diff --git a/src/Flux.jl b/src/Flux.jl index da040aa0..7b1fd800 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -8,7 +8,7 @@ using MacroTools: @forward export Chain, Dense, RNN, LSTM, GRU, Conv, MaxPool, MeanPool, DepthwiseConv, Dropout, LayerNorm, BatchNorm, - params, mapleaves, cpu, gpu + params, mapleaves, cpu, gpu, f32, f64 @reexport using NNlib diff --git a/src/treelike.jl b/src/treelike.jl index bc1a5c0b..3d2b9185 100644 --- a/src/treelike.jl +++ b/src/treelike.jl @@ -1,4 +1,4 @@ -import Adapt: adapt +import Adapt: adapt, adapt_storage import .Tracker: IdSet children(x) = () @@ -64,3 +64,22 @@ gpu_adaptor = identity end gpu(x) = mapleaves(gpu_adaptor, x) + +# Precision + +adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Real}) = convert.(T, xs) + +paramtype(T::Type{<:Real}, m) = mapleaves(x -> adapt(T, x), m) + +f32(m) = paramtype(Float32, m) +f64(m) = paramtype(Float64, m) + +# General parameter map + +function mapparams(f, m) + mapleaves(m) do x + Tracker.istracked(x) ? param(f(Tracker.data(x))) : + x isa Union{AbstractArray,Number} ? f(x) : + x + end +end diff --git a/test/utils.jl b/test/utils.jl index af0d50fe..c60645c6 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -86,3 +86,14 @@ end m = RNN(10, 5) @test size.(params(m)) == [(5, 10), (5, 5), (5,), (5,)] end + +@testset "Precision" begin + m = Chain(Dense(10, 5, relu), Dense(5, 2)) + x = rand(10) + @test eltype(m[1].W.data) == Float32 + @test eltype(m(x).data) == Float32 + @test eltype(f64(m)(x).data) == Float64 + @test eltype(f64(m)[1].W.data) == Float64 + @test eltype(f32(f64(m))[1].W.data) == Float32 + @test Tracker.isleaf(f32(f64(m))[1].W) +end