From 8b4bc7cc5245c6c8cd2a30ad872b9b1800c2e7e9 Mon Sep 17 00:00:00 2001 From: Mike Innes Date: Thu, 28 Feb 2019 13:44:54 +0000 Subject: [PATCH] organise params --- src/tracker/Tracker.jl | 1 + src/tracker/back.jl | 49 ++---------------------------------------- src/tracker/params.jl | 46 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 49 insertions(+), 47 deletions(-) create mode 100644 src/tracker/params.jl diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 2fbe6437..adceea61 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -62,6 +62,7 @@ macro grad(ex) end include("idset.jl") +include("params.jl") include("back.jl") include("numeric.jl") include("lib/real.jl") diff --git a/src/tracker/back.jl b/src/tracker/back.jl index ef65ecb6..0dda0082 100644 --- a/src/tracker/back.jl +++ b/src/tracker/back.jl @@ -1,3 +1,5 @@ +# In-place gradients + init_grad(x) = zero(x) zero_grad!(x) = zero(x) zero_grad!(x::AbstractArray) = (x .= 0) @@ -77,53 +79,6 @@ end # Out-of-place gradients -struct Params - order::Vector{Any} - params::IdSet{Any} - Params() = new([], IdSet()) -end - -@forward Params.order Base.iterate, Base.length - -function Base.push!(ps::Params, x) - if !(x in ps.params) - push!(ps.order, x) - push!(ps.params, x) - end - return ps -end - -Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps) - -Params(xs) = push!(Params(), xs...) - -function Base.show(io::IO, ps::Params) - print(io, "Params([") - join(io, ps.order, ", ") - print(io, "])") -end - -struct Grads - grads::IdDict{Any,Any} -end - -Base.show(io::IO, ps::Grads) = println(io, "Grads(...)") - -Grads() = Grads(IdDict()) - -@forward Grads.grads Base.setindex!, Base.haskey, Base.length, Base.iterate - -Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps)) - -Base.getindex(g::Grads, x::Tracked) = g.grads[x] - -function Base.getindex(g::Grads, x) - istracked(x) || error("Object not tracked: $x") - g[tracker(x)] -end - -accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ - function back_(g::Grads, c::Call, Δ) Δs = c.func(Δ) (Δs isa Tuple && length(Δs) >= length(c.args)) || diff --git a/src/tracker/params.jl b/src/tracker/params.jl new file mode 100644 index 00000000..7a1db1e9 --- /dev/null +++ b/src/tracker/params.jl @@ -0,0 +1,46 @@ +struct Params + order::Vector{Any} + params::IdSet{Any} + Params() = new([], IdSet()) +end + +@forward Params.order Base.iterate, Base.length + +function Base.push!(ps::Params, x) + if !(x in ps.params) + push!(ps.order, x) + push!(ps.params, x) + end + return ps +end + +Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps) + +Params(xs) = push!(Params(), xs...) + +function Base.show(io::IO, ps::Params) + print(io, "Params([") + join(io, ps.order, ", ") + print(io, "])") +end + +struct Grads + grads::IdDict{Any,Any} +end + +Base.show(io::IO, ps::Grads) = println(io, "Grads(...)") + +Grads() = Grads(IdDict()) + +@forward Grads.grads Base.setindex!, Base.haskey, Base.length, Base.iterate + +Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps)) + +Base.getindex(g::Grads, x::Tracked) = g.grads[x] + +function Base.getindex(g::Grads, x) + istracked(x) || error("Object not tracked: $x") + g[tracker(x)] +end + +accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ