2018-08-18 15:50:52 +00:00
|
|
|
|
# Helpers for slicing the `n`-th gate (of width `h`) out of arrays in which
# several gates' pre-activations are stacked along the first axis.

# Index range selecting gate `n` of width `h`.
function gate(h, n)
  offset = h * (n - 1)
  return (1:h) .+ offset
end

# Slice gate `n` out of a stacked gate vector.
gate(x::AbstractVector, h, n) = x[gate(h, n)]

# Slice gate `n` out of a stacked gate matrix (features × batch).
gate(x::AbstractMatrix, h, n) = x[gate(h, n), :]
|
2017-09-05 06:29:31 +00:00
|
|
|
|
|
2017-09-03 06:12:44 +00:00
|
|
|
|
# Stateful recurrence
|
|
|
|
|
|
2017-10-18 14:30:05 +00:00
|
|
|
|
"""
    Recur(cell)

`Recur` takes a recurrent cell and makes it stateful, managing the hidden state
in the background. `cell` should be a model of the form:

    h, y = cell(h, x...)

For example, here's a recurrent network that keeps a running total of its inputs.

```julia
accum(h, x) = (h+x, x)
rnn = Flux.Recur(accum, 0)
rnn(2) # 2
rnn(3) # 3
rnn.state # 5
rnn.(1:10) # apply to a sequence
rnn.state # 60
```
"""
mutable struct Recur{T}
  cell::T  # the wrapped recurrent cell, called as `cell(state, xs...)`
  init     # initial hidden state, kept so `reset!` can restore it
  state    # current hidden state, overwritten on every call
end
|
|
|
|
|
|
# Convenience constructor: the starting state (by default the cell's own
# hidden state from `hidden(m)`) doubles as the stored `init`.
function Recur(m, h = hidden(m))
  return Recur(m, h, h)
end
|
2017-09-03 06:12:44 +00:00
|
|
|
|
|
|
|
|
|
# Calling a `Recur` runs the wrapped cell on the stored state, saves the
# cell's new state, and returns only the cell's output.
function (m::Recur)(xs...)
  newstate, out = m.cell(m.state, xs...)
  m.state = newstate
  return out
end
|
|
|
|
|
|
# Register `cell` and `init` as the traversable children of `Recur`
# (`state` is deliberately not listed — presumably so the running state is
# not treated as a trainable parameter; confirm against `@treelike`'s docs).
@treelike Recur cell, init
|
2017-09-07 04:05:02 +00:00
|
|
|
|
|
# Compact printing: show only the cell a `Recur` wraps.
function Base.show(io::IO, m::Recur)
  print(io, "Recur(", m.cell, ")")
end
|
|
|
|
|
|
2017-10-19 16:21:08 +00:00
|
|
|
|
"""
    reset!(rnn)

Reset the hidden state of a recurrent layer back to its original value. See also
`truncate!`.

Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to

    rnn.state = hidden(rnn.cell)
"""
function reset!(m::Recur)
  m.state = m.init
end

# Generic fallback: recursively reset every child of a container model.
function reset!(m)
  for child in functor(m)[1]
    reset!(child)
  end
end
|
2017-09-06 18:03:25 +00:00
|
|
|
|
|
# Apply `f` elementwise to `xs` in reverse order, returning the results in the
# original order (useful for running a recurrent model backwards over a sequence).
function flip(f, xs)
  backwards = reverse(xs)
  return reverse(f.(backwards))
end
|
|
|
|
|
|
2017-09-03 06:12:44 +00:00
|
|
|
|
# Vanilla RNN
|
|
|
|
|
|
# Parameters of a vanilla RNN cell mapping `in` features to `out` features.
mutable struct RNNCell{F,A,V}
  σ::F   # activation function applied elementwise to the pre-activation
  Wi::A  # input-to-hidden weights (out × in)
  Wh::A  # hidden-to-hidden weights (out × out)
  b::V   # bias, length out
  h::V   # initial hidden state (length out), returned by `hidden`
end
|
|
|
|
|
|
# Construct an `RNNCell` mapping `in` features to `out` features.
# Weights and bias are drawn from `init`; the stored hidden state starts at zero.
function RNNCell(in::Integer, out::Integer, σ = tanh; init = glorot_uniform)
  return RNNCell(σ, init(out, in), init(out, out), init(out), zeros(out))
end
|
2017-09-03 06:12:44 +00:00
|
|
|
|
|
|
|
|
|
# One step of the vanilla RNN: h′ = σ.(Wi*x .+ Wh*h .+ b).
# Returns the pair (new state, output); for a plain RNN these are the same array.
function (m::RNNCell)(h, x)
  h′ = m.σ.(m.Wi*x .+ m.Wh*h .+ m.b)
  return h′, h′
end
|
|
|
|
|
# The cell's stored initial hidden state; `Recur(cell)` uses this as the
# default starting state.
hidden(m::RNNCell) = m.h
|
|
|
|
|
|
# Register all of `RNNCell`'s fields with the treelike traversal machinery
# (no field list given, so presumably every field is included — confirm
# against `@treelike`'s definition).
@treelike RNNCell
|
2017-09-07 04:05:02 +00:00
|
|
|
|
|
# Print as `RNNCell(in, out[, σ])`, inferring dimensions from `Wi` and
# omitting the activation when it is `identity`.
function Base.show(io::IO, l::RNNCell)
  in_dim = size(l.Wi, 2)
  out_dim = size(l.Wi, 1)
  print(io, "RNNCell(", in_dim, ", ", out_dim)
  if l.σ != identity
    print(io, ", ", l.σ)
  end
  print(io, ")")
end
|
|
|
|
|
|
"""
    RNN(in::Integer, out::Integer, σ = tanh)

The most basic recurrent layer; essentially acts as a `Dense` layer, but with the
output fed back into the input each time step.
"""
# Thin wrapper: build the cell, then make it stateful with `Recur`.
RNN(a...; ka...) = Recur(RNNCell(a...; ka...))
|
|
|
|
|
|
2017-09-03 06:12:44 +00:00
|
|
|
|
# LSTM
|
|
|
|
|
|
# Parameters of an LSTM cell. The four gates' parameters are stacked along the
# first axis in the order input, forget, cell, output (see the call operator).
mutable struct LSTMCell{A,V}
  Wi::A  # input weights, (4*out) × in
  Wh::A  # recurrent weights, (4*out) × out
  b::V   # bias, length 4*out (forget-gate slice initialised to 1 by the constructor)
  h::V   # initial hidden state, length out
  c::V   # initial cell state, length out
end
|
|
|
|
|
|
# Construct an `LSTMCell` mapping `in` features to `out` features. Weights and
# bias come from `init`; hidden and cell states start at zero. The forget-gate
# bias (slice 2 of the stacked bias) is set to 1, biasing the cell towards
# remembering early in training.
function LSTMCell(in::Integer, out::Integer; init = glorot_uniform)
  c = LSTMCell(init(out * 4, in), init(out * 4, out), init(out * 4),
               zeros(out), zeros(out))
  c.b[gate(out, 2)] .= 1  # forget gate
  return c
end
|
2017-09-03 06:12:44 +00:00
|
|
|
|
|
# One LSTM step. Takes the `(hidden, cell)` state pair and an input `x`;
# returns the new state pair and the output (which equals the new hidden state).
function (m::LSTMCell)((h, c), x)
  o = size(h, 1)
  g = m.Wi*x .+ m.Wh*h .+ m.b   # all four gates' pre-activations, stacked
  i_t = σ.(gate(g, o, 1))       # input gate
  f_t = σ.(gate(g, o, 2))       # forget gate
  g_t = tanh.(gate(g, o, 3))    # candidate cell update
  o_t = σ.(gate(g, o, 4))       # output gate
  c′ = f_t .* c .+ i_t .* g_t
  h′ = o_t .* tanh.(c′)
  return (h′, c′), h′
end
|
|
|
|
|
|
# The cell's stored initial `(hidden, cell)` state pair, used by `Recur` as
# the default starting state.
hidden(m::LSTMCell) = (m.h, m.c)
|
2017-09-05 23:25:34 +00:00
|
|
|
|
|
# Register all of `LSTMCell`'s fields with the treelike traversal machinery.
@treelike LSTMCell
|
2017-09-06 22:59:07 +00:00
|
|
|
|
|
# Print as `LSTMCell(in, out)`; the first axis of `Wi` stacks four gates,
# so the output width is its row count divided by 4.
function Base.show(io::IO, l::LSTMCell)
  print(io, "LSTMCell(", size(l.Wi, 2), ", ", size(l.Wi, 1) ÷ 4, ")")
end
|
2017-09-05 23:25:34 +00:00
|
|
|
|
|
"""
    LSTM(in::Integer, out::Integer)

Long Short Term Memory recurrent layer. Behaves like an RNN but generally
exhibits a longer memory span over sequences.

See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.
"""
# Thin wrapper: build the cell, then make it stateful with `Recur`.
LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
|
2017-11-24 13:33:06 +00:00
|
|
|
|
|
|
|
|
|
# GRU
|
|
|
|
|
|
# Parameters of a GRU cell. The three gates' parameters are stacked along the
# first axis in the order reset, update, candidate (see the call operator).
mutable struct GRUCell{A,V}
  Wi::A  # input weights, (3*out) × in
  Wh::A  # recurrent weights, (3*out) × out
  b::V   # bias, length 3*out
  h::V   # initial hidden state, length out
end
|
|
|
|
|
|
# Construct a `GRUCell` mapping `in` features to `out` features. Weights and
# bias come from `init`; the stored hidden state starts at zero.
# `in`/`out` are annotated `::Integer` for consistency with the `RNNCell` and
# `LSTMCell` constructors (all callers pass integer dimensions).
GRUCell(in::Integer, out::Integer; init = glorot_uniform) =
  GRUCell(init(out * 3, in), init(out * 3, out),
          init(out * 3), zeros(out))
|
2017-11-24 13:33:06 +00:00
|
|
|
|
# One GRU step. Takes the hidden state `h` and input `x`; returns the new
# state twice, as the `(state, output)` pair `Recur` expects.
function (m::GRUCell)(h, x)
  o = size(h, 1)
  b = m.b
  gx = m.Wi*x
  gh = m.Wh*h
  r  = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1))  # reset gate
  z  = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2))  # update gate
  h̃  = tanh.(gate(gx, o, 3) .+ r .* gate(gh, o, 3) .+ gate(b, o, 3))  # candidate
  h′ = (1 .- z) .* h̃ .+ z .* h
  return h′, h′
end
|
|
|
|
|
# The cell's stored initial hidden state; `Recur(cell)` uses this as the
# default starting state.
hidden(m::GRUCell) = m.h
|
|
|
|
|
|
# Register all of `GRUCell`'s fields with the treelike traversal machinery.
@treelike GRUCell
|
2017-11-24 13:33:06 +00:00
|
|
|
|
|
# Print as `GRUCell(in, out)`; the first axis of `Wi` stacks three gates,
# so the output width is its row count divided by 3.
function Base.show(io::IO, l::GRUCell)
  print(io, "GRUCell(", size(l.Wi, 2), ", ", size(l.Wi, 1) ÷ 3, ")")
end
|
2017-11-24 13:33:06 +00:00
|
|
|
|
"""
    GRU(in::Integer, out::Integer)

Gated Recurrent Unit layer. Behaves like an RNN but generally
exhibits a longer memory span over sequences.

See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.
"""
# Thin wrapper: build the cell, then make it stateful with `Recur`.
GRU(a...; ka...) = Recur(GRUCell(a...; ka...))
|