From 79de829fdc272f81adda4cc725288ddb430c3255 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 27 Feb 2019 11:46:20 +0000 Subject: [PATCH] move Dense's overloads to be near its defn --- src/layers/basic.jl | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 758aa0a9..a0399411 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -88,6 +88,14 @@ function Base.show(io::IO, l::Dense) print(io, ")") end +# Try to avoid hitting generic matmul in some simple cases +# Base's matmul is so slow that it's worth the extra conversion to hit BLAS +(a::Dense{<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} = + invoke(a, Tuple{AbstractArray}, x) + +(a::Dense{<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} = + a(T.(x)) + """ Diagonal(in::Integer) @@ -117,10 +125,3 @@ function Base.show(io::IO, l::Diagonal) print(io, "Diagonal(", length(l.α), ")") end -# Try to avoid hitting generic matmul in some simple cases -# Base's matmul is so slow that it's worth the extra conversion to hit BLAS -(a::Dense{<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} = - invoke(a, Tuple{AbstractArray}, x) - -(a::Dense{<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} = - a(T.(x))