Flux.jl/src/layers/recurrent.jl

65 lines
1.2 KiB
Julia
Raw Normal View History

2017-09-03 06:12:44 +00:00
# Stateful recurrence
"""
    Recur(cell[, state])

Make a stateful wrapper around the recurrent `cell`.

On every call the wrapper invokes `cell(state, xs...)`, destructures the result
into `(new_state, output)`, stores `new_state` back into the wrapper (which is
why this is a `mutable struct`) and returns `output`. The one-argument form
takes the initial state from `hidden(cell)`.
"""
mutable struct Recur{T}
  cell::T
  state
end

function Recur(cell)
  initial = hidden(cell)
  return Recur(cell, initial)
end

function (rec::Recur)(inputs...)
  # Run one step of the cell from the currently stored state.
  next_state, output = rec.cell(rec.state, inputs...)
  # Persist the carried state so the next call continues from here.
  rec.state = next_state
  return output
end
# Vanilla RNN
"""
    RNNCell(in::Integer, out::Integer, init = initn)
    RNNCell(in::Integer, out::Integer; init = initn)

Vanilla recurrent cell.

- `d` is applied to the vertical concatenation `[x; h]` of the input and the
  previous hidden state. Its result is both the new hidden state and the cell's
  output.
- `h` is the initial hidden state returned by `hidden`.

The convenience constructor builds `d` as `Dense(in + out, out)`. It draws both
that layer's initialisation and the initial hidden state from `init`. `init`
may be passed positionally (original form) or by keyword (the form `LSTMCell`
uses); the keyword wins when both are given. The initial state is wrapped in
`track` (presumably to mark it as a trainable parameter; confirm).

NOTE(review): no activation is passed to `Dense` here, so the recurrence uses
whatever activation `Dense` applies by default. Confirm that is intended.
"""
struct RNNCell{D,V}
  d::D
  h::V
end

function RNNCell(in::Integer, out::Integer, init_ = initn; init = init_)
  # Use the caller's initialiser for both the weights and the initial hidden
  # state. Hard-coding `initn` here would silently ignore the `init` argument.
  return RNNCell(Dense(in+out, out, init = init), track(init(out)))
end

function (m::RNNCell)(h, x)
  # Stack the input on top of the previous state and map it to the new state.
  h = m.d([x; h])
  return h, h
end

hidden(m::RNNCell) = m.h

function Base.show(io::IO, m::RNNCell)
  print(io, "RNNCell(", m.d, ")")
end
# LSTM
"""
    LSTMCell(in, out; init = initn)

Long short-term memory cell.

Parameters, per component:
- For each of the forget (`f`), input (`i`) and output (`o`) gates, and for the
  candidate cell contents (`c`), the cell holds an input weight `Wx*`, a
  recurrent weight `Wy*` and a bias `b*`.
- `hidden` and `cell` are the initial hidden and cell states, returned as a
  pair by `hidden`.

The keyword constructor:
- builds each component's triple, in field order, as `init(out, in)`,
  `init(out, out)` and `init(out, 1)`;
- sets both initial states to `zeros(out, 1)`;
- wraps every array in `track` (presumably to mark it as a trainable
  parameter; confirm).

All fields must share the single array type `M`.

Calling the cell as `cell((h, c), x)` computes:

    c_new = f .* c .+ i .* c_tilde
    h_new = o .* tanh.(c_new)

Here `f`, `i`, `o` are the sigmoid gates and `c_tilde` is the tanh candidate.
It returns `((h_new, c_new), h_new)`.
"""
struct LSTMCell{M}
  Wxf::M; Wyf::M; bf::M
  Wxi::M; Wyi::M; bi::M
  Wxo::M; Wyo::M; bo::M
  Wxc::M; Wyc::M; bc::M
  hidden::M; cell::M
end

# The comprehension yields [Wx, Wy, b] for each of f, i, o, c; `vcat` flattens
# them into the 12 weight/bias fields in declaration order.
LSTMCell(in, out; init = initn) =
  LSTMCell(track.(vcat([[init(out, in), init(out, out), init(out, 1)] for _ = 1:4]...))...,
           track(zeros(out, 1)), track(zeros(out, 1)))

function (m::LSTMCell)(h_, x)
  h, c = h_
  # Gates
  forget = σ.( m.Wxf * x .+ m.Wyf * h .+ m.bf )
  input  = σ.( m.Wxi * x .+ m.Wyi * h .+ m.bi )
  output = σ.( m.Wxo * x .+ m.Wyo * h .+ m.bo )
  # Candidate cell contents. Kept in its own variable because `c` must still
  # hold the previous cell state when the forget gate is applied below.
  candidate = tanh.( m.Wxc * x .+ m.Wyc * h .+ m.bc )
  # State update and output
  c = forget .* c .+ input .* candidate
  h = output .* tanh.(c)
  return (h, c), h
end

hidden(m::LSTMCell) = (m.hidden, m.cell)