diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 5392dffc..99fc16f2 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -37,7 +37,7 @@ Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; @treelike Conv -function (c::Conv)(x) +function (c::Conv)(x::AbstractArray) # TODO: breaks gpu broadcast :( # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1))) σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1) @@ -51,6 +51,12 @@ function Base.show(io::IO, l::Conv) print(io, ")") end +(a::Conv{<:Any,<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} = + invoke(a, Tuple{AbstractArray}, x) + +(a::Conv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} = + a(T.(x)) + """ DepthwiseConv(size, in) DepthwiseConv(size, in=>mul)