Implement Gated Recurrent Unit
This commit is contained in:
parent
9f5c4dd3e9
commit
fa718c7475
@ -7,7 +7,7 @@ module Flux
|
||||
using Juno, Requires
|
||||
using Lazy: @forward
|
||||
|
||||
export Chain, Dense, RNN, LSTM, Dropout, LayerNorm,
|
||||
export Chain, Dense, RNN, LSTM, GRU, Dropout, LayerNorm,
|
||||
SGD, ADAM, Momentum, Nesterov,
|
||||
param, params, mapleaves
|
||||
|
||||
|
@ -150,3 +150,51 @@ See [this article](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
|
||||
for a good overview of the internals.
|
||||
"""
|
||||
LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
|
||||
|
||||
|
||||
|
||||
|
||||
# GRU
|
||||
|
||||
struct GRUCell{D1,D2,V}
|
||||
update::D1
|
||||
reset::D1
|
||||
candidate::D2
|
||||
h::V
|
||||
end
|
||||
|
||||
function GRUCell(in, out; init = initn)
|
||||
cell = GRUCell([Dense(in+out, out, σ, init = init) for _ = 1:2]...,
|
||||
Dense(in+out, out, tanh, init = init),
|
||||
param(init(out)))
|
||||
return cell
|
||||
end
|
||||
|
||||
function (m::GRUCell)(h, x)
|
||||
x′ = combine(x, h)
|
||||
z = m.update(x′)
|
||||
r = m.reset(x′)
|
||||
h̃ = m.candidate(combine(r.*h, x))
|
||||
h = (1.-z).*h .+ z.*h̃
|
||||
return h, h
|
||||
end
|
||||
|
||||
hidden(m::GRUCell) = m.h
|
||||
|
||||
treelike(GRUCell)
|
||||
|
||||
Base.show(io::IO, m::GRUCell) =
|
||||
print(io, "GRUCell(",
|
||||
size(m.update.W, 2) - size(m.update.W, 1), ", ",
|
||||
size(m.update.W, 1), ')')
|
||||
|
||||
"""
|
||||
GRU(in::Integer, out::Integer, σ = tanh)
|
||||
|
||||
Gated Recurrent Unit layer. Behaves like an RNN but generally
|
||||
exhibits a longer memory span over sequences.
|
||||
|
||||
See [this article](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
|
||||
for a good overview of the internals.
|
||||
"""
|
||||
GRU(a...; ka...) = Recur(GRUCell(a...; ka...))
|
||||
|
Loading…
Reference in New Issue
Block a user