return Params from params

This commit is contained in:
Mike J Innes 2018-10-31 15:50:08 +00:00
parent 4a54d30cbf
commit 554c4c7c7a
3 changed files with 19 additions and 5 deletions

View File

@ -66,15 +66,28 @@ end
# Out-of-place gradients
struct Params
params::IdSet
Params(xs) = new(IdSet(xs))
order::Vector{Any}
params::IdSet{Any}
Params() = new([], IdSet())
end
@forward Params.params Base.iterate, Base.length
@forward Params.order Base.iterate, Base.length
function Base.push!(ps::Params, x)
if !(x in ps.params)
push!(ps.order, x)
push!(ps.params, x)
end
return ps
end
Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps)
Params(xs) = push!(Params(), xs...)
function Base.show(io::IO, ps::Params)
print(io, "Params([")
join(io, ps.params, ", ")
join(io, ps.order, ", ")
print(io, "])")
end

View File

@ -7,6 +7,7 @@ Base.eltype(::IdSet{T}) where T = T
IdSet() = IdSet{Any}()
Base.push!(s::IdSet) = s
Base.push!(s::IdSet{T}, x::T) where T = (s.dict[x] = nothing; s)
Base.delete!(s::IdSet{T}, x::T) where T = (delete!(s.dict, x); s)
Base.in(x, s::IdSet) = haskey(s.dict, x)

View File

@ -40,7 +40,7 @@ function prefor(f, x; seen = IdSet())
end
function params(m)
ps = []
ps = Params()
prefor(p ->
Tracker.istracked(p) && Tracker.isleaf(p) &&
!any(p -> p === p, ps) && push!(ps, p),