diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index fd94bdb1..5afe9ced 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -11,9 +11,9 @@ tracker(x) = nothing istracked(x) = tracker(x) ≠ nothing isleaf(x) = !istracked(x) || isleaf(tracker(x)) -data(x) = istracked(x) ? data(tracker(x)) : x grad(x) = grad(tracker(x)) grad(::Void) = nothing +data(x) = x struct Call{F,As<:Tuple} func::F @@ -32,29 +32,23 @@ mutable struct Tracked{T} ref::UInt32 f::Call isleaf::Bool - data::T grad::T - Tracked{T}(f::Call, data::T) where T = new(0, f, false, data) - Tracked{T}(f::Call, data::T, grad::T) where T = new(0, f, false, data, grad) - Tracked{T}(f::Call{Void}, data::T, grad::T) where T = new(0, f, true, data, grad) + Tracked{T}(f::Call) where T = new(0, f, false) + Tracked{T}(f::Call, grad::T) where T = new(0, f, false, grad) + Tracked{T}(f::Call{Void}, grad::T) where T = new(0, f, true, grad) end -Tracked(f::Call, x) = Tracked{typeof(x)}(f, x) -Tracked(f::Call, x, Δ) = Tracked{typeof(x)}(f, x, Δ) - istracked(x::Tracked) = true isleaf(x::Tracked) = x.f == Call() -data(x::Tracked) = x.data grad(x::Tracked) = x.grad -track(f::Call, x) = Tracked(f, x) -track(f::Call) = track(f, f()) +track(f::Call, x) = Tracked{typeof(x)}(f) function _forward end function track(f, xs...; kw...) y, back = _forward(f, data.(xs)...; kw...) - track(Call(back, xs), y) + track(Call(back, tracker.(xs)), y) end macro grad(ex) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 709f0136..dbf789ac 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -6,6 +6,7 @@ struct TrackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N} TrackedArray{T,N,A}(t::Tracked{A}, data::A, grad::A) where {T,N,A} = new(t, data, grad) end +data(x::TrackedArray) = x.data tracker(x::TrackedArray) = x.tracker TrackedVector{T,A} = TrackedArray{T,1,A} @@ -15,10 +16,10 @@ TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}} track(c::Call, x::AbstractArray) = TrackedArray(c, x) TrackedArray(c::Call, x::A) where A <: AbstractArray = - TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, x), x) + TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c), x) TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray = - TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, x, Δ), x, Δ) + TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, Δ), x, Δ) TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, zeros(x)) @@ -369,7 +370,7 @@ function ∇broadcast(f, args::Vararg{Any,N}) where N map((x, Δ) -> unbroadcast(x, Δ), args, Δargs) end # So we can return non-tracked arrays - track(Call(back, args), y) + track(Call(back, tracker.(args)), y) end Base.Broadcast._containertype(::Type{<:TrackedReal}) = TrackedArray diff --git a/src/tracker/back.jl b/src/tracker/back.jl index 62cae1d0..c6d1646a 100644 --- a/src/tracker/back.jl +++ b/src/tracker/back.jl @@ -10,8 +10,6 @@ function scan(x::Tracked) if ref == 1 scan(x.f) isdefined(x, :grad) && (x.grad = zero_grad!(x.grad)) - else - isdefined(x, :grad) || (x.grad = init_grad(x.data)) end return end @@ -25,7 +23,7 @@ function back_(c::Call, Δ) Δs = c.func(Δ) (Δs isa Tuple && length(Δs) >= length(c.args)) || error("Gradient is not a tuple of length $(length(c.args))") - foreach((x, Δ) -> istracked(x) && back(x, Δ), c.args, Δs) + foreach(back, c.args, Δs) end back_(::Call{Void}, Δ) = nothing @@ -36,8 +34,12 @@ accum!(x::AbstractArray, Δ) = (x .+= Δ) function back(x::Tracked, Δ) x.isleaf && (x.grad = accum!(x.grad, Δ); return) ref = x.ref -= 1 - if isdefined(x, :grad) - x.grad = accum!(x.grad, Δ) + if ref > 0 || isdefined(x, :grad) + if isdefined(x, :grad) + x.grad = accum!(x.grad, Δ) + else + x.grad = Δ + end ref == 0 && back_(x.f, x.grad) else ref == 0 && back_(x.f, Δ) @@ -45,8 +47,7 @@ function back(x::Tracked, Δ) return end -back(x, Δ) = back(tracker(x), Δ) -back(x::Void, Δ) = error("Can't backpropagate through `nothing`") +back(::Void, _) = return # Interface methods @@ -55,13 +56,13 @@ back(x::Void, Δ) = error("Can't backpropagate through `nothing`") # Refcounts are also probably not safe in some situations (e.g. back called # from within a backpropagator) -function back!(x::Tracked, Δ) +function back!(x, Δ) + istracked(x) || return scan(x) - back(x, Δ) + back(tracker(x), Δ) + return end -back!(x, Δ) = back!(tracker(x), Δ) - # Out-of-place gradients struct Params @@ -91,7 +92,7 @@ function back_(g::Grads, c::Call, Δ) Δs = c.func(Δ) (Δs isa Tuple && length(Δs) >= length(c.args)) || error("Gradient is not a tuple of length $(length(c.args))") - foreach((x, Δ) -> istracked(x) && back(g, x, Δ), c.args, Δs) + foreach((x, Δ) -> back(g, x, Δ), c.args, Δs) end back_(g::Grads, ::Call{Void}, Δ) = nothing @@ -108,8 +109,7 @@ function back(g::Grads, x::Tracked, Δ) return end -back(g::Grads, x, Δ) = back(g, tracker(x), Δ) -back(g::Grads, x::Void, Δ) = error("Can't backpropagate through `nothing`") +back(::Grads, ::Void, _) = return function forward(f, ps::Params) y = f() @@ -117,7 +117,7 @@ function forward(f, ps::Params) g = Grads() if istracked(y) scan(y) - back(g, y, Δ) + back(g, tracker(y), Δ) end for p in ps haskey(g, tracker(p)) || diff --git a/src/tracker/scalar.jl b/src/tracker/scalar.jl index 93f9f7dc..6232807f 100644 --- a/src/tracker/scalar.jl +++ b/src/tracker/scalar.jl @@ -1,12 +1,14 @@ struct TrackedReal{T<:Real} <: Real + data::T tracker::Tracked{T} end -TrackedReal(x::Real) = TrackedReal(Tracked(Call(), x, zero(x))) +TrackedReal(x::Real) = TrackedReal(x, Tracked{typeof(x)}(Call(), zero(x))) +data(x::TrackedReal) = x.data tracker(x::TrackedReal) = x.tracker -track(f::Call, x::Real) = TrackedReal(Tracked(f, x, zero(x))) +track(f::Call, x::Real) = TrackedReal(x, Tracked{typeof(x)}(f, zero(x))) function back!(x::TrackedReal) isinf(x) && error("Loss is Inf") @@ -73,6 +75,7 @@ import Base:^ # Tuples struct TrackedTuple{T<:Tuple} + data::T tracker::Tracked{T} end @@ -82,7 +85,7 @@ accum!(x::Tuple, Δ::Tuple) = accum!.(x, Δ) init_grad(x::Tuple) = init_grad.(x) zero_grad!(x::Tuple) = zero_grad!.(x) -track(f::Call, xs::Tuple) = TrackedTuple(Tracked(f, xs)) +track(f::Call, xs::Tuple) = TrackedTuple(xs, Tracked{typeof(xs)}(f)) function Base.show(io::IO, xs::TrackedTuple) show(io, data(xs)) @@ -100,7 +103,7 @@ back(::typeof(getindex), Δ, t, i) = function collect(xs) xs = Base.collect(xs) - track(Call(collect, (xs,)), data.(xs)) + track(Call(collect, (tracker.(xs),)), data.(xs)) end function scan(c::Call{typeof(collect)}) @@ -108,5 +111,5 @@ function scan(c::Call{typeof(collect)}) end function back_(c::Call{typeof(collect)}, Δ) - foreach((x, Δ) -> istracked(x) && back(x, Δ), c.args[1], Δ) + foreach(back, c.args[1], Δ) end