pool padding

This commit is contained in:
Mike J Innes 2017-12-18 18:18:14 +00:00
parent 6b6974e14a
commit 98b362729d
1 changed files with 5 additions and 5 deletions

View File

@ -151,13 +151,13 @@ function back(::typeof(_conv2d), Δ, x, w, stride, pad)
@back(w, NNlib.conv2d_grad_w(data(x), data(w), Δ; stride = stride, padding = pad))
end
_pool(x, k, mode) = pool(x, window = k, mode = mode)
_pool(x, k, pad, mode) = pool(x, window = k, mode = mode, padding = pad)
pool(x::TrackedArray{<:Any,4}; window = 2, mode = 0) =
TrackedArray(Call(_pool, x, window, mode))
pool(x::TrackedArray{<:Any,4}; window = 2, mode = 0, padding = 0) =
TrackedArray(Call(_pool, x, window, padding, mode))
back_(::typeof(_pool), y, Δ, x, k, mode) =
back(x, NNlib.pool_grad(data(x), y, Δ, window = k, mode = mode))
back_(::typeof(_pool), y, Δ, x, k, pad, mode) =
back(x, NNlib.pool_grad(data(x), y, Δ, window=k, mode=mode, padding=pad))
# Broadcasting