rnn state reset
This commit is contained in:
parent
99a7697d13
commit
2a66545ef8
@ -112,3 +112,5 @@ truncate!(m)
|
|||||||
```
|
```
|
||||||
|
|
||||||
Calling `truncate!` wipes the slate clean, so we can call the model with more inputs without building up an expensive gradient computation.
|
Calling `truncate!` wipes the slate clean, so we can call the model with more inputs without building up an expensive gradient computation.
|
||||||
|
|
||||||
|
`truncate!` makes sense when you are working with multiple chunks of a large sequence, but we may also want to work with a set of independent sequences. In this case the hidden state should be completely reset to its original value, throwing away any accumulated information. `reset!` does this for you.
|
||||||
|
@ -25,10 +25,11 @@ rnn.state # 60
|
|||||||
"""
|
"""
|
||||||
mutable struct Recur{T}
|
mutable struct Recur{T}
|
||||||
cell::T
|
cell::T
|
||||||
|
init
|
||||||
state
|
state
|
||||||
end
|
end
|
||||||
|
|
||||||
Recur(m) = Recur(m, hidden(m))
|
Recur(m, h = hidden(m)) = Recur(m, h, h)
|
||||||
|
|
||||||
function (m::Recur)(xs...)
|
function (m::Recur)(xs...)
|
||||||
h, y = m.cell(m.state, xs...)
|
h, y = m.cell(m.state, xs...)
|
||||||
@ -40,12 +41,32 @@ treelike(Recur)
|
|||||||
|
|
||||||
Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
|
Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
|
||||||
|
|
||||||
_truncate(x::AbstractArray) = x
|
_truncate(x::AbstractArray) = Tracker.data(x)
|
||||||
_truncate(x::TrackedArray) = x.data
|
|
||||||
_truncate(x::Tuple) = _truncate.(x)
|
_truncate(x::Tuple) = _truncate.(x)
|
||||||
|
|
||||||
truncate!(m) = foreach(truncate!, children(m))
|
"""
|
||||||
truncate!(m::Recur) = (m.state = _truncate(m.state))
|
truncate!(rnn)
|
||||||
|
|
||||||
|
Truncates the gradient of the hidden state in recurrent layers. The value of the
|
||||||
|
state is preserved. See also `reset!`.
|
||||||
|
|
||||||
|
Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to
|
||||||
|
|
||||||
|
rnn.state = Tracker.data(rnn.state)
|
||||||
|
"""
|
||||||
|
truncate!(m) = prefor(x -> x isa Recur && (x.state = _truncate(x.state)), m)
|
||||||
|
|
||||||
|
"""
|
||||||
|
reset!(rnn)
|
||||||
|
|
||||||
|
Reset the hidden state of a recurrent layer back to its original value. See also
|
||||||
|
`truncate!`.
|
||||||
|
|
||||||
|
Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to
|
||||||
|
|
||||||
|
rnn.state = hidden(rnn.cell)
|
||||||
|
"""
|
||||||
|
reset!(m) = prefor(x -> x isa Recur && (x.state = x.init), m)
|
||||||
|
|
||||||
flip(f, xs) = reverse(f.(reverse(xs)))
|
flip(f, xs) = reverse(f.(reverse(xs)))
|
||||||
|
|
||||||
|
@ -20,15 +20,15 @@ export mapparams
|
|||||||
|
|
||||||
using DataFlow: OSet
|
using DataFlow: OSet
|
||||||
|
|
||||||
function forleaves(f, x; seen = OSet())
|
function prefor(f, x; seen = OSet())
|
||||||
x ∈ seen && return
|
x ∈ seen && return
|
||||||
push!(seen, x)
|
f(x)
|
||||||
isleaf(x) ? f(x) : foreach(x -> forleaves(f, x, seen = seen), children(x))
|
foreach(x -> prefor(f, x, seen = seen), children(x))
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
function params(m)
|
function params(m)
|
||||||
ps = []
|
ps = []
|
||||||
forleaves(p -> p isa TrackedArray && push!(ps, p), m)
|
prefor(p -> p isa TrackedArray && push!(ps, p), m)
|
||||||
return ps
|
return ps
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user