rm chainseq

This commit is contained in:
Mike J Innes 2017-09-11 14:02:43 +01:00
parent c80fb999ff
commit 7041ab9960
3 changed files with 14 additions and 19 deletions

View File

@ -103,6 +103,13 @@ m(seq) # returns a new Seq of length 10
When we apply the model `m` to a seq, it gets mapped over every item in the sequence in order. This is just like the code above, but often more convenient.
You can get this behaviour more generally with the `Over` wrapper.
```julia
m = Over(Dense(10,5))
m(seq) # returns a new Seq of length 10
```
## Truncating Gradients
By default, calculating the gradients in a recurrent layer involves the entire history. For example, if we call the model on 100 inputs, calling `back!` will calculate the gradient for those 100 calls. If we then calculate another 10 inputs, we have to calculate 110 gradients – this accumulates and quickly becomes expensive.

View File

@ -7,7 +7,7 @@ module Flux
using Juno
using Lazy: @forward
export Chain, Dense, Seq, ChainSeq, RNN, LSTM,
export Chain, Dense, Seq, Over, RNN, LSTM,
SGD, params
using NNlib

View File

@ -13,26 +13,14 @@ Seq(xs) = Seq(collect(xs))
Base.getindex(s::Seq, i) = s.data[i]
type ChainSeq
layers::Vector{Any}
ChainSeq(xs...) = new([xs...])
# Wrapper around a single model `m` (type parameter `T` is the model's type).
# Calling an `Over` on a `Seq` maps the wrapper over the elements of the
# sequence's `data` and returns a new `Seq`; any other arguments are passed
# straight through to the wrapped model.
struct Over{T}
m::T
end
# Delegate collection behaviour of a ChainSeq to its `layers` field:
# element access ...
@forward ChainSeq.layers Base.getindex, Base.first, Base.last, Base.endof
# ... in-place growth ...
@forward ChainSeq.layers Base.push!
# ... and the start/next/done iteration protocol.
@forward ChainSeq.layers Base.start, Base.next, Base.done
# Generic call: hand every argument straight to the wrapped model.
function (m::Over)(xs...)
    return m.m(xs...)
end

# Sequence call: apply the wrapper to each item of the sequence, in order,
# and collect the results into a new `Seq`.
function (m::Over)(xs::Seq)
    mapped = map(m, xs.data)
    return Seq(mapped)
end
# The children of a ChainSeq are the layers it holds.
function Optimise.children(c::ChainSeq)
    return c.layers
end
# Apply the layers one after another, feeding each layer's output to the next.
# With no layers the input is returned unchanged.
function (c::ChainSeq)(x)
    for layer in c.layers
        x = layer(x)
    end
    return x
end

# On a `Seq`, run the whole chain on each item and wrap the results in a new `Seq`.
function (c::ChainSeq)(s::Seq)
    return Seq(map(c, s.data))
end
# Indexing with an array of indices selects those layers and builds a new chain
# from them.
# NOTE(review): the result is a `Chain`, not a `ChainSeq` -- confirm this is intended.
function Base.getindex(c::ChainSeq, i::AbstractArray)
    selected = c.layers[i]
    return Chain(selected...)
end
# Print as `ChainSeq(layer1, layer2, ...)`.
function Base.show(io::IO, c::ChainSeq)
    print(io, "ChainSeq(")
    for (n, layer) in enumerate(c.layers)
        # Separator goes between items only, never before the first one.
        n > 1 && print(io, ", ")
        print(io, layer)
    end
    print(io, ")")
end
# Print as `Over(<wrapped model>)`.
function Base.show(io::IO, m::Over)
    print(io, "Over(")
    print(io, m.m)
    print(io, ")")
end
# Stateful recurrence
@ -49,7 +37,7 @@ function (m::Recur)(xs...)
return y
end
# Run the recurrent layer over each item of the sequence, in order, and
# return the outputs as a new `Seq`.
function (m::Recur)(s::Seq)
    outputs = map(m, s.data)
    return Seq(outputs)
end
# Run the recurrent layer over each item of the sequence `s`, in order, and
# return the outputs as a new `Seq`.
# The items are read from `s.data`; the argument is named `s`, so referring to
# `x.data` here would raise an UndefVarError.
(m::Recur)(s::Seq) = Seq(map(m, s.data))
# A Recur has exactly one child: its cell, returned as a 1-tuple.
function Optimise.children(m::Recur)
    return (m.cell,)
end