diff --git a/REQUIRE b/REQUIRE
index 8bb92ddb..95fda02c 100644
--- a/REQUIRE
+++ b/REQUIRE
@@ -1,5 +1,4 @@
 julia 0.6.0
-DataFlow 0.2.1
 Juno
 MacroTools 0.3.3
 NNlib
diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl
index 4a58df29..d5f7dcfb 100644
--- a/src/tracker/Tracker.jl
+++ b/src/tracker/Tracker.jl
@@ -1,7 +1,7 @@
 module Tracker
 
 using MacroTools
-using MacroTools: @q
+using MacroTools: @q, @forward
 
 import Base: ==
 
@@ -71,6 +71,7 @@ function update!(x, Δ)
   return x
 end
 
+include("idset.jl")
 include("back.jl")
 include("scalar.jl")
 include("array.jl")
diff --git a/src/tracker/back.jl b/src/tracker/back.jl
index 3d769778..62cae1d0 100644
--- a/src/tracker/back.jl
+++ b/src/tracker/back.jl
@@ -52,6 +52,8 @@ back(x::Void, Δ) = error("Can't backpropagate through `nothing`")
 
 # TODO: if an error occurs in `back` the refcounts will be broken
 # and `back` will silently fail to update.
+# Refcounts are also probably not safe in some situations (e.g. back called
+# from within a backpropagator)
 
 function back!(x::Tracked, Δ)
   scan(x)
@@ -59,3 +61,108 @@ function back!(x::Tracked, Δ)
 end
 
 back!(x, Δ) = back!(tracker(x), Δ)
+
+# Out-of-place gradients
+
+# The parameters to differentiate with respect to, held in an `IdSet`.
+struct Params
+  params::IdSet
+  Params(xs) = new(IdSet(xs))
+end
+
+# Iterating a `Params` iterates the wrapped `IdSet`.
+@forward Params.params Base.start, Base.next, Base.done
+
+# Gradients collected by the out-of-place `back`, stored in an `ObjectIdDict`.
+struct Grads
+  grads::ObjectIdDict
+end
+
+Grads() = Grads(ObjectIdDict())
+
+# Gradients are looked up by tracker; any other object is first mapped to its
+# tracker via `tracker(x)`, and an error is raised if it is not tracked.
+Base.getindex(g::Grads, x::Tracked) = g.grads[x]
+function Base.getindex(g::Grads, x)
+  istracked(x) || error("Object not tracked: $x")
+  g[tracker(x)]
+end
+
+@forward Grads.grads Base.setindex!, Base.haskey
+
+# Add `Δ` to the gradient stored for `x`, or store `Δ` if there is none yet.
+accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] + Δ : Δ
+
+# Apply the backpropagator `c.func` to `Δ` and hand each resulting gradient
+# to the matching argument of the call, skipping untracked arguments.
+function back_(g::Grads, c::Call, Δ)
+  Δs = c.func(Δ)
+  (Δs isa Tuple && length(Δs) >= length(c.args)) ||
+    error("Gradient is not a tuple of length $(length(c.args))")
+  foreach((x, Δ) -> istracked(x) && back(g, x, Δ), c.args, Δs)
+end
+
+# No-op for a `Call{Void}`.
+back_(g::Grads, ::Call{Void}, Δ) = nothing
+
+# Propagate `Δ` into `x`. Leaves only accumulate into `g`. Otherwise `x.ref`
+# is decremented and `x.f` is back-propagated once it reaches zero, using the
+# accumulated gradient when one is (or has to be) stored and plain `Δ` if not.
+function back(g::Grads, x::Tracked, Δ)
+  x.isleaf && (accum!(g, x, Δ); return)
+  ref = x.ref -= 1
+  if ref > 0 || haskey(g, x)
+    accum!(g, x, Δ)
+    ref == 0 && back_(g, x.f, g[x])
+  else
+    ref == 0 && back_(g, x.f, Δ)
+  end
+  return
+end
+
+back(g::Grads, x, Δ) = back(g, tracker(x), Δ)
+back(g::Grads, x::Void, Δ) = error("Can't backpropagate through `nothing`")
+
+# Run `f()` and return its result together with a closure mapping `Δ` to a
+# fresh `Grads`. Parameters in `ps` that received no gradient are given
+# `init_grad(data(p))`, so every parameter has an entry in the result.
+function forward(f, ps::Params)
+  y = f()
+  y, function (Δ)
+    g = Grads()
+    if istracked(y)
+      scan(y)
+      back(g, y, Δ)
+    end
+    for p in ps
+      haskey(g, tracker(p)) ||
+        (g[tracker(p)] = init_grad(data(p)))
+    end
+    return g
+  end
+end
+
+# Wrap `args` with `param`, call `f` on them and return the result plus a
+# closure that returns one gradient per argument, in argument order.
+function forward(f, args...)
+  args = param.(args)
+  y, back = forward(() -> f(args...), Params(args))
+  y, Δ -> getindex.(back(Δ), args)
+end
+
+# Raise an error unless `x` is a `Real` that is neither infinite nor NaN.
+function losscheck(x)
+  x isa Real || error("Function output is not scalar")
+  isinf(x) && error("Loss is infinite")
+  isnan(x) && error("Loss is NaN")
+end
+
+# Gradients of `f(args...)` for each argument; the output must pass `losscheck`.
+function gradient(f, args...)
+  y, back = forward(f, args...)
+  losscheck(y)
+  return back(1)
+end
+
+# First entry of `gradient(f, x)`.
+derivative(f, x) = gradient(f, x)[1]
diff --git a/src/tracker/idset.jl b/src/tracker/idset.jl
new file mode 100644
index 00000000..68d1eea1
--- /dev/null
+++ b/src/tracker/idset.jl
@@ -0,0 +1,29 @@
+# A set whose elements are stored as the keys of an `ObjectIdDict`, so
+# membership follows that dictionary's key lookup.
+struct IdSet{T} <: AbstractSet{T}
+  dict::ObjectIdDict
+  IdSet{T}() where T = new(ObjectIdDict())
+end
+
+Base.eltype{T}(::IdSet{T}) = T
+
+IdSet() = IdSet{Any}()
+
+Base.push!{T}(s::IdSet{T}, x::T) = (s.dict[x] = nothing; s)
+Base.delete!{T}(s::IdSet{T}, x::T) = (delete!(s.dict, x); s)
+Base.in(x, s::IdSet) = haskey(s.dict, x)
+
+# Build a set from any iterable. `foldl` (rather than `push!(s, xs...)`) keeps
+# this working when `xs` is empty, where the splat would call `push!(s)`.
+(::Type{IdSet{T}}){T}(xs) = foldl(push!, IdSet{T}(), xs)
+
+IdSet(xs) = IdSet{eltype(xs)}(xs)
+
+Base.collect(s::IdSet) = Base.collect(keys(s.dict))
+Base.similar(s::IdSet, T::Type) = IdSet{T}()
+
+@forward IdSet.dict Base.length
+
+Base.start(s::IdSet) = start(keys(s.dict))
+Base.next(s::IdSet, st) = next(keys(s.dict), st)
+Base.done(s::IdSet, st) = done(keys(s.dict), st)
diff --git a/src/tracker/numeric.jl b/src/tracker/numeric.jl
index 755e1f7d..e0028b7c 100644
--- a/src/tracker/numeric.jl
+++ b/src/tracker/numeric.jl
@@ -1,9 +1,3 @@
-function gradient(f, xs...)
-  xs = param.(xs)
-  back!(f(xs...))
-  grad.(xs)
-end
-
 function ngradient(f, xs::AbstractArray...)
   grads = zeros.(xs)
   for (x, Δ) in zip(xs, grads), i in 1:length(x)
diff --git a/src/treelike.jl b/src/treelike.jl
index fbe9fcad..13e562e6 100644
--- a/src/treelike.jl
+++ b/src/treelike.jl
@@ -1,4 +1,5 @@
 import Adapt: adapt
+import .Tracker: IdSet
 
 children(x) = ()
 mapchildren(f, x) = x
@@ -20,9 +21,7 @@ function mapleaves(f, x; cache = ObjectIdDict())
   cache[x] = isleaf(x) ? f(x) : mapchildren(x -> mapleaves(f, x, cache = cache), x)
 end
 
-using DataFlow: OSet
-
-function prefor(f, x; seen = OSet())
+function prefor(f, x; seen = IdSet())
   x ∈ seen && return
   f(x)
   foreach(x -> prefor(f, x, seen = seen), children(x))