organise params
This commit is contained in:
parent
d6cf116a74
commit
8b4bc7cc52
@ -62,6 +62,7 @@ macro grad(ex)
|
|||||||
end
|
end
|
||||||
|
|
||||||
include("idset.jl")
|
include("idset.jl")
|
||||||
|
include("params.jl")
|
||||||
include("back.jl")
|
include("back.jl")
|
||||||
include("numeric.jl")
|
include("numeric.jl")
|
||||||
include("lib/real.jl")
|
include("lib/real.jl")
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
# In-place gradients
|
||||||
|
|
||||||
init_grad(x) = zero(x)
|
init_grad(x) = zero(x)
|
||||||
zero_grad!(x) = zero(x)
|
zero_grad!(x) = zero(x)
|
||||||
zero_grad!(x::AbstractArray) = (x .= 0)
|
zero_grad!(x::AbstractArray) = (x .= 0)
|
||||||
@ -77,53 +79,6 @@ end
|
|||||||
|
|
||||||
# Out-of-place gradients
|
# Out-of-place gradients
|
||||||
|
|
||||||
struct Params
|
|
||||||
order::Vector{Any}
|
|
||||||
params::IdSet{Any}
|
|
||||||
Params() = new([], IdSet())
|
|
||||||
end
|
|
||||||
|
|
||||||
@forward Params.order Base.iterate, Base.length
|
|
||||||
|
|
||||||
function Base.push!(ps::Params, x)
|
|
||||||
if !(x in ps.params)
|
|
||||||
push!(ps.order, x)
|
|
||||||
push!(ps.params, x)
|
|
||||||
end
|
|
||||||
return ps
|
|
||||||
end
|
|
||||||
|
|
||||||
Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps)
|
|
||||||
|
|
||||||
Params(xs) = push!(Params(), xs...)
|
|
||||||
|
|
||||||
function Base.show(io::IO, ps::Params)
|
|
||||||
print(io, "Params([")
|
|
||||||
join(io, ps.order, ", ")
|
|
||||||
print(io, "])")
|
|
||||||
end
|
|
||||||
|
|
||||||
struct Grads
|
|
||||||
grads::IdDict{Any,Any}
|
|
||||||
end
|
|
||||||
|
|
||||||
Base.show(io::IO, ps::Grads) = println(io, "Grads(...)")
|
|
||||||
|
|
||||||
Grads() = Grads(IdDict())
|
|
||||||
|
|
||||||
@forward Grads.grads Base.setindex!, Base.haskey, Base.length, Base.iterate
|
|
||||||
|
|
||||||
Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps))
|
|
||||||
|
|
||||||
Base.getindex(g::Grads, x::Tracked) = g.grads[x]
|
|
||||||
|
|
||||||
function Base.getindex(g::Grads, x)
|
|
||||||
istracked(x) || error("Object not tracked: $x")
|
|
||||||
g[tracker(x)]
|
|
||||||
end
|
|
||||||
|
|
||||||
accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
|
|
||||||
|
|
||||||
function back_(g::Grads, c::Call, Δ)
|
function back_(g::Grads, c::Call, Δ)
|
||||||
Δs = c.func(Δ)
|
Δs = c.func(Δ)
|
||||||
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
|
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
|
||||||
|
46
src/tracker/params.jl
Normal file
46
src/tracker/params.jl
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
struct Params
|
||||||
|
order::Vector{Any}
|
||||||
|
params::IdSet{Any}
|
||||||
|
Params() = new([], IdSet())
|
||||||
|
end
|
||||||
|
|
||||||
|
@forward Params.order Base.iterate, Base.length
|
||||||
|
|
||||||
|
function Base.push!(ps::Params, x)
|
||||||
|
if !(x in ps.params)
|
||||||
|
push!(ps.order, x)
|
||||||
|
push!(ps.params, x)
|
||||||
|
end
|
||||||
|
return ps
|
||||||
|
end
|
||||||
|
|
||||||
|
Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps)
|
||||||
|
|
||||||
|
Params(xs) = push!(Params(), xs...)
|
||||||
|
|
||||||
|
function Base.show(io::IO, ps::Params)
|
||||||
|
print(io, "Params([")
|
||||||
|
join(io, ps.order, ", ")
|
||||||
|
print(io, "])")
|
||||||
|
end
|
||||||
|
|
||||||
|
struct Grads
|
||||||
|
grads::IdDict{Any,Any}
|
||||||
|
end
|
||||||
|
|
||||||
|
Base.show(io::IO, ps::Grads) = println(io, "Grads(...)")
|
||||||
|
|
||||||
|
Grads() = Grads(IdDict())
|
||||||
|
|
||||||
|
@forward Grads.grads Base.setindex!, Base.haskey, Base.length, Base.iterate
|
||||||
|
|
||||||
|
Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps))
|
||||||
|
|
||||||
|
Base.getindex(g::Grads, x::Tracked) = g.grads[x]
|
||||||
|
|
||||||
|
function Base.getindex(g::Grads, x)
|
||||||
|
istracked(x) || error("Object not tracked: $x")
|
||||||
|
g[tracker(x)]
|
||||||
|
end
|
||||||
|
|
||||||
|
accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
|
Loading…
Reference in New Issue
Block a user