fix
This commit is contained in:
parent
d09eb8a8e7
commit
4320738d87
@ -29,7 +29,7 @@ accum!(x, Δ) = x .+ Δ
|
|||||||
accum!(x::AbstractArray, Δ) = (x .+= Δ)
|
accum!(x::AbstractArray, Δ) = (x .+= Δ)
|
||||||
|
|
||||||
function back(x::Tracked, Δ)
|
function back(x::Tracked, Δ)
|
||||||
x.isleaf && (accum!(x.grad, Δ); return)
|
x.isleaf && (x.grad = accum!(x.grad, Δ); return)
|
||||||
ref = x.ref -= 1
|
ref = x.ref -= 1
|
||||||
if isdefined(x, :grad)
|
if isdefined(x, :grad)
|
||||||
x.grad = accum!(x.grad, Δ)
|
x.grad = accum!(x.grad, Δ)
|
||||||
|
@ -104,4 +104,8 @@ end
|
|||||||
|
|
||||||
@inferred NNlib.conv(param(rand(10,10,3,2)),randn(2,2,3,4))
|
@inferred NNlib.conv(param(rand(10,10,3,2)),randn(2,2,3,4))
|
||||||
|
|
||||||
|
b = param(rand())
|
||||||
|
Tracker.back!(b)
|
||||||
|
@test Tracker.grad(b) == 1
|
||||||
|
|
||||||
end #testset
|
end #testset
|
||||||
|
Loading…
Reference in New Issue
Block a user