From 8e59160df6e8126282d251379f421e84e0021d86 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Sun, 20 Aug 2017 13:35:20 +0100 Subject: [PATCH] inferable broadcast --- src/Tracker/lib.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/Tracker/lib.jl b/src/Tracker/lib.jl index 55a1cdaf..0e8c573f 100644 --- a/src/Tracker/lib.jl +++ b/src/Tracker/lib.jl @@ -33,8 +33,11 @@ dualify(xs, n) = xs dualify(xs::TrackedArray, ps) = Dual.(data(xs), Ref(ps)) function tracked_broadcast(f, args::Vararg{Any,N}) where N - dargs = ntuple(i -> dualify(args[i], ntuple(j -> i==j, Val{N})), Val{N}) - TrackedArray(Call(Broadcasted(broadcast(f, dargs...)), args...)) + dargs = map((x,i) -> dualify(x, ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N})) + # TrackedArray(Call(Broadcasted(broadcast(f, dargs...)), args...)) + # Works around a 0.6 type inference issue + b = Broadcasted(broadcast(f, dargs...)) + TrackedArray(Call(b, args...), b()) end function back!(b::Broadcasted, Δ, args...)