gpu-friendly dualify

This commit is contained in:
Mike J Innes 2017-08-21 16:35:39 +01:00
parent 65a49188e6
commit 227e41c37b

View File

@ -34,7 +34,7 @@ end
(b::Broadcasted)(xs...) = map(x -> x.value, b.data) (b::Broadcasted)(xs...) = map(x -> x.value, b.data)
dualify(xs, n) = xs dualify(xs, n) = xs
dualify(xs::TrackedArray, ps) = Dual.(data(xs), Ref(ps)) dualify(xs::TrackedArray, ps) = map(x -> Dual(x, ps), data(xs))
function tracked_broadcast(f, args::Vararg{Any,N}) where N function tracked_broadcast(f, args::Vararg{Any,N}) where N
dargs = map((x,i) -> dualify(x, ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N})) dargs = map((x,i) -> dualify(x, ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N}))