return Params from params
This commit is contained in:
parent
4a54d30cbf
commit
554c4c7c7a
@ -66,15 +66,28 @@ end
|
|||||||
# Out-of-place gradients
|
# Out-of-place gradients
|
||||||
|
|
||||||
struct Params
|
struct Params
|
||||||
params::IdSet
|
order::Vector{Any}
|
||||||
Params(xs) = new(IdSet(xs))
|
params::IdSet{Any}
|
||||||
|
Params() = new([], IdSet())
|
||||||
end
|
end
|
||||||
|
|
||||||
@forward Params.params Base.iterate, Base.length
|
@forward Params.order Base.iterate, Base.length
|
||||||
|
|
||||||
|
function Base.push!(ps::Params, x)
|
||||||
|
if !(x in ps.params)
|
||||||
|
push!(ps.order, x)
|
||||||
|
push!(ps.params, x)
|
||||||
|
end
|
||||||
|
return ps
|
||||||
|
end
|
||||||
|
|
||||||
|
Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps)
|
||||||
|
|
||||||
|
Params(xs) = push!(Params(), xs...)
|
||||||
|
|
||||||
function Base.show(io::IO, ps::Params)
|
function Base.show(io::IO, ps::Params)
|
||||||
print(io, "Params([")
|
print(io, "Params([")
|
||||||
join(io, ps.params, ", ")
|
join(io, ps.order, ", ")
|
||||||
print(io, "])")
|
print(io, "])")
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -7,6 +7,7 @@ Base.eltype(::IdSet{T}) where T = T
|
|||||||
|
|
||||||
IdSet() = IdSet{Any}()
|
IdSet() = IdSet{Any}()
|
||||||
|
|
||||||
|
Base.push!(s::IdSet) = s
|
||||||
Base.push!(s::IdSet{T}, x::T) where T = (s.dict[x] = nothing; s)
|
Base.push!(s::IdSet{T}, x::T) where T = (s.dict[x] = nothing; s)
|
||||||
Base.delete!(s::IdSet{T}, x::T) where T = (delete!(s.dict, x); s)
|
Base.delete!(s::IdSet{T}, x::T) where T = (delete!(s.dict, x); s)
|
||||||
Base.in(x, s::IdSet) = haskey(s.dict, x)
|
Base.in(x, s::IdSet) = haskey(s.dict, x)
|
||||||
|
@ -40,7 +40,7 @@ function prefor(f, x; seen = IdSet())
|
|||||||
end
|
end
|
||||||
|
|
||||||
function params(m)
|
function params(m)
|
||||||
ps = []
|
ps = Params()
|
||||||
prefor(p ->
|
prefor(p ->
|
||||||
Tracker.istracked(p) && Tracker.isleaf(p) &&
|
Tracker.istracked(p) && Tracker.isleaf(p) &&
|
||||||
!any(p′ -> p′ === p, ps) && push!(ps, p),
|
!any(p′ -> p′ === p, ps) && push!(ps, p),
|
||||||
|
Loading…
Reference in New Issue
Block a user