rm chainseq

This commit is contained in:
Mike J Innes 2017-09-11 14:02:43 +01:00
parent c80fb999ff
commit 7041ab9960
3 changed files with 14 additions and 19 deletions

View File

@ -103,6 +103,13 @@ m(seq) # returns a new Seq of length 10
When we apply the model `m` to a seq, it gets mapped over every item in the sequence in order. This is just like the code above, but often more convenient. When we apply the model `m` to a seq, it gets mapped over every item in the sequence in order. This is just like the code above, but often more convenient.
You can get this behaviour more generally with the `Over` wrapper.
```julia
m = Over(Dense(10,5))
m(seq) # returns a new Seq of length 10
```
## Truncating Gradients ## Truncating Gradients
By default, calculating the gradients in a recurrent layer involves the entire history. For example, if we call the model on 100 inputs, calling `back!` will calculate the gradient for those 100 calls. If we then calculate another 10 inputs we have to calculate 110 gradients this accumulates and quickly becomes expensive. By default, calculating the gradients in a recurrent layer involves the entire history. For example, if we call the model on 100 inputs, calling `back!` will calculate the gradient for those 100 calls. If we then calculate another 10 inputs we have to calculate 110 gradients this accumulates and quickly becomes expensive.

View File

@ -7,7 +7,7 @@ module Flux
using Juno using Juno
using Lazy: @forward using Lazy: @forward
export Chain, Dense, Seq, ChainSeq, RNN, LSTM, export Chain, Dense, Seq, Over, RNN, LSTM,
SGD, params SGD, params
using NNlib using NNlib

View File

@ -13,26 +13,14 @@ Seq(xs) = Seq(collect(xs))
Base.getindex(s::Seq, i) = s.data[i] Base.getindex(s::Seq, i) = s.data[i]
type ChainSeq struct Over{T}
layers::Vector{Any} m::T
ChainSeq(xs...) = new([xs...])
end end
@forward ChainSeq.layers Base.getindex, Base.first, Base.last, Base.endof, Base.push! (m::Over)(xs...) = m.m(xs...)
@forward ChainSeq.layers Base.start, Base.next, Base.done (m::Over)(xs::Seq) = Seq(map(m, xs.data))
Optimise.children(c::ChainSeq) = c.layers Base.show(io::IO, m::Over) = print(io, "Over(", m.m, ")")
(c::ChainSeq)(x) = foldl((x, m) -> m(x), x, c.layers)
(c::ChainSeq)(s::Seq) = Seq([c(x) for x in s.data])
Base.getindex(c::ChainSeq, i::AbstractArray) = Chain(c.layers[i]...)
function Base.show(io::IO, c::ChainSeq)
print(io, "ChainSeq(")
join(io, c.layers, ", ")
print(io, ")")
end
# Stateful recurrence # Stateful recurrence
@ -49,7 +37,7 @@ function (m::Recur)(xs...)
return y return y
end end
(m::Recur)(s::Seq) = Seq([m(x) for x in s.data]) (m::Recur)(s::Seq) = Seq(map(m, x.data))
Optimise.children(m::Recur) = (m.cell,) Optimise.children(m::Recur) = (m.cell,)