rnn docs

commit fd249b773e (parent 897f812055)
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -10,7 +10,7 @@ makedocs(modules=[Flux, NNlib],
          "Building Models" =>
            ["Basics" => "models/basics.md",
             "Recurrence" => "models/recurrence.md",
-           "Layer Reference" => "models/layers.md"],
+           "Model Reference" => "models/layers.md"],
          "Training Models" =>
            ["Optimisers" => "training/optimisers.md",
             "Training" => "training/training.md"],
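For context, here is roughly how the renamed page slots into the surrounding `makedocs` call from Documenter.jl. Only the "Building Models" and "Training Models" entries come from the hunk above; everything else (the `sitename`, for instance) is an assumption sketched in for illustration.

```julia
using Documenter, Flux, NNlib

# Hypothetical surrounding call; only the two page groups below are
# taken from this commit's hunk.
makedocs(modules = [Flux, NNlib],
         sitename = "Flux",  # assumed, not shown in this diff
         pages = ["Building Models" =>
                    ["Basics" => "models/basics.md",
                     "Recurrence" => "models/recurrence.md",
                     "Model Reference" => "models/layers.md"],
                  "Training Models" =>
                    ["Optimisers" => "training/optimisers.md",
                     "Training" => "training/training.md"]])
```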
--- a/docs/src/models/layers.md
+++ b/docs/src/models/layers.md
@@ -1,4 +1,4 @@
-## Model Layers
+## Layers
 
 These core layers form the foundation of almost all neural networks.
 
@@ -7,6 +7,16 @@ Chain
 Dense
 ```
 
+## Recurrent Cells
+
+Much like the core layers above, but can be used to process sequence data (as well as other kinds of structured data).
+
+```@docs
+RNN
+LSTM
+Recur
+```
+
 ## Activation Functions
 
 Non-linearities that go between layers of your model. Most of these functions are defined in [NNlib](https://github.com/FluxML/NNlib.jl) but are available by default in Flux.
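A hypothetical usage sketch for the newly documented layers: since `RNN` and `LSTM` return a stateful `Recur`, a sequence is processed by calling the layer once per time step (names and dimensions below are made up for illustration, not taken from this commit).

```julia
using Flux

m = RNN(10, 5)                 # 10-dimensional inputs, 5-dimensional hidden state
seq = [rand(10) for t in 1:3]  # a toy sequence of three time steps

# Each call advances the hidden state, so mapping over the sequence
# processes it in order.
ys = map(m, seq)
```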
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -3,6 +3,24 @@ combine(x, h) = vcat(x, h .* trues(1, size(x, 2)))
 
 # Stateful recurrence
 
+"""
+    Recur(cell)
+
+`Recur` takes a recurrent cell and makes it stateful, managing the hidden state
+in the background. `cell` should be a model of the form:
+
+    h, y = cell(h, x...)
+
+For example, here's a recurrent network that keeps a running total of its inputs.
+
+    accum(h, x) = (h+x, x)
+    rnn = Flux.Recur(accum, 0)
+    rnn(2) # 2
+    rnn(3) # 3
+    rnn.state # 5
+    rnn.(1:10) # apply to a sequence
+    rnn.state # 60
+"""
 mutable struct Recur{T}
   cell::T
   state
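The docstring's `h, y = cell(h, x...)` contract is the whole trick. A minimal re-implementation sketch of the pattern (hypothetical code, not Flux's actual source) makes the state handling explicit:

```julia
# Sketch of what a Recur-style wrapper does: hold the state, thread it
# through the cell on every call, and hand back only the output.
mutable struct MyRecur{T}
  cell::T
  state
end

function (m::MyRecur)(x...)
  h, y = m.cell(m.state, x...)  # cell returns (new state, output)
  m.state = h                   # keep the state for the next call
  return y
end

# Re-running the docstring's running-total example against the sketch:
accum(h, x) = (h + x, x)
rnn = MyRecur(accum, 0)
rnn(2), rnn(3), rnn.state  # (2, 3, 5)
```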
@@ -52,6 +70,12 @@ function Base.show(io::IO, m::RNNCell)
   print(io, "RNNCell(", m.d, ")")
 end
 
+"""
+    RNN(in::Integer, out::Integer, σ = tanh)
+
+The most basic recurrent layer; essentially acts as a `Dense` layer, but with the
+output fed back into the input each time step.
+"""
 RNN(a...; ka...) = Recur(RNNCell(a...; ka...))
 
 # LSTM
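To unpack "a `Dense` layer with the output fed back in": a vanilla RNN step pushes the current input and the previous hidden state through one affine map and a nonlinearity. A self-contained sketch of the standard formulation (illustrative, not Flux's internal `RNNCell` code):

```julia
# One step of a vanilla RNN: h′ = tanh.(Wi*x .+ Wh*h .+ b).
# The new hidden state h′ is also the output, hence "fed back into the input".
function rnn_step(Wi, Wh, b, h, x)
  h′ = tanh.(Wi * x .+ Wh * h .+ b)
  return h′, h′  # (new state, output) — the shape `Recur` expects
end

# Example with 4-dimensional inputs and a 3-dimensional hidden state:
Wi, Wh, b = randn(3, 4), randn(3, 3), randn(3)
h = zeros(3)
h, y = rnn_step(Wi, Wh, b, h, randn(4))
```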
@@ -91,4 +115,13 @@ Base.show(io::IO, m::LSTMCell) =
             size(m.forget.W, 2) - size(m.forget.W, 1), ", ",
             size(m.forget.W, 1), ')')
 
+"""
+    LSTM(in::Integer, out::Integer, σ = tanh)
+
+Long Short Term Memory recurrent layer. Behaves like an RNN but generally
+exhibits a longer memory span over sequences.
+
+See [this article](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
+for a good overview of the internals.
+"""
 LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
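Why the longer memory span: an LSTM carries a separate cell state that is updated additively through gates, so information can persist across many steps. A sketch of the standard formulation (weight names are hypothetical and biases are dropped for brevity; this is not Flux's `LSTMCell` source, though the hunk above does show the cell has a `forget` gate):

```julia
sigm(x) = 1 / (1 + exp(-x))  # logistic sigmoid

function lstm_step(Wf, Wi, Wc, Wo, h, c, x)
  xh = vcat(x, h)
  f  = sigm.(Wf * xh)    # forget gate: how much of c to keep
  i  = sigm.(Wi * xh)    # input gate: how much new information to admit
  c̃  = tanh.(Wc * xh)    # candidate cell state
  o  = sigm.(Wo * xh)    # output gate
  c′ = f .* c .+ i .* c̃  # additive update: memories can survive many steps
  h′ = o .* tanh.(c′)    # new hidden state / output
  return h′, c′
end

# Example with 4-dimensional inputs and a 3-dimensional state:
n, m = 4, 3
Wf, Wi, Wc, Wo = (randn(m, n + m) for _ in 1:4)
h, c = zeros(m), zeros(m)
h, c = lstm_step(Wf, Wi, Wc, Wo, h, c, randn(n))
```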