rnn state reset

Calling `truncate!` drops the gradient history attached to the hidden state while keeping the state's value, so we can call the model with more inputs without building up an expensive gradient computation.
`truncate!` makes sense when you are working with multiple chunks of a large sequence, but we may also want to work with a set of independent sequences. In this case the hidden state should be completely reset to its original value, throwing away any accumulated information. `reset!` does this for you.
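
To make the reset behaviour concrete, here is a minimal, self-contained sketch that does not depend on Flux. The names `ToyRecur`, `toy_reset!`, and `sumcell` are hypothetical stand-ins invented for this illustration, not the library's types. The sketch only covers resetting the state; gradient truncation needs a tracked array type and is not modelled.

```julia
# Hypothetical stand-in for a recurrent wrapper; not Flux's actual `Recur`.
mutable struct ToyRecur{F,S}
    cell::F    # (state, x) -> (new_state, output)
    init::S    # initial hidden state, kept so it can be restored later
    state::S   # current hidden state
end

# Start with the current state equal to the initial state.
ToyRecur(cell, init) = ToyRecur(cell, init, init)

# Calling the layer advances the hidden state and returns the output.
function (m::ToyRecur)(x)
    h, y = m.cell(m.state, x)
    m.state = h
    return y
end

# Restore the hidden state to its initial value.
toy_reset!(m::ToyRecur) = (m.state = m.init)

# A toy cell that keeps a running sum of its inputs.
sumcell(h, x) = (h + x, h + x)

rnn = ToyRecur(sumcell, 0.0)

first_run = [rnn(x) for x in (1.0, 2.0, 3.0)]   # state carries across calls
toy_reset!(rnn)                                  # begin an independent sequence
second_run = [rnn(x) for x in (10.0, 20.0)]      # no leakage from first_run
```

Without the call to `toy_reset!`, the second sequence would start from the state left behind by the first one, which is exactly the leakage that resetting is meant to prevent.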

mutable struct Recur{T}
Recur(m) = Recur(m, hidden(m))
Recur(m, h = hidden(m)) = Recur(m, h, h)
function (m::Recur)(xs...)
h, y = m.cell(m.state, xs...)
Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
_truncate(x::AbstractArray) = x
_truncate(x::TrackedArray) = x.data
_truncate(x::AbstractArray) = Tracker.data(x)
_truncate(x::Tuple) = _truncate.(x)
truncate!(m) = foreach(truncate!, children(m))
truncate!(m::Recur) = (m.state = _truncate(m.state))
Truncates the gradient of the hidden state in recurrent layers. The value of the
state is preserved. See also `reset!`.
Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to
rnn.state = Tracker.data(rnn.state)
truncate!(m) = prefor(x -> x isa Recur && (x.state = _truncate(x.state)), m)
Reset the hidden state of a recurrent layer back to its original value. See also `truncate!`.
Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to
rnn.state = hidden(rnn.cell)
reset!(m) = prefor(x -> x isa Recur && (x.state = x.init), m)
flip(f, xs) = reverse(f.(reverse(xs)))
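
# A small sketch of what `flip` does, using a hypothetical helper `record`
# that logs the order in which elements are processed. The inputs are visited
# back to front, while the results come out in the original order. This relies
# on broadcasting over a `Vector` evaluating elements in index order.

```julia
flip(f, xs) = reverse(f.(reverse(xs)))

order = Int[]                          # log of the inputs as they are processed
record(x) = (push!(order, x); 10x)     # hypothetical helper: log, then transform

ys = flip(record, [1, 2, 3])           # ys lines up with the original input order
```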

using DataFlow: OSet
function forleaves(f, x; seen = OSet())
function prefor(f, x; seen = OSet())
x ∈ seen && return
push!(seen, x)
isleaf(x) ? f(x) : foreach(x -> forleaves(f, x, seen = seen), children(x))
foreach(x -> prefor(f, x, seen = seen), children(x))
function params(m)
ps = []
forleaves(p -> p isa TrackedArray && push!(ps, p), m)
prefor(p -> p isa TrackedArray && push!(ps, p), m)
return ps
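
The traversal idea behind `prefor` can be sketched without any dependencies. Everything here is an assumption made for illustration: `Node`, `toy_children`, and `toy_prefor` are hypothetical names, an `IdDict` stands in for the `OSet` of seen objects, and the visit order (parent first, then children) is inferred from the name `prefor` rather than taken from the code above.

```julia
# Hypothetical toy tree node; `kids` plays the role of `children(x)`.
struct Node
    name::String
    kids::Vector{Any}
end

toy_children(x::Node) = x.kids
toy_children(x) = ()          # anything else is treated as a leaf

# Pre-order traversal: apply `f` to `x`, then recurse into each child,
# skipping objects that were already visited (compared by identity).
function toy_prefor(f, x; seen = IdDict{Any,Bool}())
    haskey(seen, x) && return
    seen[x] = true
    f(x)
    foreach(c -> toy_prefor(f, c, seen = seen), toy_children(x))
    return
end

# `shared` is reachable from both "a" and "b" but should be visited once.
shared = Node("shared", Any[])
root = Node("root", Any[Node("a", Any[shared]), Node("b", Any[shared])])

visited = String[]
toy_prefor(x -> x isa Node && push!(visited, x.name), root)
```

Because every node is visited, not just the leaves, a predicate such as `x isa Recur` can match interior layers. Tracking seen objects keeps shared parameters from being collected twice.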