cleaner lstm

This commit is contained in:
Mike J Innes 2017-09-03 02:24:47 -04:00
parent cf58748680
commit bd5822fd71

View File

@ -36,29 +36,27 @@ end
# LSTM # LSTM
struct LSTMCell{M} struct LSTMCell{D1,D2,V}
Wxf::M; Wyf::M; bf::M forget::D1
Wxi::M; Wyi::M; bi::M input::D1
Wxo::M; Wyo::M; bo::M output::D1
Wxc::M; Wyc::M; bc::M cell::D2
hidden::M; cell::M h::V; c::V
end end
LSTMCell(in, out; init = initn) = LSTMCell(in, out; init = initn) =
LSTMCell(track.(vcat([[init(out, in), init(out, out), init(out, 1)] for _ = 1:4]...))..., LSTMCell([Dense(in+out, out, σ, init = initn) for _ = 1:3]...,
track(zeros(out, 1)), track(zeros(out, 1))) Dense(in+out, out, tanh, init = initn),
track(zeros(out)), track(zeros(out)))
function (m::LSTMCell)(h_, x) function (m::LSTMCell)(h_, x)
h, c = h_ h, c = h_
# Gates x = [x; h]
forget = σ.( m.Wxf * x .+ m.Wyf * h .+ m.bf ) forget, input, output, cell =
input = σ.( m.Wxi * x .+ m.Wyi * h .+ m.bi ) m.forget(x), m.input(x), m.output(x), m.cell(x)
output = σ.( m.Wxo * x .+ m.Wyo * h .+ m.bo ) c = forget .* c .+ input .* cell
# State update and output
c = tanh.( m.Wxc * x .+ m.Wyc * h .+ m.bc )
c = forget .* c .+ input .* c
h = output .* tanh.(c) h = output .* tanh.(c)
return (h, c), h return (h, c), h
end end
hidden(m::LSTMCell) = (m.hidden, m.cell) hidden(m::LSTMCell) = (m.h, m.c)