export Recurrent, GatedRecurrent, LSTM

# Simple (Elman-style) recurrent layer.
#
# Each step computes
#     y = tanh(x * Wxy .+ y{-1} * Wyy .+ by)
# where `y{-1}` is `@net`'s notation for the previous step's value of `y`.
# `x` is right-multiplied by the (in, out) matrix `Wxy`, so it must have `in`
# columns; the (1, out) bias is broadcast across rows with `.+`.
#
# Fields (order matters: it is the positional constructor order):
#   Wxy :: input weights, (in, out)
#   Wyy :: recurrent weights, (out, out)
#   by  :: bias, (1, out)
#   y   :: output / hidden state. NOTE(review): presumably this field supplies
#          `y{-1}` on the first step — confirm against the `@net` macro.
@net type Recurrent
  Wxy; Wyy; by
  y
  function (x)
    y = tanh( x * Wxy .+ y{-1} * Wyy .+ by )
  end
end

# Build a `Recurrent` layer mapping `in` features to `out` features.
# `init` is always called with a single dimensions tuple, matching the weight
# calls and the `LSTM` constructor. The default `initn` is defined elsewhere
# in the package.
# NOTE(review): unlike `GatedRecurrent` and `LSTM`, the initial state `y` is
# drawn from `init` rather than set to zeros; kept to preserve behaviour —
# confirm whether this is intentional.
Recurrent(in, out; init = initn) =
  Recurrent(init((in, out)), init((out, out)), init((1, out)), init((1, out)))

# Gated recurrent unit (GRU-style) layer.
#
# Each step computes a reset gate, an update gate and a candidate output:
#     reset  = σ(x * Wxr .+ y{-1} * Wyr .+ br)
#     update = σ(x * Wxu .+ y{-1} * Wyu .+ bu)
#     y′     = tanh(x * Wxh .+ (reset .* y{-1}) * Wyh .+ bh)
#     y      = (1 .- update) .* y′ .+ update .* y{-1}
# An update gate near 1 keeps the previous output; near 0 takes the candidate.
#
# Fields come in (input weights, recurrent weights, bias) triples for the
# reset (r), update (u) and candidate (h) parts, followed by the state `y`.
@net type GatedRecurrent
  Wxr; Wyr; br
  Wxu; Wyu; bu
  Wxh; Wyh; bh
  y
  function (x)
    reset = σ( x * Wxr .+ y{-1} * Wyr .+ br )
    update = σ( x * Wxu .+ y{-1} * Wyu .+ bu )
    y′ = tanh( x * Wxh .+ (reset .* y{-1}) * Wyh .+ bh )
    y = (1 .- update) .* y′ .+ update .* y{-1}
  end
end

# Build a `GatedRecurrent` layer mapping `in` features to `out` features.
# Three [Wx (in, out), Wy (out, out), b (1, out)] triples are created in field
# order (reset, update, candidate) and flattened with `vcat`. The initial state
# is a (1, out) Float32 zero matrix. `init` is called with a dims tuple.
GatedRecurrent(in, out; init = initn) =
  GatedRecurrent(vcat([[init((in, out)), init((out, out)), init((1, out))] for _ = 1:3]...)...,
                 zeros(Float32, (1, out)))

# Long short-term memory layer.
#
# Each step computes forget, input and output gates, a candidate cell value,
# the new cell `state`, and the output `y`:
#     state = forget .* state{-1} .+ input .* state′
#     y     = output .* tanh(state)
#
# Fields come in (input weights, recurrent weights, bias) triples for the
# forget (f), input (i), output (o) and candidate-cell (c) parts, followed by
# the output `y` and the cell `state`.
@net type LSTM
  Wxf; Wyf; bf
  Wxi; Wyi; bi
  Wxo; Wyo; bo
  Wxc; Wyc; bc
  y; state
  function (x)
    # Gates
    forget = σ( x * Wxf .+ y{-1} * Wyf .+ bf )
    input = σ( x * Wxi .+ y{-1} * Wyi .+ bi )
    output = σ( x * Wxo .+ y{-1} * Wyo .+ bo )
    # State update and output
    state′ = tanh( x * Wxc .+ y{-1} * Wyc .+ bc )
    state = forget .* state{-1} .+ input .* state′
    y = output .* tanh(state)
  end
end

# Build an `LSTM` layer mapping `in` features to `out` features.
# Four [Wx (in, out), Wy (out, out), b (1, out)] triples are created in field
# order (forget, input, output, candidate) and flattened with `vcat`. Both
# `y` and `state` start as (1, out) Float32 zero matrices.
LSTM(in, out; init = initn) =
  LSTM(vcat([[init((in, out)), init((out, out)), init((1, out))] for _ = 1:4]...)...,
       zeros(Float32, (1, out)), zeros(Float32, (1, out)))