From 41b9412439192a93934306ebf816afe9b9b652b4 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Fri, 6 Jul 2018 11:28:18 +0100 Subject: [PATCH] new grad api --- src/tracker/Tracker.jl | 29 +++++++++++++++++++++++------ src/tracker/array.jl | 4 ++-- src/tracker/back.jl | 15 ++++++++++----- src/tracker/scalar.jl | 18 ++++++++---------- 4 files changed, 43 insertions(+), 23 deletions(-) diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 1296d179..959fc8f1 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -1,5 +1,7 @@ module Tracker +using MacroTools + import Base: == export TrackedArray, TrackedVector, TrackedMatrix, param, back! @@ -17,7 +19,8 @@ struct Call{F,As<:Tuple} args::As end -Call(f, args...) = Call{typeof(f),typeof(args)}(f, args) +Call(f, args) = Call{typeof(f),typeof(args)}(f, args) +Call() = Call(nothing, ()) # When deserialising, the object_id changes a::Call == b::Call = a.func == b.func && a.args == b.args @@ -38,15 +41,29 @@ end Tracked(f::Call, x) = Tracked{typeof(x)}(f, x) Tracked(f::Call, x, Δ) = Tracked{typeof(x)}(f, x, Δ) -track(f::Call, x) = Tracked(f, x) -track(f::Call) = track(f, f()) -track(f, xs...) = track(Call(f, xs...)) - istracked(x::Tracked) = true -isleaf(x::Tracked) = x.f == Call(nothing) +isleaf(x::Tracked) = x.f == Call() data(x::Tracked) = x.data grad(x::Tracked) = x.grad +track(f::Call, x) = Tracked(f, x) +track(f::Call) = track(f, f()) + +function _forward end + +function track(f, xs...) + y, back = _forward(f, data.(xs)...) + track(Call(back, xs), y) +end + +macro grad(ex) + @capture(shortdef(ex), (name_(args__) = body_) | + (name_(args__) where {T__} = body_)) || error("Need a function definition") + T == nothing && (T = []) + unshift!(args, :(::typeof($name))) + :(Tracker._forward($(args...)) where $(T...) = $body) |> esc +end + function update!(x, Δ) tracker(x).data += Δ tracker(x).grad .= 0 diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 7a54d2eb..987630c7 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -20,7 +20,7 @@ TrackedArray(c::Call, x::A) where A <: AbstractArray = TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray = TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, x, Δ), x, Δ) -TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x)) +TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, zeros(x)) Base.eltype(x::Type{<:TrackedArray{T}}) where T <: Real = TrackedReal{T} @@ -101,7 +101,7 @@ function back(::typeof(_repeat), Δ, xs::TrackedArray, inner, outer) Δ′ = similar(xs.data) Δ′ .= 0 S = size(xs.data) - + # Loop through each element of Δ, calculate source dimensions, accumulate into Δ′ for (dest_idx, val) in enumerate(IndexCartesian(), Δ) # First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then diff --git a/src/tracker/back.jl b/src/tracker/back.jl index 60b12868..5bf13d56 100644 --- a/src/tracker/back.jl +++ b/src/tracker/back.jl @@ -21,9 +21,14 @@ function scan(x) return end -back_(f, y, args...) = back(f, args...) -back_(c::Call, y, Δ) = back_(c.func, y, Δ, c.args...) -back_(::Call{Void}, y, Δ) = nothing +function back_(c::Call, Δ) + Δs = c.func(Δ) + (Δs isa Tuple && length(Δs) == length(c.args)) || + error("Gradient is not a tuple of length $(length(c.args))") + foreach((x, Δ) -> istracked(x) && back(x, Δ), c.args, Δs) +end + +back_(::Call{Void}, Δ) = nothing accum!(x, Δ) = x .+ Δ accum!(x::AbstractArray, Δ) = (x .+= Δ) @@ -33,9 +38,9 @@ function back(x::Tracked, Δ) ref = x.ref -= 1 if isdefined(x, :grad) x.grad = accum!(x.grad, Δ) - ref == 0 && back_(x.f, x.data, x.grad) + ref == 0 && back_(x.f, x.grad) else - ref == 0 && back_(x.f, x.data, Δ) + ref == 0 && back_(x.f, Δ) end return end diff --git a/src/tracker/scalar.jl b/src/tracker/scalar.jl index 773943c0..9d2d724f 100644 --- a/src/tracker/scalar.jl +++ b/src/tracker/scalar.jl @@ -2,7 +2,7 @@ struct TrackedReal{T<:Real} <: Real tracker::Tracked{T} end -TrackedReal(x::Real) = TrackedReal(Tracked(Call(nothing), x, zero(x))) +TrackedReal(x::Real) = TrackedReal(Tracked(Call(), x, zero(x))) tracker(x::TrackedReal) = x.tracker @@ -47,23 +47,21 @@ using DiffRules, SpecialFunctions, NaNMath for (M, f, arity) in DiffRules.diffrules() arity == 1 || continue @eval begin + @grad $M.$f(a::Real) = + $M.$f(a), Δ -> (Δ * $(DiffRules.diffrule(M, f, :(data(a)))),) $M.$f(a::TrackedReal) = track($M.$f, a) - back(::typeof($M.$f), Δ::Real, a::TrackedReal) = - back(a, Δ * $(DiffRules.diffrule(M, f, :(data(a))))) end end for (M, f, arity) in DiffRules.diffrules() arity == 2 || continue da, db = DiffRules.diffrule(M, f, :(data(a)), :(data(b))) + f = :($M.$f) @eval begin - $M.$f(a::TrackedReal, b::TrackedReal) = track($M.$f, a, b) - $M.$f(a::TrackedReal, b::Real) = track($M.$f, a, b) - $M.$f(a::Real, b::TrackedReal) = track($M.$f, a, b) - function back(::typeof($M.$f), Δ::Real, a::Real, b::Real) - @back(a, Δ * $da) - @back(b, Δ * $db) - end + @grad $f(a::Real, b::Real) = $f(a, b), Δ -> (Δ * $da, Δ * $db) + $f(a::TrackedReal, b::TrackedReal) = track($f, a, b) + $f(a::TrackedReal, b::Real) = track($f, a, b) + $f(a::Real, b::TrackedReal) = track($f, a, b) end end