zeros replaced by fill to avoid nothing grad
This commit is contained in:
parent
8292cfd81f
commit
812541f8d6
@ -69,7 +69,7 @@ end
|
|||||||
RNNCell(in::Integer, out::Integer, σ = tanh;
|
RNNCell(in::Integer, out::Integer, σ = tanh;
|
||||||
init = glorot_uniform) =
|
init = glorot_uniform) =
|
||||||
RNNCell(σ, init(out, in), init(out, out),
|
RNNCell(σ, init(out, in), init(out, out),
|
||||||
init(out), zeros(out))
|
init(out), fill(Float32(0), out))
|
||||||
|
|
||||||
function (m::RNNCell)(h, x)
|
function (m::RNNCell)(h, x)
|
||||||
σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
|
σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
|
||||||
@ -108,7 +108,7 @@ end
|
|||||||
function LSTMCell(in::Integer, out::Integer;
|
function LSTMCell(in::Integer, out::Integer;
|
||||||
init = glorot_uniform)
|
init = glorot_uniform)
|
||||||
cell = LSTMCell(init(out * 4, in), init(out * 4, out), init(out * 4),
|
cell = LSTMCell(init(out * 4, in), init(out * 4, out), init(out * 4),
|
||||||
zeros(out), zeros(out))
|
fill(Float32(0), out), fill(Float32(0), out))
|
||||||
cell.b[gate(out, 2)] .= 1
|
cell.b[gate(out, 2)] .= 1
|
||||||
return cell
|
return cell
|
||||||
end
|
end
|
||||||
@ -154,7 +154,7 @@ end
|
|||||||
|
|
||||||
GRUCell(in, out; init = glorot_uniform) =
|
GRUCell(in, out; init = glorot_uniform) =
|
||||||
GRUCell(init(out * 3, in), init(out * 3, out),
|
GRUCell(init(out * 3, in), init(out * 3, out),
|
||||||
init(out * 3), zeros(out))
|
init(out * 3), fill(Float32(0), out))
|
||||||
|
|
||||||
function (m::GRUCell)(h, x)
|
function (m::GRUCell)(h, x)
|
||||||
b, o = m.b, size(h, 1)
|
b, o = m.b, size(h, 1)
|
||||||
|
Loading…
Reference in New Issue
Block a user