Merge pull request #572 from FluxML/precision
Numeric precision utilities
This commit is contained in:
commit
962ce88c0d
@ -8,7 +8,7 @@ using MacroTools: @forward
|
|||||||
|
|
||||||
export Chain, Dense, RNN, LSTM, GRU, Conv, MaxPool, MeanPool,
|
export Chain, Dense, RNN, LSTM, GRU, Conv, MaxPool, MeanPool,
|
||||||
DepthwiseConv, Dropout, LayerNorm, BatchNorm,
|
DepthwiseConv, Dropout, LayerNorm, BatchNorm,
|
||||||
params, mapleaves, cpu, gpu
|
params, mapleaves, cpu, gpu, f32, f64
|
||||||
|
|
||||||
@reexport using NNlib
|
@reexport using NNlib
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import Adapt: adapt
|
import Adapt: adapt, adapt_storage
|
||||||
import .Tracker: IdSet
|
import .Tracker: IdSet
|
||||||
|
|
||||||
children(x) = ()
|
children(x) = ()
|
||||||
@ -64,3 +64,22 @@ gpu_adaptor = identity
|
|||||||
end
|
end
|
||||||
|
|
||||||
gpu(x) = mapleaves(gpu_adaptor, x)
|
gpu(x) = mapleaves(gpu_adaptor, x)
|
||||||
|
|
||||||
|
# Precision
|
||||||
|
|
||||||
|
adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Real}) = convert.(T, xs)
|
||||||
|
|
||||||
|
paramtype(T::Type{<:Real}, m) = mapleaves(x -> adapt(T, x), m)
|
||||||
|
|
||||||
|
f32(m) = paramtype(Float32, m)
|
||||||
|
f64(m) = paramtype(Float64, m)
|
||||||
|
|
||||||
|
# General parameter map
|
||||||
|
|
||||||
|
function mapparams(f, m)
|
||||||
|
mapleaves(m) do x
|
||||||
|
Tracker.istracked(x) ? param(f(Tracker.data(x))) :
|
||||||
|
x isa Union{AbstractArray,Number} ? f(x) :
|
||||||
|
x
|
||||||
|
end
|
||||||
|
end
|
||||||
|
@ -86,3 +86,14 @@ end
|
|||||||
m = RNN(10, 5)
|
m = RNN(10, 5)
|
||||||
@test size.(params(m)) == [(5, 10), (5, 5), (5,), (5,)]
|
@test size.(params(m)) == [(5, 10), (5, 5), (5,), (5,)]
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@testset "Precision" begin
|
||||||
|
m = Chain(Dense(10, 5, relu), Dense(5, 2))
|
||||||
|
x = rand(10)
|
||||||
|
@test eltype(m[1].W.data) == Float32
|
||||||
|
@test eltype(m(x).data) == Float32
|
||||||
|
@test eltype(f64(m)(x).data) == Float64
|
||||||
|
@test eltype(f64(m)[1].W.data) == Float64
|
||||||
|
@test eltype(f32(f64(m))[1].W.data) == Float32
|
||||||
|
@test Tracker.isleaf(f32(f64(m))[1].W)
|
||||||
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user