out of place gradients for collect
This commit is contained in:
parent
6cdf4ff56a
commit
62d594af43
|
@ -115,3 +115,7 @@ end
|
|||
function back_(c::Call{typeof(collect)}, Δ)
|
||||
foreach(back, c.args[1], data(Δ))
|
||||
end
|
||||
|
||||
function back_(g::Grads, c::Call{typeof(collect)}, Δ)
|
||||
foreach((x, Δ) -> back(g, x, Δ), c.args[1], Δ)
|
||||
end
|
||||
|
|
|
@ -232,6 +232,11 @@ Tracker.back!(b)
|
|||
z = xy[1]*xy[2]
|
||||
back!(z)
|
||||
@test grad.((x,y)) == (3, 2)
|
||||
|
||||
@test Tracker.gradient(2, 3) do x, y
|
||||
xy = Tracker.collect([x, y])
|
||||
xy[1]*xy[2]
|
||||
end == (3, 2)
|
||||
end
|
||||
|
||||
# Gradient Hooks
|
||||
|
|
Loading…
Reference in New Issue