diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 751689f5..c2cc15bf 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -55,6 +55,11 @@ function Conv(w::AbstractArray{T,N}, b::Union{Zeros, AbstractVector{T}}, σ = id return Conv(σ, w, b, stride, pad, dilation) end +function Conv(;weight::AbstractArray, bias::Union{Zeros, AbstractVector{T}}, activation = identity, + stride = 1, pad = 0, dilation = 1) where {T,N} + Conv(weight, bias, activation, stride = stride, pad = pad, dilation = dilation) +end + """ convfilter(filter::Tuple, in=>out) @@ -144,6 +149,11 @@ function ConvTranspose(w::AbstractArray{T,N}, b::Union{Zeros, AbstractVector{T}} return ConvTranspose(σ, w, b, stride, pad, dilation) end +function ConvTranspose(;weight::AbstractArray{T,N}, bias::Union{Zeros, AbstractVector{T}}, + activation = identity, stride = 1, pad = 0, dilation = 1) where {T,N} + ConvTranspose(weight, bias, activation, stride = stride, pad = pad, dilation = dilation) +end + function ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = glorot_uniform, stride = 1, pad = 0, dilation = 1, weight = convfilter(k, reverse(ch), init = init), bias = zeros(ch[2])) where N @@ -233,6 +243,11 @@ function DepthwiseConv(w::AbstractArray{T,N}, b::Union{Zeros, AbstractVector{T}} return DepthwiseConv(σ, w, b, stride, pad, dilation) end +function DepthwiseConv(;weight::AbstractArray, bias::Union{Zeros, AbstractVector{T}}, + activation = identity, stride = 1, pad = 0, dilation = 1) where {T,N} + DepthwiseConv(weight, bias, activation, stride = stride, pad = pad, dilation = dilation) +end + """ depthwiseconvfilter(filter::Tuple, in=>out)