pool padding
This commit is contained in:
parent
6b6974e14a
commit
98b362729d
|
@ -151,13 +151,13 @@ function back(::typeof(_conv2d), Δ, x, w, stride, pad)
|
|||
@back(w, NNlib.conv2d_grad_w(data(x), data(w), Δ; stride = stride, padding = pad))
|
||||
end
|
||||
|
||||
_pool(x, k, mode) = pool(x, window = k, mode = mode)
|
||||
_pool(x, k, pad, mode) = pool(x, window = k, mode = mode, padding = pad)
|
||||
|
||||
pool(x::TrackedArray{<:Any,4}; window = 2, mode = 0) =
|
||||
TrackedArray(Call(_pool, x, window, mode))
|
||||
pool(x::TrackedArray{<:Any,4}; window = 2, mode = 0, padding = 0) =
|
||||
TrackedArray(Call(_pool, x, window, padding, mode))
|
||||
|
||||
back_(::typeof(_pool), y, Δ, x, k, mode) =
|
||||
back(x, NNlib.pool_grad(data(x), y, Δ, window = k, mode = mode))
|
||||
back_(::typeof(_pool), y, Δ, x, k, pad, mode) =
|
||||
back(x, NNlib.pool_grad(data(x), y, Δ, window=k, mode=mode, padding=pad))
|
||||
|
||||
# Broadcasting
|
||||
|
||||
|
|
Loading…
Reference in New Issue