From fa718c7475375be21c05761d11c41041bbf1dda8 Mon Sep 17 00:00:00 2001
From: baggepinnen
Date: Fri, 24 Nov 2017 14:33:06 +0100
Subject: [PATCH] Implement Gated Recurrent Unit

---
 src/Flux.jl             |  2 +-
 src/layers/recurrent.jl | 48 +++++++++++++++++++++++++++++++++++++++++
 2 files changed, 49 insertions(+), 1 deletion(-)

diff --git a/src/Flux.jl b/src/Flux.jl
index df4b1636..b04afc12 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -7,7 +7,7 @@ module Flux
 using Juno, Requires
 using Lazy: @forward
 
-export Chain, Dense, RNN, LSTM, Dropout, LayerNorm,
+export Chain, Dense, RNN, LSTM, GRU, Dropout, LayerNorm,
   SGD, ADAM, Momentum, Nesterov,
   param, params, mapleaves
 
diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index 599776ce..47aef83f 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -150,3 +150,51 @@ See [this article](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
 for a good overview of the internals.
 """
 LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
+
+
+
+
+# GRU
+
+struct GRUCell{D1,D2,V}
+  update::D1
+  reset::D1
+  candidate::D2
+  h::V
+end
+
+function GRUCell(in, out; init = initn)
+  cell = GRUCell([Dense(in+out, out, σ, init = init) for _ = 1:2]...,
+                 Dense(in+out, out, tanh, init = init),
+                 param(init(out)))
+  return cell
+end
+
+function (m::GRUCell)(h, x)
+  x′ = combine(x, h)
+  z = m.update(x′)
+  r = m.reset(x′)
+  h̃ = m.candidate(combine(r.*h, x))
+  h = (1.-z).*h .+ z.*h̃
+  return h, h
+end
+
+hidden(m::GRUCell) = m.h
+
+treelike(GRUCell)
+
+Base.show(io::IO, m::GRUCell) =
+  print(io, "GRUCell(",
+        size(m.update.W, 2) - size(m.update.W, 1), ", ",
+        size(m.update.W, 1), ')')
+
+"""
+    GRU(in::Integer, out::Integer)
+
+Gated Recurrent Unit layer. Behaves like an RNN but generally
+exhibits a longer memory span over sequences.
+
+See [this article](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
+for a good overview of the internals.
+"""
+GRU(a...; ka...) = Recur(GRUCell(a...; ka...))
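
Not part of the patch itself: a minimal usage sketch of how the newly exported layer slots into the Flux API of this era (Chain, Dense, and the Recur wrapper that GRU returns). The layer sizes and input data below are made up for illustration.

    using Flux

    # A GRU mapping 10-dimensional inputs to a 5-dimensional hidden state,
    # followed by a dense readout layer.
    m = Chain(GRU(10, 5), Dense(5, 2))

    # A toy sequence of three time steps, each a length-10 input vector.
    seq = [rand(10) for i = 1:3]

    # GRU(10, 5) returns a Recur, which carries the hidden state between
    # calls, so mapping the model over the sequence threads the GRU state
    # through time and yields one output per step.
    ys = map(m, seq)

Like the RNN and LSTM cells already in recurrent.jl, the GRU cell returns (h, h), so the hidden state doubles as the per-step output.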