cleaner lstm
This commit is contained in:
parent
cf58748680
commit
bd5822fd71
@ -36,29 +36,27 @@ end
|
|||||||
|
|
||||||
# LSTM
|
# LSTM
|
||||||
|
|
||||||
struct LSTMCell{M}
|
struct LSTMCell{D1,D2,V}
|
||||||
Wxf::M; Wyf::M; bf::M
|
forget::D1
|
||||||
Wxi::M; Wyi::M; bi::M
|
input::D1
|
||||||
Wxo::M; Wyo::M; bo::M
|
output::D1
|
||||||
Wxc::M; Wyc::M; bc::M
|
cell::D2
|
||||||
hidden::M; cell::M
|
h::V; c::V
|
||||||
end
|
end
|
||||||
|
|
||||||
LSTMCell(in, out; init = initn) =
|
LSTMCell(in, out; init = initn) =
|
||||||
LSTMCell(track.(vcat([[init(out, in), init(out, out), init(out, 1)] for _ = 1:4]...))...,
|
LSTMCell([Dense(in+out, out, σ, init = initn) for _ = 1:3]...,
|
||||||
track(zeros(out, 1)), track(zeros(out, 1)))
|
Dense(in+out, out, tanh, init = initn),
|
||||||
|
track(zeros(out)), track(zeros(out)))
|
||||||
|
|
||||||
function (m::LSTMCell)(h_, x)
|
function (m::LSTMCell)(h_, x)
|
||||||
h, c = h_
|
h, c = h_
|
||||||
# Gates
|
x′ = [x; h]
|
||||||
forget = σ.( m.Wxf * x .+ m.Wyf * h .+ m.bf )
|
forget, input, output, cell =
|
||||||
input = σ.( m.Wxi * x .+ m.Wyi * h .+ m.bi )
|
m.forget(x′), m.input(x′), m.output(x′), m.cell(x′)
|
||||||
output = σ.( m.Wxo * x .+ m.Wyo * h .+ m.bo )
|
c = forget .* c .+ input .* cell
|
||||||
# State update and output
|
|
||||||
c′ = tanh.( m.Wxc * x .+ m.Wyc * h .+ m.bc )
|
|
||||||
c = forget .* c .+ input .* c′
|
|
||||||
h = output .* tanh.(c)
|
h = output .* tanh.(c)
|
||||||
return (h, c), h
|
return (h, c), h
|
||||||
end
|
end
|
||||||
|
|
||||||
hidden(m::LSTMCell) = (m.hidden, m.cell)
|
hidden(m::LSTMCell) = (m.h, m.c)
|
||||||
|
Loading…
Reference in New Issue
Block a user