fixes #171
This commit is contained in:
parent
0b3c02fe8d
commit
334ae9e1cb
|
@ -21,10 +21,12 @@ Call(f, args...) = Call{typeof(f),typeof(args)}(f, args)
|
|||
mutable struct Tracked{T}
|
||||
ref::UInt32
|
||||
f::Call
|
||||
isleaf::Bool
|
||||
data::T
|
||||
grad::T
|
||||
Tracked{T}(f::Call, data::T) where T = new(0, f, data)
|
||||
Tracked{T}(f::Call, data::T, grad::T) where T = new(0, f, data, grad)
|
||||
Tracked{T}(f::Call, data::T) where T = new(0, f, false, data)
|
||||
Tracked{T}(f::Call, data::T, grad::T) where T = new(0, f, false, data, grad)
|
||||
Tracked{T}(f::Call{Void}, data::T, grad::T) where T = new(0, f, true, data, grad)
|
||||
end
|
||||
|
||||
Tracked(f::Call, x) = Tracked{typeof(x)}(f, x)
|
||||
|
|
|
@ -1,11 +1,15 @@
|
|||
init_grad(x) = zero(x)
|
||||
zero_grad!(x) = zero(x)
|
||||
zero_grad!(x::AbstractArray) = (x .= 0)
|
||||
|
||||
scan(c::Call) = foreach(scan, c.args)
|
||||
|
||||
function scan(x::Tracked)
|
||||
x.isleaf && return
|
||||
ref = x.ref += 1
|
||||
if ref == 1
|
||||
scan(x.f)
|
||||
isdefined(x, :grad) && (x.grad = zero_grad!(x.grad))
|
||||
else
|
||||
isdefined(x, :grad) || (x.grad = init_grad(x.data))
|
||||
end
|
||||
|
@ -25,6 +29,7 @@ accum!(x, Δ) = x .+ Δ
|
|||
accum!(x::AbstractArray, Δ) = (x .+= Δ)
|
||||
|
||||
function back(x::Tracked, Δ)
|
||||
x.isleaf && (accum!(x.grad, Δ); return)
|
||||
ref = x.ref -= 1
|
||||
if isdefined(x, :grad)
|
||||
x.grad = accum!(x.grad, Δ)
|
||||
|
|
|
@ -69,6 +69,7 @@ tracker(xs::TrackedTuple) = xs.tracker
|
|||
|
||||
accum!(x::Tuple, Δ::Tuple) = accum!.(x, Δ)
|
||||
init_grad(x::Tuple) = init_grad.(x)
|
||||
zero_grad!(x::Tuple) = zero_grad!.(x)
|
||||
|
||||
track(f::Call, xs::Tuple) = TrackedTuple(Tracked(f, xs))
|
||||
|
||||
|
|
|
@ -57,4 +57,14 @@ end
|
|||
|
||||
@test (param([1,2,3]) .< 2) == [true, false, false]
|
||||
|
||||
@testset "Intermediates" begin
|
||||
x = param([1])
|
||||
l = sum((x .+ x).^2)
|
||||
Flux.back!(l)
|
||||
@test x.grad == [8]
|
||||
x.grad .= 0
|
||||
Flux.back!(l)
|
||||
@test x.grad == [8]
|
||||
end
|
||||
|
||||
end #testset
|
||||
|
|
Loading…
Reference in New Issue