Implement Gated Recurrent Unit

This commit is contained in:
baggepinnen 2017-11-24 14:33:06 +01:00
parent 9f5c4dd3e9
commit fa718c7475
2 changed files with 49 additions and 1 deletions

View File

@ -7,7 +7,7 @@ module Flux
using Juno, Requires using Juno, Requires
using Lazy: @forward using Lazy: @forward
export Chain, Dense, RNN, LSTM, Dropout, LayerNorm, export Chain, Dense, RNN, LSTM, GRU, Dropout, LayerNorm,
SGD, ADAM, Momentum, Nesterov, SGD, ADAM, Momentum, Nesterov,
param, params, mapleaves param, params, mapleaves

View File

@ -150,3 +150,51 @@ See [this article](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals. for a good overview of the internals.
""" """
LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
# GRU
# Container for the pieces of one GRU step: two gate layers that share a type,
# a candidate-state layer, and the hidden-state value stored with the cell.
struct GRUCell{G,C,S}
  update::G     # update-gate layer
  reset::G      # reset-gate layer (must have the same type as `update`)
  candidate::C  # candidate-state layer
  h::S          # hidden-state value kept with the cell; returned by `hidden`
end
# Build a GRUCell for `in` inputs and `out` hidden units.
# Every layer maps `in + out` features to `out`; the two gates use `σ`, the
# candidate layer uses `tanh`, and `init` initialises both the layers and the
# stored hidden state.
function GRUCell(in, out; init = initn)
  width = in + out
  # Construct in the same order as the fields: update, reset, candidate, h.
  update_gate = Dense(width, out, σ, init = init)
  reset_gate = Dense(width, out, σ, init = init)
  candidate_layer = Dense(width, out, tanh, init = init)
  state = param(init(out))
  return GRUCell(update_gate, reset_gate, candidate_layer, state)
end
# Advance the cell by one step.
#
# Arguments: `h` is the previous hidden state, `x` the current input.
# Returns `(hnew, hnew)`: the new hidden state twice, once as the carried
# state and once as the step output.
#
# Steps:
#   z     = update(combine(x, h))        -- update gate
#   r     = reset(combine(x, h))         -- reset gate
#   hcand = candidate(combine(r .* h, x)) -- candidate state
#   hnew  = (1 - z) .* h + z .* hcand
function (m::GRUCell)(h, x)
  # Keep the raw input `x` intact: the candidate layer below is constructed for
  # `in + out` features, so it must not be fed the already-combined input.
  # NOTE(review): assumes `combine` (defined elsewhere) concatenates its
  # arguments along the feature dimension — confirm.
  xh = combine(x, h)
  z = m.update(xh)
  r = m.reset(xh)
  hcand = m.candidate(combine(r .* h, x))
  # `1 .- z` rather than `1.-z`: the latter is ambiguous between the float
  # literal `1.` and the dotted operator, and is rejected by newer Julia.
  hnew = (1 .- z) .* h .+ z .* hcand
  return hnew, hnew
end
# Return the hidden-state value stored in the cell.
function hidden(m::GRUCell)
  return m.h
end
# Register GRUCell with `treelike` (defined elsewhere in the package).
# NOTE(review): presumably this lets parameter collection / `mapleaves` walk
# the cell's fields — confirm against the definition of `treelike`.
treelike(GRUCell)
# Print the cell as `GRUCell(in, out)`, recovering both sizes from the update
# gate's weight matrix: it has `out` rows and `in + out` columns.
function Base.show(io::IO, m::GRUCell)
  W = m.update.W
  nout = size(W, 1)
  nin = size(W, 2) - nout
  return print(io, "GRUCell(", nin, ", ", nout, ')')
end
"""
    GRU(in::Integer, out::Integer; init = initn)

Gated Recurrent Unit layer. Behaves like an RNN but generally
exhibits a longer memory span over sequences.

`in` is the input size and `out` the hidden/output size. All positional and
keyword arguments are forwarded unchanged to `GRUCell`, and the resulting cell
is wrapped in `Recur`.

See [this article](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.
"""
GRU(a...; ka...) = Recur(GRUCell(a...; ka...))