rm chainseq
This commit is contained in:
parent
c80fb999ff
commit
7041ab9960
|
@ -103,6 +103,13 @@ m(seq) # returns a new Seq of length 10
|
|||
|
||||
When we apply the model `m` to a seq, it gets mapped over every item in the sequence in order. This is just like the code above, but often more convenient.
|
||||
|
||||
You can get this behaviour more generally with the `Over` wrapper.
|
||||
|
||||
```julia
|
||||
m = Over(Dense(10,5))
|
||||
m(seq) # returns a new Seq of length 10
|
||||
```
|
||||
|
||||
## Truncating Gradients
|
||||
|
||||
By default, calculating the gradients in a recurrent layer involves the entire history. For example, if we call the model on 100 inputs, calling `back!` will calculate the gradient for those 100 calls. If we then calculate another 10 inputs we have to calculate 110 gradients – this accumulates and quickly becomes expensive.
|
||||
|
|
|
@ -7,7 +7,7 @@ module Flux
|
|||
using Juno
|
||||
using Lazy: @forward
|
||||
|
||||
export Chain, Dense, Seq, ChainSeq, RNN, LSTM,
|
||||
export Chain, Dense, Seq, Over, RNN, LSTM,
|
||||
SGD, params
|
||||
|
||||
using NNlib
|
||||
|
|
|
@ -13,26 +13,14 @@ Seq(xs) = Seq(collect(xs))
|
|||
|
||||
Base.getindex(s::Seq, i) = s.data[i]
|
||||
|
||||
type ChainSeq
|
||||
layers::Vector{Any}
|
||||
ChainSeq(xs...) = new([xs...])
|
||||
struct Over{T}
|
||||
m::T
|
||||
end
|
||||
|
||||
@forward ChainSeq.layers Base.getindex, Base.first, Base.last, Base.endof, Base.push!
|
||||
@forward ChainSeq.layers Base.start, Base.next, Base.done
|
||||
(m::Over)(xs...) = m.m(xs...)
|
||||
(m::Over)(xs::Seq) = Seq(map(m, xs.data))
|
||||
|
||||
Optimise.children(c::ChainSeq) = c.layers
|
||||
|
||||
(c::ChainSeq)(x) = foldl((x, m) -> m(x), x, c.layers)
|
||||
(c::ChainSeq)(s::Seq) = Seq([c(x) for x in s.data])
|
||||
|
||||
Base.getindex(c::ChainSeq, i::AbstractArray) = Chain(c.layers[i]...)
|
||||
|
||||
function Base.show(io::IO, c::ChainSeq)
|
||||
print(io, "ChainSeq(")
|
||||
join(io, c.layers, ", ")
|
||||
print(io, ")")
|
||||
end
|
||||
Base.show(io::IO, m::Over) = print(io, "Over(", m.m, ")")
|
||||
|
||||
# Stateful recurrence
|
||||
|
||||
|
@ -49,7 +37,7 @@ function (m::Recur)(xs...)
|
|||
return y
|
||||
end
|
||||
|
||||
(m::Recur)(s::Seq) = Seq([m(x) for x in s.data])
|
||||
(m::Recur)(s::Seq) = Seq(map(m, x.data))
|
||||
|
||||
Optimise.children(m::Recur) = (m.cell,)
|
||||
|
||||
|
|
Loading…
Reference in New Issue