From 554c4c7c7ac3be1c7e77b1a7693bf905122e13de Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 31 Oct 2018 15:50:08 +0000 Subject: [PATCH] return Params from params --- src/tracker/back.jl | 21 +++++++++++++++++---- src/tracker/idset.jl | 1 + src/treelike.jl | 2 +- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/src/tracker/back.jl b/src/tracker/back.jl index e5a84a71..2be772b0 100644 --- a/src/tracker/back.jl +++ b/src/tracker/back.jl @@ -66,15 +66,28 @@ end # Out-of-place gradients struct Params - params::IdSet - Params(xs) = new(IdSet(xs)) + order::Vector{Any} + params::IdSet{Any} + Params() = new([], IdSet()) end -@forward Params.params Base.iterate, Base.length +@forward Params.order Base.iterate, Base.length + +function Base.push!(ps::Params, x) + if !(x in ps.params) + push!(ps.order, x) + push!(ps.params, x) + end + return ps +end + +Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps) + +Params(xs) = push!(Params(), xs...) function Base.show(io::IO, ps::Params) print(io, "Params([") - join(io, ps.params, ", ") + join(io, ps.order, ", ") print(io, "])") end diff --git a/src/tracker/idset.jl b/src/tracker/idset.jl index 62570c99..372e262a 100644 --- a/src/tracker/idset.jl +++ b/src/tracker/idset.jl @@ -7,6 +7,7 @@ Base.eltype(::IdSet{T}) where T = T IdSet() = IdSet{Any}() +Base.push!(s::IdSet) = s Base.push!(s::IdSet{T}, x::T) where T = (s.dict[x] = nothing; s) Base.delete!(s::IdSet{T}, x::T) where T = (delete!(s.dict, x); s) Base.in(x, s::IdSet) = haskey(s.dict, x) diff --git a/src/treelike.jl b/src/treelike.jl index 3d83d448..ae94590b 100644 --- a/src/treelike.jl +++ b/src/treelike.jl @@ -40,7 +40,7 @@ function prefor(f, x; seen = IdSet()) end function params(m) - ps = [] + ps = Params() prefor(p -> Tracker.istracked(p) && Tracker.isleaf(p) && !any(p′ -> p′ === p, ps) && push!(ps, p),