added stride for pooling in tracker

This commit is contained in:
tejank10 2018-03-20 01:12:04 +05:30 committed by Mike J Innes
parent 0ba5ce4601
commit 5cc681317a

View File

@ -261,21 +261,21 @@ function back(::typeof(_conv), Δ, x, w, stride, pad)
@back(w, NNlib.∇conv_filter(Δ, data(x), data(w); stride = stride, pad = pad))
end
_maxpool(x, k, pad) = maxpool(x, k; pad = pad)
_maxpool(x, k, pad, stride) = maxpool(x, k; pad = pad, stride = stride)
maxpool(x::TrackedArray, k; pad = map(_->0,k)) =
track(_maxpool, x, k, pad)
maxpool(x::TrackedArray, k; pad = map(_->0,k), stride = k) =
track(_maxpool, x, k, pad, stride)
back_(::typeof(_maxpool), y, Δ, x, k, pad) =
back(x, NNlib.∇maxpool(Δ, y, data(x), k, pad=pad))
back_(::typeof(_maxpool), y, Δ, x, k, pad, stride) =
back(x, NNlib.∇maxpool(Δ, y, data(x), k, pad=pad, stride=stride))
_meanpool(x, k, pad) = meanpool(x, k; pad = pad)
_meanpool(x, k, pad, stride) = meanpool(x, k; pad = pad, stride = stride)
meanpool(x::TrackedArray, k; pad = map(_->0,k)) =
track(_meanpool, x, k, pad)
meanpool(x::TrackedArray, k; pad = map(_->0,k), stride = k) =
track(_meanpool, x, k, pad, stride)
back_(::typeof(_meanpool), y, Δ, x, k, pad) =
back(x, NNlib.∇meanpool(Δ, y, data(x), k, pad=pad))
back_(::typeof(_meanpool), y, Δ, x, k, pad, stride) =
back(x, NNlib.∇meanpool(Δ, y, data(x), k, pad=pad, stride=stride))
# Broadcasting