diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 1ac7f9f5..01848379 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -1,6 +1,12 @@ # TODO: broadcasting cat combine(x, h) = vcat(x, h .* trues(1, size(x, 2))) +# Sequences + +struct Seq{T} + data::Vector{T} +end + # Stateful recurrence mutable struct Recur{T} @@ -18,6 +24,8 @@ end Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")") +(m::Recur)(s::Seq) = Seq([m(x) for x in s.data]) + # Vanilla RNN struct RNNCell{D,V}