0.7 fix
This commit is contained in:
parent
c21d768b7c
commit
b77433cdfd
|
@ -26,8 +26,8 @@ function back_(c::Call, Δ, once)
|
|||
foreach((x, d) -> back(x, d, once), c.args, data.(Δs))
|
||||
end
|
||||
|
||||
back_(::Call{Nothing}, _, _) = nothing
|
||||
back_(::Call{Missing}, _, _) = error("`back!` was already used")
|
||||
back_(::Call{Nothing}, Δ, once) = nothing
|
||||
back_(::Call{Missing}, Δ, once) = error("`back!` was already used")
|
||||
|
||||
accum!(x, Δ) = x .+ Δ
|
||||
accum!(x::AbstractArray, Δ) = (x .+= Δ)
|
||||
|
@ -49,7 +49,7 @@ function back(x::Tracked, Δ, once)
|
|||
return
|
||||
end
|
||||
|
||||
back(::Nothing, _, _) = return
|
||||
back(::Nothing, Δ, once) = return
|
||||
|
||||
# Interface methods
|
||||
|
||||
|
@ -94,12 +94,12 @@ Grads() = Grads(IdDict())
|
|||
Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps))
|
||||
|
||||
Base.getindex(g::Grads, x::Tracked) = g.grads[x]
|
||||
|
||||
function Base.getindex(g::Grads, x)
|
||||
istracked(x) || error("Object not tracked: $x")
|
||||
g[tracker(x)]
|
||||
end
|
||||
|
||||
|
||||
accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
|
||||
|
||||
function back_(g::Grads, c::Call, Δ)
|
||||
|
|
Loading…
Reference in New Issue