diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 461fb0ca..b1b7cd11 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -34,8 +34,8 @@ struct RNNCell{D,V} h::V end -RNNCell(in::Integer, out::Integer; init = initn) = - RNNCell(Dense(in+out, out, init = initn), param(initn(out))) +RNNCell(in::Integer, out::Integer, σ=identity; init = initn) = + RNNCell(Dense(in+out, out, σ, init = init), param(init(out))) function (m::RNNCell)(h, x) h = m.d(combine(x, h))