Implement Gated Recurrent Unit

This commit is contained in:
baggepinnen 2017-11-24 14:33:06 +01:00
parent 9f5c4dd3e9
commit fa718c7475
2 changed files with 49 additions and 1 deletions

View File

@ -7,7 +7,7 @@ module Flux
using Juno, Requires
using Lazy: @forward
export Chain, Dense, RNN, LSTM, Dropout, LayerNorm,
export Chain, Dense, RNN, LSTM, GRU, Dropout, LayerNorm,
SGD, ADAM, Momentum, Nesterov,
param, params, mapleaves

View File

@ -150,3 +150,51 @@ See [this article](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.
"""
function LSTM(args...; kwargs...)
    # Forward every positional and keyword argument to `LSTMCell`, then wrap
    # the resulting cell in `Recur`.
    return Recur(LSTMCell(args...; kwargs...))
end
# GRU
"""
    GRUCell(update, reset, candidate, h)

Container for the pieces of a single Gated Recurrent Unit step.

The two gate layers (`update`, `reset`) share one type parameter, the
candidate layer has its own, and `h` is the state value handed back by
`hidden(cell)`.
"""
struct GRUCell{G,C,S}
    # Gate layers; both are called on the same input in the forward pass.
    update::G
    reset::G
    # Layer producing the candidate state.
    candidate::C
    # Stored hidden-state value.
    h::S
end
# Build a GRU cell for inputs of size `in` and a state of size `out`.
#
# All three internal `Dense` layers take `in + out` inputs and produce `out`
# outputs: the update and reset gates use `σ`, the candidate layer uses `tanh`.
# `init` is used both for the layer weights and for the stored state
# `param(init(out))`.
function GRUCell(in, out; init = initn)
    width = in + out
    gate() = Dense(width, out, σ, init = init)
    # Creation order (update, reset, candidate, state) is kept so that any
    # stateful `init` is consumed in the same sequence.
    update = gate()
    reset = gate()
    candidate = Dense(width, out, tanh, init = init)
    state = param(init(out))
    return GRUCell(update, reset, candidate, state)
end
# One GRU step: given the previous state `h` and the input `x`, compute
#
#     z     = update(combine(x, h))          # update gate
#     r     = reset(combine(x, h))           # reset gate
#     cand  = candidate(combine(r .* h, x))  # candidate state
#     hnew  = (1 .- z) .* h .+ z .* cand
#
# and return `(hnew, hnew)`. The pair is presumably (new state, output) as
# expected by `Recur` -- confirm against its definition.
#
# `combine` is assumed to stack its two arguments so the result has
# `in + out` rows, matching the `Dense(in+out, out, ...)` layers built by the
# constructor -- confirm against its definition.
function (m::GRUCell)(h, x)
    # Keep the raw input `x` intact: it is needed again for the candidate
    # layer, which must also receive exactly `in + out` rows.
    xh = combine(x, h)
    z = m.update(xh)
    r = m.reset(xh)
    candidate = m.candidate(combine(r .* h, x))
    # Spaces around `.-` avoid the ambiguous `1.-z` literal syntax.
    hnew = (1 .- z) .* h .+ z .* candidate
    return hnew, hnew
end
# Return the state value stored in the cell's `h` field.
function hidden(m::GRUCell)
    return m.h
end
# Register GRUCell via `treelike`. Presumably this lets utilities such as
# `mapleaves`/`params` traverse the cell's fields -- confirm against the
# definition of `treelike`.
treelike(GRUCell)
# Print the cell as `GRUCell(in, out)`. Both sizes are recovered from the
# update gate's weight matrix: its row count is `out`, and its column count
# minus the row count is `in`.
function Base.show(io::IO, m::GRUCell)
    W = m.update.W
    nout = size(W, 1)
    nin = size(W, 2) - nout
    print(io, "GRUCell(", nin, ", ", nout, ')')
end
"""
    GRU(in::Integer, out::Integer; init = initn)

Gated Recurrent Unit layer. Behaves like an RNN but generally
exhibits a longer memory span over sequences.

All arguments are forwarded to `GRUCell`, and the resulting cell is wrapped
in `Recur`. The gate and candidate activations are fixed (`σ` and `tanh`
respectively); only the initialiser `init` is configurable.

See [this article](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.
"""
GRU(a...; ka...) = Recur(GRUCell(a...; ka...))