Merge pull request #112 from baggepinnen/gru
Implement Gated Recurrent Unit
This commit is contained in:
commit
b0b8c9dbd1
|
@ -7,7 +7,7 @@ module Flux
|
|||
using Juno, Requires
|
||||
using Lazy: @forward
|
||||
|
||||
export Chain, Dense, RNN, LSTM, Conv2D,
|
||||
export Chain, Dense, RNN, LSTM, GRU, Conv2D,
|
||||
Dropout, LayerNorm, BatchNorm,
|
||||
SGD, ADAM, Momentum, Nesterov, AMSGrad,
|
||||
param, params, mapleaves
|
||||
|
|
|
@ -150,3 +150,49 @@ See [this article](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
|
|||
for a good overview of the internals.
|
||||
"""
|
||||
LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
|
||||
|
||||
# GRU
|
||||
|
||||
struct GRUCell{D1,D2,V}
|
||||
update::D1
|
||||
reset::D1
|
||||
candidate::D2
|
||||
h::V
|
||||
end
|
||||
|
||||
function GRUCell(in, out)
|
||||
cell = GRUCell(Dense(in+out, out, σ),
|
||||
Dense(in+out, out, σ),
|
||||
Dense(in+out, out, tanh),
|
||||
param(initn(out)))
|
||||
return cell
|
||||
end
|
||||
|
||||
function (m::GRUCell)(h, x)
|
||||
x′ = combine(x, h)
|
||||
z = m.update(x′)
|
||||
r = m.reset(x′)
|
||||
h̃ = m.candidate(combine(r.*h, x))
|
||||
h = (1.-z).*h .+ z.*h̃
|
||||
return h, h
|
||||
end
|
||||
|
||||
hidden(m::GRUCell) = m.h
|
||||
|
||||
treelike(GRUCell)
|
||||
|
||||
Base.show(io::IO, m::GRUCell) =
|
||||
print(io, "GRUCell(",
|
||||
size(m.update.W, 2) - size(m.update.W, 1), ", ",
|
||||
size(m.update.W, 1), ')')
|
||||
|
||||
"""
|
||||
GRU(in::Integer, out::Integer, σ = tanh)
|
||||
|
||||
Gated Recurrent Unit layer. Behaves like an RNN but generally
|
||||
exhibits a longer memory span over sequences.
|
||||
|
||||
See [this article](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
|
||||
for a good overview of the internals.
|
||||
"""
|
||||
GRU(a...; ka...) = Recur(GRUCell(a...; ka...))
|
||||
|
|
Loading…
Reference in New Issue