basic seq functionality
This commit is contained in:
parent
2c8b7bc64b
commit
1946c46e29
@ -1,6 +1,12 @@
|
||||
# TODO: broadcasting cat
|
||||
combine(x, h) = vcat(x, h .* trues(1, size(x, 2)))
|
||||
|
||||
# Sequences
|
||||
|
||||
struct Seq{T}
|
||||
data::Vector{T}
|
||||
end
|
||||
|
||||
# Stateful recurrence
|
||||
|
||||
mutable struct Recur{T}
|
||||
@ -18,6 +24,8 @@ end
|
||||
|
||||
Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
|
||||
|
||||
(m::Recur)(s::Seq) = Seq([m(x) for x in s.data])
|
||||
|
||||
# Vanilla RNN
|
||||
|
||||
struct RNNCell{D,V}
|
||||
|
Loading…
Reference in New Issue
Block a user