Flux.jl/src/layers/recurrent.jl

187 lines
4.0 KiB
Julia
Raw Normal View History

2018-08-18 15:50:52 +00:00
# Index range covering the `n`-th (1-based) block of length `h` in a vector made
# of equally sized blocks laid end to end, e.g. gate(3, 2) == 4:6.
function gate(h, n)
    offset = h * (n - 1)
    return (1:h) .+ offset
end
2018-02-01 20:57:39 +00:00
# Extract the `n`-th block of `h` entries from a stacked vector. Plain indexing
# is used, so the result is a copy rather than a view.
function gate(x::AbstractVector, h, n)
    rows = gate(h, n)
    return x[rows]
end

# Matrix version: take the `n`-th block of `h` rows, keeping every column.
function gate(x::AbstractMatrix, h, n)
    rows = gate(h, n)
    return x[rows, :]
end
2017-09-05 06:29:31 +00:00
2017-09-03 06:12:44 +00:00
# Stateful recurrence
2017-10-18 14:30:05 +00:00
"""
    Recur(cell)
    Recur(cell, h)

`Recur` takes a recurrent cell and makes it stateful, managing the hidden state
in the background. `cell` should be a model of the form:

    h, y = cell(h, x...)

The state starts out as `h`. When `h` is omitted it defaults to `hidden(cell)`.
The same value is kept in `rnn.init` so that `reset!` can restore it later.

For example, here's a recurrent network that keeps a running total of its inputs.

```julia
accum(h, x) = (h+x, x)
rnn = Flux.Recur(accum, 0)
rnn(2) # 2
rnn(3) # 3
rnn.state # 5
rnn.(1:10) # apply to a sequence
rnn.state # 60
```
"""
mutable struct Recur{T}
    cell::T  # the wrapped cell, called as `cell(state, x...)`
    init     # initial state; `reset!` copies it back into `state`
    state    # current state, replaced after every call
end

Recur(m, h = hidden(m)) = Recur(m, h, h)

# Run one step: feed the stored state and the inputs to the cell, keep the new
# state it returns, and hand back only the output `y`.
function (m::Recur)(xs...)
    h, y = m.cell(m.state, xs...)
    m.state = h
    return y
end
2018-07-12 21:43:11 +00:00
# Register `Recur` with Flux's `@treelike` macro, listing only the fields `cell`
# and `init`; the mutable `state` field is not listed.
@treelike Recur cell, init
2017-09-07 04:05:02 +00:00
2017-09-05 23:25:34 +00:00
# Display a `Recur` as `Recur(<cell>)`, delegating to the cell's own printing.
function Base.show(io::IO, m::Recur)
    print(io, "Recur(")
    print(io, m.cell)
    print(io, ")")
    return nothing
end
2017-10-19 16:21:08 +00:00
"""
    reset!(rnn)

Reset the hidden state of a recurrent layer back to its original value. See also
`truncate!`.

For a `Recur` layer `rnn` this sets `rnn.state` back to `rnn.init`. Unless an
explicit initial state was passed to `Recur`, that is the same as

    rnn.state = hidden(rnn.cell)

For any other model `m`, `reset!` is applied to every element of
`functor(m)[1]`, so nested models containing `Recur` layers are reset
recursively.
"""
reset!(m::Recur) = (m.state = m.init)
reset!(m) = foreach(reset!, functor(m)[1])
2017-09-06 18:03:25 +00:00
2017-10-16 07:53:39 +00:00
# Apply `f` to the elements of `xs` starting from the last one, then return the
# results in the original order of `xs`. A stateful `f` therefore sees the
# sequence back-to-front.
function flip(f, xs)
    backwards = reverse(xs)
    results = f.(backwards)
    return reverse(results)
end
2017-09-03 06:12:44 +00:00
# Vanilla RNN
2018-02-02 16:19:56 +00:00
# Fully connected recurrent cell: h′ = σ.(Wi*x .+ Wh*h .+ b).
mutable struct RNNCell{F,A,V}
    σ::F   # elementwise activation
    Wi::A  # input weights; size (out, in) when built by the constructor below
    Wh::A  # recurrent weights; size (out, out)
    b::V   # bias; length out
    h::V   # initial hidden state returned by `hidden`; length out
end

function RNNCell(in::Integer, out::Integer, σ = tanh;
                 init = glorot_uniform)
    # Keep the original call order of `init` so seeded weights are unchanged.
    Wi = init(out, in)
    Wh = init(out, out)
    b = init(out)
    # `b` and `h` must share the type parameter `V`. Deriving the zero state from
    # `b` keeps them consistent for any `init` (e.g. one returning Float64).
    h = fill!(similar(b, out), 0)
    return RNNCell(σ, Wi, Wh, b, h)
end

# One step: compute the new hidden state, which is also the cell's output.
function (m::RNNCell)(h, x)
    σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
    h = σ.(Wi*x .+ Wh*h .+ b)
    return h, h
end

hidden(m::RNNCell) = m.h
2018-07-12 21:43:11 +00:00
# Register `RNNCell` with Flux's `@treelike` macro. No explicit field list is
# given here. NOTE(review): presumably this means every field (including the
# initial state `h`) is treated as a child; confirm against `@treelike`.
@treelike RNNCell
2017-09-07 04:05:02 +00:00
2018-02-01 20:57:39 +00:00
# Display as `RNNCell(in, out)`, or `RNNCell(in, out, σ)` when the activation is
# not `identity`; sizes are read from the input weight matrix `Wi`.
function Base.show(io::IO, l::RNNCell)
    print(io, "RNNCell(", size(l.Wi, 2), ", ", size(l.Wi, 1))
    l.σ == identity || print(io, ", ", l.σ)
    print(io, ")")
    return nothing
end
2017-10-18 14:30:05 +00:00
"""
    RNN(in::Integer, out::Integer, σ = tanh)

The most basic recurrent layer; essentially acts as a `Dense` layer, but with the
output fed back into the input each time step.

All arguments, including keywords such as `init`, are forwarded to `RNNCell`,
and the resulting cell is wrapped in a stateful `Recur`.
"""
RNN(a...; ka...) = Recur(RNNCell(a...; ka...))
2017-09-03 06:12:44 +00:00
# LSTM
2018-02-02 16:19:56 +00:00
# LSTM cell. The four gates (input, forget, cell, output) are stacked along the
# first dimension of `Wi`, `Wh` and `b` and split apart with `gate`.
mutable struct LSTMCell{A,V}
    Wi::A  # input weights; size (4out, in) when built by the constructor below
    Wh::A  # recurrent weights; size (4out, out)
    b::V   # bias; length 4out
    h::V   # initial hidden state; length out
    c::V   # initial cell state; length out
end

function LSTMCell(in::Integer, out::Integer;
                  init = glorot_uniform)
    # Keep the original call order of `init` so seeded weights are unchanged.
    Wi = init(out * 4, in)
    Wh = init(out * 4, out)
    b = init(out * 4)
    # `b`, `h` and `c` must share the type parameter `V`. Deriving the zero
    # states from `b` keeps them consistent for any `init`. Two separate
    # allocations are used so `h` and `c` never alias.
    h = fill!(similar(b, out), 0)
    c = fill!(similar(b, out), 0)
    cell = LSTMCell(Wi, Wh, b, h, c)
    # Start the second gate block (the forget gate) with a bias of 1.
    cell.b[gate(out, 2)] .= 1
    return cell
end

# One step: the state is the pair (h, c); the output is the new hidden state h.
function (m::LSTMCell)((h, c), x)
    b, o = m.b, size(h, 1)
    g = m.Wi*x .+ m.Wh*h .+ b
    input = σ.(gate(g, o, 1))
    forget = σ.(gate(g, o, 2))
    cell = tanh.(gate(g, o, 3))
    output = σ.(gate(g, o, 4))
    c = forget .* c .+ input .* cell
    h = output .* tanh.(c)
    return (h, c), h
end

hidden(m::LSTMCell) = (m.h, m.c)
2017-09-05 23:25:34 +00:00
2018-07-12 21:43:11 +00:00
# Register `LSTMCell` with Flux's `@treelike` macro. No explicit field list is
# given here. NOTE(review): presumably every field, including the initial states
# `h` and `c`, is treated as a child; confirm against `@treelike`.
@treelike LSTMCell
2017-09-06 22:59:07 +00:00
2018-02-01 20:57:39 +00:00
# Display as `LSTMCell(in, out)`. `Wi` stacks four gates, so its row count is
# divided by 4 to recover `out`.
Base.show(io::IO, l::LSTMCell) =
    print(io, "LSTMCell(", size(l.Wi, 2), ", ", div(size(l.Wi, 1), 4), ")")
2017-09-05 23:25:34 +00:00
2017-10-18 14:30:05 +00:00
"""
    LSTM(in::Integer, out::Integer)

Long Short Term Memory recurrent layer. Behaves like an RNN but generally
exhibits a longer memory span over sequences.

All arguments, including keywords such as `init`, are forwarded to `LSTMCell`,
and the resulting cell is wrapped in a stateful `Recur`.

See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.
"""
LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
2017-11-24 13:33:06 +00:00
# GRU
2018-02-02 16:19:56 +00:00
# GRU cell. The three gate blocks (reset, update, candidate) are stacked along
# the first dimension of `Wi`, `Wh` and `b` and split apart with `gate`.
mutable struct GRUCell{A,V}
    Wi::A  # input weights; size (3out, in) when built by the constructor below
    Wh::A  # recurrent weights; size (3out, out)
    b::V   # bias; length 3out
    h::V   # initial hidden state; length out
end

function GRUCell(in::Integer, out::Integer; init = glorot_uniform)
    # Keep the original call order of `init` so seeded weights are unchanged.
    Wi = init(out * 3, in)
    Wh = init(out * 3, out)
    b = init(out * 3)
    # `b` and `h` must share the type parameter `V`; derive the zero state
    # from `b` so this holds for any `init`.
    h = fill!(similar(b, out), 0)
    return GRUCell(Wi, Wh, b, h)
end

# One step: compute the new hidden state, which is also the cell's output.
function (m::GRUCell)(h, x)
    b, o = m.b, size(h, 1)
    gx, gh = m.Wi*x, m.Wh*h
    r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1))  # reset gate
    z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2))  # update gate
    # Candidate state; the reset gate scales the recurrent contribution.
    hcand = tanh.(gate(gx, o, 3) .+ r .* gate(gh, o, 3) .+ gate(b, o, 3))
    # Interpolate between the candidate and the previous state using `z`.
    hnew = (1 .- z) .* hcand .+ z .* h
    return hnew, hnew
end

hidden(m::GRUCell) = m.h
2018-07-12 21:43:11 +00:00
# Register `GRUCell` with Flux's `@treelike` macro. No explicit field list is
# given here. NOTE(review): presumably every field, including the initial state
# `h`, is treated as a child; confirm against `@treelike`.
@treelike GRUCell
2017-11-24 13:33:06 +00:00
2018-02-01 20:57:39 +00:00
# Display as `GRUCell(in, out)`. `Wi` stacks three gates, so its row count is
# divided by 3 to recover `out`.
Base.show(io::IO, l::GRUCell) =
    print(io, "GRUCell(", size(l.Wi, 2), ", ", div(size(l.Wi, 1), 3), ")")
2017-11-24 13:33:06 +00:00
"""
    GRU(in::Integer, out::Integer)

Gated Recurrent Unit layer. Behaves like an RNN but generally
exhibits a longer memory span over sequences.

All arguments, including keywords such as `init`, are forwarded to `GRUCell`,
and the resulting cell is wrapped in a stateful `Recur`.

See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals (GRUs are covered among the LSTM variants).
"""
GRU(a...; ka...) = Recur(GRUCell(a...; ka...))