Merge pull request #572 from FluxML/precision

Numeric precision utilities
This commit is contained in:
Mike J Innes 2019-01-25 10:45:13 +00:00 committed by GitHub
commit 962ce88c0d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 32 additions and 2 deletions

View File

@ -8,7 +8,7 @@ using MacroTools: @forward
export Chain, Dense, RNN, LSTM, GRU, Conv, MaxPool, MeanPool, export Chain, Dense, RNN, LSTM, GRU, Conv, MaxPool, MeanPool,
DepthwiseConv, Dropout, LayerNorm, BatchNorm, DepthwiseConv, Dropout, LayerNorm, BatchNorm,
params, mapleaves, cpu, gpu params, mapleaves, cpu, gpu, f32, f64
@reexport using NNlib @reexport using NNlib

View File

@ -1,4 +1,4 @@
import Adapt: adapt import Adapt: adapt, adapt_storage
import .Tracker: IdSet import .Tracker: IdSet
children(x) = () children(x) = ()
@ -64,3 +64,22 @@ gpu_adaptor = identity
end end
gpu(x) = mapleaves(gpu_adaptor, x) gpu(x) = mapleaves(gpu_adaptor, x)
# Precision
adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Real}) = convert.(T, xs)
paramtype(T::Type{<:Real}, m) = mapleaves(x -> adapt(T, x), m)
f32(m) = paramtype(Float32, m)
f64(m) = paramtype(Float64, m)
# General parameter map
function mapparams(f, m)
mapleaves(m) do x
Tracker.istracked(x) ? param(f(Tracker.data(x))) :
x isa Union{AbstractArray,Number} ? f(x) :
x
end
end

View File

@ -86,3 +86,14 @@ end
m = RNN(10, 5) m = RNN(10, 5)
@test size.(params(m)) == [(5, 10), (5, 5), (5,), (5,)] @test size.(params(m)) == [(5, 10), (5, 5), (5,), (5,)]
end end
@testset "Precision" begin
m = Chain(Dense(10, 5, relu), Dense(5, 2))
x = rand(10)
@test eltype(m[1].W.data) == Float32
@test eltype(m(x).data) == Float32
@test eltype(f64(m)(x).data) == Float64
@test eltype(f64(m)[1].W.data) == Float64
@test eltype(f32(f64(m))[1].W.data) == Float32
@test Tracker.isleaf(f32(f64(m))[1].W)
end