diff --git a/src/layers/conv.jl b/src/layers/conv.jl index b586915a..f2b9c5d7 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -73,8 +73,8 @@ struct DepthwiseConv{N,F,A,V} pad::NTuple{N,Int} end -DepthwiseConv(w::AbstractArray{T}, b::AbstractVector{T}, σ = identity; - stride = 1, pad = 0) where T = +DepthwiseConv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity; + stride = 1, pad = 0) where {T,N} = DepthwiseConv(σ, w, b, expand.(sub2(Val(N)), (stride, pad))...) DepthwiseConv(k::NTuple{N,Integer}, ch::Integer, σ = identity; init = initn,