diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index b3c37406..89f29f1e 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -1,15 +1,12 @@ export Recurrent @net type Recurrent - Wxh; Whh; Why - bh; by - hidden + Wxy; Wyy; by + y function (x) - hidden = σ( x * Wxh + hidden * Whh + bh ) - y = hidden * Why + by + y = tanh( x * Wxy + y * Wyy + by ) end end -Recurrent(in::Integer, hidden::Integer, out::Integer; init = initn) = - Recurrent(init((in, hidden)), init((hidden, hidden)), init((hidden, out)), - init(hidden), init(out), zeros(Float32, hidden)) +Recurrent(in, out; init = initn) = + Recurrent(init((in, out)), init((out, out)), init(out), init(out))