0.7 fix
This commit is contained in:
parent
c21d768b7c
commit
b77433cdfd
@ -26,8 +26,8 @@ function back_(c::Call, Δ, once)
|
|||||||
foreach((x, d) -> back(x, d, once), c.args, data.(Δs))
|
foreach((x, d) -> back(x, d, once), c.args, data.(Δs))
|
||||||
end
|
end
|
||||||
|
|
||||||
back_(::Call{Nothing}, _, _) = nothing
|
back_(::Call{Nothing}, Δ, once) = nothing
|
||||||
back_(::Call{Missing}, _, _) = error("`back!` was already used")
|
back_(::Call{Missing}, Δ, once) = error("`back!` was already used")
|
||||||
|
|
||||||
accum!(x, Δ) = x .+ Δ
|
accum!(x, Δ) = x .+ Δ
|
||||||
accum!(x::AbstractArray, Δ) = (x .+= Δ)
|
accum!(x::AbstractArray, Δ) = (x .+= Δ)
|
||||||
@ -49,7 +49,7 @@ function back(x::Tracked, Δ, once)
|
|||||||
return
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
back(::Nothing, _, _) = return
|
back(::Nothing, Δ, once) = return
|
||||||
|
|
||||||
# Interface methods
|
# Interface methods
|
||||||
|
|
||||||
@ -94,12 +94,12 @@ Grads() = Grads(IdDict())
|
|||||||
Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps))
|
Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps))
|
||||||
|
|
||||||
Base.getindex(g::Grads, x::Tracked) = g.grads[x]
|
Base.getindex(g::Grads, x::Tracked) = g.grads[x]
|
||||||
|
|
||||||
function Base.getindex(g::Grads, x)
|
function Base.getindex(g::Grads, x)
|
||||||
istracked(x) || error("Object not tracked: $x")
|
istracked(x) || error("Object not tracked: $x")
|
||||||
g[tracker(x)]
|
g[tracker(x)]
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
|
accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
|
||||||
|
|
||||||
function back_(g::Grads, c::Call, Δ)
|
function back_(g::Grads, c::Call, Δ)
|
||||||
|
Loading…
Reference in New Issue
Block a user