diff --git a/src/tracker/back.jl b/src/tracker/back.jl index 06f62e5d..9ca0377c 100644 --- a/src/tracker/back.jl +++ b/src/tracker/back.jl @@ -79,14 +79,14 @@ function Base.show(io::IO, ps::Params) end struct Grads - grads::ObjectIdDict + grads::IdDict{Any,Any} end Base.show(io::IO, ps::Grads) = println(io, "Grads(...)") -Grads() = Grads(ObjectIdDict()) +Grads() = Grads(IdDict()) -Grads(ps::Params) = Grads(ObjectIdDict(tracker(p) => init_grad(data(p)) for p in ps)) +Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps)) Base.getindex(g::Grads, x::Tracked) = g.grads[x] function Base.getindex(g::Grads, x) diff --git a/src/tracker/idset.jl b/src/tracker/idset.jl index 68d1eea1..0d5fade9 100644 --- a/src/tracker/idset.jl +++ b/src/tracker/idset.jl @@ -1,6 +1,6 @@ struct IdSet{T} <: AbstractSet{T} - dict::ObjectIdDict - IdSet{T}() where T = new(ObjectIdDict()) + dict::IdDict{T,Nothing} + IdSet{T}() where T = new(IdDict{T,Nothing}()) end Base.eltype{T}(::IdSet{T}) = T diff --git a/src/treelike.jl b/src/treelike.jl index e65ac41a..e4c4e33f 100644 --- a/src/treelike.jl +++ b/src/treelike.jl @@ -16,7 +16,7 @@ end isleaf(x) = isempty(children(x)) -function mapleaves(f, x; cache = ObjectIdDict()) +function mapleaves(f, x; cache = IdDict()) haskey(cache, x) && return cache[x] cache[x] = isleaf(x) ? f(x) : mapchildren(x -> mapleaves(f, x, cache = cache), x) end