rnn state reset
This commit is contained in:
parent
99a7697d13
commit
2a66545ef8
|
@ -112,3 +112,5 @@ truncate!(m)
|
|||
```
|
||||
|
||||
Calling `truncate!` discards the gradient history accumulated in the hidden state while keeping the state's current value, so we can call the model with more inputs without building up an expensive gradient computation.
|
||||
|
||||
`truncate!` makes sense when you are working with multiple chunks of a large sequence, but we may also want to work with a set of independent sequences. In this case the hidden state should be completely reset to its original value, throwing away any accumulated information. `reset!` does this for you.
|
||||
|
|
|
@ -25,10 +25,11 @@ rnn.state # 60
|
|||
"""
|
||||
mutable struct Recur{T}
# Pairs a cell with its hidden state. Only `cell` is typed; `init` and `state`
# are left untyped, so the stored state is free to change type over time.
|
||||
# The wrapped cell; invoked as `cell(state, xs...)` when the layer is called.
cell::T
|
||||
# Starting hidden state; `reset!` assigns this value back to `state`.
init
|
||||
# Current hidden state; passed to `cell` on each call and replaced by
# `truncate!` and `reset!`.
state
|
||||
end
|
||||
|
||||
Recur(m) = Recur(m, hidden(m))
|
||||
function Recur(m, h = hidden(m))
    # The same value seeds both the stored initial state and the live state.
    return Recur(m, h, h)
end
|
||||
|
||||
function (m::Recur)(xs...)
|
||||
h, y = m.cell(m.state, xs...)
|
||||
|
@ -40,12 +41,32 @@ treelike(Recur)
|
|||
|
||||
function Base.show(io::IO, m::Recur)
    # Only the wrapped cell is printed; the hidden state is left out.
    return print(io, "Recur(", m.cell, ")")
end
|
||||
|
||||
_truncate(x::AbstractArray) = x
|
||||
_truncate(x::TrackedArray) = x.data
|
||||
function _truncate(x::AbstractArray)
    # Delegates to `Tracker.data`; presumably this yields the array's plain
    # value without its gradient history (confirm against Tracker).
    return Tracker.data(x)
end
|
||||
function _truncate(x::Tuple)
    # Apply `_truncate` to every element, preserving the tuple structure.
    return map(_truncate, x)
end
|
||||
|
||||
truncate!(m) = foreach(truncate!, children(m))
|
||||
truncate!(m::Recur) = (m.state = _truncate(m.state))
|
||||
"""
|
||||
truncate!(rnn)
|
||||
|
||||
Truncates the gradient of the hidden state in recurrent layers. The value of the
|
||||
state is preserved. See also `reset!`.
|
||||
|
||||
Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to
|
||||
|
||||
rnn.state = Tracker.data(rnn.state)
|
||||
"""
|
||||
function truncate!(m)
    # Visit every node reachable from `m`; only `Recur` nodes are modified,
    # by replacing their state with its truncated counterpart.
    detach(node) = node isa Recur && (node.state = _truncate(node.state))
    return prefor(detach, m)
end
|
||||
|
||||
"""
|
||||
reset!(rnn)
|
||||
|
||||
Reset the hidden state of a recurrent layer back to its original value. See also
|
||||
`truncate!`.
|
||||
|
||||
Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to
|
||||
|
||||
rnn.state = hidden(rnn.cell)
|
||||
"""
|
||||
function reset!(m)
    # Visit every node reachable from `m`; only `Recur` nodes are modified,
    # by assigning their stored `init` value back to `state`.
    restore(node) = node isa Recur && (node.state = node.init)
    return prefor(restore, m)
end
|
||||
|
||||
function flip(f, xs)
    # Apply `f` elementwise to `xs` in reverse order, then restore the
    # original ordering of the results.
    backwards = reverse(xs)
    outputs = f.(backwards)
    return reverse(outputs)
end
|
||||
|
||||
|
|
|
@ -20,15 +20,15 @@ export mapparams
|
|||
|
||||
using DataFlow: OSet
|
||||
|
||||
function forleaves(f, x; seen = OSet())
|
||||
function prefor(f, x; seen = OSet())
|
||||
x ∈ seen && return
|
||||
push!(seen, x)
|
||||
isleaf(x) ? f(x) : foreach(x -> forleaves(f, x, seen = seen), children(x))
|
||||
f(x)
|
||||
foreach(x -> prefor(f, x, seen = seen), children(x))
|
||||
return
|
||||
end
|
||||
|
||||
function params(m)
|
||||
ps = []
|
||||
forleaves(p -> p isa TrackedArray && push!(ps, p), m)
|
||||
prefor(p -> p isa TrackedArray && push!(ps, p), m)
|
||||
return ps
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue