new grad api
This commit is contained in:
parent
ce88273880
commit
41b9412439
@ -1,5 +1,7 @@
|
|||||||
module Tracker
|
module Tracker
|
||||||
|
|
||||||
|
using MacroTools
|
||||||
|
|
||||||
import Base: ==
|
import Base: ==
|
||||||
|
|
||||||
export TrackedArray, TrackedVector, TrackedMatrix, param, back!
|
export TrackedArray, TrackedVector, TrackedMatrix, param, back!
|
||||||
@ -17,7 +19,8 @@ struct Call{F,As<:Tuple}
|
|||||||
args::As
|
args::As
|
||||||
end
|
end
|
||||||
|
|
||||||
Call(f, args...) = Call{typeof(f),typeof(args)}(f, args)
|
Call(f, args) = Call{typeof(f),typeof(args)}(f, args)
|
||||||
|
Call() = Call(nothing, ())
|
||||||
|
|
||||||
# When deserialising, the object_id changes
|
# When deserialising, the object_id changes
|
||||||
a::Call == b::Call = a.func == b.func && a.args == b.args
|
a::Call == b::Call = a.func == b.func && a.args == b.args
|
||||||
@ -38,15 +41,29 @@ end
|
|||||||
Tracked(f::Call, x) = Tracked{typeof(x)}(f, x)
|
Tracked(f::Call, x) = Tracked{typeof(x)}(f, x)
|
||||||
Tracked(f::Call, x, Δ) = Tracked{typeof(x)}(f, x, Δ)
|
Tracked(f::Call, x, Δ) = Tracked{typeof(x)}(f, x, Δ)
|
||||||
|
|
||||||
track(f::Call, x) = Tracked(f, x)
|
|
||||||
track(f::Call) = track(f, f())
|
|
||||||
track(f, xs...) = track(Call(f, xs...))
|
|
||||||
|
|
||||||
istracked(x::Tracked) = true
|
istracked(x::Tracked) = true
|
||||||
isleaf(x::Tracked) = x.f == Call(nothing)
|
isleaf(x::Tracked) = x.f == Call()
|
||||||
data(x::Tracked) = x.data
|
data(x::Tracked) = x.data
|
||||||
grad(x::Tracked) = x.grad
|
grad(x::Tracked) = x.grad
|
||||||
|
|
||||||
|
track(f::Call, x) = Tracked(f, x)
|
||||||
|
track(f::Call) = track(f, f())
|
||||||
|
|
||||||
|
function _forward end
|
||||||
|
|
||||||
|
function track(f, xs...)
|
||||||
|
y, back = _forward(f, data.(xs)...)
|
||||||
|
track(Call(back, xs), y)
|
||||||
|
end
|
||||||
|
|
||||||
|
macro grad(ex)
|
||||||
|
@capture(shortdef(ex), (name_(args__) = body_) |
|
||||||
|
(name_(args__) where {T__} = body_)) || error("Need a function definition")
|
||||||
|
T == nothing && (T = [])
|
||||||
|
unshift!(args, :(::typeof($name)))
|
||||||
|
:(Tracker._forward($(args...)) where $(T...) = $body) |> esc
|
||||||
|
end
|
||||||
|
|
||||||
function update!(x, Δ)
|
function update!(x, Δ)
|
||||||
tracker(x).data += Δ
|
tracker(x).data += Δ
|
||||||
tracker(x).grad .= 0
|
tracker(x).grad .= 0
|
||||||
|
@ -20,7 +20,7 @@ TrackedArray(c::Call, x::A) where A <: AbstractArray =
|
|||||||
TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
|
TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
|
||||||
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, x, Δ), x, Δ)
|
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, x, Δ), x, Δ)
|
||||||
|
|
||||||
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x))
|
TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, zeros(x))
|
||||||
|
|
||||||
Base.eltype(x::Type{<:TrackedArray{T}}) where T <: Real = TrackedReal{T}
|
Base.eltype(x::Type{<:TrackedArray{T}}) where T <: Real = TrackedReal{T}
|
||||||
|
|
||||||
|
@ -21,9 +21,14 @@ function scan(x)
|
|||||||
return
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
back_(f, y, args...) = back(f, args...)
|
function back_(c::Call, Δ)
|
||||||
back_(c::Call, y, Δ) = back_(c.func, y, Δ, c.args...)
|
Δs = c.func(Δ)
|
||||||
back_(::Call{Void}, y, Δ) = nothing
|
(Δs isa Tuple && length(Δs) == length(c.args)) ||
|
||||||
|
error("Gradient is not a tuple of length $(length(c.args))")
|
||||||
|
foreach((x, Δ) -> istracked(x) && back(x, Δ), c.args, Δs)
|
||||||
|
end
|
||||||
|
|
||||||
|
back_(::Call{Void}, Δ) = nothing
|
||||||
|
|
||||||
accum!(x, Δ) = x .+ Δ
|
accum!(x, Δ) = x .+ Δ
|
||||||
accum!(x::AbstractArray, Δ) = (x .+= Δ)
|
accum!(x::AbstractArray, Δ) = (x .+= Δ)
|
||||||
@ -33,9 +38,9 @@ function back(x::Tracked, Δ)
|
|||||||
ref = x.ref -= 1
|
ref = x.ref -= 1
|
||||||
if isdefined(x, :grad)
|
if isdefined(x, :grad)
|
||||||
x.grad = accum!(x.grad, Δ)
|
x.grad = accum!(x.grad, Δ)
|
||||||
ref == 0 && back_(x.f, x.data, x.grad)
|
ref == 0 && back_(x.f, x.grad)
|
||||||
else
|
else
|
||||||
ref == 0 && back_(x.f, x.data, Δ)
|
ref == 0 && back_(x.f, Δ)
|
||||||
end
|
end
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
|
@ -2,7 +2,7 @@ struct TrackedReal{T<:Real} <: Real
|
|||||||
tracker::Tracked{T}
|
tracker::Tracked{T}
|
||||||
end
|
end
|
||||||
|
|
||||||
TrackedReal(x::Real) = TrackedReal(Tracked(Call(nothing), x, zero(x)))
|
TrackedReal(x::Real) = TrackedReal(Tracked(Call(), x, zero(x)))
|
||||||
|
|
||||||
tracker(x::TrackedReal) = x.tracker
|
tracker(x::TrackedReal) = x.tracker
|
||||||
|
|
||||||
@ -47,23 +47,21 @@ using DiffRules, SpecialFunctions, NaNMath
|
|||||||
for (M, f, arity) in DiffRules.diffrules()
|
for (M, f, arity) in DiffRules.diffrules()
|
||||||
arity == 1 || continue
|
arity == 1 || continue
|
||||||
@eval begin
|
@eval begin
|
||||||
|
@grad $M.$f(a::Real) =
|
||||||
|
$M.$f(a), Δ -> (Δ * $(DiffRules.diffrule(M, f, :(data(a)))),)
|
||||||
$M.$f(a::TrackedReal) = track($M.$f, a)
|
$M.$f(a::TrackedReal) = track($M.$f, a)
|
||||||
back(::typeof($M.$f), Δ::Real, a::TrackedReal) =
|
|
||||||
back(a, Δ * $(DiffRules.diffrule(M, f, :(data(a)))))
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
for (M, f, arity) in DiffRules.diffrules()
|
for (M, f, arity) in DiffRules.diffrules()
|
||||||
arity == 2 || continue
|
arity == 2 || continue
|
||||||
da, db = DiffRules.diffrule(M, f, :(data(a)), :(data(b)))
|
da, db = DiffRules.diffrule(M, f, :(data(a)), :(data(b)))
|
||||||
|
f = :($M.$f)
|
||||||
@eval begin
|
@eval begin
|
||||||
$M.$f(a::TrackedReal, b::TrackedReal) = track($M.$f, a, b)
|
@grad $f(a::Real, b::Real) = $f(a, b), Δ -> (Δ * $da, Δ * $db)
|
||||||
$M.$f(a::TrackedReal, b::Real) = track($M.$f, a, b)
|
$f(a::TrackedReal, b::TrackedReal) = track($f, a, b)
|
||||||
$M.$f(a::Real, b::TrackedReal) = track($M.$f, a, b)
|
$f(a::TrackedReal, b::Real) = track($f, a, b)
|
||||||
function back(::typeof($M.$f), Δ::Real, a::Real, b::Real)
|
$f(a::Real, b::TrackedReal) = track($f, a, b)
|
||||||
@back(a, Δ * $da)
|
|
||||||
@back(b, Δ * $db)
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user