functional API
This commit is contained in:
parent
5e319c7395
commit
7778d17884
1
REQUIRE
1
REQUIRE
@ -1,5 +1,4 @@
|
|||||||
julia 0.6.0
|
julia 0.6.0
|
||||||
DataFlow 0.2.1
|
|
||||||
Juno
|
Juno
|
||||||
MacroTools 0.3.3
|
MacroTools 0.3.3
|
||||||
NNlib
|
NNlib
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
module Tracker
|
module Tracker
|
||||||
|
|
||||||
using MacroTools
|
using MacroTools
|
||||||
using MacroTools: @q
|
using MacroTools: @q, @forward
|
||||||
|
|
||||||
import Base: ==
|
import Base: ==
|
||||||
|
|
||||||
@ -71,6 +71,7 @@ function update!(x, Δ)
|
|||||||
return x
|
return x
|
||||||
end
|
end
|
||||||
|
|
||||||
|
include("idset.jl")
|
||||||
include("back.jl")
|
include("back.jl")
|
||||||
include("scalar.jl")
|
include("scalar.jl")
|
||||||
include("array.jl")
|
include("array.jl")
|
||||||
|
@ -52,6 +52,8 @@ back(x::Void, Δ) = error("Can't backpropagate through `nothing`")
|
|||||||
|
|
||||||
# TODO: if an error occurs in `back` the refcounts will be broken
|
# TODO: if an error occurs in `back` the refcounts will be broken
|
||||||
# and `back` will silently fail to update.
|
# and `back` will silently fail to update.
|
||||||
|
# Refcounts are also probably not safe in some situations (e.g. back called
|
||||||
|
# from within a backpropagator)
|
||||||
|
|
||||||
function back!(x::Tracked, Δ)
|
function back!(x::Tracked, Δ)
|
||||||
scan(x)
|
scan(x)
|
||||||
@ -59,3 +61,88 @@ function back!(x::Tracked, Δ)
|
|||||||
end
|
end
|
||||||
|
|
||||||
back!(x, Δ) = back!(tracker(x), Δ)
|
back!(x, Δ) = back!(tracker(x), Δ)
|
||||||
|
|
||||||
|
# Out-of-place gradients
|
||||||
|
|
||||||
|
# A collection of parameters. Whatever iterable is handed to the constructor
# is stored as an `IdSet` built from it.
struct Params
  params::IdSet
  function Params(xs)
    return new(IdSet(xs))
  end
end

# Iterating a `Params` iterates the wrapped `IdSet`.
@forward Params.params Base.start, Base.next, Base.done
|
||||||
|
|
||||||
|
# Gradient store returned by the out-of-place API; values live in an
# `ObjectIdDict`.
struct Grads
  grads::ObjectIdDict
end

# Create a gradient store backed by a fresh, empty `ObjectIdDict`.
function Grads()
  return Grads(ObjectIdDict())
end
|
||||||
|
|
||||||
|
# Indexing with a `Tracked` reads straight from the underlying dictionary.
Base.getindex(g::Grads, x::Tracked) = g.grads[x]

# Indexing with any other object goes through `tracker(x)`; objects for which
# `istracked(x)` is false are rejected with an error.
function Base.getindex(g::Grads, x)
  if !istracked(x)
    error("Object not tracked: $x")
  end
  return g[tracker(x)]
end

# `setindex!` and `haskey` are forwarded to the underlying dictionary as-is,
# i.e. with the key exactly as given by the caller.
@forward Grads.grads Base.setindex!, Base.haskey
|
||||||
|
|
||||||
|
# Add `Δ` to the value stored for `x` in `g`, or store `Δ` as-is when `g` has
# no entry for `x` yet. Returns the value that was stored.
function accum!(g::Grads, x, Δ)
  total = haskey(g, x) ? g[x] + Δ : Δ
  g[x] = total
  return total
end
|
||||||
|
|
||||||
|
# Push the gradient `Δ` through the call `c`.
#
# `c.func(Δ)` must return a tuple with at least as many entries as `c.args`;
# each entry is then passed on to the matching argument via `back`, skipping
# arguments for which `istracked` is false. Returns `nothing`.
function back_(g::Grads, c::Call, Δ)
  Δs = c.func(Δ)
  if !(Δs isa Tuple) || length(Δs) < length(c.args)
    error("Gradient is not a tuple of length $(length(c.args))")
  end
  for (arg, δ) in zip(c.args, Δs)
    istracked(arg) && back(g, arg, δ)
  end
  return nothing
end

# A `Call{Void}` has nothing to propagate into.
function back_(g::Grads, ::Call{Void}, Δ)
  return nothing
end
|
||||||
|
|
||||||
|
# Record the gradient `Δ` for the tracked node `x` in `g` and, once every
# expected contribution has arrived, continue into the call that produced `x`.
#
# Leaves just accumulate. For other nodes `x.ref` is decremented on every
# call; the gradient is accumulated in `g` while further contributions are
# outstanding (or one is already stored), and `back_` runs only when the
# decremented count is exactly zero. Always returns `nothing`.
function back(g::Grads, x::Tracked, Δ)
  if x.isleaf
    accum!(g, x, Δ)
    return
  end
  ref = x.ref -= 1
  stored = ref > 0 || haskey(g, x)
  stored && accum!(g, x, Δ)
  if ref == 0
    # Use the accumulated total when one was stored, otherwise `Δ` itself.
    back_(g, x.f, stored ? g[x] : Δ)
  end
  return
end

# Any other object is handled through its tracker.
function back(g::Grads, x, Δ)
  return back(g, tracker(x), Δ)
end

# `nothing` cannot be differentiated through.
function back(g::Grads, x::Void, Δ)
  return error("Can't backpropagate through `nothing`")
end
|
||||||
|
|
||||||
|
# Evaluate the zero-argument function `f` and return `(y, pullback)`, where
# `y = f()` and `pullback(Δ)` builds a `Grads` for that evaluation.
#
# `pullback` backpropagates `Δ` from `y` when `istracked(y)` holds, then makes
# sure every parameter in `ps` has an entry, filling missing ones with
# `init_grad(data(p))`.
function forward(f, ps::Params)
  y = f()
  function pullback(Δ)
    g = Grads()
    if istracked(y)
      scan(y)
      back(g, y, Δ)
    end
    for p in ps
      t = tracker(p)
      if !haskey(g, t)
        g[t] = init_grad(data(p))
      end
    end
    return g
  end
  return y, pullback
end
|
||||||
|
|
||||||
|
# Evaluate `f(args...)` with every argument wrapped by `param`, returning
# `(y, pullback)`. `pullback(Δ)` yields one gradient per argument, looked up
# in the `Grads` produced by the `Params`-based `forward` method.
function forward(f, args...)
  xs = param.(args)
  y, pullback = forward(() -> f(xs...), Params(xs))
  return y, Δ -> getindex.(pullback(Δ), xs)
end
|
||||||
|
|
||||||
|
# Validate a loss value: it must be a `Real` that is neither infinite nor NaN.
# Raises an error describing the first failed check; returns `false` for a
# valid value.
function losscheck(x)
  if !(x isa Real)
    error("Function output is not scalar")
  elseif isinf(x)
    error("Loss is infinite")
  end
  return isnan(x) && error("Loss is NaN")
end
|
||||||
|
|
||||||
|
# Gradients of `f` at `args`. The output of `f` is validated with `losscheck`
# and the pullback returned by `forward` is seeded with `1`.
function gradient(f, args...)
  y, pullback = forward(f, args...)
  losscheck(y)
  return pullback(1)
end
|
||||||
|
|
||||||
|
# Single-argument convenience wrapper: the first entry of `gradient(f, x)`.
function derivative(f, x)
  grads = gradient(f, x)
  return grads[1]
end
|
||||||
|
25
src/tracker/idset.jl
Normal file
25
src/tracker/idset.jl
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
# A set with element type `T` whose members are stored as the keys of an
# `ObjectIdDict`.
struct IdSet{T} <: AbstractSet{T}
  dict::ObjectIdDict
  # The only inner constructor: an empty set.
  function IdSet{T}() where T
    return new(ObjectIdDict())
  end
end
|
||||||
|
|
||||||
|
# The element type is the set's type parameter.
Base.eltype(::IdSet{T}) where T = T

# An `IdSet` created without a type parameter accepts any element.
function IdSet()
  return IdSet{Any}()
end

# Insert `x`; the stored dictionary value is irrelevant, only the key matters.
function Base.push!(s::IdSet{T}, x::T) where T
  s.dict[x] = nothing
  return s
end

# Remove `x` from the set.
function Base.delete!(s::IdSet{T}, x::T) where T
  delete!(s.dict, x)
  return s
end

# Membership is a key lookup in the underlying dictionary.
function Base.in(x, s::IdSet)
  return haskey(s.dict, x)
end
|
||||||
|
|
||||||
|
# Build an `IdSet{T}` containing every element of the iterable `xs`.
#
# Elements are inserted one at a time instead of being splatted into a single
# `push!` call. With an empty `xs` the splat turned into a zero-item
# `push!(::IdSet{T})` call, which has no method, so constructing an empty set
# from a collection threw a `MethodError`. Looping also avoids splatting
# large collections into varargs.
function IdSet{T}(xs) where T
  s = IdSet{T}()
  for x in xs
    push!(s, x)
  end
  return s
end
|
||||||
|
|
||||||
|
# Build a set from a collection, taking the element type from `eltype(xs)`.
function IdSet(xs)
  return IdSet{eltype(xs)}(xs)
end

# Collect the members, i.e. the keys of the underlying dictionary.
function Base.collect(s::IdSet)
  members = keys(s.dict)
  return Base.collect(members)
end

# `similar` yields a fresh, empty set with the requested element type.
function Base.similar(s::IdSet, T::Type)
  return IdSet{T}()
end
|
||||||
|
|
||||||
|
# The number of members is the length of the underlying dictionary.
@forward IdSet.dict Base.length

# Iteration (`start`/`next`/`done`) walks the keys of the underlying
# dictionary.
function Base.start(s::IdSet)
  return start(keys(s.dict))
end

function Base.next(s::IdSet, st)
  return next(keys(s.dict), st)
end

function Base.done(s::IdSet, st)
  return done(keys(s.dict), st)
end
|
@ -1,9 +1,3 @@
|
|||||||
# In-place gradient helper: wrap each input with `param`, run `back!` on the
# output of `f`, then read each wrapped input's gradient with `grad`.
function gradient(f, xs...)
  tracked = param.(xs)
  back!(f(tracked...))
  return grad.(tracked)
end
|
|
||||||
|
|
||||||
function ngradient(f, xs::AbstractArray...)
|
function ngradient(f, xs::AbstractArray...)
|
||||||
grads = zeros.(xs)
|
grads = zeros.(xs)
|
||||||
for (x, Δ) in zip(xs, grads), i in 1:length(x)
|
for (x, Δ) in zip(xs, grads), i in 1:length(x)
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import Adapt: adapt
|
import Adapt: adapt
|
||||||
|
import .Tracker: IdSet
|
||||||
|
|
||||||
children(x) = ()
|
children(x) = ()
|
||||||
mapchildren(f, x) = x
|
mapchildren(f, x) = x
|
||||||
@ -20,9 +21,7 @@ function mapleaves(f, x; cache = ObjectIdDict())
|
|||||||
cache[x] = isleaf(x) ? f(x) : mapchildren(x -> mapleaves(f, x, cache = cache), x)
|
cache[x] = isleaf(x) ? f(x) : mapchildren(x -> mapleaves(f, x, cache = cache), x)
|
||||||
end
|
end
|
||||||
|
|
||||||
using DataFlow: OSet
|
function prefor(f, x; seen = IdSet())
|
||||||
|
|
||||||
function prefor(f, x; seen = OSet())
|
|
||||||
x ∈ seen && return
|
x ∈ seen && return
|
||||||
f(x)
|
f(x)
|
||||||
foreach(x -> prefor(f, x, seen = seen), children(x))
|
foreach(x -> prefor(f, x, seen = seen), children(x))
|
||||||
|
Loading…
Reference in New Issue
Block a user