From 5cc681317a6e432199cc7ac6ea5bf807ea955ea1 Mon Sep 17 00:00:00 2001 From: tejank10 Date: Tue, 20 Mar 2018 01:12:04 +0530 Subject: [PATCH] added stride for pooling in tracker --- src/tracker/array.jl | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 35261abe..5bffa7a1 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -261,21 +261,21 @@ function back(::typeof(_conv), Δ, x, w, stride, pad) @back(w, NNlib.∇conv_filter(Δ, data(x), data(w); stride = stride, pad = pad)) end -_maxpool(x, k, pad) = maxpool(x, k; pad = pad) +_maxpool(x, k, pad, stride) = maxpool(x, k; pad = pad, stride = stride) -maxpool(x::TrackedArray, k; pad = map(_->0,k)) = - track(_maxpool, x, k, pad) +maxpool(x::TrackedArray, k; pad = map(_->0,k), stride = k) = + track(_maxpool, x, k, pad, stride) -back_(::typeof(_maxpool), y, Δ, x, k, pad) = - back(x, NNlib.∇maxpool(Δ, y, data(x), k, pad=pad)) +back_(::typeof(_maxpool), y, Δ, x, k, pad, stride) = + back(x, NNlib.∇maxpool(Δ, y, data(x), k, pad=pad, stride=stride)) -_meanpool(x, k, pad) = meanpool(x, k; pad = pad) +_meanpool(x, k, pad, stride) = meanpool(x, k; pad = pad, stride = stride) -meanpool(x::TrackedArray, k; pad = map(_->0,k)) = - track(_meanpool, x, k, pad) +meanpool(x::TrackedArray, k; pad = map(_->0,k), stride = k) = + track(_meanpool, x, k, pad, stride) -back_(::typeof(_meanpool), y, Δ, x, k, pad) = - back(x, NNlib.∇meanpool(Δ, y, data(x), k, pad=pad)) +back_(::typeof(_meanpool), y, Δ, x, k, pad, stride) = + back(x, NNlib.∇meanpool(Δ, y, data(x), k, pad=pad, stride=stride)) # Broadcasting