functional API
This commit is contained in:
parent
5e319c7395
commit
7778d17884
@ -1,7 +1,7 @@
|
||||
module Tracker
|
||||
|
||||
using MacroTools
|
||||
using MacroTools: @q
|
||||
using MacroTools: @q, @forward
|
||||
|
||||
import Base: ==
|
||||
|
||||
@ -71,6 +71,7 @@ function update!(x, Δ)
|
||||
return x
|
||||
end
|
||||
|
||||
include("idset.jl")
|
||||
include("back.jl")
|
||||
include("scalar.jl")
|
||||
include("array.jl")
|
||||
|
@ -52,6 +52,8 @@ back(x::Void, Δ) = error("Can't backpropagate through `nothing`")
|
||||
|
||||
# TODO: if an error occurs in `back` the refcounts will be broken
|
||||
# and `back` will silently fail to update.
|
||||
# Refcounts are also probably not safe in some situations (e.g. back called
|
||||
# from within a backpropagator)
|
||||
|
||||
function back!(x::Tracked, Δ)
|
||||
scan(x)
|
||||
@ -59,3 +61,88 @@ function back!(x::Tracked, Δ)
|
||||
end
|
||||
|
||||
# Fallback for any `x` not handled by a more specific method: look up its
# tracker and run `back!` on that.
function back!(x, Δ)
  return back!(tracker(x), Δ)
end
|
||||
|
||||
# Out-of-place gradients
|
||||
|
||||
"""
    Params(xs)

Collection of parameters. The elements of `xs` are stored in an `IdSet`, and
iteration (`start`/`next`/`done`) is forwarded to that set.
"""
struct Params
  params::IdSet

  function Params(xs)
    return new(IdSet(xs))
  end
end

# Iterating a `Params` iterates the wrapped `IdSet`.
@forward Params.params Base.start, Base.next, Base.done
|
||||
|
||||
"""
    Grads()

Container of gradients backed by an `ObjectIdDict`. The zero-argument
constructor starts from an empty dictionary.
"""
struct Grads
  grads::ObjectIdDict
end

function Grads()
  return Grads(ObjectIdDict())
end
|
||||
|
||||
# Direct lookup: a `Tracked` object is itself the dictionary key.
function Base.getindex(g::Grads, x::Tracked)
  return g.grads[x]
end

# Any other object is looked up through its tracker; untracked objects are
# rejected with an error.
function Base.getindex(g::Grads, x)
  if !istracked(x)
    error("Object not tracked: $x")
  end
  return g[tracker(x)]
end

# `setindex!` and `haskey` operate directly on the wrapped dictionary.
@forward Grads.grads Base.setindex!, Base.haskey
|
||||
|
||||
"""
    accum!(g::Grads, x, Δ)

Store `Δ` under `x` in `g`, adding it to the existing entry when `g` already
has one for `x`. Returns the value that was stored.
"""
function accum!(g::Grads, x, Δ)
  total = haskey(g, x) ? g[x] + Δ : Δ
  g[x] = total
  return total
end
|
||||
|
||||
"""
    back_(g::Grads, c::Call, Δ)

Apply `c.func` to `Δ` and hand each resulting gradient to `back` for the
matching entry of `c.args`. Arguments for which `istracked` is false are
skipped. Raises an error unless the result is a tuple with at least
`length(c.args)` entries.
"""
function back_(g::Grads, c::Call, Δ)
  Δs = c.func(Δ)
  if !(Δs isa Tuple) || length(Δs) < length(c.args)
    error("Gradient is not a tuple of length $(length(c.args))")
  end
  for (arg, δ) in zip(c.args, Δs)
    istracked(arg) && back(g, arg, δ)
  end
  return
end

# A `Call{Void}` has nothing to propagate into.
function back_(g::Grads, ::Call{Void}, Δ)
  return nothing
end
|
||||
|
||||
"""
    back(g::Grads, x::Tracked, Δ)

Record the gradient `Δ` for `x` in `g` and, once `x.ref` has been decremented
to zero, continue backwards through `x.f`.

Leaves only accumulate `Δ`. For other nodes, `Δ` is accumulated when further
references remain or an entry for `x` already exists; when the count reaches
zero, `back_` is called with the accumulated entry if one was written, and
with `Δ` itself otherwise.
"""
function back(g::Grads, x::Tracked, Δ)
  if x.isleaf
    accum!(g, x, Δ)
    return
  end
  ref = x.ref -= 1
  # Decide before accumulating: `haskey` must see the state prior to this call.
  stored = ref > 0 || haskey(g, x)
  stored && accum!(g, x, Δ)
  if ref == 0
    back_(g, x.f, stored ? g[x] : Δ)
  end
  return
end
|
||||
|
||||
# Fallback for any `x` without a more specific method: propagate through its
# tracker.
function back(g::Grads, x, Δ)
  return back(g, tracker(x), Δ)
end

# `nothing` cannot be differentiated through.
function back(g::Grads, x::Void, Δ)
  error("Can't backpropagate through `nothing`")
end
|
||||
|
||||
"""
    forward(f, ps::Params)

Call `f()` and return its result `y` together with a function of `Δ`.

That function builds a fresh `Grads`. If `y` is tracked it runs `scan(y)`
followed by `back(g, y, Δ)`. Every parameter in `ps` that still has no entry
is then given `init_grad(data(p))`, and the `Grads` is returned.
"""
function forward(f, ps::Params)
  y = f()
  function pullback(Δ)
    g = Grads()
    if istracked(y)
      scan(y)
      back(g, y, Δ)
    end
    for p in ps
      t = tracker(p)
      haskey(g, t) || (g[t] = init_grad(data(p)))
    end
    return g
  end
  return y, pullback
end
|
||||
|
||||
"""
    forward(f, args...)

Convert every argument with `param`, evaluate `f` on the converted arguments
and return `(y, back)`. `back(Δ)` yields the gradient entries for the
converted arguments, in argument order.
"""
function forward(f, args...)
  # Bind the converted arguments to a fresh name: reassigning `args` while the
  # closures below capture it would force the captured variable to be boxed.
  ps = param.(args)
  # Named `pullback` so the local does not hide the module's `back` function.
  y, pullback = forward(() -> f(ps...), Params(ps))
  return y, Δ -> getindex.(pullback(Δ), ps)
end
|
||||
|
||||
"""
    losscheck(x)

Validate a loss value. Raises an error when `x` is not a `Real`, is infinite,
or is NaN; returns `false` otherwise.
"""
function losscheck(x)
  if !(x isa Real)
    error("Function output is not scalar")
  elseif isinf(x)
    error("Loss is infinite")
  elseif isnan(x)
    error("Loss is NaN")
  end
  return false
end
|
||||
|
||||
"""
    gradient(f, args...)

Run `forward(f, args...)`, validate the result with `losscheck`, and return
the output of the backward function called with `1`.
"""
function gradient(f, args...)
  loss, pullback = forward(f, args...)
  losscheck(loss)
  return pullback(1)
end
|
||||
|
||||
"""
    derivative(f, x)

Return the first entry of `gradient(f, x)`.
"""
function derivative(f, x)
  grads = gradient(f, x)
  return grads[1]
end
|
||||
|
25
src/tracker/idset.jl
Normal file
25
src/tracker/idset.jl
Normal file
@ -0,0 +1,25 @@
|
||||
"""
    IdSet{T}()

Set type whose elements are stored as the keys of an `ObjectIdDict`.
`IdSet{T}()` creates an empty set.
"""
struct IdSet{T} <: AbstractSet{T}
  dict::ObjectIdDict

  function IdSet{T}() where T
    return new(ObjectIdDict())
  end
end
|
||||
|
||||
# Element type is the set's type parameter.
Base.eltype(::IdSet{T}) where T = T

# Untyped empty set.
IdSet() = IdSet{Any}()

# Membership is recorded as a key of the wrapped dictionary; the stored value
# is irrelevant. Both mutators return the set itself.
Base.push!(s::IdSet{T}, x::T) where T = (s.dict[x] = nothing; s)
Base.delete!(s::IdSet{T}, x::T) where T = (delete!(s.dict, x); s)
Base.in(x, s::IdSet) = haskey(s.dict, x)

"""
    IdSet{T}(xs)
    IdSet(xs)

Build a set from the elements of the iterable `xs`. Without an explicit `T`
the element type is `eltype(xs)`.
"""
function IdSet{T}(xs) where T
  s = IdSet{T}()
  # Insert one element at a time rather than splatting `xs` into `push!`:
  # this also works for an empty `xs` and avoids splatting a collection of
  # unknown length.
  for x in xs
    push!(s, x)
  end
  return s
end

IdSet(xs) = IdSet{eltype(xs)}(xs)
|
||||
|
||||
# Materialise the elements, i.e. the keys of the wrapped dictionary.
function Base.collect(s::IdSet)
  return Base.collect(keys(s.dict))
end

# `similar` yields a new empty set with element type `T`.
function Base.similar(s::IdSet, T::Type)
  return IdSet{T}()
end

# `length` is answered by the wrapped dictionary.
@forward IdSet.dict Base.length

# Iteration protocol: delegate each step to the dictionary's key iterator.
function Base.start(s::IdSet)
  return start(keys(s.dict))
end

function Base.next(s::IdSet, st)
  return next(keys(s.dict), st)
end

function Base.done(s::IdSet, st)
  return done(keys(s.dict), st)
end
|
@ -1,9 +1,3 @@
|
||||
# Convert the inputs with `param`, run `back!` on the result of `f`, and
# return `grad` of each converted input.
function gradient(f, xs...)
  ps = param.(xs)
  back!(f(ps...))
  return grad.(ps)
end
|
||||
|
||||
function ngradient(f, xs::AbstractArray...)
|
||||
grads = zeros.(xs)
|
||||
for (x, Δ) in zip(xs, grads), i in 1:length(x)
|
||||
|
@ -1,4 +1,5 @@
|
||||
import Adapt: adapt
|
||||
import .Tracker: IdSet
|
||||
|
||||
# Default: an object has no children.
function children(x)
  return ()
end

# Default: nothing to map over, so the object is returned unchanged and `f`
# is never called.
function mapchildren(f, x)
  return x
end
|
||||
@ -20,9 +21,7 @@ function mapleaves(f, x; cache = ObjectIdDict())
|
||||
cache[x] = isleaf(x) ? f(x) : mapchildren(x -> mapleaves(f, x, cache = cache), x)
|
||||
end
|
||||
|
||||
using DataFlow: OSet
|
||||
|
||||
function prefor(f, x; seen = OSet())
|
||||
function prefor(f, x; seen = IdSet())
|
||||
x ∈ seen && return
|
||||
f(x)
|
||||
foreach(x -> prefor(f, x, seen = seen), children(x))
|
||||
|
Loading…
Reference in New Issue
Block a user