added stride for pooling in tracker
This commit is contained in:
parent
0ba5ce4601
commit
5cc681317a
@ -261,21 +261,21 @@ function back(::typeof(_conv), Δ, x, w, stride, pad)
|
||||
@back(w, NNlib.∇conv_filter(Δ, data(x), data(w); stride = stride, pad = pad))
|
||||
end
|
||||
|
||||
_maxpool(x, k, pad) = maxpool(x, k; pad = pad)
|
||||
_maxpool(x, k, pad, stride) = maxpool(x, k; pad = pad, stride = stride)
|
||||
|
||||
maxpool(x::TrackedArray, k; pad = map(_->0,k)) =
|
||||
track(_maxpool, x, k, pad)
|
||||
maxpool(x::TrackedArray, k; pad = map(_->0,k), stride = k) =
|
||||
track(_maxpool, x, k, pad, stride)
|
||||
|
||||
back_(::typeof(_maxpool), y, Δ, x, k, pad) =
|
||||
back(x, NNlib.∇maxpool(Δ, y, data(x), k, pad=pad))
|
||||
back_(::typeof(_maxpool), y, Δ, x, k, pad, stride) =
|
||||
back(x, NNlib.∇maxpool(Δ, y, data(x), k, pad=pad, stride=stride))
|
||||
|
||||
_meanpool(x, k, pad) = meanpool(x, k; pad = pad)
|
||||
_meanpool(x, k, pad, stride) = meanpool(x, k; pad = pad, stride = stride)
|
||||
|
||||
meanpool(x::TrackedArray, k; pad = map(_->0,k)) =
|
||||
track(_meanpool, x, k, pad)
|
||||
meanpool(x::TrackedArray, k; pad = map(_->0,k), stride = k) =
|
||||
track(_meanpool, x, k, pad, stride)
|
||||
|
||||
back_(::typeof(_meanpool), y, Δ, x, k, pad) =
|
||||
back(x, NNlib.∇meanpool(Δ, y, data(x), k, pad=pad))
|
||||
back_(::typeof(_meanpool), y, Δ, x, k, pad, stride) =
|
||||
back(x, NNlib.∇meanpool(Δ, y, data(x), k, pad=pad, stride=stride))
|
||||
|
||||
# Broadcasting
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user