diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 80132854..c067a302 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -47,10 +47,13 @@ struct LSTMCell{D1,D2,V} h::V; c::V end -LSTMCell(in, out; init = initn) = - LSTMCell([Dense(in+out, out, σ, init = initn) for _ = 1:3]..., - Dense(in+out, out, tanh, init = initn), - track(initn(out)), track(initn(out))) +function LSTMCell(in, out; init = initn) + cell = LSTMCell([Dense(in+out, out, σ, init = initn) for _ = 1:3]..., + Dense(in+out, out, tanh, init = initn), + track(initn(out)), track(initn(out))) + cell.forget.b.x .= 1 + return cell +end function (m::LSTMCell)(h_, x) h, c = h_