auto-collect in forward

This commit is contained in:
Mike J Innes 2019-02-04 10:37:02 +00:00
parent 838070968e
commit cfe6859186
3 changed files with 11 additions and 1 deletions

View File

@ -147,8 +147,10 @@ end
back(::Grads, ::Nothing, _) = return
collectmemaybe(xs) = xs
function forward(f, ps::Params)
y = f()
y = collectmemaybe(f())
y, function (Δ)
g = Grads(ps)
if istracked(y)

View File

@ -155,3 +155,6 @@ end
function back_(g::Grads, c::Call{typeof(collect)}, Δ)
foreach((x, Δ) -> back(g, x, Δ), c.args[1], Δ)
end
collectmemaybe(xs::AbstractArray{>:TrackedReal}) = collect(xs)
collectmemaybe(xs::AbstractArray{<:TrackedReal}) = collect(xs)

View File

@ -323,4 +323,9 @@ end
end == ([3, 2],)
end
@testset "Custom Sensitivities" begin
y, back = Tracker.forward(x -> [3x^2, 2x], 5)
@test back([1, 1]) == (32,)
end
end #testset