out of place gradients for collect

This commit is contained in:
Mike J Innes 2018-08-07 22:09:20 +01:00
parent 6cdf4ff56a
commit 62d594af43
2 changed files with 9 additions and 0 deletions

View File

@ -115,3 +115,7 @@ end
function back_(c::Call{typeof(collect)}, Δ)
foreach(back, c.args[1], data(Δ))
end
function back_(g::Grads, c::Call{typeof(collect)}, Δ)
foreach((x, Δ) -> back(g, x, Δ), c.args[1], Δ)
end

View File

@ -232,6 +232,11 @@ Tracker.back!(b)
z = xy[1]*xy[2]
back!(z)
@test grad.((x,y)) == (3, 2)
@test Tracker.gradient(2, 3) do x, y
xy = Tracker.collect([x, y])
xy[1]*xy[2]
end == (3, 2)
end
# Gradient Hooks