diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index ee9cad00..62a15157 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -1,3 +1,13 @@
+# Stack the input `x` on top of the hidden state `h` (both are passed to `vcat`).
+# When both are vectors there is nothing to broadcast, so concatenate directly;
+# this keeps the result a vector, exactly like the old `[x; h]`.
+combine(x::AbstractVector, h::AbstractVector) = vcat(x, h)
+
+# Otherwise `h` is multiplied by a `1 × size(x, 2)` row of `true`s, which
+# broadcasts it across the columns of `x` before concatenation.
+# TODO: broadcasting cat
+combine(x, h) = vcat(x, h .* trues(1, size(x, 2)))
+
 # Stateful recurrence
 
 mutable struct Recur{T}
@@ -24,7 +34,7 @@ RNNCell(in::Integer, out::Integer, init = initn) =
   RNNCell(Dense(in+out, out, init = initn), track(initn(out)))
 
 function (m::RNNCell)(h, x)
-  h = m.d([x; h])
+  h = m.d(combine(x, h))
   return h, h
 end
 
@@ -51,7 +61,7 @@ LSTMCell(in, out; init = initn) =
 
 function (m::LSTMCell)(h_, x)
   h, c = h_
-  x′ = [x; h]
+  x′ = combine(x, h)
   forget, input, output, cell =
     m.forget(x′), m.input(x′), m.output(x′), m.cell(x′)
   c = forget .* c .+ input .* cell