diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 216957a7..99a04890 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -165,6 +165,12 @@ function Base.show(io::IO, l::DepthwiseConv) print(io, ")") end +(a::DepthwiseConv{<:Any,<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} = + invoke(a, Tuple{AbstractArray}, x) + +(a::DepthwiseConv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} = + a(T.(x)) + """ MaxPool(k)