From 903db70673daa7d079f40c09667f8317a910f3d0 Mon Sep 17 00:00:00 2001
From: Mike J Innes
Date: Fri, 7 Sep 2018 01:25:32 +0100
Subject: [PATCH 1/2] float32 param initialisers

---
 src/layers/basic.jl |  8 ++++++++
 src/utils.jl        | 10 ++++++++--
 test/layers/conv.jl |  4 ++--
 3 files changed, 18 insertions(+), 4 deletions(-)

diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index 0c2d3715..48d51d53 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -114,3 +114,11 @@ end
 function Base.show(io::IO, l::Diagonal)
   print(io, "Diagonal(", length(l.α), ")")
 end
+
+# Try to avoid hitting generic matmul in some simple cases
+# Base's matmul is so slow that it's worth the extra conversion to hit BLAS
+(a::Dense{<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
+  invoke(a, Tuple{AbstractArray}, x)
+
+(a::Dense{<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
+  a(T.(x))
diff --git a/src/utils.jl b/src/utils.jl
index 1a585e60..9bad3760 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -1,6 +1,12 @@
 # Arrays
-glorot_uniform(dims...) = (rand(dims...) .- 0.5) .* sqrt(24.0/(sum(dims)))
-glorot_normal(dims...) = randn(dims...) .* sqrt(2.0/sum(dims))
+glorot_uniform(dims...) = (rand(Float32, dims...) .- 0.5f0) .* sqrt(24.0f0/sum(dims))
+glorot_normal(dims...) = randn(Float32, dims...) .* sqrt(2.0f0/sum(dims))
+
+ones(T::Type, dims...) = Base.ones(T, dims...)
+zeros(T::Type, dims...) = Base.zeros(T, dims...)
+
+ones(dims...) = Base.ones(Float32, dims...)
+zeros(dims...) = Base.zeros(Float32, dims...)
 
 unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
 
diff --git a/test/layers/conv.jl b/test/layers/conv.jl
index 5928bd75..160b7fbb 100644
--- a/test/layers/conv.jl
+++ b/test/layers/conv.jl
@@ -2,7 +2,7 @@ using Flux, Test
 using Flux: maxpool, meanpool
 
 @testset "Pooling" begin
-  x = randn(10, 10, 3, 2)
+  x = randn(Float32, 10, 10, 3, 2)
   mp = MaxPool((2, 2))
   @test mp(x) == maxpool(x, (2,2))
   mp = MeanPool((2, 2))
@@ -10,7 +10,7 @@ using Flux: maxpool, meanpool
 end
 
 @testset "CNN" begin
-  r = zeros(28, 28, 1, 5)
+  r = zeros(Float32, 28, 28, 1, 5)
   m = Chain(
     Conv((2, 2), 1=>16, relu),
     MaxPool((2,2)),

From 75ecc0b6badd73132ec534ce4acb050d07604d9a Mon Sep 17 00:00:00 2001
From: Mike J Innes
Date: Mon, 12 Nov 2018 20:21:27 +0000
Subject: [PATCH 2/2] downconversion for conv

---
 src/layers/conv.jl | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/src/layers/conv.jl b/src/layers/conv.jl
index 5392dffc..99fc16f2 100644
--- a/src/layers/conv.jl
+++ b/src/layers/conv.jl
@@ -37,7 +37,7 @@ Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
 
 @treelike Conv
 
-function (c::Conv)(x)
+function (c::Conv)(x::AbstractArray)
   # TODO: breaks gpu broadcast :(
   # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
   σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
@@ -51,6 +51,12 @@ function Base.show(io::IO, l::Conv)
   print(io, ")")
 end
 
+(a::Conv{<:Any,<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
+  invoke(a, Tuple{AbstractArray}, x)
+
+(a::Conv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
+  a(T.(x))
+
 """
     DepthwiseConv(size, in)
     DepthwiseConv(size, in=>mul)
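
A quick usage sketch (not part of either patch) of the behaviour these
commits aim for, assuming a Flux checkout with both applied. The
Dense(10, 5) and Conv((2, 2), 1=>16, relu) constructors are the ones Flux
had at this point (the latter appears in the patched test file); the d.W
field name reflects its Tracker-based parameters of this era.

using Flux

# glorot_uniform/glorot_normal now draw Float32 values, so Dense
# parameters default to Float32:
d = Dense(10, 5)               # d.W has eltype Float32

# Float32 input dispatches straight through to the BLAS-backed matmul:
d(rand(Float32, 10))

# Any other Real input is first converted to the weights' element type,
# instead of falling back to Julia's slow generic matmul:
x64 = rand(10)                 # Float64
d(x64)                         # same as d(Float32.(x64))

# The second commit gives Conv the same treatment: Real inputs whose
# element type differs from the weights' are converted before the
# convolution runs.
c = Conv((2, 2), 1=>16, relu)
c(rand(28, 28, 1, 1))          # converted to eltype(c.weight) if needed

# utils.jl also shadows ones/zeros inside the Flux module so that they
# default to Float32; a type can still be passed explicitly:
Flux.ones(3, 3)                # eltype Float32
Flux.zeros(Float64, 3, 3)      # explicit override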