auto-collect in forward
This commit is contained in:
parent
838070968e
commit
cfe6859186
|
@ -147,8 +147,10 @@ end
|
|||
|
||||
back(::Grads, ::Nothing, _) = return
|
||||
|
||||
collectmemaybe(xs) = xs
|
||||
|
||||
function forward(f, ps::Params)
|
||||
y = f()
|
||||
y = collectmemaybe(f())
|
||||
y, function (Δ)
|
||||
g = Grads(ps)
|
||||
if istracked(y)
|
||||
|
|
|
@ -155,3 +155,6 @@ end
|
|||
function back_(g::Grads, c::Call{typeof(collect)}, Δ)
|
||||
foreach((x, Δ) -> back(g, x, Δ), c.args[1], Δ)
|
||||
end
|
||||
|
||||
collectmemaybe(xs::AbstractArray{>:TrackedReal}) = collect(xs)
|
||||
collectmemaybe(xs::AbstractArray{<:TrackedReal}) = collect(xs)
|
||||
|
|
|
@ -323,4 +323,9 @@ end
|
|||
end == ([3, 2],)
|
||||
end
|
||||
|
||||
@testset "Custom Sensitivities" begin
|
||||
y, back = Tracker.forward(x -> [3x^2, 2x], 5)
|
||||
@test back([1, 1]) == (32,)
|
||||
end
|
||||
|
||||
end #testset
|
||||
|
|
Loading…
Reference in New Issue