diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index db2d3fc5..ee9cad00 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -36,29 +36,27 @@ end # LSTM -struct LSTMCell{M} - Wxf::M; Wyf::M; bf::M - Wxi::M; Wyi::M; bi::M - Wxo::M; Wyo::M; bo::M - Wxc::M; Wyc::M; bc::M - hidden::M; cell::M +struct LSTMCell{D1,D2,V} + forget::D1 + input::D1 + output::D1 + cell::D2 + h::V; c::V end LSTMCell(in, out; init = initn) = - LSTMCell(track.(vcat([[init(out, in), init(out, out), init(out, 1)] for _ = 1:4]...))..., - track(zeros(out, 1)), track(zeros(out, 1))) + LSTMCell([Dense(in+out, out, σ, init = initn) for _ = 1:3]..., + Dense(in+out, out, tanh, init = initn), + track(zeros(out)), track(zeros(out))) function (m::LSTMCell)(h_, x) h, c = h_ - # Gates - forget = σ.( m.Wxf * x .+ m.Wyf * h .+ m.bf ) - input = σ.( m.Wxi * x .+ m.Wyi * h .+ m.bi ) - output = σ.( m.Wxo * x .+ m.Wyo * h .+ m.bo ) - # State update and output - c′ = tanh.( m.Wxc * x .+ m.Wyc * h .+ m.bc ) - c = forget .* c .+ input .* c′ + x′ = [x; h] + forget, input, output, cell = + m.forget(x′), m.input(x′), m.output(x′), m.cell(x′) + c = forget .* c .+ input .* cell h = output .* tanh.(c) return (h, c), h end -hidden(m::LSTMCell) = (m.hidden, m.cell) +hidden(m::LSTMCell) = (m.h, m.c)