rm Over Seq
This commit is contained in:
parent
519f4c3c32
commit
972ecab9f9
@ -91,23 +91,14 @@ seq = [rand(10) for i = 1:10]
|
|||||||
With `Recur`, applying our model to each element of a sequence is trivial:
|
With `Recur`, applying our model to each element of a sequence is trivial:
|
||||||
|
|
||||||
```julia
|
```julia
|
||||||
map(m, seq) # returns a list of 5-element vectors
|
m.(seq) # returns a list of 5-element vectors
|
||||||
```
|
```
|
||||||
|
|
||||||
To make this a bit more convenient, Flux has the `Seq` type. This is just a list, but tagged so that we know it's meant to be used as a sequence of data points.
|
This works even when we've chain recurrent layers into a larger model.
|
||||||
|
|
||||||
```julia
|
```julia
|
||||||
seq = Seq([rand(10) for i = 1:10])
|
m = Chain(LSTM(10, 15), Dense(15, 5))
|
||||||
m(seq) # returns a new Seq of length 10
|
m.(seq)
|
||||||
```
|
|
||||||
|
|
||||||
When we apply the model `m` to a seq, it gets mapped over every item in the sequence in order. This is just like the code above, but often more convenient.
|
|
||||||
|
|
||||||
You can get this behaviour more generally with the `Over` wrapper.
|
|
||||||
|
|
||||||
```julia
|
|
||||||
m = Over(Dense(10,5))
|
|
||||||
m(seq) # returns a new Seq of length 10
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## Truncating Gradients
|
## Truncating Gradients
|
||||||
|
@ -7,7 +7,7 @@ module Flux
|
|||||||
using Juno
|
using Juno
|
||||||
using Lazy: @forward
|
using Lazy: @forward
|
||||||
|
|
||||||
export Chain, Dense, Seq, Over, RNN, LSTM,
|
export Chain, Dense, RNN, LSTM,
|
||||||
SGD, params
|
SGD, params
|
||||||
|
|
||||||
using NNlib
|
using NNlib
|
||||||
|
@ -1,29 +1,6 @@
|
|||||||
# TODO: broadcasting cat
|
# TODO: broadcasting cat
|
||||||
combine(x, h) = vcat(x, h .* trues(1, size(x, 2)))
|
combine(x, h) = vcat(x, h .* trues(1, size(x, 2)))
|
||||||
|
|
||||||
# Sequences
|
|
||||||
|
|
||||||
struct Seq{T,A<:AbstractVector{T}}
|
|
||||||
data::A
|
|
||||||
end
|
|
||||||
|
|
||||||
Seq(xs::AbstractVector{T}) where T = Seq{T,typeof(xs)}(xs)
|
|
||||||
|
|
||||||
Seq(xs) = Seq(collect(xs))
|
|
||||||
|
|
||||||
Base.getindex(s::Seq, i) = s.data[i]
|
|
||||||
|
|
||||||
struct Over{T}
|
|
||||||
m::T
|
|
||||||
end
|
|
||||||
|
|
||||||
(m::Over)(xs...) = m.m(xs...)
|
|
||||||
(m::Over)(s::Seq) = Seq(map(m, s.data))
|
|
||||||
|
|
||||||
Base.show(io::IO, m::Over) = print(io, "Over(", m.m, ")")
|
|
||||||
|
|
||||||
Optimise.children(m::Over) = (m.m,)
|
|
||||||
|
|
||||||
# Stateful recurrence
|
# Stateful recurrence
|
||||||
|
|
||||||
mutable struct Recur{T}
|
mutable struct Recur{T}
|
||||||
|
Loading…
Reference in New Issue
Block a user