new grad api

This commit is contained in:
Mike J Innes 2018-07-06 11:28:18 +01:00
parent ce88273880
commit 41b9412439
4 changed files with 43 additions and 23 deletions

View File

@ -1,5 +1,7 @@
module Tracker module Tracker
using MacroTools
import Base: == import Base: ==
export TrackedArray, TrackedVector, TrackedMatrix, param, back! export TrackedArray, TrackedVector, TrackedMatrix, param, back!
@ -17,7 +19,8 @@ struct Call{F,As<:Tuple}
args::As args::As
end end
Call(f, args...) = Call{typeof(f),typeof(args)}(f, args) Call(f, args) = Call{typeof(f),typeof(args)}(f, args)
Call() = Call(nothing, ())
# When deserialising, the object_id changes # When deserialising, the object_id changes
a::Call == b::Call = a.func == b.func && a.args == b.args a::Call == b::Call = a.func == b.func && a.args == b.args
@ -38,15 +41,29 @@ end
Tracked(f::Call, x) = Tracked{typeof(x)}(f, x) Tracked(f::Call, x) = Tracked{typeof(x)}(f, x)
Tracked(f::Call, x, Δ) = Tracked{typeof(x)}(f, x, Δ) Tracked(f::Call, x, Δ) = Tracked{typeof(x)}(f, x, Δ)
track(f::Call, x) = Tracked(f, x)
track(f::Call) = track(f, f())
track(f, xs...) = track(Call(f, xs...))
istracked(x::Tracked) = true istracked(x::Tracked) = true
isleaf(x::Tracked) = x.f == Call(nothing) isleaf(x::Tracked) = x.f == Call()
data(x::Tracked) = x.data data(x::Tracked) = x.data
grad(x::Tracked) = x.grad grad(x::Tracked) = x.grad
track(f::Call, x) = Tracked(f, x)
track(f::Call) = track(f, f())
"""
    _forward(f, args...)

Generic function that holds the gradient definitions. No methods are defined
here; `track` expects every method to return a `(value, back)` pair.
"""
function _forward end

"""
    track(f, xs...)

Run the `_forward` definition for `f` on the untracked data of `xs` and wrap the
result in a tracked value whose `Call` stores the returned backward function
together with the original (possibly tracked) arguments.
"""
function track(f, xs...)
  # Strip tracking from every argument before evaluating the forward pass.
  primal, pullback = _forward(f, map(data, xs)...)
  # The Call keeps `xs` themselves (not their data) alongside the pullback.
  return track(Call(pullback, xs), primal)
end
"""
    @grad name(args...) = body
    @grad name(args...) where {T...} = body

Turn a function definition into a method of `Tracker._forward` whose first
argument is `::typeof(name)`, followed by the original `args`, with `body` as
the method body. `track` destructures the result of `_forward` as
`y, back = ...`, so `body` should evaluate to a `(value, back)` pair.

Throws an error ("Need a function definition") when `ex` does not match either
of the two accepted shapes. The generated definition is escaped as a whole, so
it is evaluated in the caller's scope.
"""
macro grad(ex)
  # `shortdef` comes from MacroTools; presumably it normalises long-form
  # `function ... end` definitions so that the short-form patterns below match.
  @capture(shortdef(ex), (name_(args__) = body_) |
                         (name_(args__) where {T__} = body_)) || error("Need a function definition")
  # `T` is only bound by the `where` alternative; compare by identity, which is
  # the idiomatic (and non-overloadable) way to test for `nothing`.
  T === nothing && (T = [])
  # Dispatch on the function being differentiated via a leading `::typeof(name)`.
  unshift!(args, :(::typeof($name)))
  :(Tracker._forward($(args...)) where $(T...) = $body) |> esc
end
function update!(x, Δ) function update!(x, Δ)
tracker(x).data += Δ tracker(x).data += Δ
tracker(x).grad .= 0 tracker(x).grad .= 0

View File

@ -20,7 +20,7 @@ TrackedArray(c::Call, x::A) where A <: AbstractArray =
TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray = TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, x, Δ), x, Δ) TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, x, Δ), x, Δ)
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x)) TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, zeros(x))
Base.eltype(x::Type{<:TrackedArray{T}}) where T <: Real = TrackedReal{T} Base.eltype(x::Type{<:TrackedArray{T}}) where T <: Real = TrackedReal{T}

View File

@ -21,9 +21,14 @@ function scan(x)
return return
end end
back_(f, y, args...) = back(f, args...) function back_(c::Call, Δ)
back_(c::Call, y, Δ) = back_(c.func, y, Δ, c.args...) Δs = c.func(Δ)
back_(::Call{Void}, y, Δ) = nothing (Δs isa Tuple && length(Δs) == length(c.args)) ||
error("Gradient is not a tuple of length $(length(c.args))")
foreach((x, Δ) -> istracked(x) && back(x, Δ), c.args, Δs)
end
back_(::Call{Void}, Δ) = nothing
accum!(x, Δ) = x .+ Δ accum!(x, Δ) = x .+ Δ
accum!(x::AbstractArray, Δ) = (x .+= Δ) accum!(x::AbstractArray, Δ) = (x .+= Δ)
@ -33,9 +38,9 @@ function back(x::Tracked, Δ)
ref = x.ref -= 1 ref = x.ref -= 1
if isdefined(x, :grad) if isdefined(x, :grad)
x.grad = accum!(x.grad, Δ) x.grad = accum!(x.grad, Δ)
ref == 0 && back_(x.f, x.data, x.grad) ref == 0 && back_(x.f, x.grad)
else else
ref == 0 && back_(x.f, x.data, Δ) ref == 0 && back_(x.f, Δ)
end end
return return
end end

View File

@ -2,7 +2,7 @@ struct TrackedReal{T<:Real} <: Real
tracker::Tracked{T} tracker::Tracked{T}
end end
TrackedReal(x::Real) = TrackedReal(Tracked(Call(nothing), x, zero(x))) TrackedReal(x::Real) = TrackedReal(Tracked(Call(), x, zero(x)))
tracker(x::TrackedReal) = x.tracker tracker(x::TrackedReal) = x.tracker
@ -47,23 +47,21 @@ using DiffRules, SpecialFunctions, NaNMath
for (M, f, arity) in DiffRules.diffrules() for (M, f, arity) in DiffRules.diffrules()
arity == 1 || continue arity == 1 || continue
@eval begin @eval begin
@grad $M.$f(a::Real) =
$M.$f(a), Δ -> (Δ * $(DiffRules.diffrule(M, f, :(data(a)))),)
$M.$f(a::TrackedReal) = track($M.$f, a) $M.$f(a::TrackedReal) = track($M.$f, a)
back(::typeof($M.$f), Δ::Real, a::TrackedReal) =
back(a, Δ * $(DiffRules.diffrule(M, f, :(data(a)))))
end end
end end
for (M, f, arity) in DiffRules.diffrules() for (M, f, arity) in DiffRules.diffrules()
arity == 2 || continue arity == 2 || continue
da, db = DiffRules.diffrule(M, f, :(data(a)), :(data(b))) da, db = DiffRules.diffrule(M, f, :(data(a)), :(data(b)))
f = :($M.$f)
@eval begin @eval begin
$M.$f(a::TrackedReal, b::TrackedReal) = track($M.$f, a, b) @grad $f(a::Real, b::Real) = $f(a, b), Δ -> (Δ * $da, Δ * $db)
$M.$f(a::TrackedReal, b::Real) = track($M.$f, a, b) $f(a::TrackedReal, b::TrackedReal) = track($f, a, b)
$M.$f(a::Real, b::TrackedReal) = track($M.$f, a, b) $f(a::TrackedReal, b::Real) = track($f, a, b)
function back(::typeof($M.$f), Δ::Real, a::Real, b::Real) $f(a::Real, b::TrackedReal) = track($f, a, b)
@back(a, Δ * $da)
@back(b, Δ * $db)
end
end end
end end