numeric precision utilities
This commit is contained in:
parent
1cf37ab9eb
commit
791939709b
|
@ -8,7 +8,7 @@ using MacroTools: @forward
|
|||
|
||||
export Chain, Dense, RNN, LSTM, GRU, Conv, MaxPool, MeanPool,
|
||||
DepthwiseConv, Dropout, LayerNorm, BatchNorm,
|
||||
params, mapleaves, cpu, gpu
|
||||
params, mapleaves, cpu, gpu, f32, f64
|
||||
|
||||
@reexport using NNlib
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
import Adapt: adapt
|
||||
import Adapt: adapt, adapt_storage
|
||||
import .Tracker: IdSet
|
||||
|
||||
children(x) = ()
|
||||
|
@ -64,3 +64,12 @@ gpu_adaptor = identity
|
|||
end
|
||||
|
||||
gpu(x) = mapleaves(gpu_adaptor, x)
|
||||
|
||||
# Precision
|
||||
|
||||
adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Real}) = convert.(T, xs)
|
||||
|
||||
paramtype(T::Type{<:Real}, m) = mapleaves(x -> adapt(T, x), m)
|
||||
|
||||
f32(m) = paramtype(Float32, m)
|
||||
f64(m) = paramtype(Float64, m)
|
||||
|
|
|
@ -86,3 +86,14 @@ end
|
|||
m = RNN(10, 5)
|
||||
@test size.(params(m)) == [(5, 10), (5, 5), (5,), (5,)]
|
||||
end
|
||||
|
||||
@testset "Precision" begin
|
||||
m = Chain(Dense(10, 5, relu), Dense(5, 2))
|
||||
x = rand(10)
|
||||
@test eltype(m[1].W.data) == Float32
|
||||
@test eltype(m(x).data) == Float32
|
||||
@test eltype(f64(m)(x).data) == Float64
|
||||
@test eltype(f64(m)[1].W.data) == Float64
|
||||
@test eltype(f32(f64(m))[1].W.data) == Float32
|
||||
@test Tracker.isleaf(f32(f64(m))[1].W)
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue