Fix error while backpropagatio
This commit is contained in:
parent
1d93fb8e59
commit
5d7ee884b8
|
@ -339,7 +339,7 @@ depthwiseconv(x::TrackedArray{<:Real,N}, w::AbstractArray{<:Real,N}; stride = 1,
|
|||
|
||||
function back(::typeof(_depthwiseconv), Δ, x, w, stride, pad)
|
||||
@back(x, NNlib.∇depthwiseconv_data(Δ, data(x), data(w), stride = stride, pad = pad))
|
||||
@back(x, NNlib.∇depthwiseconv_filter(Δ, data(x), data(w), stride = stride, pad = pad))
|
||||
@back(w, NNlib.∇depthwiseconv_filter(Δ, data(x), data(w), stride = stride, pad = pad))
|
||||
end
|
||||
|
||||
_maxpool(x, k, pad, stride) = maxpool(x, k; pad = pad, stride = stride)
|
||||
|
|
|
@ -169,7 +169,7 @@ end
|
|||
@test gradtest(conv, rand(10, 10, 3, 2), randn(2, 2, 3, 2))
|
||||
@test gradtest(conv, rand(10, 10, 10, 3, 2), randn(2, 2, 2, 3, 2))
|
||||
|
||||
@test gradtest(depthwiseconv, rand(10,10,3,2), randn(2, 2, 1,3))
|
||||
@test gradtest(depthwiseconv, rand(10,10,3,2), randn(2, 2, 1, 3))
|
||||
|
||||
@test gradtest(x -> maxpool(x, (2,2)), rand(10, 10, 3, 2))
|
||||
@test gradtest(x -> maxpool(x, (2,2,2)), rand(10, 10, 10, 3, 2))
|
||||
|
|
Loading…
Reference in New Issue