basic seq functionality

This commit is contained in:
Mike J Innes 2017-09-06 14:03:25 -04:00
parent 2c8b7bc64b
commit 1946c46e29

View File

@ -1,6 +1,12 @@
# TODO: broadcasting cat
combine(x, h) = vcat(x, h .* trues(1, size(x, 2)))
# Sequences
struct Seq{T}
data::Vector{T}
end
# Stateful recurrence
mutable struct Recur{T}
@ -18,6 +24,8 @@ end
Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
(m::Recur)(s::Seq) = Seq([m(x) for x in s.data])
# Vanilla RNN
struct RNNCell{D,V}