Merge pull request #120 from staticfloat/sf/dense_initialization
Better default initialization for Dense layers
This commit is contained in:
commit
27d896943e
@ -63,8 +63,10 @@ struct Dense{F,S,T}
|
|||||||
b::T
|
b::T
|
||||||
end
|
end
|
||||||
|
|
||||||
Dense(in::Integer, out::Integer, σ = identity; init = initn) =
|
function Dense(in::Integer, out::Integer, σ = identity;
|
||||||
Dense(σ, param(init(out, in)), param(init(out)))
|
initW = glorot_uniform, initb = zeros)
|
||||||
|
return Dense(σ, param(initW(out, in)), param(initb(out)))
|
||||||
|
end
|
||||||
|
|
||||||
treelike(Dense)
|
treelike(Dense)
|
||||||
|
|
||||||
|
@ -79,8 +79,8 @@ struct RNNCell{D,V}
|
|||||||
h::V
|
h::V
|
||||||
end
|
end
|
||||||
|
|
||||||
RNNCell(in::Integer, out::Integer, σ = tanh; init = initn) =
|
RNNCell(in::Integer, out::Integer, σ = tanh; initW = glorot_uniform, initb = zeros) =
|
||||||
RNNCell(Dense(in+out, out, σ, init = init), param(init(out)))
|
RNNCell(Dense(in+out, out, σ, initW = initW, initb = initb), param(initW(out)))
|
||||||
|
|
||||||
function (m::RNNCell)(h, x)
|
function (m::RNNCell)(h, x)
|
||||||
h = m.d(combine(x, h))
|
h = m.d(combine(x, h))
|
||||||
@ -113,10 +113,10 @@ struct LSTMCell{D1,D2,V}
|
|||||||
h::V; c::V
|
h::V; c::V
|
||||||
end
|
end
|
||||||
|
|
||||||
function LSTMCell(in, out; init = initn)
|
function LSTMCell(in, out; initW = glorot_uniform, initb = zeros)
|
||||||
cell = LSTMCell([Dense(in+out, out, σ, init = init) for _ = 1:3]...,
|
cell = LSTMCell([Dense(in+out, out, σ, initW = initW, initb = initb) for _ = 1:3]...,
|
||||||
Dense(in+out, out, tanh, init = init),
|
Dense(in+out, out, tanh, initW = initW, initb = initb),
|
||||||
param(init(out)), param(init(out)))
|
param(initW(out)), param(initW(out)))
|
||||||
cell.forget.b.data .= 1
|
cell.forget.b.data .= 1
|
||||||
return cell
|
return cell
|
||||||
end
|
end
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
# Arrays
|
# Arrays
|
||||||
|
|
||||||
"""
    initn(dims...)

Return an array of size `dims` of standard-normal samples, each divided by 100.
"""
function initn(dims...)
    samples = randn(dims...)
    return samples / 100
end
|
"""
    initn(dims...)

Return an array of size `dims` of standard-normal samples, each divided by 100.
"""
function initn(dims...)
    samples = randn(dims...)
    return samples / 100
end
|
||||||
|
"""
    glorot_uniform(dims...)

Return an array of size `dims` whose entries are drawn uniformly from
`[-b, b)`, where `b = sqrt(6 / sum(dims))`.
"""
# `rand` samples from [0, 1); shifting by 0.5 and scaling by sqrt(24/s) gives a
# half-width of 0.5 * sqrt(24/s) == sqrt(6/s). Explicit broadcasting (`.-`, `.*`)
# is used because array-minus-scalar without a dot is not defined on Julia >= 0.7.
glorot_uniform(dims...) = (rand(dims...) .- 0.5) .* sqrt(24.0 / sum(dims))
|
||||||
|
"""
    glorot_normal(dims...)

Return an array of size `dims` of standard-normal samples scaled by
`sqrt(2 / sum(dims))`.
"""
function glorot_normal(dims...)
    samples = randn(dims...)
    return samples * sqrt(2.0 / sum(dims))
end
|
||||||
|
|
||||||
"""
    flatten(xs)

Reshape `xs` into a matrix that keeps the first dimension and collapses all
remaining dimensions into the second.
"""
function flatten(xs)
    rows = size(xs, 1)
    return reshape(xs, rows, :)
end
|
"""
    flatten(xs)

Reshape `xs` into a matrix that keeps the first dimension and collapses all
remaining dimensions into the second.
"""
function flatten(xs)
    rows = size(xs, 1)
    return reshape(xs, rows, :)
end
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
using Flux: throttle
|
using Flux: throttle, initn, glorot_uniform, glorot_normal
|
||||||
|
|
||||||
@testset "Throttle" begin
|
@testset "Throttle" begin
|
||||||
@testset "default behaviour" begin
|
@testset "default behaviour" begin
|
||||||
@ -47,3 +47,26 @@ using Flux: throttle
|
|||||||
@test a == [1, 3]
|
@test a == [1, 3]
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@testset "Initialization" begin
|
||||||
|
# Set random seed so that these tests don't fail randomly
|
||||||
|
srand(0)
|
||||||
|
# initn() should yield a kernel with stddev ~= 1e-2
|
||||||
|
v = initn(10, 10)
|
||||||
|
@test std(v) > 0.9*1e-2
|
||||||
|
@test std(v) < 1.1*1e-2
|
||||||
|
|
||||||
|
# glorot_uniform should yield a kernel with extrema ~= ±sqrt(6/(n_in + n_out)),
|
||||||
|
# and glorot_normal should yield a kernel with stddev ~= sqrt(2/(n_in + n_out))
|
||||||
|
for (n_in, n_out) in [(100, 100), (100, 400)]
|
||||||
|
v = glorot_uniform(n_in, n_out)
|
||||||
|
@test minimum(v) > -1.1*sqrt(6/(n_in + n_out))
|
||||||
|
@test minimum(v) < -0.9*sqrt(6/(n_in + n_out))
|
||||||
|
@test maximum(v) > 0.9*sqrt(6/(n_in + n_out))
|
||||||
|
@test maximum(v) < 1.1*sqrt(6/(n_in + n_out))
|
||||||
|
|
||||||
|
v = glorot_normal(n_in, n_out)
|
||||||
|
@test std(v) > 0.9*sqrt(2/(n_in + n_out))
|
||||||
|
@test std(v) < 1.1*sqrt(2/(n_in + n_out))
|
||||||
|
end
|
||||||
|
end
|
Loading…
Reference in New Issue
Block a user