float32 param initialisers

This commit is contained in:
Mike J Innes 2018-09-07 01:25:32 +01:00
parent 30486f9c03
commit 903db70673
3 changed files with 18 additions and 4 deletions

View File

@ -114,3 +114,11 @@ end
function Base.show(io::IO, l::Diagonal)
print(io, "Diagonal(", length(l.α), ")")
end
# Try to avoid hitting generic matmul in some simple cases
# Base's matmul is so slow that it's worth the extra conversion to hit BLAS
(a::Dense{<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
invoke(a, Tuple{AbstractArray}, x)
(a::Dense{<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
a(T.(x))

View File

@ -1,6 +1,12 @@
# Arrays
glorot_uniform(dims...) = (rand(dims...) .- 0.5) .* sqrt(24.0/(sum(dims)))
glorot_normal(dims...) = randn(dims...) .* sqrt(2.0/sum(dims))
glorot_uniform(dims...) = (rand(Float32, dims...) .- 0.5f0) .* sqrt(24.0f0/sum(dims))
glorot_normal(dims...) = randn(Float32, dims...) .* sqrt(2.0f0/sum(dims))
ones(T::Type, dims...) = Base.ones(T, dims...)
zeros(T::Type, dims...) = Base.zeros(T, dims...)
ones(dims...) = Base.ones(Float32, dims...)
zeros(dims...) = Base.zeros(Float32, dims...)
unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))

View File

@ -2,7 +2,7 @@ using Flux, Test
using Flux: maxpool, meanpool
@testset "Pooling" begin
x = randn(10, 10, 3, 2)
x = randn(Float32, 10, 10, 3, 2)
mp = MaxPool((2, 2))
@test mp(x) == maxpool(x, (2,2))
mp = MeanPool((2, 2))
@ -10,7 +10,7 @@ using Flux: maxpool, meanpool
end
@testset "CNN" begin
r = zeros(28, 28, 1, 5)
r = zeros(Float32, 28, 28, 1, 5)
m = Chain(
Conv((2, 2), 1=>16, relu),
MaxPool((2,2)),