diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 94a10606..41f0e2e3 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -60,8 +60,8 @@ function Conv(w::AbstractArray{T,N}, b::Union{Zeros, AbstractVector{T}}, σ = id return Conv(σ, w, b, stride, pad, dilation) end -function Conv(;weight::AbstractArray, bias::Union{Zeros, AbstractVector{T}}, activation = identity, - stride = 1, pad = 0, dilation = 1) where {T,N} +function Conv(;weight::AbstractArray{T,N}, bias::Union{Zeros, AbstractVector{T}}, + activation = identity, stride = 1, pad = 0, dilation = 1) where {T,N} Conv(weight, bias, activation, stride = stride, pad = pad, dilation = dilation) end @@ -268,7 +268,7 @@ function DepthwiseConv(w::AbstractArray{T,N}, b::Union{Zeros, AbstractVector{T}} return DepthwiseConv(σ, w, b, stride, pad, dilation) end -function DepthwiseConv(;weight::AbstractArray, bias::Union{Zeros, AbstractVector{T}}, +function DepthwiseConv(;weight::AbstractArray{T,N}, bias::Union{Zeros, AbstractVector{T}}, activation = identity, stride = 1, pad = 0, dilation = 1) where {T,N} DepthwiseConv(weight, bias, activation, stride = stride, pad = pad, dilation = dilation) end @@ -379,7 +379,7 @@ function CrossCor(w::AbstractArray{T,N}, b::Union{Zeros, AbstractVector{T}}, σ return CrossCor(σ, w, b, stride, pad, dilation) end -function CrossCor(;weight::AbstractArray, bias::Union{Zeros, AbstractVector{T}}, +function CrossCor(;weight::AbstractArray{T,N}, bias::Union{Zeros, AbstractVector{T}}, activation = identity, stride = 1, pad = 0, dilation = 1) where {T,N} CrossCor(weight, bias, activation, stride = stride, pad = pad, dilation = dilation) end