Merge pull request #112 from baggepinnen/gru

Implement Gated Recurrent Unit
This commit is contained in:
Mike J Innes 2018-01-10 14:13:01 +00:00 committed by GitHub
commit b0b8c9dbd1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 47 additions and 1 deletions

View File

@ -7,7 +7,7 @@ module Flux
using Juno, Requires using Juno, Requires
using Lazy: @forward using Lazy: @forward
export Chain, Dense, RNN, LSTM, Conv2D, export Chain, Dense, RNN, LSTM, GRU, Conv2D,
Dropout, LayerNorm, BatchNorm, Dropout, LayerNorm, BatchNorm,
SGD, ADAM, Momentum, Nesterov, AMSGrad, SGD, ADAM, Momentum, Nesterov, AMSGrad,
param, params, mapleaves param, params, mapleaves

View File

@ -150,3 +150,49 @@ See [this article](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals. for a good overview of the internals.
""" """
LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...)) LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
# GRU
struct GRUCell{D1,D2,V}
update::D1
reset::D1
candidate::D2
h::V
end
function GRUCell(in, out)
cell = GRUCell(Dense(in+out, out, σ),
Dense(in+out, out, σ),
Dense(in+out, out, tanh),
param(initn(out)))
return cell
end
function (m::GRUCell)(h, x)
x = combine(x, h)
z = m.update(x)
r = m.reset(x)
= m.candidate(combine(r.*h, x))
h = (1.-z).*h .+ z.*
return h, h
end
hidden(m::GRUCell) = m.h
treelike(GRUCell)
Base.show(io::IO, m::GRUCell) =
print(io, "GRUCell(",
size(m.update.W, 2) - size(m.update.W, 1), ", ",
size(m.update.W, 1), ')')
"""
GRU(in::Integer, out::Integer, σ = tanh)
Gated Recurrent Unit layer. Behaves like an RNN but generally
exhibits a longer memory span over sequences.
See [this article](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.
"""
GRU(a...; ka...) = Recur(GRUCell(a...; ka...))