This commit is contained in:
Mike J Innes 2018-02-12 12:31:15 +00:00
parent 0b3c02fe8d
commit 334ae9e1cb
4 changed files with 20 additions and 2 deletions

View File

@ -21,10 +21,12 @@ Call(f, args...) = Call{typeof(f),typeof(args)}(f, args)
mutable struct Tracked{T}
ref::UInt32
f::Call
isleaf::Bool
data::T
grad::T
Tracked{T}(f::Call, data::T) where T = new(0, f, data)
Tracked{T}(f::Call, data::T, grad::T) where T = new(0, f, data, grad)
Tracked{T}(f::Call, data::T) where T = new(0, f, false, data)
Tracked{T}(f::Call, data::T, grad::T) where T = new(0, f, false, data, grad)
Tracked{T}(f::Call{Void}, data::T, grad::T) where T = new(0, f, true, data, grad)
end
Tracked(f::Call, x) = Tracked{typeof(x)}(f, x)

View File

@ -1,11 +1,15 @@
init_grad(x) = zero(x)
zero_grad!(x) = zero(x)
zero_grad!(x::AbstractArray) = (x .= 0)
scan(c::Call) = foreach(scan, c.args)
function scan(x::Tracked)
x.isleaf && return
ref = x.ref += 1
if ref == 1
scan(x.f)
isdefined(x, :grad) && (x.grad = zero_grad!(x.grad))
else
isdefined(x, :grad) || (x.grad = init_grad(x.data))
end
@ -25,6 +29,7 @@ accum!(x, Δ) = x .+ Δ
accum!(x::AbstractArray, Δ) = (x .+= Δ)
function back(x::Tracked, Δ)
x.isleaf && (accum!(x.grad, Δ); return)
ref = x.ref -= 1
if isdefined(x, :grad)
x.grad = accum!(x.grad, Δ)

View File

@ -69,6 +69,7 @@ tracker(xs::TrackedTuple) = xs.tracker
accum!(x::Tuple, Δ::Tuple) = accum!.(x, Δ)
init_grad(x::Tuple) = init_grad.(x)
zero_grad!(x::Tuple) = zero_grad!.(x)
track(f::Call, xs::Tuple) = TrackedTuple(Tracked(f, xs))

View File

@ -57,4 +57,14 @@ end
@test (param([1,2,3]) .< 2) == [true, false, false]
@testset "Intermediates" begin
x = param([1])
l = sum((x .+ x).^2)
Flux.back!(l)
@test x.grad == [8]
x.grad .= 0
Flux.back!(l)
@test x.grad == [8]
end
end #testset