numeric precision utilities

This commit is contained in:
Mike J Innes 2019-01-25 10:06:37 +00:00
parent 1cf37ab9eb
commit 791939709b
3 changed files with 22 additions and 2 deletions

View File

@ -8,7 +8,7 @@ using MacroTools: @forward
export Chain, Dense, RNN, LSTM, GRU, Conv, MaxPool, MeanPool, export Chain, Dense, RNN, LSTM, GRU, Conv, MaxPool, MeanPool,
DepthwiseConv, Dropout, LayerNorm, BatchNorm, DepthwiseConv, Dropout, LayerNorm, BatchNorm,
params, mapleaves, cpu, gpu params, mapleaves, cpu, gpu, f32, f64
@reexport using NNlib @reexport using NNlib

View File

@ -1,4 +1,4 @@
import Adapt: adapt import Adapt: adapt, adapt_storage
import .Tracker: IdSet import .Tracker: IdSet
children(x) = () children(x) = ()
@ -64,3 +64,12 @@ gpu_adaptor = identity
end end
gpu(x) = mapleaves(gpu_adaptor, x) gpu(x) = mapleaves(gpu_adaptor, x)
# Precision
adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Real}) = convert.(T, xs)
paramtype(T::Type{<:Real}, m) = mapleaves(x -> adapt(T, x), m)
f32(m) = paramtype(Float32, m)
f64(m) = paramtype(Float64, m)

View File

@ -86,3 +86,14 @@ end
m = RNN(10, 5) m = RNN(10, 5)
@test size.(params(m)) == [(5, 10), (5, 5), (5,), (5,)] @test size.(params(m)) == [(5, 10), (5, 5), (5,), (5,)]
end end
@testset "Precision" begin
m = Chain(Dense(10, 5, relu), Dense(5, 2))
x = rand(10)
@test eltype(m[1].W.data) == Float32
@test eltype(m(x).data) == Float32
@test eltype(f64(m)(x).data) == Float64
@test eltype(f64(m)[1].W.data) == Float64
@test eltype(f32(f64(m))[1].W.data) == Float32
@test Tracker.isleaf(f32(f64(m))[1].W)
end