From 7cfc42d166ab272eed0f0fd8e87f8273ad539200 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 6 Sep 2017 21:21:35 -0400 Subject: [PATCH] grad refactor --- src/tracker/Tracker.jl | 41 ++++++++++++----------------------------- src/tracker/back.jl | 16 ++++++++++++++++ src/tracker/lib.jl | 12 ++++++------ 3 files changed, 34 insertions(+), 35 deletions(-) create mode 100644 src/tracker/back.jl diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 9510febc..ebc38f35 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -1,5 +1,7 @@ module Tracker +using Base: RefValue + export track, back! data(x) = x @@ -14,13 +16,10 @@ Call(f, args...) = Call{typeof(f),typeof(args)}(f, args) (c::Call)() = c.func(data.(c.args)...) -back!(c::Call, Δ) = back!(c.func, Δ, c.args...) -back!(::Call{Void}, Δ) = nothing - struct TrackedArray{T,N,A} <: AbstractArray{T,N} f::Call - x::A - Δ::A + data::A + grad::RefValue{A} end TrackedScalar{T,A} = TrackedArray{T,0,A} @@ -28,36 +27,19 @@ TrackedVector{T,A} = TrackedArray{T,1,A} TrackedMatrix{T,A} = TrackedArray{T,2,A} TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}} -TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray = +TrackedArray(c::Call, x::A, Δ::Ref{A}) where A <: AbstractArray = TrackedArray{eltype(A),ndims(A),A}(c, x, Δ) -TrackedArray(c::Call, x::AbstractArray) = TrackedArray(c, x, zeros(x)) +TrackedArray(c::Call, x::AbstractArray) = TrackedArray(c, x, RefValue{typeof(x)}()) TrackedArray(c::Call) = TrackedArray(c, c()) -TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x) +TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, RefValue(zeros(x))) -track(xs) = TrackedArray(xs) +track(xs) = TrackedArray(AbstractFloat.(xs)) istracked(x::TrackedArray) = true -data(x::TrackedArray) = x.x -grad(x::TrackedArray) = x.Δ - -tovec(xs::AbstractArray) = vec(xs) -tovec(xs) = xs - -function back!(x::TrackedArray, Δ) - x.Δ .+= Δ - back!(x.f, Δ) -end - -back!(x::TrackedScalar) = back!(x, 1) - -macro back!(x, Δ) - quote - x = $(esc(x)) - istracked(x) && back!(x, $(esc(Δ))) - end -end +data(x::TrackedArray) = x.data +grad(x::TrackedArray) = x.grad[] # Fallthrough methods @@ -84,6 +66,7 @@ function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = tru end end +include("back.jl") include("lib.jl") include("numeric.jl") @@ -91,7 +74,7 @@ using Requires @require CuArrays begin import CuArrays: cu - cu(xs::TrackedArray) = TrackedArray(xs.f, cu(xs.x), cu(xs.Δ)) + cu(xs::TrackedArray) = TrackedArray(xs.f, cu(xs.data), RefValue(cu(grad(xs)))) end end diff --git a/src/tracker/back.jl b/src/tracker/back.jl new file mode 100644 index 00000000..42d70001 --- /dev/null +++ b/src/tracker/back.jl @@ -0,0 +1,16 @@ +back!(c::Call, Δ) = back!(c.func, Δ, c.args...) +back!(::Call{Void}, Δ) = nothing + +function back!(x::TrackedArray, Δ) + isassigned(x.grad) && (x.grad[] .+= Δ) + back!(x.f, Δ) +end + +back!(x::TrackedScalar) = back!(x, 1) + +macro back!(x, Δ) + quote + x = $(esc(x)) + istracked(x) && back!(x, $(esc(Δ))) + end +end diff --git a/src/tracker/lib.jl b/src/tracker/lib.jl index d57669ab..71d2ed0f 100644 --- a/src/tracker/lib.jl +++ b/src/tracker/lib.jl @@ -7,10 +7,10 @@ unarray(xs) = xs unarray(xs::AbstractArray{T,0} where T) = xs[] Base.getindex(xs::TrackedArray, i...) = - TrackedArray(Call(getindex, xs, i...), toarray(xs.x, xs.x[i...])) + TrackedArray(Call(getindex, xs, i...), toarray(xs.data, xs.data[i...])) function back!(::typeof(getindex), Δ, xs::TrackedArray, i...) - Δ′ = zeros(xs.x) + Δ′ = zeros(xs.data) Δ′[i...] = unarray(Δ) @back!(xs, Δ′) end @@ -49,13 +49,13 @@ end # Reductions Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim)) -Base.sum(xs::TrackedArray) = TrackedArray(Call(sum, xs), toarray(xs.x, sum(xs.x))) +Base.sum(xs::TrackedArray) = TrackedArray(Call(sum, xs), toarray(xs.data, sum(xs.data))) Base.sum(xs::TrackedScalar, dim...) = xs -back!(::typeof(sum), Δ, xs::TrackedArray, dim...) = back!(xs, similar(xs.x) .= Δ) +back!(::typeof(sum), Δ, xs::TrackedArray, dim...) = back!(xs, similar(xs.data) .= Δ) -Base.maximum(xs::TrackedArray, args...) = maximum(xs.x, args...) -Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.x, args...) +Base.maximum(xs::TrackedArray, args...) = maximum(xs.data, args...) +Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...) # BLAS