diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index b04ed5da..e4eb0c3d 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -160,10 +160,11 @@ struct GRUCell{D1,D2,V} h::V end -function GRUCell(in, out; init = initn) - cell = GRUCell([Dense(in+out, out, σ, init = init) for _ = 1:2]..., - Dense(in+out, out, tanh, init = init), - param(init(out))) +function GRUCell(in, out) + cell = GRUCell(Dense(in+out, out, σ), + Dense(in+out, out, σ), + Dense(in+out, out, tanh), + param(initn(out))) return cell end