# TODO: broadcasting cat
combine(x::AbstractMatrix, h::AbstractVector) = vcat(x, h .* trues(1, size(x, 2)))
combine(x::AbstractVector, h::AbstractVector) = vcat(x, h)
combine(x::AbstractMatrix, h::AbstractMatrix) = vcat(x, h)
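
# Illustrative sketch (not part of the API): `combine` concatenates the input
# and the hidden state along the feature dimension, broadcasting a vector
# state across a batch of columns:
#
#   x = rand(3, 2)       # 3 features, batch of 2
#   h = rand(4)          # hidden state of length 4
#   size(combine(x, h))  # (7, 2); each column gets a copy of h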

# Stateful recurrence

"""
    Recur(cell)

`Recur` takes a recurrent cell and makes it stateful, managing the hidden state
in the background. `cell` should be a model of the form:

    h, y = cell(h, x...)

For example, here's a recurrent network that keeps a running total of its inputs.

```julia
accum(h, x) = (h+x, x)
rnn = Flux.Recur(accum, 0)
rnn(2)     # 2
rnn(3)     # 3
rnn.state  # 5
rnn.(1:10) # apply to a sequence
rnn.state  # 60
```
"""
mutable struct Recur{T}
  cell::T
  init
  state
end

Recur(m, h = hidden(m)) = Recur(m, h, h)

function (m::Recur)(xs...)
  h, y = m.cell(m.state, xs...)
  m.state = h
  return y
end

treelike(Recur)

Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")

_truncate(x::AbstractArray) = Tracker.data(x)
_truncate(x::Tuple) = _truncate.(x)

"""
    truncate!(rnn)

Truncates the gradient of the hidden state in recurrent layers. The value of the
state is preserved. See also `reset!`.

Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to

    rnn.state = Tracker.data(rnn.state)
"""
truncate!(m) = prefor(x -> x isa Recur && (x.state = _truncate(x.state)), m)
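
# Sketch of truncated backpropagation through time (illustrative; `loss`,
# `opt` and `data` are assumed to be defined elsewhere):
#
#   for (x, y) in data
#     l = loss(rnn(x), y)
#     back!(l); opt()
#     truncate!(rnn) # keep the state's value, drop its gradient history
#   end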

"""
    reset!(rnn)

Reset the hidden state of a recurrent layer back to its original value. See also
`truncate!`.

Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to

    rnn.state = hidden(rnn.cell)
"""
reset!(m) = prefor(x -> x isa Recur && (x.state = x.init), m)
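
# Illustrative: call `reset!` between independent sequences, so that state
# from one sequence does not leak into the next:
#
#   for seq in dataset
#     reset!(rnn)
#     l = sum(loss(rnn(x), y) for (x, y) in seq)
#   end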

flip(f, xs) = reverse(f.(reverse(xs)))
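
# `flip` runs `f` over a sequence in reverse and restores the original output
# order, e.g. for processing a sequence backwards with a recurrent layer
# (illustrative):
#
#   xs = [rand(10) for _ = 1:5]
#   ys = flip(LSTM(10, 5), xs) # ys[i] lines up with xs[i]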

# Vanilla RNN

struct RNNCell{D,V}
  d::D
  h::V
end

RNNCell(in::Integer, out::Integer, σ = tanh; initW = glorot_uniform, initb = zeros) =
  RNNCell(Dense(in+out, out, σ, initW = initW, initb = initb), param(initW(out)))

function (m::RNNCell)(h, x)
  h = m.d(combine(x, h))
  return h, h
end

hidden(m::RNNCell) = m.h

treelike(RNNCell)

function Base.show(io::IO, m::RNNCell)
  print(io, "RNNCell(", m.d, ")")
end

"""
    RNN(in::Integer, out::Integer, σ = tanh)

The most basic recurrent layer; essentially acts as a `Dense` layer, but with the
output fed back into the input at each time step.
"""
RNN(a...; ka...) = Recur(RNNCell(a...; ka...))
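
# Usage sketch (illustrative; `seq` is assumed to be a vector of inputs):
#
#   m = RNN(10, 5)
#   y = m(rand(10)) # one time step
#   ys = m.(seq)    # map over a sequence, carrying state between steps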

# LSTM

struct LSTMCell{D1,D2,V}
  forget::D1
  input::D1
  output::D1
  cell::D2
  h::V; c::V
end

function LSTMCell(in, out; initW = glorot_uniform, initb = zeros)
  cell = LSTMCell([Dense(in+out, out, σ, initW = initW, initb = initb) for _ = 1:3]...,
                  Dense(in+out, out, tanh, initW = initW, initb = initb),
                  param(initW(out)), param(initW(out)))
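  # Setting the forget gate bias to 1 is a standard LSTM trick that encourages
  # the cell to retain its state early in training.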
  cell.forget.b.data .= 1
  return cell
end

function (m::LSTMCell)(h_, x)
  h, c = h_
  x′ = combine(x, h)
  forget, input, output, cell =
    m.forget(x′), m.input(x′), m.output(x′), m.cell(x′)
  c = forget .* c .+ input .* cell
  h = output .* tanh.(c)
  return (h, c), h
end

hidden(m::LSTMCell) = (m.h, m.c)

treelike(LSTMCell)

Base.show(io::IO, m::LSTMCell) =
  print(io, "LSTMCell(",
        size(m.forget.W, 2) - size(m.forget.W, 1), ", ",
        size(m.forget.W, 1), ')')

"""
    LSTM(in::Integer, out::Integer)

Long Short-Term Memory recurrent layer. Behaves like an RNN but generally
exhibits a longer memory span over sequences.

See [this article](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.
"""
LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
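
# Usage sketch (illustrative):
#
#   m = LSTM(10, 5)
#   m(rand(10)) # one time step; hidden and cell state are carried internally
#   reset!(m)   # restore the initial state between sequences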

# GRU

struct GRUCell{D1,D2,V}
  update::D1
  reset::D1
  candidate::D2
  h::V
end

function GRUCell(in, out)
  cell = GRUCell(Dense(in+out, out, σ),
                 Dense(in+out, out, σ),
                 Dense(in+out, out, tanh),
                 param(initn(out)))
  return cell
end

function (m::GRUCell)(h, x)
  x′ = combine(x, h)
  z = m.update(x′)
  r = m.reset(x′)
  h̃ = m.candidate(combine(r .* h, x))
  h = (1 .- z) .* h .+ z .* h̃
  return h, h
end

hidden(m::GRUCell) = m.h

treelike(GRUCell)

Base.show(io::IO, m::GRUCell) =
  print(io, "GRUCell(",
        size(m.update.W, 2) - size(m.update.W, 1), ", ",
        size(m.update.W, 1), ')')

"""
    GRU(in::Integer, out::Integer)

Gated Recurrent Unit layer. Behaves like an RNN but generally
exhibits a longer memory span over sequences.

See [this article](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.
"""
GRU(a...; ka...) = Recur(GRUCell(a...; ka...))
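
# Usage sketch (illustrative):
#
#   m = GRU(10, 5)
#   m(rand(10)) # one time step; returns the updated hidden state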