Add glorot (Xavier) initialization
Set default `Dense` and `RNN` inits to `glorot_uniform()` for `W`, `zeros` for `b`.
This commit is contained in:
parent
cab235a578
commit
c59b820bed
|
@ -63,8 +63,10 @@ struct Dense{F,S,T}
|
|||
b::T
|
||||
end
|
||||
|
||||
Dense(in::Integer, out::Integer, σ = identity; init = initn) =
|
||||
Dense(σ, param(init(out, in)), param(init(out)))
|
||||
function Dense(in::Integer, out::Integer, σ = identity;
|
||||
initW = glorot_uniform, initb = zeros)
|
||||
return Dense(σ, param(initW(out, in)), param(initb(out)))
|
||||
end
|
||||
|
||||
treelike(Dense)
|
||||
|
||||
|
|
|
@ -79,8 +79,8 @@ struct RNNCell{D,V}
|
|||
h::V
|
||||
end
|
||||
|
||||
RNNCell(in::Integer, out::Integer, σ = tanh; init = initn) =
|
||||
RNNCell(Dense(in+out, out, σ, init = init), param(init(out)))
|
||||
RNNCell(in::Integer, out::Integer, σ = tanh; initW = glorot_uniform, initb = zeros) =
|
||||
RNNCell(Dense(in+out, out, σ, initW = initW, initb = initb), param(initW(out)))
|
||||
|
||||
function (m::RNNCell)(h, x)
|
||||
h = m.d(combine(x, h))
|
||||
|
@ -113,10 +113,10 @@ struct LSTMCell{D1,D2,V}
|
|||
h::V; c::V
|
||||
end
|
||||
|
||||
function LSTMCell(in, out; init = initn)
|
||||
cell = LSTMCell([Dense(in+out, out, σ, init = init) for _ = 1:3]...,
|
||||
Dense(in+out, out, tanh, init = init),
|
||||
param(init(out)), param(init(out)))
|
||||
function LSTMCell(in, out; initW = glorot_uniform, initb = zeros)
|
||||
cell = LSTMCell([Dense(in+out, out, σ, initW = initW, initb = initb) for _ = 1:3]...,
|
||||
Dense(in+out, out, tanh, initW = initW, initb = initb),
|
||||
param(initW(out)), param(initW(out)))
|
||||
cell.forget.b.data .= 1
|
||||
return cell
|
||||
end
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
# Arrays
|
||||
|
||||
initn(dims...) = randn(dims...)/100
|
||||
glorot_uniform(dims...) = (rand(dims...) - 0.5)*sqrt(24.0/(sum(dims)))
|
||||
glorot_normal(dims...) = (randn(dims...)*sqrt(2.0/sum(dims)))
|
||||
|
||||
flatten(xs) = reshape(xs, size(xs, 1), :)
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
using Flux: throttle
|
||||
using Flux: throttle, initn, glorot_uniform, glorot_normal
|
||||
|
||||
@testset "Throttle" begin
|
||||
@testset "default behaviour" begin
|
||||
|
@ -47,3 +47,26 @@ using Flux: throttle
|
|||
@test a == [1, 3]
|
||||
end
|
||||
end
|
||||
|
||||
@testset "Initialization" begin
|
||||
# Set random seed so that these tests don't fail randomly
|
||||
srand(0)
|
||||
# initn() should yield a kernel with stddev ~= 1e-2
|
||||
v = initn(10, 10)
|
||||
@test std(v) > 0.9*1e-2
|
||||
@test std(v) < 1.1*1e-2
|
||||
|
||||
# glorot_uniform should yield a kernel with stddev ~= sqrt(6/(n_in + n_out)),
|
||||
# and glorot_normal should yield a kernel with stddev != 2/(n_in _ n_out)
|
||||
for (n_in, n_out) in [(100, 100), (100, 400)]
|
||||
v = glorot_uniform(n_in, n_out)
|
||||
@test minimum(v) > -1.1*sqrt(6/(n_in + n_out))
|
||||
@test minimum(v) < -0.9*sqrt(6/(n_in + n_out))
|
||||
@test maximum(v) > 0.9*sqrt(6/(n_in + n_out))
|
||||
@test maximum(v) < 1.1*sqrt(6/(n_in + n_out))
|
||||
|
||||
v = glorot_normal(n_in, n_out)
|
||||
@test std(v) > 0.9*sqrt(2/(n_in + n_out))
|
||||
@test std(v) < 1.1*sqrt(2/(n_in + n_out))
|
||||
end
|
||||
end
|
Loading…
Reference in New Issue