Merge pull request #112 from baggepinnen/gru

Implement Gated Recurrent Unit
Mike J Innes 2018-01-10 14:13:01 +00:00 committed by GitHub
commit b0b8c9dbd1
2 changed files with 47 additions and 1 deletion

@@ -7,7 +7,7 @@ module Flux
using Juno, Requires
using Lazy: @forward
-export Chain, Dense, RNN, LSTM, Conv2D,
+export Chain, Dense, RNN, LSTM, GRU, Conv2D,
Dropout, LayerNorm, BatchNorm,
SGD, ADAM, Momentum, Nesterov, AMSGrad,
param, params, mapleaves

@@ -150,3 +150,49 @@ See [this article](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.
"""
LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))

# GRU
struct GRUCell{D1,D2,V}
  update::D1
  reset::D1
  candidate::D2
  h::V
end
function GRUCell(in, out)
  # each gate sees the concatenated input and hidden state, hence in+out
  cell = GRUCell(Dense(in+out, out, σ),
                 Dense(in+out, out, σ),
                 Dense(in+out, out, tanh),
                 param(initn(out)))
  return cell
end
function (m::GRUCell)(h, x)
  x′ = combine(x, h)                   # concatenate input and hidden state
  z = m.update(x′)                     # update gate
  r = m.reset(x′)                      # reset gate
  h̃ = m.candidate(combine(r.*h, x))    # candidate hidden state
  h = (1.-z).*h .+ z.*h̃                # blend previous state and candidate
  return h, h
end
hidden(m::GRUCell) = m.h
treelike(GRUCell)
Base.show(io::IO, m::GRUCell) =
  print(io, "GRUCell(",
        size(m.update.W, 2) - size(m.update.W, 1), ", ",
        size(m.update.W, 1), ')')
"""
GRU(in::Integer, out::Integer, σ = tanh)
Gated Recurrent Unit layer. Behaves like an RNN but generally
exhibits a longer memory span over sequences.
See [this article](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.
"""
GRU(a...; ka...) = Recur(GRUCell(a...; ka...))
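
For context, a minimal usage sketch (not part of the diff), assuming the Flux API as of this commit: the `Recur` wrapper keeps the hidden state between calls, so the layer can simply be mapped over a sequence step by step.

# Hypothetical usage sketch, assuming Flux as of this commit.
using Flux

gru = GRU(10, 5)               # 10 input features, 5 hidden units
xs  = [rand(10) for _ = 1:3]   # a toy three-step sequence

# Recur carries the hidden state from call to call, so mapping the layer
# over the inputs processes the sequence in order.
ys = map(gru, xs)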