diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 99fc16f2..25629895 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -83,12 +83,12 @@ DepthwiseConv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity; stride = 1, pad = 0) where {T,N} = DepthwiseConv(σ, w, b, expand.(sub2(Val(N)), (stride, pad))...) -DepthwiseConv(k::NTuple{N,Integer}, ch::Integer, σ = identity; init = initn, +DepthwiseConv(k::NTuple{N,Integer}, ch::Integer, σ = identity; init = glorot_uniform, stride = 1, pad = 0) where N = DepthwiseConv(param(init(k..., 1, ch)), param(zeros(ch)), σ, stride = stride, pad = pad) -DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn, +DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = glorot_uniform, stride::NTuple{N,Integer} = map(_->1,k), pad::NTuple{N,Integer} = map(_->0,k)) where N = DepthwiseConv(param(init(k..., ch[2], ch[1])), param(zeros(ch[2]*ch[1])), σ,