inferable broadcast
This commit is contained in:
parent
18e69b33c9
commit
8e59160df6
@ -33,8 +33,11 @@ dualify(xs, n) = xs
|
|||||||
dualify(xs::TrackedArray, ps) = Dual.(data(xs), Ref(ps))
|
dualify(xs::TrackedArray, ps) = Dual.(data(xs), Ref(ps))
|
||||||
|
|
||||||
function tracked_broadcast(f, args::Vararg{Any,N}) where N
|
function tracked_broadcast(f, args::Vararg{Any,N}) where N
|
||||||
dargs = ntuple(i -> dualify(args[i], ntuple(j -> i==j, Val{N})), Val{N})
|
dargs = map((x,i) -> dualify(x, ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N}))
|
||||||
TrackedArray(Call(Broadcasted(broadcast(f, dargs...)), args...))
|
# TrackedArray(Call(Broadcasted(broadcast(f, dargs...)), args...))
|
||||||
|
# Works around a 0.6 type inference issue
|
||||||
|
b = Broadcasted(broadcast(f, dargs...))
|
||||||
|
TrackedArray(Call(b, args...), b())
|
||||||
end
|
end
|
||||||
|
|
||||||
function back!(b::Broadcasted, Δ, args...)
|
function back!(b::Broadcasted, Δ, args...)
|
||||||
|
Loading…
Reference in New Issue
Block a user