pool gradients
This commit is contained in:
parent
d949b31aa5
commit
0bf22dfb8e
|
@ -12,16 +12,17 @@ function scan(x::TrackedArray)
|
|||
return
|
||||
end
|
||||
|
||||
back(c::Call, Δ) = back(c.func, Δ, c.args...)
|
||||
back(::Call{Void}, Δ) = nothing
|
||||
back_(f, y, args...) = back(f, args...)
|
||||
back_(c::Call, y, Δ) = back_(c.func, y, Δ, c.args...)
|
||||
back_(::Call{Void}, y, Δ) = nothing
|
||||
|
||||
function back(x::TrackedArray, Δ)
|
||||
ref = x.ref -= 1
|
||||
if isdefined(x, :grad)
|
||||
x.grad .+= Δ
|
||||
ref == 0 && back(x.f, x.grad)
|
||||
ref == 0 && back_(x.f, x.data, x.grad)
|
||||
else
|
||||
ref == 0 && back(x.f, Δ)
|
||||
ref == 0 && back_(x.f, x.data, Δ)
|
||||
end
|
||||
return
|
||||
end
|
||||
|
|
|
@ -124,7 +124,7 @@ end
|
|||
# NNlib
|
||||
|
||||
using NNlib
|
||||
import NNlib: softmax, ∇softmax, conv2d
|
||||
import NNlib: softmax, ∇softmax, conv2d, pool
|
||||
|
||||
softmax(xs::TrackedArray) = TrackedArray(Call(softmax, xs))
|
||||
|
||||
|
@ -139,6 +139,14 @@ function back(::typeof(conv2d), Δ, x, w)
|
|||
@back(w, NNlib.conv2d_grad_w(data(x), data(w), Δ))
|
||||
end
|
||||
|
||||
_pool(x, k, mode) = pool(x, window = k, mode = mode)
|
||||
|
||||
pool(x::TrackedArray{<:Any,4}; window = 2, mode = 0) =
|
||||
TrackedArray(Call(_pool, x, window, mode))
|
||||
|
||||
back_(::typeof(_pool), y, Δ, x, k, mode) =
|
||||
back(x, NNlib.pool_grad(data(x), y, Δ, window = k, mode = mode))
|
||||
|
||||
# Broadcasting
|
||||
|
||||
using ForwardDiff: Dual, partials
|
||||
|
|
|
@ -47,5 +47,7 @@ end
|
|||
end
|
||||
|
||||
@test gradtest(conv2d, rand(10, 10, 3, 2), randn(2, 2, 3, 2))
|
||||
@test gradtest(x -> maxpool2d(x, 2), rand(10, 10, 3, 2))
|
||||
@test gradtest(x -> avgpool2d(x, 2), rand(10, 10, 3, 2))
|
||||
|
||||
end #testset
|
||||
|
|
Loading…
Reference in New Issue