cleaner lstm
This commit is contained in:
parent
cf58748680
commit
bd5822fd71
@ -36,29 +36,27 @@ end
|
||||
|
||||
# LSTM
|
||||
|
||||
struct LSTMCell{M}
|
||||
Wxf::M; Wyf::M; bf::M
|
||||
Wxi::M; Wyi::M; bi::M
|
||||
Wxo::M; Wyo::M; bo::M
|
||||
Wxc::M; Wyc::M; bc::M
|
||||
hidden::M; cell::M
|
||||
struct LSTMCell{D1,D2,V}
|
||||
forget::D1
|
||||
input::D1
|
||||
output::D1
|
||||
cell::D2
|
||||
h::V; c::V
|
||||
end
|
||||
|
||||
LSTMCell(in, out; init = initn) =
|
||||
LSTMCell(track.(vcat([[init(out, in), init(out, out), init(out, 1)] for _ = 1:4]...))...,
|
||||
track(zeros(out, 1)), track(zeros(out, 1)))
|
||||
LSTMCell([Dense(in+out, out, σ, init = initn) for _ = 1:3]...,
|
||||
Dense(in+out, out, tanh, init = initn),
|
||||
track(zeros(out)), track(zeros(out)))
|
||||
|
||||
function (m::LSTMCell)(h_, x)
|
||||
h, c = h_
|
||||
# Gates
|
||||
forget = σ.( m.Wxf * x .+ m.Wyf * h .+ m.bf )
|
||||
input = σ.( m.Wxi * x .+ m.Wyi * h .+ m.bi )
|
||||
output = σ.( m.Wxo * x .+ m.Wyo * h .+ m.bo )
|
||||
# State update and output
|
||||
c′ = tanh.( m.Wxc * x .+ m.Wyc * h .+ m.bc )
|
||||
c = forget .* c .+ input .* c′
|
||||
x′ = [x; h]
|
||||
forget, input, output, cell =
|
||||
m.forget(x′), m.input(x′), m.output(x′), m.cell(x′)
|
||||
c = forget .* c .+ input .* cell
|
||||
h = output .* tanh.(c)
|
||||
return (h, c), h
|
||||
end
|
||||
|
||||
hidden(m::LSTMCell) = (m.hidden, m.cell)
|
||||
hidden(m::LSTMCell) = (m.h, m.c)
|
||||
|
Loading…
Reference in New Issue
Block a user