diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index b5eea4a4..ddfa6426 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -69,7 +69,7 @@ end RNNCell(in::Integer, out::Integer, σ = tanh; init = glorot_uniform) = RNNCell(σ, init(out, in), init(out, out), - init(out), zeros(out)) + init(out), fill(Float32(0), out)) function (m::RNNCell)(h, x) σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b @@ -108,7 +108,7 @@ end function LSTMCell(in::Integer, out::Integer; init = glorot_uniform) cell = LSTMCell(init(out * 4, in), init(out * 4, out), init(out * 4), - zeros(out), zeros(out)) + fill(Float32(0), out), fill(Float32(0), out)) cell.b[gate(out, 2)] .= 1 return cell end @@ -154,7 +154,7 @@ end GRUCell(in, out; init = glorot_uniform) = GRUCell(init(out * 3, in), init(out * 3, out), - init(out * 3), zeros(out)) + init(out * 3), fill(Float32(0), out)) function (m::GRUCell)(h, x) b, o = m.b, size(h, 1)