Merge pull request #590 from oxinabox/patch-2
Default to zero'ed initial state for all RNN
This commit is contained in:
commit
e774053126
@ -84,7 +84,7 @@ end
|
||||
RNNCell(in::Integer, out::Integer, σ = tanh;
|
||||
init = glorot_uniform) =
|
||||
RNNCell(σ, param(init(out, in)), param(init(out, out)),
|
||||
param(zeros(out)), param(init(out)))
|
||||
param(init(out)), param(zeros(out)))
|
||||
|
||||
function (m::RNNCell)(h, x)
|
||||
σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
|
||||
@ -122,8 +122,8 @@ end
|
||||
|
||||
function LSTMCell(in::Integer, out::Integer;
|
||||
init = glorot_uniform)
|
||||
cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(zeros(out*4)),
|
||||
param(init(out)), param(init(out)))
|
||||
cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(init(out*4)),
|
||||
param(zeros(out)), param(zeros(out)))
|
||||
cell.b.data[gate(out, 2)] .= 1
|
||||
return cell
|
||||
end
|
||||
@ -169,7 +169,7 @@ end
|
||||
|
||||
GRUCell(in, out; init = glorot_uniform) =
|
||||
GRUCell(param(init(out*3, in)), param(init(out*3, out)),
|
||||
param(zeros(out*3)), param(init(out)))
|
||||
param(init(out*3)), param(zeros(out)))
|
||||
|
||||
function (m::GRUCell)(h, x)
|
||||
b, o = m.b, size(h, 1)
|
||||
|
Loading…
Reference in New Issue
Block a user