From c59b820bed91a214126f5d3b66461ee44c855be0 Mon Sep 17 00:00:00 2001 From: Elliot Saba Date: Mon, 4 Dec 2017 23:47:03 -0800 Subject: [PATCH] Add glorot (Xavier) initialization Set default `Dense` and `RNN` inits to `glorot_uniform()` for `W`, `zeros` for `b`. --- src/layers/basic.jl | 6 ++++-- src/layers/recurrent.jl | 12 ++++++------ src/utils.jl | 2 ++ test/utils.jl | 25 ++++++++++++++++++++++++- 4 files changed, 36 insertions(+), 9 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index aa101c43..9f458ab4 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -63,8 +63,10 @@ struct Dense{F,S,T} b::T end -Dense(in::Integer, out::Integer, σ = identity; init = initn) = - Dense(σ, param(init(out, in)), param(init(out))) +function Dense(in::Integer, out::Integer, σ = identity; + initW = glorot_uniform, initb = zeros) + return Dense(σ, param(initW(out, in)), param(initb(out))) +end treelike(Dense) diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 599776ce..781bd405 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -79,8 +79,8 @@ struct RNNCell{D,V} h::V end -RNNCell(in::Integer, out::Integer, σ = tanh; init = initn) = - RNNCell(Dense(in+out, out, σ, init = init), param(init(out))) +RNNCell(in::Integer, out::Integer, σ = tanh; initW = glorot_uniform, initb = zeros) = + RNNCell(Dense(in+out, out, σ, initW = initW, initb = initb), param(initW(out))) function (m::RNNCell)(h, x) h = m.d(combine(x, h)) @@ -113,10 +113,10 @@ struct LSTMCell{D1,D2,V} h::V; c::V end -function LSTMCell(in, out; init = initn) - cell = LSTMCell([Dense(in+out, out, σ, init = init) for _ = 1:3]..., - Dense(in+out, out, tanh, init = init), - param(init(out)), param(init(out))) +function LSTMCell(in, out; initW = glorot_uniform, initb = zeros) + cell = LSTMCell([Dense(in+out, out, σ, initW = initW, initb = initb) for _ = 1:3]..., + Dense(in+out, out, tanh, initW = initW, initb = initb), + param(initW(out)), param(initW(out))) 
cell.forget.b.data .= 1 return cell end diff --git a/src/utils.jl b/src/utils.jl index f822c111..944d35bf 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,6 +1,8 @@ # Arrays initn(dims...) = randn(dims...)/100 +glorot_uniform(dims...) = (rand(dims...) - 0.5)*sqrt(24.0/(sum(dims))) +glorot_normal(dims...) = (randn(dims...)*sqrt(2.0/sum(dims))) flatten(xs) = reshape(xs, size(xs, 1), :) diff --git a/test/utils.jl b/test/utils.jl index 7638fd2a..1c313a3d 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,4 +1,4 @@ -using Flux: throttle +using Flux: throttle, initn, glorot_uniform, glorot_normal @testset "Throttle" begin @testset "default behaviour" begin @@ -47,3 +47,26 @@ using Flux: throttle @test a == [1, 3] end end + +@testset "Initialization" begin + # Set random seed so that these tests don't fail randomly + srand(0) + # initn() should yield a kernel with stddev ~= 1e-2 + v = initn(10, 10) + @test std(v) > 0.9*1e-2 + @test std(v) < 1.1*1e-2 + + # glorot_uniform should yield a kernel with extrema ~= ±sqrt(6/(n_in + n_out)), + # and glorot_normal should yield a kernel with stddev ~= sqrt(2/(n_in + n_out)) + for (n_in, n_out) in [(100, 100), (100, 400)] + v = glorot_uniform(n_in, n_out) + @test minimum(v) > -1.1*sqrt(6/(n_in + n_out)) + @test minimum(v) < -0.9*sqrt(6/(n_in + n_out)) + @test maximum(v) > 0.9*sqrt(6/(n_in + n_out)) + @test maximum(v) < 1.1*sqrt(6/(n_in + n_out)) + + v = glorot_normal(n_in, n_out) + @test std(v) > 0.9*sqrt(2/(n_in + n_out)) + @test std(v) < 1.1*sqrt(2/(n_in + n_out)) + end +end \ No newline at end of file