Merge pull request #112 from baggepinnen/gru
Implement Gated Recurrent Unit
This commit is contained in:
commit
b0b8c9dbd1
@ -7,7 +7,7 @@ module Flux
|
|||||||
using Juno, Requires
|
using Juno, Requires
|
||||||
using Lazy: @forward
|
using Lazy: @forward
|
||||||
|
|
||||||
export Chain, Dense, RNN, LSTM, Conv2D,
|
export Chain, Dense, RNN, LSTM, GRU, Conv2D,
|
||||||
Dropout, LayerNorm, BatchNorm,
|
Dropout, LayerNorm, BatchNorm,
|
||||||
SGD, ADAM, Momentum, Nesterov, AMSGrad,
|
SGD, ADAM, Momentum, Nesterov, AMSGrad,
|
||||||
param, params, mapleaves
|
param, params, mapleaves
|
||||||
|
@ -150,3 +150,49 @@ See [this article](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
|
|||||||
for a good overview of the internals.
|
for a good overview of the internals.
|
||||||
"""
|
"""
|
||||||
LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
|
LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
|
||||||
|
|
||||||
|
# GRU
|
||||||
|
|
||||||
|
struct GRUCell{D1,D2,V}
|
||||||
|
update::D1
|
||||||
|
reset::D1
|
||||||
|
candidate::D2
|
||||||
|
h::V
|
||||||
|
end
|
||||||
|
|
||||||
|
function GRUCell(in, out)
|
||||||
|
cell = GRUCell(Dense(in+out, out, σ),
|
||||||
|
Dense(in+out, out, σ),
|
||||||
|
Dense(in+out, out, tanh),
|
||||||
|
param(initn(out)))
|
||||||
|
return cell
|
||||||
|
end
|
||||||
|
|
||||||
|
function (m::GRUCell)(h, x)
|
||||||
|
x′ = combine(x, h)
|
||||||
|
z = m.update(x′)
|
||||||
|
r = m.reset(x′)
|
||||||
|
h̃ = m.candidate(combine(r.*h, x))
|
||||||
|
h = (1.-z).*h .+ z.*h̃
|
||||||
|
return h, h
|
||||||
|
end
|
||||||
|
|
||||||
|
hidden(m::GRUCell) = m.h
|
||||||
|
|
||||||
|
treelike(GRUCell)
|
||||||
|
|
||||||
|
Base.show(io::IO, m::GRUCell) =
|
||||||
|
print(io, "GRUCell(",
|
||||||
|
size(m.update.W, 2) - size(m.update.W, 1), ", ",
|
||||||
|
size(m.update.W, 1), ')')
|
||||||
|
|
||||||
|
"""
|
||||||
|
GRU(in::Integer, out::Integer, σ = tanh)
|
||||||
|
|
||||||
|
Gated Recurrent Unit layer. Behaves like an RNN but generally
|
||||||
|
exhibits a longer memory span over sequences.
|
||||||
|
|
||||||
|
See [this article](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
|
||||||
|
for a good overview of the internals.
|
||||||
|
"""
|
||||||
|
GRU(a...; ka...) = Recur(GRUCell(a...; ka...))
|
||||||
|
Loading…
Reference in New Issue
Block a user