basic seq functionality
This commit is contained in:
parent
2c8b7bc64b
commit
1946c46e29
@ -1,6 +1,12 @@
|
|||||||
# TODO: broadcasting cat
|
# TODO: broadcasting cat
|
||||||
combine(x, h) = vcat(x, h .* trues(1, size(x, 2)))
|
combine(x, h) = vcat(x, h .* trues(1, size(x, 2)))
|
||||||
|
|
||||||
|
# Sequences
|
||||||
|
|
||||||
|
struct Seq{T}
|
||||||
|
data::Vector{T}
|
||||||
|
end
|
||||||
|
|
||||||
# Stateful recurrence
|
# Stateful recurrence
|
||||||
|
|
||||||
mutable struct Recur{T}
|
mutable struct Recur{T}
|
||||||
@ -18,6 +24,8 @@ end
|
|||||||
|
|
||||||
Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
|
Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
|
||||||
|
|
||||||
|
(m::Recur)(s::Seq) = Seq([m(x) for x in s.data])
|
||||||
|
|
||||||
# Vanilla RNN
|
# Vanilla RNN
|
||||||
|
|
||||||
struct RNNCell{D,V}
|
struct RNNCell{D,V}
|
||||||
|
Loading…
Reference in New Issue
Block a user