Implement Gated Recurrent Unit
This commit is contained in:
parent
9f5c4dd3e9
commit
fa718c7475
@ -7,7 +7,7 @@ module Flux
|
|||||||
using Juno, Requires
|
using Juno, Requires
|
||||||
using Lazy: @forward
|
using Lazy: @forward
|
||||||
|
|
||||||
export Chain, Dense, RNN, LSTM, Dropout, LayerNorm,
|
export Chain, Dense, RNN, LSTM, GRU, Dropout, LayerNorm,
|
||||||
SGD, ADAM, Momentum, Nesterov,
|
SGD, ADAM, Momentum, Nesterov,
|
||||||
param, params, mapleaves
|
param, params, mapleaves
|
||||||
|
|
||||||
|
@ -150,3 +150,51 @@ See [this article](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
|
|||||||
for a good overview of the internals.
|
for a good overview of the internals.
|
||||||
"""
|
"""
|
||||||
LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
|
LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# GRU
|
||||||
|
|
||||||
|
# State of a single Gated Recurrent Unit step.
#
# Fields:
#   update    - layer producing the update gate
#   reset     - layer producing the reset gate (same type as `update`)
#   candidate - layer producing the candidate hidden state
#   h         - hidden state stored on the cell
struct GRUCell{G,C,S}
    update::G
    reset::G
    candidate::C
    h::S
end
|
||||||
|
|
||||||
|
# Build a GRU cell mapping `in` input features to `out` hidden features.
#
# Both gates are `Dense(in + out, out, σ)` layers and the candidate layer is
# `Dense(in + out, out, tanh)`; every layer, and the initial hidden state
# `param(init(out))`, is initialised with `init`.
function GRUCell(in, out; init = initn)
    width = in + out
    gate() = Dense(width, out, σ, init = init)

    # Construction order (update, reset, candidate, h) is kept so that any
    # stateful `init` is consumed in the same sequence.
    update = gate()
    reset = gate()
    candidate = Dense(width, out, tanh, init = init)
    h0 = param(init(out))

    return GRUCell(update, reset, candidate, h0)
end
|
||||||
|
|
||||||
|
# Advance the cell by one step.
#
# Arguments:
#   h - current hidden state
#   x - current input
#
# Returns the new hidden state twice, as `(h′, h′)`
# (presumably the (state, output) pair expected by `Recur` — confirm).
function (m::GRUCell)(h, x)
    x′ = combine(x, h)
    z = m.update(x′)  # update gate
    r = m.reset(x′)   # reset gate

    # NOTE(review): the argument order here is (r .* h, x) while `x′` above
    # uses (x, h). It is kept as-is because swapping it would change which
    # weight columns see the input and which see the state — confirm intended.
    h̃ = m.candidate(combine(r .* h, x))

    # `1 .- z` must be written with spaces: `1.-z` is ambiguous between
    # `1. - z` and `1 .- z` and is rejected by newer Julia parsers.
    h′ = (1 .- z) .* h .+ z .* h̃
    return h′, h′
end
|
||||||
|
|
||||||
|
# Return the hidden state `h` stored on the cell.
function hidden(m::GRUCell)
    return m.h
end
|
||||||
|
|
||||||
|
# Register GRUCell with `treelike`.
# NOTE(review): presumably this lets `params`/`mapleaves` traverse the cell's
# fields — confirm against the definition of `treelike`.
treelike(GRUCell)
|
||||||
|
|
||||||
|
# Print the cell as `GRUCell(in, out)`.
#
# Sizes are recovered from the update gate's weight matrix: the row count is
# taken as `out`, and `in` is the column count minus the row count
# (assumes `Dense` stores `W` as `out × (in + out)` — confirm).
function Base.show(io::IO, m::GRUCell)
    W = m.update.W
    nout = size(W, 1)
    nin = size(W, 2) - nout
    print(io, "GRUCell(", nin, ", ", nout, ')')
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
GRU(in::Integer, out::Integer, σ = tanh)
|
||||||
|
|
||||||
|
Gated Recurrent Unit layer. Behaves like an RNN but generally
|
||||||
|
exhibits a longer memory span over sequences.
|
||||||
|
|
||||||
|
See [this article](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
|
||||||
|
for a good overview of the internals.
|
||||||
|
"""
|
||||||
|
GRU(a...; ka...) = Recur(GRUCell(a...; ka...))
|
||||||
|
Loading…
Reference in New Issue
Block a user