seq stuff
This commit is contained in:
parent
1855a37319
commit
4083c34547
@ -7,8 +7,8 @@ module Flux
|
|||||||
using Juno
|
using Juno
|
||||||
using Lazy: @forward
|
using Lazy: @forward
|
||||||
|
|
||||||
export Chain, Dense, RNN, LSTM,
|
export Chain, Dense, Seq, ChainSeq, RNN, LSTM,
|
||||||
SGD
|
SGD, params
|
||||||
|
|
||||||
using NNlib
|
using NNlib
|
||||||
export σ, relu, softmax
|
export σ, relu, softmax
|
||||||
|
@ -3,8 +3,30 @@ combine(x, h) = vcat(x, h .* trues(1, size(x, 2)))
|
|||||||
|
|
||||||
# Sequences
|
# Sequences
|
||||||
|
|
||||||
struct Seq{T}
|
struct Seq{T,A<:AbstractVector{T}}
|
||||||
data::Vector{T}
|
data::A
|
||||||
|
end
|
||||||
|
|
||||||
|
Seq(xs::AbstractVector{T}) where T = Seq{T,typeof(xs)}(xs)
|
||||||
|
|
||||||
|
Seq(xs) = Seq(collect(xs))
|
||||||
|
|
||||||
|
Base.getindex(s::Seq, i) = s.data[i]
|
||||||
|
|
||||||
|
type ChainSeq
|
||||||
|
layers::Vector{Any}
|
||||||
|
ChainSeq(xs...) = new([xs...])
|
||||||
|
end
|
||||||
|
|
||||||
|
Optimise.children(c::ChainSeq) = c.layers
|
||||||
|
|
||||||
|
(c::ChainSeq)(x) = foldl((x, m) -> m(x), x, c.layers)
|
||||||
|
(c::ChainSeq)(s::Seq) = Seq([c(x) for x in s.data])
|
||||||
|
|
||||||
|
function Base.show(io::IO, c::ChainSeq)
|
||||||
|
print(io, "ChainSeq(")
|
||||||
|
join(io, c.layers, ", ")
|
||||||
|
print(io, ")")
|
||||||
end
|
end
|
||||||
|
|
||||||
# Stateful recurrence
|
# Stateful recurrence
|
||||||
@ -79,6 +101,9 @@ end
|
|||||||
|
|
||||||
hidden(m::LSTMCell) = (m.h, m.c)
|
hidden(m::LSTMCell) = (m.h, m.c)
|
||||||
|
|
||||||
|
Optimise.children(m::LSTMCell) =
|
||||||
|
(m.forget, m.input, m.output, m.cell, m.h, m.c)
|
||||||
|
|
||||||
Base.show(io::IO, m::LSTMCell) =
|
Base.show(io::IO, m::LSTMCell) =
|
||||||
print(io, "LSTMCell(",
|
print(io, "LSTMCell(",
|
||||||
size(m.forget.W, 2) - size(m.forget.W, 1), ", ",
|
size(m.forget.W, 2) - size(m.forget.W, 1), ", ",
|
||||||
|
Loading…
Reference in New Issue
Block a user