seq stuff

This commit is contained in:
Mike J Innes 2017-09-06 18:59:07 -04:00
parent 1855a37319
commit 4083c34547
2 changed files with 29 additions and 4 deletions

View File

@ -7,8 +7,8 @@ module Flux
using Juno
using Lazy: @forward
export Chain, Dense, RNN, LSTM,
SGD
export Chain, Dense, Seq, ChainSeq, RNN, LSTM,
SGD, params
using NNlib
export σ, relu, softmax

View File

@ -3,8 +3,30 @@ combine(x, h) = vcat(x, h .* trues(1, size(x, 2)))
# Sequences
struct Seq{T}
data::Vector{T}
struct Seq{T,A<:AbstractVector{T}}
data::A
end
Seq(xs::AbstractVector{T}) where T = Seq{T,typeof(xs)}(xs)
Seq(xs) = Seq(collect(xs))
Base.getindex(s::Seq, i) = s.data[i]
type ChainSeq
layers::Vector{Any}
ChainSeq(xs...) = new([xs...])
end
Optimise.children(c::ChainSeq) = c.layers
(c::ChainSeq)(x) = foldl((x, m) -> m(x), x, c.layers)
(c::ChainSeq)(s::Seq) = Seq([c(x) for x in s.data])
function Base.show(io::IO, c::ChainSeq)
print(io, "ChainSeq(")
join(io, c.layers, ", ")
print(io, ")")
end
# Stateful recurrence
@ -79,6 +101,9 @@ end
hidden(m::LSTMCell) = (m.h, m.c)
Optimise.children(m::LSTMCell) =
(m.forget, m.input, m.output, m.cell, m.h, m.c)
Base.show(io::IO, m::LSTMCell) =
print(io, "LSTMCell(",
size(m.forget.W, 2) - size(m.forget.W, 1), ", ",