Fix error while backpropagatio

This commit is contained in:
Avik Pal 2018-06-09 13:04:49 +05:30
parent 1d93fb8e59
commit 5d7ee884b8
2 changed files with 2 additions and 2 deletions

View File

@ -339,7 +339,7 @@ depthwiseconv(x::TrackedArray{<:Real,N}, w::AbstractArray{<:Real,N}; stride = 1,
function back(::typeof(_depthwiseconv), Δ, x, w, stride, pad)
@back(x, NNlib.∇depthwiseconv_data(Δ, data(x), data(w), stride = stride, pad = pad))
@back(x, NNlib.∇depthwiseconv_filter(Δ, data(x), data(w), stride = stride, pad = pad))
@back(w, NNlib.∇depthwiseconv_filter(Δ, data(x), data(w), stride = stride, pad = pad))
end
_maxpool(x, k, pad, stride) = maxpool(x, k; pad = pad, stride = stride)

View File

@ -169,7 +169,7 @@ end
@test gradtest(conv, rand(10, 10, 3, 2), randn(2, 2, 3, 2))
@test gradtest(conv, rand(10, 10, 10, 3, 2), randn(2, 2, 2, 3, 2))
@test gradtest(depthwiseconv, rand(10,10,3,2), randn(2, 2, 1,3))
@test gradtest(depthwiseconv, rand(10,10,3,2), randn(2, 2, 1, 3))
@test gradtest(x -> maxpool(x, (2,2)), rand(10, 10, 3, 2))
@test gradtest(x -> maxpool(x, (2,2,2)), rand(10, 10, 10, 3, 2))