# Flux.jl/src/layers/recurrent.jl

# TODO: broadcasting cat
combine(x::AbstractMatrix, h::AbstractVector) = vcat(x, h .* trues(1, size(x, 2)))
combine(x::AbstractVector, h::AbstractVector) = vcat(x, h)
combine(x::AbstractMatrix, h::AbstractMatrix) = vcat(x, h)
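
# `combine` concatenates the input with the hidden state along the feature
# dimension so a single `Dense` layer can mix both. When `x` is a batch (a
# matrix) and `h` a vector, `h` is broadcast across the batch columns via
# `h .* trues(1, size(x, 2))`. A rough sketch of the shapes involved, assuming
# a feature size of 3, a state size of 2 and a batch of 4:
#
#   x = rand(3, 4)          # 3 features × 4 samples
#   h = rand(2)             # hidden state, one entry per output unit
#   size(combine(x, h))     # (5, 4): state repeated for every sample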
# Stateful recurrence
"""
Recur(cell)
`Recur` takes a recurrent cell and makes it stateful, managing the hidden state
in the background. `cell` should be a model of the form:
h, y = cell(h, x...)
For example, here's a recurrent network that keeps a running total of its inputs.
```julia
accum(h, x) = (h+x, x)
rnn = Flux.Recur(accum, 0)
rnn(2) # 2
rnn(3) # 3
rnn.state # 5
rnn.(1:10) # apply to a sequence
rnn.state # 60
```
"""
mutable struct Recur{T}
  cell::T
  init
  state
end

Recur(m, h = hidden(m)) = Recur(m, h, h)
function (m::Recur)(xs...)
  h, y = m.cell(m.state, xs...)
  m.state = h
  return y
end
treelike(Recur)
Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
_truncate(x::AbstractArray) = Tracker.data(x)
_truncate(x::Tuple) = _truncate.(x)
"""
truncate!(rnn)
Truncates the gradient of the hidden state in recurrent layers. The value of the
state is preserved. See also `reset!`.
Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to
rnn.state = Tracker.data(rnn.state)
"""
truncate!(m) = prefor(x -> x isa Recur && (x.state = _truncate(x.state)), m)
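
# A rough usage sketch: with truncated backpropagation through time, the
# gradient chain is cut between chunks while keeping the numeric state.
# The layer size and the `sequence_chunks` / `loss` names are illustrative only.
#
#   m = RNN(10, 5)
#   for chunk in sequence_chunks
#     l = sum(loss(m(x), y) for (x, y) in chunk)
#     # ... backpropagate and update ...
#     truncate!(m)   # drop gradient history, keep the state value
#   end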
"""
reset!(rnn)
Reset the hidden state of a recurrent layer back to its original value. See also
`truncate!`.
Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to
rnn.state = hidden(rnn.cell)
"""
reset!(m) = prefor(x -> x isa Recur && (x.state = x.init), m)
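
# A rough usage sketch: `reset!` is typically called between independent
# sequences so state from one does not leak into the next. The layer size and
# the `sequence1` / `sequence2` names are illustrative only.
#
#   m = LSTM(10, 5)
#   ys1 = m.(sequence1)
#   reset!(m)           # start the next sequence from the initial state
#   ys2 = m.(sequence2)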
flip(f, xs) = reverse(f.(reverse(xs)))
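
# A rough usage sketch: `flip` runs a model over a reversed sequence and
# reverses the outputs again, i.e. the backward pass of a bidirectional RNN.
# Sizes are illustrative only.
#
#   fwd = RNN(10, 5)
#   bwd = RNN(10, 5)
#   xs  = [rand(10) for _ = 1:8]
#   hs  = vcat.(fwd.(xs), flip(bwd, xs))   # per-step forward/backward features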
# Vanilla RNN
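# The cell below concatenates input and state and pushes them through a single
# `Dense` layer, i.e. roughly h′ = σ(W * [x; h] .+ b), with the new state also
# serving as the output.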
struct RNNCell{D,V}
  d::D
  h::V
end

RNNCell(in::Integer, out::Integer, σ = tanh; initW = glorot_uniform, initb = zeros) =
  RNNCell(Dense(in+out, out, σ, initW = initW, initb = initb), param(initW(out)))
function (m::RNNCell)(h, x)
  h = m.d(combine(x, h))
  return h, h
end
hidden(m::RNNCell) = m.h
treelike(RNNCell)
function Base.show(io::IO, m::RNNCell)
print(io, "RNNCell(", m.d, ")")
end
"""
RNN(in::Integer, out::Integer, σ = tanh)
The most basic recurrent layer; essentially acts as a `Dense` layer, but with the
output fed back into the input each time step.
"""
RNN(a...; ka...) = Recur(RNNCell(a...; ka...))
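
# A rough usage sketch (layer sizes are illustrative only):
#
#   rnn = RNN(10, 5)
#   y   = rnn(rand(10))                   # one step: 5-element output
#   ys  = rnn.([rand(10) for _ = 1:8])    # a sequence, state carried across steps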
# LSTM
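# The cell below computes, roughly (with [x; h] the concatenated input and state):
#   f  = σ(W_f * [x; h] .+ b_f)      forget gate
#   i  = σ(W_i * [x; h] .+ b_i)      input gate
#   o  = σ(W_o * [x; h] .+ b_o)      output gate
#   c̃  = tanh(W_c * [x; h] .+ b_c)   candidate cell state
#   c′ = f .* c .+ i .* c̃
#   h′ = o .* tanh.(c′)
# The forget-gate bias is initialised to 1 so the cell starts out remembering.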
struct LSTMCell{D1,D2,V}
  forget::D1
  input::D1
  output::D1
  cell::D2
  h::V; c::V
end
function LSTMCell(in, out; initW = glorot_uniform, initb = zeros)
  cell = LSTMCell([Dense(in+out, out, σ, initW = initW, initb = initb) for _ = 1:3]...,
                  Dense(in+out, out, tanh, initW = initW, initb = initb),
                  param(initW(out)), param(initW(out)))
  cell.forget.b.data .= 1
  return cell
end
function (m::LSTMCell)(h_, x)
  h, c = h_
  x = combine(x, h)
  forget, input, output, cell =
    m.forget(x), m.input(x), m.output(x), m.cell(x)
  c = forget .* c .+ input .* cell
  h = output .* tanh.(c)
  return (h, c), h
end
hidden(m::LSTMCell) = (m.h, m.c)
treelike(LSTMCell)
Base.show(io::IO, m::LSTMCell) =
print(io, "LSTMCell(",
size(m.forget.W, 2) - size(m.forget.W, 1), ", ",
size(m.forget.W, 1), ')')
"""
LSTM(in::Integer, out::Integer, σ = tanh)
Long Short Term Memory recurrent layer. Behaves like an RNN but generally
exhibits a longer memory span over sequences.
See [this article](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.
"""
LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
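
# A rough usage sketch: the `(h, c)` state pair is managed by `Recur`, so the
# layer composes like any other. Sizes are illustrative only.
#
#   m  = Chain(LSTM(10, 20), Dense(20, 2), softmax)
#   ys = m.([rand(10) for _ = 1:8])   # per-step predictions over a sequence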
# GRU
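# The cell below computes, roughly (with [x; h] the concatenated input and state):
#   z  = σ(W_z * [x; h] .+ b_z)             update gate
#   r  = σ(W_r * [x; h] .+ b_r)             reset gate
#   h̃  = tanh(W_h * [r .* h; x] .+ b_h)     candidate state
#   h′ = (1 .- z) .* h .+ z .* h̃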
struct GRUCell{D1,D2,V}
  update::D1
  reset::D1
  candidate::D2
  h::V
end
function GRUCell(in, out)
  cell = GRUCell(Dense(in+out, out, σ),
                 Dense(in+out, out, σ),
                 Dense(in+out, out, tanh),
                 param(initn(out)))
  return cell
end
function (m::GRUCell)(h, x)
  x = combine(x, h)
  z = m.update(x)
  r = m.reset(x)
  h̃ = m.candidate(combine(r.*h, x))
  h = (1.-z).*h .+ z.*h̃
  return h, h
end
hidden(m::GRUCell) = m.h
treelike(GRUCell)
Base.show(io::IO, m::GRUCell) =
print(io, "GRUCell(",
size(m.update.W, 2) - size(m.update.W, 1), ", ",
size(m.update.W, 1), ')')
"""
GRU(in::Integer, out::Integer, σ = tanh)
Gated Recurrent Unit layer. Behaves like an RNN but generally
exhibits a longer memory span over sequences.
See [this article](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.
"""
GRU(a...; ka...) = Recur(GRUCell(a...; ka...))
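
# A rough usage sketch: unlike `LSTM`, a GRU keeps a single hidden vector, so it
# has fewer parameters per unit. Sizes are illustrative only.
#
#   gru = GRU(10, 5)
#   ys  = gru.([rand(10) for _ = 1:8])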