gpu-friendly dualify

This commit is contained in:
Mike J Innes 2017-08-21 16:35:39 +01:00
parent 65a49188e6
commit 227e41c37b

View File

@ -34,7 +34,7 @@ end
(b::Broadcasted)(xs...) = map(x -> x.value, b.data)
dualify(xs, n) = xs
dualify(xs::TrackedArray, ps) = Dual.(data(xs), Ref(ps))
dualify(xs::TrackedArray, ps) = map(x -> Dual(x, ps), data(xs))
function tracked_broadcast(f, args::Vararg{Any,N}) where N
dargs = map((x,i) -> dualify(x, ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N}))