diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 758aa0a9..a0399411 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -88,6 +88,14 @@ function Base.show(io::IO, l::Dense) print(io, ")") end +# Try to avoid hitting generic matmul in some simple cases +# Base's matmul is so slow that it's worth the extra conversion to hit BLAS +(a::Dense{<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} = + invoke(a, Tuple{AbstractArray}, x) + +(a::Dense{<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} = + a(T.(x)) + """ Diagonal(in::Integer) @@ -117,10 +125,3 @@ function Base.show(io::IO, l::Diagonal) print(io, "Diagonal(", length(l.α), ")") end -# Try to avoid hitting generic matmul in some simple cases -# Base's matmul is so slow that it's worth the extra conversion to hit BLAS -(a::Dense{<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} = - invoke(a, Tuple{AbstractArray}, x) - -(a::Dense{<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} = - a(T.(x))