Mike J Innes 2017-10-18 15:30:05 +01:00
parent 897f812055
commit fd249b773e
3 changed files with 45 additions and 2 deletions

docs/make.jl

@@ -10,7 +10,7 @@ makedocs(modules=[Flux, NNlib],
         "Building Models" =>
           ["Basics" => "models/basics.md",
            "Recurrence" => "models/recurrence.md",
-           "Layer Reference" => "models/layers.md"],
+           "Model Reference" => "models/layers.md"],
         "Training Models" =>
           ["Optimisers" => "training/optimisers.md",
            "Training" => "training/training.md"],

docs/src/models/layers.md

@@ -1,4 +1,4 @@
-## Model Layers
+## Layers

 These core layers form the foundation of almost all neural networks.
@@ -7,6 +7,16 @@ Chain
 Dense
 ```

+## Recurrent Cells
+
+Much like the core layers above, these can be used to process sequence data (as well as other kinds of structured data).
+
+```@docs
+RNN
+LSTM
+Recur
+```
+
 ## Activation Functions

 Non-linearities that go between layers of your model. Most of these functions are defined in [NNlib](https://github.com/FluxML/NNlib.jl) but are available by default in Flux.
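
The reference page edited above documents `Chain`, `Dense`, the new recurrent wrappers, and the NNlib activations side by side. A quick sketch of how they compose in practice, with all sizes and inputs chosen arbitrarily rather than taken from this commit:

```julia
using Flux

# Core layers: `Chain` composes layers; `Dense` is a fully connected layer.
m = Chain(Dense(10, 5, relu), Dense(5, 2), softmax)
m(rand(10))  # 2-element output summing to 1

# Recurrent layers: `RNN`/`LSTM` return a cell wrapped in the stateful `Recur`.
r = LSTM(10, 5)
xs = [rand(10) for _ in 1:3]  # a length-3 sequence of inputs
ys = r.(xs)                   # broadcast over the sequence, as in the docs
```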

src/layers/recurrent.jl

@@ -3,6 +3,24 @@ combine(x, h) = vcat(x, h .* trues(1, size(x, 2)))

 # Stateful recurrence

+"""
+    Recur(cell)
+
+`Recur` takes a recurrent cell and makes it stateful, managing the hidden state
+in the background. `cell` should be a model of the form:
+
+    h, y = cell(h, x...)
+
+For example, here's a recurrent network that keeps a running total of its inputs.
+
+    accum(h, x) = (h+x, x)
+    rnn = Flux.Recur(accum, 0)
+    rnn(2) # 2
+    rnn(3) # 3
+    rnn.state # 5
+    rnn.(1:10) # apply to a sequence
+    rnn.state # 60
+"""
 mutable struct Recur{T}
   cell::T
   state
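
The piece of `Recur` this hunk elides is the call overload that threads the state through the cell. A minimal sketch of that mechanism, reconstructed from the docstring's `h, y = cell(h, x...)` contract rather than copied from the commit:

```julia
# Sketch: how a stateful wrapper like `Recur` threads its hidden state.
# Field names follow the struct above; the method body is an assumption.
function (m::Recur)(xs...)
  h, y = m.cell(m.state, xs...)  # cell maps (old state, input) -> (new state, output)
  m.state = h                    # remember the new state for the next call
  return y                       # only the output is returned to the caller
end
```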
@@ -52,6 +70,12 @@ function Base.show(io::IO, m::RNNCell)
   print(io, "RNNCell(", m.d, ")")
 end

+"""
+    RNN(in::Integer, out::Integer, σ = tanh)
+
+The most basic recurrent layer; essentially acts as a `Dense` layer, but with the
+output fed back into the input each time step.
+"""
 RNN(a...; ka...) = Recur(RNNCell(a...; ka...))
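
As a usage sketch of the constructor above, with sizes chosen arbitrarily and not part of this commit:

```julia
# Illustrative only: a 10 -> 5 vanilla RNN, wrapped in Recur by the line above.
r = RNN(10, 5)
y1 = r(rand(10))  # first step; the hidden state is updated in place
y2 = r(rand(10))  # second step sees the state left by the first
```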
# LSTM
@@ -91,4 +115,13 @@ Base.show(io::IO, m::LSTMCell) =
   size(m.forget.W, 2) - size(m.forget.W, 1), ", ",
   size(m.forget.W, 1), ')')

+"""
+    LSTM(in::Integer, out::Integer)
+
+Long Short-Term Memory recurrent layer. Behaves like an RNN but generally
+exhibits a longer memory span over sequences.
+
+See [this article](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
+for a good overview of the internals.
+"""
 LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
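
Finally, a sketch of driving the layer above over a sequence, mirroring the broadcasting idiom from the `Recur` docstring; sizes are arbitrary assumptions:

```julia
# Illustrative only: an LSTM over a short sequence of 10-dimensional inputs.
m = LSTM(10, 5)
xs = [rand(10) for _ in 1:4]
ys = m.(xs)  # one 5-dimensional output per step; state carries between calls
```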