diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index b2754d1e..f8fb8943 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -32,3 +32,27 @@ end
 LSTM(in, out; init = initn) =
   LSTM(vcat([[init((in, out)), init((out, out)), init(out)] for _ = 1:4]...)...,
        zeros(Float32, out), zeros(Float32, out))
+
+# Gated recurrent unit. `state` is the only recurrent field: it is read as the
+# previous state by every gate and reassigned at the end of each step.
+@net type GatedRecurrent
+  Wxr; Wyr; br  # reset gate: input weights, recurrent weights, bias
+  Wxu; Wyu; bu  # update gate: input weights, recurrent weights, bias
+  Wxh; Wyh; bh  # candidate state: input weights, recurrent weights, bias
+  state
+  function (x)
+    # Gates, both computed from the input and the previous state
+    reset  = σ( x * Wxr + state * Wyr + br )
+    update = σ( x * Wxu + state * Wyu + bu )
+    # Candidate state; the reset gate scales the previous state's contribution
+    state′ = tanh( x * Wxh + (reset .* state) * Wyh + bh )
+    # Blend: `update` weights the previous state, `1 - update` the candidate
+    state = (1 .- update) .* state′ + update .* state
+  end
+end
+
+# Build a GatedRecurrent with, for each of its three weight groups, `init((in, out))`,
+# `init((out, out))` and `init(out)`, plus a zero Float32 initial state of length `out`.
+GatedRecurrent(in, out; init = initn) =
+  GatedRecurrent(vcat([[init((in, out)), init((out, out)), init(out)] for _ = 1:3]...)...,
+                 zeros(Float32, out))