From 791939709bc91aaf91730213bc2da3704744a695 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Fri, 25 Jan 2019 10:06:37 +0000 Subject: [PATCH 1/2] numeric precision utilities --- src/Flux.jl | 2 +- src/treelike.jl | 11 ++++++++++- test/utils.jl | 11 +++++++++++ 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/src/Flux.jl b/src/Flux.jl index da040aa0..7b1fd800 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -8,7 +8,7 @@ using MacroTools: @forward export Chain, Dense, RNN, LSTM, GRU, Conv, MaxPool, MeanPool, DepthwiseConv, Dropout, LayerNorm, BatchNorm, - params, mapleaves, cpu, gpu + params, mapleaves, cpu, gpu, f32, f64 @reexport using NNlib diff --git a/src/treelike.jl b/src/treelike.jl index bc1a5c0b..b68c61cf 100644 --- a/src/treelike.jl +++ b/src/treelike.jl @@ -1,4 +1,4 @@ -import Adapt: adapt +import Adapt: adapt, adapt_storage import .Tracker: IdSet children(x) = () @@ -64,3 +64,12 @@ gpu_adaptor = identity end gpu(x) = mapleaves(gpu_adaptor, x) + +# Precision + +adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Real}) = convert.(T, xs) + +paramtype(T::Type{<:Real}, m) = mapleaves(x -> adapt(T, x), m) + +f32(m) = paramtype(Float32, m) +f64(m) = paramtype(Float64, m) diff --git a/test/utils.jl b/test/utils.jl index af0d50fe..c60645c6 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -86,3 +86,14 @@ end m = RNN(10, 5) @test size.(params(m)) == [(5, 10), (5, 5), (5,), (5,)] end + +@testset "Precision" begin + m = Chain(Dense(10, 5, relu), Dense(5, 2)) + x = rand(10) + @test eltype(m[1].W.data) == Float32 + @test eltype(m(x).data) == Float32 + @test eltype(f64(m)(x).data) == Float64 + @test eltype(f64(m)[1].W.data) == Float64 + @test eltype(f32(f64(m))[1].W.data) == Float32 + @test Tracker.isleaf(f32(f64(m))[1].W) +end From 2b1a3e92da3c0795b2af81910123ac8cd6c57879 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Fri, 25 Jan 2019 10:11:46 +0000 Subject: [PATCH 2/2] mapparams --- src/treelike.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/treelike.jl b/src/treelike.jl index b68c61cf..3d2b9185 100644 --- a/src/treelike.jl +++ b/src/treelike.jl @@ -73,3 +73,13 @@ paramtype(T::Type{<:Real}, m) = mapleaves(x -> adapt(T, x), m) f32(m) = paramtype(Float32, m) f64(m) = paramtype(Float64, m) + +# General parameter map + +function mapparams(f, m) + mapleaves(m) do x + Tracker.istracked(x) ? param(f(Tracker.data(x))) : + x isa Union{AbstractArray,Number} ? f(x) : + x + end +end