Add glorot (Xavier) initialization

Set default `Dense` and `RNN` inits to `glorot_uniform()` for `W`, `zeros` for `b`.
This commit is contained in:
Elliot Saba 2017-12-04 23:47:03 -08:00
parent cab235a578
commit c59b820bed
4 changed files with 36 additions and 9 deletions

View File

@ -63,8 +63,10 @@ struct Dense{F,S,T}
b::T
end
"""
    Dense(in::Integer, out::Integer, σ = identity; init = initn)

Build a `Dense` layer from sizes. `init(out, in)` and `init(out)` are each wrapped
in `param` and handed, after `σ`, to the field-wise `Dense` constructor.
"""
function Dense(in::Integer, out::Integer, σ = identity; init = initn)
    W = param(init(out, in))
    b = param(init(out))
    return Dense(σ, W, b)
end
"""
    Dense(in::Integer, out::Integer, σ = identity;
          initW = glorot_uniform, initb = zeros)

Build a `Dense` layer from sizes. The weight array comes from `initW(out, in)` and
the bias from `initb(out)`; both are wrapped in `param` before being passed, after
`σ`, to the field-wise `Dense` constructor.
"""
function Dense(in::Integer, out::Integer, σ = identity;
               initW = glorot_uniform, initb = zeros)
    W = param(initW(out, in))
    b = param(initb(out))
    return Dense(σ, W, b)
end
treelike(Dense)

View File

@ -79,8 +79,8 @@ struct RNNCell{D,V}
h::V
end
"""
    RNNCell(in::Integer, out::Integer, σ = tanh; init = initn)

Build an `RNNCell` from a `Dense(in + out, out, σ)` layer and a `param`-wrapped
vector `init(out)`; `init` is also forwarded to the `Dense` constructor.
"""
function RNNCell(in::Integer, out::Integer, σ = tanh; init = initn)
    dense = Dense(in + out, out, σ, init = init)
    hidden = param(init(out))
    return RNNCell(dense, hidden)
end
"""
    RNNCell(in::Integer, out::Integer, σ = tanh;
            initW = glorot_uniform, initb = zeros)

Build an `RNNCell` from a `Dense(in + out, out, σ)` layer and a `param`-wrapped
vector `initW(out)`. `initW` and `initb` are forwarded to the `Dense` constructor.
"""
function RNNCell(in::Integer, out::Integer, σ = tanh;
                 initW = glorot_uniform, initb = zeros)
    dense = Dense(in + out, out, σ, initW = initW, initb = initb)
    # Note: the second field is drawn with `initW`, not `initb`.
    hidden = param(initW(out))
    return RNNCell(dense, hidden)
end
function (m::RNNCell)(h, x)
h = m.d(combine(x, h))
@ -113,10 +113,10 @@ struct LSTMCell{D1,D2,V}
h::V; c::V
end
function LSTMCell(in, out; init = initn)
cell = LSTMCell([Dense(in+out, out, σ, init = init) for _ = 1:3]...,
Dense(in+out, out, tanh, init = init),
param(init(out)), param(init(out)))
"""
    LSTMCell(in, out; initW = glorot_uniform, initb = zeros)

Build an `LSTMCell` from four `Dense(in + out, out, ...)` layers (three using `σ`,
one using `tanh`) followed by two `param`-wrapped vectors drawn with `initW(out)`.
After construction the bias of the `forget` layer is overwritten with ones.
"""
function LSTMCell(in, out; initW = glorot_uniform, initb = zeros)
    # All four layers share the same shape and initialisers; only the activation varies.
    newdense(act) = Dense(in + out, out, act, initW = initW, initb = initb)

    # Construction order is kept as: three σ layers, the tanh layer, then the two
    # vectors, so random initialisers are consumed in the same sequence.
    gates = [newdense(σ) for _ in 1:3]
    tanhlayer = newdense(tanh)
    h = param(initW(out))
    c = param(initW(out))

    lstm = LSTMCell(gates..., tanhlayer, h, c)
    lstm.forget.b.data .= 1
    return lstm
end

View File

@ -1,6 +1,8 @@
# Arrays
"""
    initn(dims...)

Return an array of size `dims` of standard-normal samples divided by 100.
"""
function initn(dims...)
    samples = randn(dims...)
    return samples ./ 100
end
"""
    glorot_uniform(dims...)

Return an array of size `dims` whose entries are drawn uniformly from `[-b, b)`,
where `b = sqrt(6 / sum(dims))` (Glorot/Xavier uniform initialisation).
"""
# `rand` is uniform on [0, 1); shifting by 0.5 and scaling by sqrt(24/sum) = 2b
# gives the symmetric interval. Broadcasting (`.-`) is required because
# array-minus-scalar without a dot is not defined on Julia 1.x.
glorot_uniform(dims...) = (rand(dims...) .- 0.5) .* sqrt(24.0 / sum(dims))
"""
    glorot_normal(dims...)

Return an array of size `dims` of standard-normal samples scaled by
`sqrt(2 / sum(dims))` (Glorot/Xavier normal initialisation).
"""
function glorot_normal(dims...)
    samples = randn(dims...)
    scale = sqrt(2.0 / sum(dims))
    return samples * scale
end
"""
    flatten(xs)

Reshape `xs` into a matrix that keeps the first dimension and collapses all
remaining dimensions into the second.
"""
function flatten(xs)
    rows = size(xs, 1)
    return reshape(xs, rows, :)
end

View File

@ -1,4 +1,4 @@
using Flux: throttle
using Flux: throttle, initn, glorot_uniform, glorot_normal
@testset "Throttle" begin
@testset "default behaviour" begin
@ -47,3 +47,26 @@ using Flux: throttle
@test a == [1, 3]
end
end
@testset "Initialization" begin
    # Fix the RNG seed so the statistical bounds below are reproducible
    srand(0)

    # initn() should yield a kernel with stddev ~= 1e-2
    v = initn(10, 10)
    s = std(v)
    @test s > 0.9*1e-2
    @test s < 1.1*1e-2

    # glorot_uniform samples should span roughly ±sqrt(6/(n_in + n_out)), and
    # glorot_normal should yield a kernel with stddev ~= sqrt(2/(n_in + n_out))
    for (n_in, n_out) in [(100, 100), (100, 400)]
        limit = sqrt(6/(n_in + n_out))
        sigma = sqrt(2/(n_in + n_out))

        v = glorot_uniform(n_in, n_out)
        lo, hi = extrema(v)
        @test lo > -1.1*limit
        @test lo < -0.9*limit
        @test hi > 0.9*limit
        @test hi < 1.1*limit

        v = glorot_normal(n_in, n_out)
        s = std(v)
        @test s > 0.9*sigma
        @test s < 1.1*sigma
    end
end