Merge pull request #397 from FluxML/nest-bcast
Nested Derivatives of Broadcast
This commit is contained in:
commit
08fb9b7df1
@ -353,9 +353,9 @@ end
|
|||||||
eltype(y) <: Real || return y
|
eltype(y) <: Real || return y
|
||||||
eltype(y) == Bool && return y
|
eltype(y) == Bool && return y
|
||||||
function back(Δ)
|
function back(Δ)
|
||||||
Δargs = ntuple(i -> partial.(f, data(Δ), i, args...), Val(N))
|
Δargs = ntuple(i -> partial.(f, Δ, i, args...), Val(N))
|
||||||
dxs = unbroadcast.(args, Δargs)
|
dxs = map(unbroadcast, args, Δargs)
|
||||||
return nobacksies(:broadcast, dxs)
|
return dxs
|
||||||
end
|
end
|
||||||
# So we can return non-tracked arrays
|
# So we can return non-tracked arrays
|
||||||
track(Call(back, tracker.(args)), y)
|
track(Call(back, tracker.(args)), y)
|
||||||
|
Loading…
Reference in New Issue
Block a user