gpu-friendly dualify
This commit is contained in:
parent
65a49188e6
commit
227e41c37b
@ -34,7 +34,7 @@ end
|
||||
(b::Broadcasted)(xs...) = map(x -> x.value, b.data)
|
||||
|
||||
dualify(xs, n) = xs
|
||||
dualify(xs::TrackedArray, ps) = Dual.(data(xs), Ref(ps))
|
||||
dualify(xs::TrackedArray, ps) = map(x -> Dual(x, ps), data(xs))
|
||||
|
||||
function tracked_broadcast(f, args::Vararg{Any,N}) where N
|
||||
dargs = map((x,i) -> dualify(x, ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N}))
|
||||
|
Loading…
Reference in New Issue
Block a user