rm Over Seq
This commit is contained in:
parent
519f4c3c32
commit
972ecab9f9
|
@ -91,23 +91,14 @@ seq = [rand(10) for i = 1:10]
|
|||
With `Recur`, applying our model to each element of a sequence is trivial:
|
||||
|
||||
```julia
|
||||
map(m, seq) # returns a list of 5-element vectors
|
||||
m.(seq) # returns a list of 5-element vectors
|
||||
```
|
||||
|
||||
To make this a bit more convenient, Flux has the `Seq` type. This is just a list, but tagged so that we know it's meant to be used as a sequence of data points.
|
||||
This works even when we've chain recurrent layers into a larger model.
|
||||
|
||||
```julia
|
||||
seq = Seq([rand(10) for i = 1:10])
|
||||
m(seq) # returns a new Seq of length 10
|
||||
```
|
||||
|
||||
When we apply the model `m` to a seq, it gets mapped over every item in the sequence in order. This is just like the code above, but often more convenient.
|
||||
|
||||
You can get this behaviour more generally with the `Over` wrapper.
|
||||
|
||||
```julia
|
||||
m = Over(Dense(10,5))
|
||||
m(seq) # returns a new Seq of length 10
|
||||
m = Chain(LSTM(10, 15), Dense(15, 5))
|
||||
m.(seq)
|
||||
```
|
||||
|
||||
## Truncating Gradients
|
||||
|
|
|
@ -7,7 +7,7 @@ module Flux
|
|||
using Juno
|
||||
using Lazy: @forward
|
||||
|
||||
export Chain, Dense, Seq, Over, RNN, LSTM,
|
||||
export Chain, Dense, RNN, LSTM,
|
||||
SGD, params
|
||||
|
||||
using NNlib
|
||||
|
|
|
@ -1,29 +1,6 @@
|
|||
# TODO: broadcasting cat
|
||||
combine(x, h) = vcat(x, h .* trues(1, size(x, 2)))
|
||||
|
||||
# Sequences
|
||||
|
||||
struct Seq{T,A<:AbstractVector{T}}
|
||||
data::A
|
||||
end
|
||||
|
||||
Seq(xs::AbstractVector{T}) where T = Seq{T,typeof(xs)}(xs)
|
||||
|
||||
Seq(xs) = Seq(collect(xs))
|
||||
|
||||
Base.getindex(s::Seq, i) = s.data[i]
|
||||
|
||||
struct Over{T}
|
||||
m::T
|
||||
end
|
||||
|
||||
(m::Over)(xs...) = m.m(xs...)
|
||||
(m::Over)(s::Seq) = Seq(map(m, s.data))
|
||||
|
||||
Base.show(io::IO, m::Over) = print(io, "Over(", m.m, ")")
|
||||
|
||||
Optimise.children(m::Over) = (m.m,)
|
||||
|
||||
# Stateful recurrence
|
||||
|
||||
mutable struct Recur{T}
|
||||
|
|
Loading…
Reference in New Issue