diff --git a/src/Tracker/lib.jl b/src/Tracker/lib.jl index 3cfdd382..41b7e676 100644 --- a/src/Tracker/lib.jl +++ b/src/Tracker/lib.jl @@ -34,7 +34,7 @@ end (b::Broadcasted)(xs...) = map(x -> x.value, b.data) dualify(xs, n) = xs -dualify(xs::TrackedArray, ps) = Dual.(data(xs), Ref(ps)) +dualify(xs::TrackedArray, ps) = map(x -> Dual(x, ps), data(xs)) function tracked_broadcast(f, args::Vararg{Any,N}) where N dargs = map((x,i) -> dualify(x, ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N}))