numeric precision utilities

This commit is contained in:
Mike J Innes 2019-01-25 10:06:37 +00:00
parent 1cf37ab9eb
commit 791939709b
3 changed files with 22 additions and 2 deletions

View File

@ -8,7 +8,7 @@ using MacroTools: @forward
export Chain, Dense, RNN, LSTM, GRU, Conv, MaxPool, MeanPool,
DepthwiseConv, Dropout, LayerNorm, BatchNorm,
params, mapleaves, cpu, gpu
params, mapleaves, cpu, gpu, f32, f64
@reexport using NNlib

View File

@ -1,4 +1,4 @@
import Adapt: adapt
import Adapt: adapt, adapt_storage
import .Tracker: IdSet
children(x) = ()
@ -64,3 +64,12 @@ gpu_adaptor = identity
end
gpu(x) = mapleaves(gpu_adaptor, x)
# Precision
adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Real}) = convert.(T, xs)
paramtype(T::Type{<:Real}, m) = mapleaves(x -> adapt(T, x), m)
f32(m) = paramtype(Float32, m)
f64(m) = paramtype(Float64, m)

View File

@ -86,3 +86,14 @@ end
m = RNN(10, 5)
@test size.(params(m)) == [(5, 10), (5, 5), (5,), (5,)]
end
@testset "Precision" begin
m = Chain(Dense(10, 5, relu), Dense(5, 2))
x = rand(10)
@test eltype(m[1].W.data) == Float32
@test eltype(m(x).data) == Float32
@test eltype(f64(m)(x).data) == Float64
@test eltype(f64(m)[1].W.data) == Float64
@test eltype(f32(f64(m))[1].W.data) == Float32
@test Tracker.isleaf(f32(f64(m))[1].W)
end