From 26550dacdad123f066931a6fd3c71b53592bd09f Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Sat, 2 Feb 2019 20:01:28 +0000 Subject: [PATCH] Default to zero'ed initial state --- src/layers/recurrent.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 40cd322a..4e23e9ee 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -84,7 +84,7 @@ end RNNCell(in::Integer, out::Integer, σ = tanh; init = glorot_uniform) = RNNCell(σ, param(init(out, in)), param(init(out, out)), - param(zeros(out)), param(init(out))) + param(init(out)), param(zeros(out))) function (m::RNNCell)(h, x) σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b @@ -122,8 +122,8 @@ end function LSTMCell(in::Integer, out::Integer; init = glorot_uniform) - cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(zeros(out*4)), - param(init(out)), param(init(out))) + cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(init(out*4)), + param(zeros(out)), param(zeros(out))) cell.b.data[gate(out, 2)] .= 1 return cell end @@ -169,7 +169,7 @@ end GRUCell(in, out; init = glorot_uniform) = GRUCell(param(init(out*3, in)), param(init(out*3, out)), - param(zeros(out*3)), param(init(out))) + param(init(out*3)), param(zeros(out))) function (m::GRUCell)(h, x) b, o = m.b, size(h, 1)