2017-09-05 06:29:31 +00:00
|
|
|
|
# TODO: broadcasting cat
|
|
|
|
|
# Stack `h` underneath `x`.
# `h` is first broadcast against a 1×size(x, 2) row of `true`s, which repeats
# it once per column of `x`, and the result is vertically concatenated to `x`.
function combine(x, h)
  ncols = size(x, 2)
  tiled = h .* trues(1, ncols)
  return vcat(x, tiled)
end
|
|
|
|
|
|
2017-09-06 18:03:25 +00:00
|
|
|
|
# Sequences
|
|
|
|
|
|
2017-09-06 22:59:07 +00:00
|
|
|
|
# Immutable wrapper around a vector of elements of type `T`.
# `A` records the concrete vector type of the wrapped storage.
struct Seq{T,A<:AbstractVector{T}}
  # The wrapped elements.
  data::A
end
|
|
|
|
|
|
|
|
|
|
# Wrap an existing vector, taking both type parameters from the argument:
# `T` from its element type and the storage type from `typeof(xs)`.
function Seq(xs::AbstractVector{T}) where {T}
  return Seq{T,typeof(xs)}(xs)
end
|
|
|
|
|
|
|
|
|
|
# Fallback for arbitrary iterables: `collect` them first, then wrap the
# collected result.
function Seq(xs)
  collected = collect(xs)
  return Seq(collected)
end
|
|
|
|
|
|
|
|
|
|
# Indexing a `Seq` forwards the index to the wrapped `data` vector.
function Base.getindex(s::Seq, i)
  return s.data[i]
end
|
|
|
|
|
|
|
|
|
|
# Ordered collection of layers.
#
# Declared with `mutable struct` (the replacement for the legacy `type`
# keyword) so it matches the other definitions in this file and keeps the
# same mutability as before.
mutable struct ChainSeq
  # The layers, stored untyped so arbitrary callables can be mixed.
  layers::Vector{Any}
  # Collect all positional arguments, in order, into the layer list.
  ChainSeq(xs...) = new([xs...])
end
|
|
|
|
|
|
|
|
|
|
# The children of a `ChainSeq` are exactly its stored layers.
function Optimise.children(c::ChainSeq)
  return c.layers
end
|
|
|
|
|
|
|
|
|
|
# Apply every layer in order, feeding each layer's output into the next.
# With no layers the input is returned unchanged.
function (c::ChainSeq)(x)
  out = x
  for layer in c.layers
    out = layer(out)
  end
  return out
end
|
|
|
|
|
# Apply the chain to each element of a `Seq`, in order, and wrap the
# per-element results in a new `Seq`.
function (c::ChainSeq)(s::Seq)
  outputs = [c(step) for step in s.data]
  return Seq(outputs)
end
|
|
|
|
|
|
|
|
|
|
# Print the chain as `ChainSeq(layer1, layer2, ...)`.
function Base.show(io::IO, c::ChainSeq)
  print(io, "ChainSeq(")
  for (idx, layer) in enumerate(c.layers)
    # Separator goes between layers only, never before the first one.
    idx > 1 && print(io, ", ")
    print(io, layer)
  end
  print(io, ")")
end
|
|
|
|
|
|
2017-09-03 06:12:44 +00:00
|
|
|
|
# Stateful recurrence
|
|
|
|
|
|
|
|
|
|
# Stateful wrapper around a cell.
# Mutable because `state` is reassigned each time the wrapper is called.
mutable struct Recur{T}
  # The wrapped cell, of type `T`.
  cell::T
  # Current state handed to `cell` on the next call (untyped).
  state
end
|
|
|
|
|
|
|
|
|
|
# Build a `Recur` whose initial state is whatever `hidden(m)` returns for
# the given cell.
function Recur(m)
  initial = hidden(m)
  return Recur(m, initial)
end
|
|
|
|
|
|
|
|
|
|
# Run one step: call the cell with the stored state followed by the inputs.
# The cell's result is destructured as (new state, output); the new state is
# stored on `m` and the output is returned.
function (m::Recur)(xs...)
  next_state, out = m.cell(m.state, xs...)
  m.state = next_state
  return out
end
|
|
|
|
|
|
2017-09-05 23:25:34 +00:00
|
|
|
|
# Print as `Recur(<cell>)`; the stored state is not shown.
function Base.show(io::IO, m::Recur)
  print(io, "Recur(", m.cell, ")")
end
|
|
|
|
|
|
2017-09-06 18:03:25 +00:00
|
|
|
|
# Step through the elements of a `Seq` in order, calling `m` on each one
# (which updates the stored state), and wrap the outputs in a new `Seq`.
function (m::Recur)(s::Seq)
  outputs = [m(step) for step in s.data]
  return Seq(outputs)
end
|
|
|
|
|
|
2017-09-03 06:12:44 +00:00
|
|
|
|
# Vanilla RNN
|
|
|
|
|
|
|
|
|
|
# Cell for the vanilla RNN.
struct RNNCell{D,V}
  # Callable applied to the combined input and state on every step.
  d::D
  # Value returned by `hidden` for this cell.
  h::V
end
|
|
|
|
|
|
2017-09-05 23:25:34 +00:00
|
|
|
|
# Build an `RNNCell` from input and output sizes.
#
# in   - input size.
# out  - output size; the Dense layer is built as `Dense(in+out, out)`.
# init - initializer forwarded to `Dense`. It was previously accepted but
#        ignored in favour of a hard-coded `initn`; it is now passed through.
#
# The initial hidden value is still `track(initn(out))`, so `init` only needs
# to support whatever call form `Dense` uses for its parameters.
RNNCell(in::Integer, out::Integer; init = initn) =
  RNNCell(Dense(in+out, out, init = init), track(initn(out)))
|
|
|
|
|
|
|
|
|
|
# One RNN step: combine the input `x` with the previous state `h`, pass the
# result through `m.d`, and return that value twice, once as the new state
# and once as the output.
function (m::RNNCell)(h, x)
  stacked = combine(x, h)
  next = m.d(stacked)
  return next, next
end
|
|
|
|
|
|
|
|
|
|
# The initial state of an `RNNCell` is its stored `h` field.
function hidden(m::RNNCell)
  return m.h
end
|
|
|
|
|
|
|
|
|
|
# Print as `RNNCell(<d>)`; the stored `h` is not shown.
Base.show(io::IO, m::RNNCell) = print(io, "RNNCell(", m.d, ")")
|
|
|
|
|
|
2017-09-05 23:25:34 +00:00
|
|
|
|
# Convenience constructor: forward all positional and keyword arguments to
# `RNNCell` and wrap the resulting cell in a `Recur`.
function RNN(a...; ka...)
  cell = RNNCell(a...; ka...)
  return Recur(cell)
end
|
|
|
|
|
|
2017-09-03 06:12:44 +00:00
|
|
|
|
# LSTM
|
|
|
|
|
|
2017-09-03 06:24:47 +00:00
|
|
|
|
# Cell for the LSTM. The three gates share one type (`D1`), the candidate
# layer has its own type (`D2`), and both state values share type `V`.
struct LSTMCell{D1,D2,V}
  # Gate callables, each applied to the combined input and hidden state.
  forget::D1
  input::D1
  output::D1
  # Callable producing the candidate values mixed into the cell state.
  cell::D2
  # Values returned (as a tuple) by `hidden` for this cell.
  h::V
  c::V
end
|
|
|
|
|
|
2017-09-05 06:42:32 +00:00
|
|
|
|
# Build an `LSTMCell` from input and output sizes.
#
# in   - input size.
# out  - output size; every Dense layer is built as `Dense(in+out, out, ...)`.
# init - initializer forwarded to all four `Dense` layers. It was previously
#        accepted but ignored in favour of a hard-coded `initn`; it is now
#        passed through.
#
# The three gates use `σ` and the candidate layer uses `tanh`. The initial
# `h` and `c` values are still `track(initn(out))`, so `init` only needs to
# support whatever call form `Dense` uses for its parameters.
function LSTMCell(in, out; init = initn)
  cell = LSTMCell([Dense(in+out, out, σ, init = init) for _ = 1:3]...,
                  Dense(in+out, out, tanh, init = init),
                  track(initn(out)), track(initn(out)))
  # Overwrite the forget gate's `b.x` with ones in place
  # (presumably the bias's underlying array; confirm against Dense/track).
  cell.forget.b.x .= 1
  return cell
end
|
2017-09-03 06:12:44 +00:00
|
|
|
|
|
|
|
|
|
# One LSTM step.
#
# h_ - previous state, destructured as (h, c).
# x  - current input.
#
# Returns ((new h, new c), new h): the updated state pair and the output.
function (m::LSTMCell)(h_, x)
  prev_h, prev_c = h_
  stacked = combine(x, prev_h)
  # All gates and the candidate see the same combined input.
  f = m.forget(stacked)
  i = m.input(stacked)
  o = m.output(stacked)
  g = m.cell(stacked)
  # Keep the gated part of the old cell state and add the gated candidate.
  next_c = f .* prev_c .+ i .* g
  next_h = o .* tanh.(next_c)
  return (next_h, next_c), next_h
end
|
|
|
|
|
|
2017-09-03 06:24:47 +00:00
|
|
|
|
# The initial state of an `LSTMCell` is the pair of its stored `h` and `c`.
function hidden(m::LSTMCell)
  return (m.h, m.c)
end
|
2017-09-05 23:25:34 +00:00
|
|
|
|
|
2017-09-06 22:59:07 +00:00
|
|
|
|
# The children of an `LSTMCell` are all six of its fields, in declaration
# order: the three gates, the candidate layer, then `h` and `c`.
function Optimise.children(m::LSTMCell)
  return (m.forget, m.input, m.output, m.cell, m.h, m.c)
end
|
|
|
|
|
|
2017-09-05 23:25:34 +00:00
|
|
|
|
# Print as `LSTMCell(a, b)`, where `b` is the row count of `forget.W` and
# `a` is its column count minus its row count.
# NOTE(review): presumably these are the `in` and `out` sizes given to the
# constructor, assuming `W` is out × (in+out); confirm against Dense.
function Base.show(io::IO, m::LSTMCell)
  W = m.forget.W
  nrows = size(W, 1)
  extra = size(W, 2) - nrows
  print(io, "LSTMCell(", extra, ", ", nrows, ')')
end
|
|
|
|
|
|
|
|
|
|
# Convenience constructor: forward all positional and keyword arguments to
# `LSTMCell` and wrap the resulting cell in a `Recur`.
function LSTM(a...; ka...)
  cell = LSTMCell(a...; ka...)
  return Recur(cell)
end
|