functional API

This commit is contained in:
Mike J Innes 2018-07-09 16:57:44 +01:00
parent 5e319c7395
commit 7778d17884
6 changed files with 116 additions and 11 deletions

View File

@ -1,5 +1,4 @@
julia 0.6.0
DataFlow 0.2.1
Juno
MacroTools 0.3.3
NNlib

View File

@ -1,7 +1,7 @@
module Tracker
using MacroTools
using MacroTools: @q
using MacroTools: @q, @forward
import Base: ==
@ -71,6 +71,7 @@ function update!(x, Δ)
return x
end
include("idset.jl")
include("back.jl")
include("scalar.jl")
include("array.jl")

View File

@ -52,6 +52,8 @@ back(x::Void, Δ) = error("Can't backpropagate through `nothing`")
# TODO: if an error occurs in `back` the refcounts will be broken
# and `back` will silently fail to update.
# Refcounts are also probably not safe in some situations (e.g. back called
# from within a backpropagator)
function back!(x::Tracked, Δ)
scan(x)
@ -59,3 +61,88 @@ function back!(x::Tracked, Δ)
end
# Fallback for anything that is not itself a `Tracked`: fetch the object's
# tracker and run the in-place backward pass on that.
function back!(x, Δ)
    return back!(tracker(x), Δ)
end
# Out-of-place gradients
# A collection of parameters. Whatever is passed to the constructor is
# stored as an `IdSet`.
struct Params
    params::IdSet
    function Params(xs)
        return new(IdSet(xs))
    end
end

# Iterating a `Params` iterates the wrapped set (pre-1.0 iteration protocol).
@forward Params.params Base.start, Base.next, Base.done
# Container for gradients computed out of place; entries live in an
# `ObjectIdDict`.
struct Grads
    grads::ObjectIdDict
end

# Start with no gradients recorded.
function Grads()
    return Grads(ObjectIdDict())
end
# Indexing with a `Tracked` reads the underlying dictionary directly.
function Base.getindex(g::Grads, x::Tracked)
    return g.grads[x]
end

# Indexing with any other object goes through that object's tracker;
# untracked objects are rejected with an error.
function Base.getindex(g::Grads, x)
    if !istracked(x)
        error("Object not tracked: $x")
    end
    return g[tracker(x)]
end

# Writes and key checks are passed straight to the underlying dictionary.
@forward Grads.grads Base.setindex!, Base.haskey
# Add `Δ` to the entry stored for `x`, or store `Δ` as the first entry when
# there is none yet. Returns the value that was stored.
function accum!(g::Grads, x, Δ)
    total = haskey(g, x) ? g[x] + Δ : Δ
    g[x] = total
    return total
end
# Push `Δ` through the call `c`: `c.func(Δ)` must produce a tuple with at
# least one entry per argument of `c`; each tracked argument then receives
# its matching entry via `back`.
function back_(g::Grads, c::Call, Δ)
    Δs = c.func(Δ)
    if !(Δs isa Tuple) || length(Δs) < length(c.args)
        error("Gradient is not a tuple of length $(length(c.args))")
    end
    for (arg, δ) in zip(c.args, Δs)
        istracked(arg) && back(g, arg, δ)
    end
    return nothing
end

# A `Call{Void}` has nothing to propagate into.
function back_(g::Grads, ::Call{Void}, Δ)
    return nothing
end
# Out-of-place backward step for a tracked node.
#
# Leaves simply accumulate `Δ` into `g`. For other nodes the reference count
# is decremented first; `Δ` is accumulated into `g` while further
# contributions are outstanding (count still positive) or when an earlier
# contribution is already stored. Once the count reaches zero the node's call
# `x.f` is back-propagated with either the stored total or, if nothing was
# stored, `Δ` itself.
function back(g::Grads, x::Tracked, Δ)
    if x.isleaf
        accum!(g, x, Δ)
        return
    end
    ref = x.ref -= 1
    stored = ref > 0 || haskey(g, x)
    stored && accum!(g, x, Δ)
    if ref == 0
        back_(g, x.f, stored ? g[x] : Δ)
    end
    return
end
# Objects that are not `Tracked` themselves are handled through their tracker.
function back(g::Grads, x, Δ)
    return back(g, tracker(x), Δ)
end

# `nothing` cannot take part in a backward pass.
function back(g::Grads, x::Void, Δ)
    error("Can't backpropagate through `nothing`")
end
# Evaluate `f()` and return `(y, pullback)`.
#
# `pullback(Δ)` creates a fresh `Grads`, back-propagates `Δ` from `y` when
# `y` is tracked, and then gives every parameter in `ps` that has no entry
# yet the value `init_grad(data(p))`. The `Grads` is returned.
function forward(f, ps::Params)
    y = f()
    function pullback(Δ)
        g = Grads()
        if istracked(y)
            scan(y)
            back(g, y, Δ)
        end
        for p in ps
            if !haskey(g, tracker(p))
                g[tracker(p)] = init_grad(data(p))
            end
        end
        return g
    end
    return y, pullback
end
# Evaluate `f(args...)` with every argument wrapped by `param`, and return
# `(y, pullback)` where `pullback(Δ)` yields one gradient per argument, in
# argument order.
#
# The wrapped arguments get their own name instead of reassigning `args`:
# a variable that is reassigned and then captured by a closure has to be
# boxed, and both closures below capture it. The pullback returned by the
# `Params` method is likewise not called `back`, so it does not shadow the
# module-level `back` function inside this body.
function forward(f, args...)
    xs = param.(args)
    y, pullback = forward(() -> f(xs...), Params(xs))
    return y, Δ -> getindex.(pullback(Δ), xs)
end
# Validate a loss value: it must be a `Real` and neither infinite nor NaN.
# Raises an error describing the first failed check; returns `false` when
# every check passes.
function losscheck(x)
    if !(x isa Real)
        error("Function output is not scalar")
    elseif isinf(x)
        error("Loss is infinite")
    elseif isnan(x)
        error("Loss is NaN")
    end
    return false
end
# Gradients of `f` at `args`: run the forward pass, check the result with
# `losscheck`, then call the pullback with a seed of `1`.
function gradient(f, args...)
    y, pullback = forward(f, args...)
    losscheck(y)
    return pullback(1)
end
# Single-argument convenience: the first (only) entry of `gradient(f, x)`.
function derivative(f, x)
    return gradient(f, x)[1]
end

25
src/tracker/idset.jl Normal file
View File

@ -0,0 +1,25 @@
# Set type whose elements are held as the keys of an `ObjectIdDict`.
# The only inner constructor creates an empty set.
struct IdSet{T} <: AbstractSet{T}
    dict::ObjectIdDict
    function IdSet{T}() where T
        return new(ObjectIdDict())
    end
end
# The element type is the set's type parameter.
Base.eltype(::IdSet{T}) where T = T

# Without a type parameter, create an empty `IdSet{Any}`.
function IdSet()
    return IdSet{Any}()
end

# Insert `x` by recording it as a key of the backing dictionary; returns the set.
function Base.push!(set::IdSet{T}, x::T) where T
    set.dict[x] = nothing
    return set
end

# Remove `x` from the backing dictionary; returns the set.
function Base.delete!(set::IdSet{T}, x::T) where T
    delete!(set.dict, x)
    return set
end

# Membership is a key lookup in the backing dictionary.
function Base.in(x, set::IdSet)
    return haskey(set.dict, x)
end
# Build an `IdSet{T}` holding every element of the iterable `xs`.
#
# Elements are inserted one at a time rather than splatted into a single
# `push!` call: with an empty `xs` the splatted form reduces to
# `push!(set)`, for which no method exists, so constructing a set (and
# anything built on it) from an empty collection would throw. Looping also
# avoids splatting a collection of unknown length into varargs.
function IdSet{T}(xs) where T
    set = IdSet{T}()
    for x in xs
        push!(set, x)
    end
    return set
end
# Build a set from `xs`, taking the element type from `eltype(xs)`.
function IdSet(xs)
    return IdSet{eltype(xs)}(xs)
end

# `collect` gathers the keys of the backing dictionary.
function Base.collect(set::IdSet)
    return Base.collect(keys(set.dict))
end

# `similar` returns a new, empty `IdSet{T}`; nothing is copied from `set`.
function Base.similar(set::IdSet, T::Type)
    return IdSet{T}()
end

# `length` is answered by the backing dictionary.
@forward IdSet.dict Base.length

# Iteration (pre-1.0 `start`/`next`/`done` protocol) walks the keys of the
# backing dictionary.
function Base.start(set::IdSet)
    return start(keys(set.dict))
end

function Base.next(set::IdSet, state)
    return next(keys(set.dict), state)
end

function Base.done(set::IdSet, state)
    return done(keys(set.dict), state)
end

View File

@ -1,9 +1,3 @@
# Wrap each input with `param`, run the in-place backward pass on `f` applied
# to the wrapped inputs, and read each input's gradient back with `grad`.
function gradient(f, xs...)
    ps = param.(xs)
    out = f(ps...)
    back!(out)
    return grad.(ps)
end
function ngradient(f, xs::AbstractArray...)
grads = zeros.(xs)
for (x, Δ) in zip(xs, grads), i in 1:length(x)

View File

@ -1,4 +1,5 @@
import Adapt: adapt
import .Tracker: IdSet
# Default: an object has no children.
function children(x)
    return ()
end

# Default: with no children to map over, the object is returned unchanged
# and `f` is never called.
function mapchildren(f, x)
    return x
end
@ -20,9 +21,7 @@ function mapleaves(f, x; cache = ObjectIdDict())
cache[x] = isleaf(x) ? f(x) : mapchildren(x -> mapleaves(f, x, cache = cache), x)
end
using DataFlow: OSet
function prefor(f, x; seen = OSet())
function prefor(f, x; seen = IdSet())
x seen && return
f(x)
foreach(x -> prefor(f, x, seen = seen), children(x))