diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index 0c2d3715..48d51d53 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -114,3 +114,11 @@ end
 function Base.show(io::IO, l::Diagonal)
   print(io, "Diagonal(", length(l.α), ")")
 end
+
+# Try to avoid hitting generic matmul in some simple cases
+# Base's matmul is so slow that it's worth the extra conversion to hit BLAS
+(a::Dense{<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
+  invoke(a, Tuple{AbstractArray}, x)
+
+(a::Dense{<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
+  a(T.(x))
diff --git a/src/layers/conv.jl b/src/layers/conv.jl
index 5392dffc..99fc16f2 100644
--- a/src/layers/conv.jl
+++ b/src/layers/conv.jl
@@ -37,7 +37,7 @@ Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
 
 @treelike Conv
 
-function (c::Conv)(x)
+function (c::Conv)(x::AbstractArray)
   # TODO: breaks gpu broadcast :(
   # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
   σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
@@ -51,6 +51,12 @@ function Base.show(io::IO, l::Conv)
   print(io, ")")
 end
 
+(a::Conv{<:Any,<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
+  invoke(a, Tuple{AbstractArray}, x)
+
+(a::Conv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
+  a(T.(x))
+
 """
     DepthwiseConv(size, in)
     DepthwiseConv(size, in=>mul)
diff --git a/src/utils.jl b/src/utils.jl
index 1a585e60..9bad3760 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -1,6 +1,12 @@
 # Arrays
 
-glorot_uniform(dims...) = (rand(dims...) .- 0.5) .* sqrt(24.0/(sum(dims)))
-glorot_normal(dims...) = randn(dims...) .* sqrt(2.0/sum(dims))
+glorot_uniform(dims...) = (rand(Float32, dims...) .- 0.5f0) .* sqrt(24.0f0/sum(dims))
+glorot_normal(dims...) = randn(Float32, dims...) .* sqrt(2.0f0/sum(dims))
+
+ones(T::Type, dims...) = Base.ones(T, dims...)
+zeros(T::Type, dims...) = Base.zeros(T, dims...)
+
+ones(dims...) = Base.ones(Float32, dims...)
+zeros(dims...) = Base.zeros(Float32, dims...)
 
 unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
diff --git a/test/layers/conv.jl b/test/layers/conv.jl
index 5928bd75..160b7fbb 100644
--- a/test/layers/conv.jl
+++ b/test/layers/conv.jl
@@ -2,7 +2,7 @@ using Flux, Test
 using Flux: maxpool, meanpool
 
 @testset "Pooling" begin
-  x = randn(10, 10, 3, 2)
+  x = randn(Float32, 10, 10, 3, 2)
   mp = MaxPool((2, 2))
   @test mp(x) == maxpool(x, (2,2))
   mp = MeanPool((2, 2))
@@ -10,7 +10,7 @@ using Flux: maxpool, meanpool
 end
 
 @testset "CNN" begin
-  r = zeros(28, 28, 1, 5)
+  r = zeros(Float32, 28, 28, 1, 5)
   m = Chain(
     Conv((2, 2), 1=>16, relu),
     MaxPool((2,2)),
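
Note for reviewers: the paired `Dense`/`Conv` methods above form a small dispatch chain, and the `invoke` is what keeps the conversion method from recursing. Here is a minimal, self-contained sketch of the same pattern, using a hypothetical `Affine` type that is not part of this patch:

```julia
# Hypothetical `Affine` type, standing in for Dense/Conv to show the pattern.
struct Affine{W<:AbstractMatrix}
  weight::W
end

# Generic fallback: accepts any array, but mismatched eltypes would fall off
# the BLAS fast path onto Julia's much slower generic matmul.
(a::Affine)(x::AbstractArray) = a.weight * x

# Eltypes already match: jump straight to the generic body via `invoke`,
# bypassing the conversion method below; this is what breaks the recursion.
(a::Affine{W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractMatrix{T}} =
  invoke(a, Tuple{AbstractArray}, x)

# Mismatched real eltypes: convert once, then re-dispatch. The new call has
# matching eltypes, so it hits the `invoke` method above and thus BLAS.
(a::Affine{W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractMatrix{T}} =
  a(T.(x))

a = Affine(rand(Float32, 5, 10))
a(rand(Float64, 10, 3))  # converted to Float32, multiplied via BLAS
```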
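
The `glorot_uniform` rescaling keeps the usual Glorot variance: samples are uniform on (-a, a) with a = 0.5·sqrt(24/sum(dims)), so Var = a²/3 = 2/sum(dims), i.e. 2/(fan_in + fan_out). A quick REPL-style sanity check, assuming this branch of Flux (values hold only up to sampling noise):

```julia
using Flux, Statistics

W = Flux.glorot_uniform(100, 200)
eltype(W)                          # Float32: initialisers now default to single precision
var(W)                             # ≈ 2 / (100 + 200) ≈ 0.0067, up to sampling error

# The shadowed `ones`/`zeros` follow the same Float32-by-default convention,
# while still accepting an explicit element type:
eltype(Flux.ones(3, 3))            # Float32
eltype(Flux.zeros(Float64, 3, 3))  # Float64 when requested explicitly
```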