rnn docs

parent 897f812055
commit fd249b773e
@@ -10,7 +10,7 @@ makedocs(modules=[Flux, NNlib],
          "Building Models" =>
            ["Basics" => "models/basics.md",
             "Recurrence" => "models/recurrence.md",
-            "Layer Reference" => "models/layers.md"],
+            "Model Reference" => "models/layers.md"],
          "Training Models" =>
            ["Optimisers" => "training/optimisers.md",
             "Training" => "training/training.md"],
@@ -1,4 +1,4 @@
-## Model Layers
+## Layers
 
 These core layers form the foundation of almost all neural networks.
 
@@ -7,6 +7,16 @@ Chain
 Dense
 ```
 
+## Recurrent Cells
+
+Much like the core layers above, but can be used to process sequence data (as well as other kinds of structured data).
+
+```@docs
+RNN
+LSTM
+Recur
+```
+
 ## Activation Functions
 
 Non-linearities that go between layers of your model. Most of these functions are defined in [NNlib](https://github.com/FluxML/NNlib.jl) but are available by default in Flux.
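
As a usage note on the "Recurrent Cells" section added above: these layers keep their hidden state between calls, so a sequence is processed one timestep at a time (the `Recur` docstring below broadcasts over a sequence the same way). A minimal sketch, assuming `RNN(10, 5)` accepts plain `Float64` vectors of the stated input size; the sizes and data are made up for illustration.

```julia
using Flux

# A recurrent layer mapping 10-dimensional inputs to 5-dimensional outputs.
m = RNN(10, 5)

# A toy sequence of 13 timesteps, each a 10-element input vector.
xs = [rand(10) for i = 1:13]

# Apply the layer to each timestep in turn; the hidden state is carried
# across calls by the `Recur` wrapper, so the order of the calls matters.
ys = m.(xs)

length(ys)  # 13 outputs, one per timestep
```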
@@ -3,6 +3,24 @@ combine(x, h) = vcat(x, h .* trues(1, size(x, 2)))
 
 # Stateful recurrence
 
+"""
+    Recur(cell)
+
+`Recur` takes a recurrent cell and makes it stateful, managing the hidden state
+in the background. `cell` should be a model of the form:
+
+    h, y = cell(h, x...)
+
+For example, here's a recurrent network that keeps a running total of its inputs.
+
+    accum(h, x) = (h+x, x)
+    rnn = Flux.Recur(accum, 0)
+    rnn(2) # 2
+    rnn(3) # 3
+    rnn.state # 5
+    rnn.(1:10) # apply to a sequence
+    rnn.state # 60
+"""
 mutable struct Recur{T}
   cell::T
   state
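
Continuing the `accum` example from the `Recur` docstring above: because `Recur` is a plain mutable struct, its `state` field can be read or reassigned directly between sequences. A small sketch; this diff introduces no dedicated reset helper, so the manual assignment below is simply one way to start a fresh sequence.

```julia
using Flux

accum(h, x) = (h + x, x)
rnn = Flux.Recur(accum, 0)

rnn.(1:3)   # process one sequence
rnn.state   # 6

# Start an independent sequence by putting the state back to its initial
# value by hand (no reset helper is assumed to exist in this commit).
rnn.state = 0
rnn.(4:6)
rnn.state   # 15
```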
@@ -52,6 +70,12 @@ function Base.show(io::IO, m::RNNCell)
   print(io, "RNNCell(", m.d, ")")
 end
 
+"""
+    RNN(in::Integer, out::Integer, σ = tanh)
+
+The most basic recurrent layer; essentially acts as a `Dense` layer, but with the
+output fed back into the input each time step.
+"""
 RNN(a...; ka...) = Recur(RNNCell(a...; ka...))
 
 # LSTM
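
A usage sketch for the `RNN` constructor documented above. Since `RNN(a...)` is just `Recur(RNNCell(a...))`, the returned layer exposes the `cell` and `state` fields described earlier, and because the output is fed back in, repeated calls on the same input generally differ. The input size and data below are illustrative assumptions.

```julia
using Flux

m = RNN(10, 5)   # a Recur wrapping an RNNCell

x = rand(10)     # a made-up input vector

y1 = m(x)        # first step: an output of length 5
y2 = m(x)        # same input, generally a different output, since the
                 # previous output was fed back in as the hidden state

m.cell           # the underlying RNNCell
m.state          # the current hidden state
```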
@@ -91,4 +115,13 @@ Base.show(io::IO, m::LSTMCell) =
   size(m.forget.W, 2) - size(m.forget.W, 1), ", ",
   size(m.forget.W, 1), ')')
 
+"""
+    LSTM(in::Integer, out::Integer, σ = tanh)
+
+Long Short Term Memory recurrent layer. Behaves like an RNN but generally
+exhibits a longer memory span over sequences.
+
+See [this article](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
+for a good overview of the internals.
+"""
 LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
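
A sketch of how the `LSTM` layer documented above might be combined with the `Chain` and `Dense` layers from the layer reference. The sizes, the `softmax` output, and the toy data are illustrative assumptions rather than anything fixed by this commit.

```julia
using Flux

# An LSTM feeding a small classifier; each call processes one timestep,
# and the LSTM's hidden state is carried across calls by Recur.
m = Chain(LSTM(10, 20), Dense(20, 5), softmax)

xs = [rand(10) for i = 1:8]  # a toy 8-step sequence of 10-element inputs
ys = m.(xs)                  # one 5-way prediction per timestep
```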