pool gradients

This commit is contained in:
Mike J Innes 2017-12-15 02:29:14 +00:00
parent d949b31aa5
commit 0bf22dfb8e
3 changed files with 16 additions and 5 deletions

View File

@@ -12,16 +12,17 @@ function scan(x::TrackedArray)
   return
 end
 
-back(c::Call, Δ) = back(c.func, Δ, c.args...)
-back(::Call{Void}, Δ) = nothing
+back_(f, y, args...) = back(f, args...)
+back_(c::Call, y, Δ) = back_(c.func, y, Δ, c.args...)
+back_(::Call{Void}, y, Δ) = nothing
 
 function back(x::TrackedArray, Δ)
   ref = x.ref -= 1
   if isdefined(x, :grad)
     x.grad .+= Δ
-    ref == 0 && back(x.f, x.grad)
+    ref == 0 && back_(x.f, x.data, x.grad)
   else
-    ref == 0 && back(x.f, Δ)
+    ref == 0 && back_(x.f, x.data, Δ)
   end
   return
 end

View File

@@ -124,7 +124,7 @@ end
 # NNlib
 using NNlib
-import NNlib: softmax, ∇softmax, conv2d
+import NNlib: softmax, ∇softmax, conv2d, pool
 softmax(xs::TrackedArray) = TrackedArray(Call(softmax, xs))
@@ -139,6 +139,14 @@ function back(::typeof(conv2d), Δ, x, w)
   @back(w, NNlib.conv2d_grad_w(data(x), data(w), Δ))
 end
+
+_pool(x, k, mode) = pool(x, window = k, mode = mode)
+
+pool(x::TrackedArray{<:Any,4}; window = 2, mode = 0) =
+  TrackedArray(Call(_pool, x, window, mode))
+
+back_(::typeof(_pool), y, Δ, x, k, mode) =
+  back(x, NNlib.pool_grad(data(x), y, Δ, window = k, mode = mode))
 
 # Broadcasting
 using ForwardDiff: Dual, partials

View File

@@ -47,5 +47,7 @@ end
 end
 @test gradtest(conv2d, rand(10, 10, 3, 2), randn(2, 2, 3, 2))
 @test gradtest(x -> maxpool2d(x, 2), rand(10, 10, 3, 2))
+@test gradtest(x -> avgpool2d(x, 2), rand(10, 10, 3, 2))
+
 end #testset