diff --git a/src/Flux.jl b/src/Flux.jl index a48b7b90..ec5238ff 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -7,8 +7,8 @@ module Flux using Juno using Lazy: @forward -export Chain, Dense, RNN, LSTM, - SGD +export Chain, Dense, Seq, ChainSeq, RNN, LSTM, + SGD, params using NNlib export σ, relu, softmax diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 01848379..f679ae37 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -3,8 +3,30 @@ combine(x, h) = vcat(x, h .* trues(1, size(x, 2))) # Sequences -struct Seq{T} - data::Vector{T} +struct Seq{T,A<:AbstractVector{T}} + data::A +end + +Seq(xs::AbstractVector{T}) where T = Seq{T,typeof(xs)}(xs) + +Seq(xs) = Seq(collect(xs)) + +Base.getindex(s::Seq, i) = s.data[i] + +type ChainSeq + layers::Vector{Any} + ChainSeq(xs...) = new([xs...]) +end + +Optimise.children(c::ChainSeq) = c.layers + +(c::ChainSeq)(x) = foldl((x, m) -> m(x), x, c.layers) +(c::ChainSeq)(s::Seq) = Seq([c(x) for x in s.data]) + +function Base.show(io::IO, c::ChainSeq) + print(io, "ChainSeq(") + join(io, c.layers, ", ") + print(io, ")") end # Stateful recurrence @@ -79,6 +101,9 @@ end hidden(m::LSTMCell) = (m.h, m.c) +Optimise.children(m::LSTMCell) = + (m.forget, m.input, m.output, m.cell, m.h, m.c) + Base.show(io::IO, m::LSTMCell) = print(io, "LSTMCell(", size(m.forget.W, 2) - size(m.forget.W, 1), ", ",