return Params from params
This commit is contained in:
parent
4a54d30cbf
commit
554c4c7c7a
@ -66,15 +66,28 @@ end
|
||||
# Out-of-place gradients
|
||||
|
||||
struct Params
|
||||
params::IdSet
|
||||
Params(xs) = new(IdSet(xs))
|
||||
order::Vector{Any}
|
||||
params::IdSet{Any}
|
||||
Params() = new([], IdSet())
|
||||
end
|
||||
|
||||
@forward Params.params Base.iterate, Base.length
|
||||
@forward Params.order Base.iterate, Base.length
|
||||
|
||||
function Base.push!(ps::Params, x)
|
||||
if !(x in ps.params)
|
||||
push!(ps.order, x)
|
||||
push!(ps.params, x)
|
||||
end
|
||||
return ps
|
||||
end
|
||||
|
||||
Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps)
|
||||
|
||||
Params(xs) = push!(Params(), xs...)
|
||||
|
||||
function Base.show(io::IO, ps::Params)
|
||||
print(io, "Params([")
|
||||
join(io, ps.params, ", ")
|
||||
join(io, ps.order, ", ")
|
||||
print(io, "])")
|
||||
end
|
||||
|
||||
|
@ -7,6 +7,7 @@ Base.eltype(::IdSet{T}) where T = T
|
||||
|
||||
IdSet() = IdSet{Any}()
|
||||
|
||||
Base.push!(s::IdSet) = s
|
||||
Base.push!(s::IdSet{T}, x::T) where T = (s.dict[x] = nothing; s)
|
||||
Base.delete!(s::IdSet{T}, x::T) where T = (delete!(s.dict, x); s)
|
||||
Base.in(x, s::IdSet) = haskey(s.dict, x)
|
||||
|
@ -40,7 +40,7 @@ function prefor(f, x; seen = IdSet())
|
||||
end
|
||||
|
||||
function params(m)
|
||||
ps = []
|
||||
ps = Params()
|
||||
prefor(p ->
|
||||
Tracker.istracked(p) && Tracker.isleaf(p) &&
|
||||
!any(p′ -> p′ === p, ps) && push!(ps, p),
|
||||
|
Loading…
Reference in New Issue
Block a user