organise params

This commit is contained in:
Mike Innes 2019-02-28 13:44:54 +00:00
parent d6cf116a74
commit 8b4bc7cc52
3 changed files with 49 additions and 47 deletions

View File

@ -62,6 +62,7 @@ macro grad(ex)
end end
include("idset.jl") include("idset.jl")
include("params.jl")
include("back.jl") include("back.jl")
include("numeric.jl") include("numeric.jl")
include("lib/real.jl") include("lib/real.jl")

View File

@ -1,3 +1,5 @@
# In-place gradients
init_grad(x) = zero(x) init_grad(x) = zero(x)
zero_grad!(x) = zero(x) zero_grad!(x) = zero(x)
zero_grad!(x::AbstractArray) = (x .= 0) zero_grad!(x::AbstractArray) = (x .= 0)
@ -77,53 +79,6 @@ end
# Out-of-place gradients # Out-of-place gradients
struct Params
order::Vector{Any}
params::IdSet{Any}
Params() = new([], IdSet())
end
@forward Params.order Base.iterate, Base.length
function Base.push!(ps::Params, x)
if !(x in ps.params)
push!(ps.order, x)
push!(ps.params, x)
end
return ps
end
Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps)
Params(xs) = push!(Params(), xs...)
function Base.show(io::IO, ps::Params)
print(io, "Params([")
join(io, ps.order, ", ")
print(io, "])")
end
struct Grads
grads::IdDict{Any,Any}
end
Base.show(io::IO, ps::Grads) = println(io, "Grads(...)")
Grads() = Grads(IdDict())
@forward Grads.grads Base.setindex!, Base.haskey, Base.length, Base.iterate
Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps))
Base.getindex(g::Grads, x::Tracked) = g.grads[x]
function Base.getindex(g::Grads, x)
istracked(x) || error("Object not tracked: $x")
g[tracker(x)]
end
accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
function back_(g::Grads, c::Call, Δ) function back_(g::Grads, c::Call, Δ)
Δs = c.func(Δ) Δs = c.func(Δ)
(Δs isa Tuple && length(Δs) >= length(c.args)) || (Δs isa Tuple && length(Δs) >= length(c.args)) ||

46
src/tracker/params.jl Normal file
View File

@ -0,0 +1,46 @@
struct Params
order::Vector{Any}
params::IdSet{Any}
Params() = new([], IdSet())
end
@forward Params.order Base.iterate, Base.length
function Base.push!(ps::Params, x)
if !(x in ps.params)
push!(ps.order, x)
push!(ps.params, x)
end
return ps
end
Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps)
Params(xs) = push!(Params(), xs...)
function Base.show(io::IO, ps::Params)
print(io, "Params([")
join(io, ps.order, ", ")
print(io, "])")
end
struct Grads
grads::IdDict{Any,Any}
end
Base.show(io::IO, ps::Grads) = println(io, "Grads(...)")
Grads() = Grads(IdDict())
@forward Grads.grads Base.setindex!, Base.haskey, Base.length, Base.iterate
Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps))
Base.getindex(g::Grads, x::Tracked) = g.grads[x]
function Base.getindex(g::Grads, x)
istracked(x) || error("Object not tracked: $x")
g[tracker(x)]
end
accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ