rnn state reset

This commit is contained in:
Mike J Innes 2017-10-19 17:21:08 +01:00
parent 99a7697d13
commit 2a66545ef8
3 changed files with 32 additions and 9 deletions

View File

@ -112,3 +112,5 @@ truncate!(m)
```
Calling `truncate!` wipes the slate clean, so we can call the model with more inputs without building up an expensive gradient computation.
`truncate!` makes sense when you are working with multiple chunks of a large sequence, but we may also want to work with a set of independent sequences. In this case the hidden state should be completely reset to its original value, throwing away any accumulated information. `reset!` does this for you.

View File

@ -25,10 +25,11 @@ rnn.state # 60
"""
mutable struct Recur{T}
cell::T
init
state
end
Recur(m) = Recur(m, hidden(m))
Recur(m, h = hidden(m)) = Recur(m, h, h)
function (m::Recur)(xs...)
h, y = m.cell(m.state, xs...)
@ -40,12 +41,32 @@ treelike(Recur)
Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
_truncate(x::AbstractArray) = x
_truncate(x::TrackedArray) = x.data
_truncate(x::AbstractArray) = Tracker.data(x)
_truncate(x::Tuple) = _truncate.(x)
truncate!(m) = foreach(truncate!, children(m))
truncate!(m::Recur) = (m.state = _truncate(m.state))
"""
truncate!(rnn)
Truncates the gradient of the hidden state in recurrent layers. The value of the
state is preserved. See also `reset!`.
Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to
rnn.state = Tracker.data(rnn.state)
"""
truncate!(m) = prefor(x -> x isa Recur && (x.state = _truncate(x.state)), m)
"""
reset!(rnn)
Reset the hidden state of a recurrent layer back to its original value. See also
`truncate!`.
Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to
rnn.state = hidden(rnn.cell)
"""
reset!(m) = prefor(x -> x isa Recur && (x.state = x.init), m)
flip(f, xs) = reverse(f.(reverse(xs)))

View File

@ -20,15 +20,15 @@ export mapparams
using DataFlow: OSet
function forleaves(f, x; seen = OSet())
function prefor(f, x; seen = OSet())
x seen && return
push!(seen, x)
isleaf(x) ? f(x) : foreach(x -> forleaves(f, x, seen = seen), children(x))
f(x)
foreach(x -> prefor(f, x, seen = seen), children(x))
return
end
function params(m)
ps = []
forleaves(p -> p isa TrackedArray && push!(ps, p), m)
prefor(p -> p isa TrackedArray && push!(ps, p), m)
return ps
end