seq stuff
This commit is contained in:
parent
1855a37319
commit
4083c34547
@ -7,8 +7,8 @@ module Flux
|
||||
using Juno
|
||||
using Lazy: @forward
|
||||
|
||||
export Chain, Dense, RNN, LSTM,
|
||||
SGD
|
||||
export Chain, Dense, Seq, ChainSeq, RNN, LSTM,
|
||||
SGD, params
|
||||
|
||||
using NNlib
|
||||
export σ, relu, softmax
|
||||
|
@ -3,8 +3,30 @@ combine(x, h) = vcat(x, h .* trues(1, size(x, 2)))
|
||||
|
||||
# Sequences
|
||||
|
||||
struct Seq{T}
|
||||
data::Vector{T}
|
||||
struct Seq{T,A<:AbstractVector{T}}
|
||||
data::A
|
||||
end
|
||||
|
||||
Seq(xs::AbstractVector{T}) where T = Seq{T,typeof(xs)}(xs)
|
||||
|
||||
Seq(xs) = Seq(collect(xs))
|
||||
|
||||
Base.getindex(s::Seq, i) = s.data[i]
|
||||
|
||||
type ChainSeq
|
||||
layers::Vector{Any}
|
||||
ChainSeq(xs...) = new([xs...])
|
||||
end
|
||||
|
||||
Optimise.children(c::ChainSeq) = c.layers
|
||||
|
||||
(c::ChainSeq)(x) = foldl((x, m) -> m(x), x, c.layers)
|
||||
(c::ChainSeq)(s::Seq) = Seq([c(x) for x in s.data])
|
||||
|
||||
function Base.show(io::IO, c::ChainSeq)
|
||||
print(io, "ChainSeq(")
|
||||
join(io, c.layers, ", ")
|
||||
print(io, ")")
|
||||
end
|
||||
|
||||
# Stateful recurrence
|
||||
@ -79,6 +101,9 @@ end
|
||||
|
||||
hidden(m::LSTMCell) = (m.h, m.c)
|
||||
|
||||
Optimise.children(m::LSTMCell) =
|
||||
(m.forget, m.input, m.output, m.cell, m.h, m.c)
|
||||
|
||||
Base.show(io::IO, m::LSTMCell) =
|
||||
print(io, "LSTMCell(",
|
||||
size(m.forget.W, 2) - size(m.forget.W, 1), ", ",
|
||||
|
Loading…
Reference in New Issue
Block a user