From 85415d4244f7a91c0bfa86c1c1f5d7008c273a1d Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 1 Nov 2016 14:42:41 +0000 Subject: [PATCH] throw GRU together --- src/layers/recurrent.jl | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index b2754d1e..f8fb8943 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -32,3 +32,20 @@ end LSTM(in, out; init = initn) = LSTM(vcat([[init((in, out)), init((out, out)), init(out)] for _ = 1:4]...)..., zeros(Float32, out), zeros(Float32, out)) + +@net type GatedRecurrent + Wxr; Wyr; br + Wxu; Wyu; bu + Wxh; Wyh; bh + state + function (x) + reset = σ( x * Wxr + y * Wyr + br ) + update = σ( x * Wxu + y * Wyu + bu ) + state′ = tanh( x * Wxh + (reset .* y) * Wyh + bh ) + state = (1 .- update) .* state′ + update .* y + end +end + +GatedRecurrent(in, out; init = initn) = + GatedRecurrent(vcat([[init((in, out)), init((out, out)), init(out)] for _ = 1:3]...)..., + zeros(Float32, out))