fixes #171
This commit is contained in:
parent
0b3c02fe8d
commit
334ae9e1cb
@ -21,10 +21,12 @@ Call(f, args...) = Call{typeof(f),typeof(args)}(f, args)
|
|||||||
mutable struct Tracked{T}
|
mutable struct Tracked{T}
|
||||||
ref::UInt32
|
ref::UInt32
|
||||||
f::Call
|
f::Call
|
||||||
|
isleaf::Bool
|
||||||
data::T
|
data::T
|
||||||
grad::T
|
grad::T
|
||||||
Tracked{T}(f::Call, data::T) where T = new(0, f, data)
|
Tracked{T}(f::Call, data::T) where T = new(0, f, false, data)
|
||||||
Tracked{T}(f::Call, data::T, grad::T) where T = new(0, f, data, grad)
|
Tracked{T}(f::Call, data::T, grad::T) where T = new(0, f, false, data, grad)
|
||||||
|
Tracked{T}(f::Call{Void}, data::T, grad::T) where T = new(0, f, true, data, grad)
|
||||||
end
|
end
|
||||||
|
|
||||||
Tracked(f::Call, x) = Tracked{typeof(x)}(f, x)
|
Tracked(f::Call, x) = Tracked{typeof(x)}(f, x)
|
||||||
|
@ -1,11 +1,15 @@
|
|||||||
init_grad(x) = zero(x)
|
init_grad(x) = zero(x)
|
||||||
|
zero_grad!(x) = zero(x)
|
||||||
|
zero_grad!(x::AbstractArray) = (x .= 0)
|
||||||
|
|
||||||
scan(c::Call) = foreach(scan, c.args)
|
scan(c::Call) = foreach(scan, c.args)
|
||||||
|
|
||||||
function scan(x::Tracked)
|
function scan(x::Tracked)
|
||||||
|
x.isleaf && return
|
||||||
ref = x.ref += 1
|
ref = x.ref += 1
|
||||||
if ref == 1
|
if ref == 1
|
||||||
scan(x.f)
|
scan(x.f)
|
||||||
|
isdefined(x, :grad) && (x.grad = zero_grad!(x.grad))
|
||||||
else
|
else
|
||||||
isdefined(x, :grad) || (x.grad = init_grad(x.data))
|
isdefined(x, :grad) || (x.grad = init_grad(x.data))
|
||||||
end
|
end
|
||||||
@ -25,6 +29,7 @@ accum!(x, Δ) = x .+ Δ
|
|||||||
accum!(x::AbstractArray, Δ) = (x .+= Δ)
|
accum!(x::AbstractArray, Δ) = (x .+= Δ)
|
||||||
|
|
||||||
function back(x::Tracked, Δ)
|
function back(x::Tracked, Δ)
|
||||||
|
x.isleaf && (accum!(x.grad, Δ); return)
|
||||||
ref = x.ref -= 1
|
ref = x.ref -= 1
|
||||||
if isdefined(x, :grad)
|
if isdefined(x, :grad)
|
||||||
x.grad = accum!(x.grad, Δ)
|
x.grad = accum!(x.grad, Δ)
|
||||||
|
@ -69,6 +69,7 @@ tracker(xs::TrackedTuple) = xs.tracker
|
|||||||
|
|
||||||
accum!(x::Tuple, Δ::Tuple) = accum!.(x, Δ)
|
accum!(x::Tuple, Δ::Tuple) = accum!.(x, Δ)
|
||||||
init_grad(x::Tuple) = init_grad.(x)
|
init_grad(x::Tuple) = init_grad.(x)
|
||||||
|
zero_grad!(x::Tuple) = zero_grad!.(x)
|
||||||
|
|
||||||
track(f::Call, xs::Tuple) = TrackedTuple(Tracked(f, xs))
|
track(f::Call, xs::Tuple) = TrackedTuple(Tracked(f, xs))
|
||||||
|
|
||||||
|
@ -57,4 +57,14 @@ end
|
|||||||
|
|
||||||
@test (param([1,2,3]) .< 2) == [true, false, false]
|
@test (param([1,2,3]) .< 2) == [true, false, false]
|
||||||
|
|
||||||
|
@testset "Intermediates" begin
|
||||||
|
x = param([1])
|
||||||
|
l = sum((x .+ x).^2)
|
||||||
|
Flux.back!(l)
|
||||||
|
@test x.grad == [8]
|
||||||
|
x.grad .= 0
|
||||||
|
Flux.back!(l)
|
||||||
|
@test x.grad == [8]
|
||||||
|
end
|
||||||
|
|
||||||
end #testset
|
end #testset
|
||||||
|
Loading…
Reference in New Issue
Block a user