rm Over Seq

This commit is contained in:
Mike J Innes 2017-09-12 13:03:16 +01:00
parent 519f4c3c32
commit 972ecab9f9
3 changed files with 5 additions and 37 deletions

View File

@ -91,23 +91,14 @@ seq = [rand(10) for i = 1:10]
With `Recur`, applying our model to each element of a sequence is trivial:
```julia
map(m, seq) # returns a list of 5-element vectors
m.(seq) # returns a list of 5-element vectors
```
To make this a bit more convenient, Flux has the `Seq` type. This is just a list, but tagged so that we know it's meant to be used as a sequence of data points.
This works even when we've chain recurrent layers into a larger model.
```julia
seq = Seq([rand(10) for i = 1:10])
m(seq) # returns a new Seq of length 10
```
When we apply the model `m` to a seq, it gets mapped over every item in the sequence in order. This is just like the code above, but often more convenient.
You can get this behaviour more generally with the `Over` wrapper.
```julia
m = Over(Dense(10,5))
m(seq) # returns a new Seq of length 10
m = Chain(LSTM(10, 15), Dense(15, 5))
m.(seq)
```
## Truncating Gradients

View File

@ -7,7 +7,7 @@ module Flux
using Juno
using Lazy: @forward
export Chain, Dense, Seq, Over, RNN, LSTM,
export Chain, Dense, RNN, LSTM,
SGD, params
using NNlib

View File

@ -1,29 +1,6 @@
# TODO: broadcasting cat
combine(x, h) = vcat(x, h .* trues(1, size(x, 2)))
# Sequences
struct Seq{T,A<:AbstractVector{T}}
data::A
end
Seq(xs::AbstractVector{T}) where T = Seq{T,typeof(xs)}(xs)
Seq(xs) = Seq(collect(xs))
Base.getindex(s::Seq, i) = s.data[i]
struct Over{T}
m::T
end
(m::Over)(xs...) = m.m(xs...)
(m::Over)(s::Seq) = Seq(map(m, s.data))
Base.show(io::IO, m::Over) = print(io, "Over(", m.m, ")")
Optimise.children(m::Over) = (m.m,)
# Stateful recurrence
mutable struct Recur{T}