auto-collect in forward
This commit is contained in:
parent
838070968e
commit
cfe6859186
@ -147,8 +147,10 @@ end
|
|||||||
|
|
||||||
back(::Grads, ::Nothing, _) = return
|
back(::Grads, ::Nothing, _) = return
|
||||||
|
|
||||||
|
collectmemaybe(xs) = xs
|
||||||
|
|
||||||
function forward(f, ps::Params)
|
function forward(f, ps::Params)
|
||||||
y = f()
|
y = collectmemaybe(f())
|
||||||
y, function (Δ)
|
y, function (Δ)
|
||||||
g = Grads(ps)
|
g = Grads(ps)
|
||||||
if istracked(y)
|
if istracked(y)
|
||||||
|
@ -155,3 +155,6 @@ end
|
|||||||
function back_(g::Grads, c::Call{typeof(collect)}, Δ)
|
function back_(g::Grads, c::Call{typeof(collect)}, Δ)
|
||||||
foreach((x, Δ) -> back(g, x, Δ), c.args[1], Δ)
|
foreach((x, Δ) -> back(g, x, Δ), c.args[1], Δ)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
collectmemaybe(xs::AbstractArray{>:TrackedReal}) = collect(xs)
|
||||||
|
collectmemaybe(xs::AbstractArray{<:TrackedReal}) = collect(xs)
|
||||||
|
@ -323,4 +323,9 @@ end
|
|||||||
end == ([3, 2],)
|
end == ([3, 2],)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@testset "Custom Sensitivities" begin
|
||||||
|
y, back = Tracker.forward(x -> [3x^2, 2x], 5)
|
||||||
|
@test back([1, 1]) == (32,)
|
||||||
|
end
|
||||||
|
|
||||||
end #testset
|
end #testset
|
||||||
|
Loading…
Reference in New Issue
Block a user