cleaner lstm

This commit is contained in:
Mike J Innes 2017-09-03 02:24:47 -04:00
parent cf58748680
commit bd5822fd71

View File

@ -36,29 +36,27 @@ end
# LSTM
struct LSTMCell{M}
Wxf::M; Wyf::M; bf::M
Wxi::M; Wyi::M; bi::M
Wxo::M; Wyo::M; bo::M
Wxc::M; Wyc::M; bc::M
hidden::M; cell::M
struct LSTMCell{D1,D2,V}
forget::D1
input::D1
output::D1
cell::D2
h::V; c::V
end
LSTMCell(in, out; init = initn) =
LSTMCell(track.(vcat([[init(out, in), init(out, out), init(out, 1)] for _ = 1:4]...))...,
track(zeros(out, 1)), track(zeros(out, 1)))
LSTMCell([Dense(in+out, out, σ, init = initn) for _ = 1:3]...,
Dense(in+out, out, tanh, init = initn),
track(zeros(out)), track(zeros(out)))
function (m::LSTMCell)(h_, x)
h, c = h_
# Gates
forget = σ.( m.Wxf * x .+ m.Wyf * h .+ m.bf )
input = σ.( m.Wxi * x .+ m.Wyi * h .+ m.bi )
output = σ.( m.Wxo * x .+ m.Wyo * h .+ m.bo )
# State update and output
c = tanh.( m.Wxc * x .+ m.Wyc * h .+ m.bc )
c = forget .* c .+ input .* c
x = [x; h]
forget, input, output, cell =
m.forget(x), m.input(x), m.output(x), m.cell(x)
c = forget .* c .+ input .* cell
h = output .* tanh.(c)
return (h, c), h
end
hidden(m::LSTMCell) = (m.hidden, m.cell)
hidden(m::LSTMCell) = (m.h, m.c)