Flux.jl/src/layers/recurrent.jl
2017-09-12 13:12:25 +01:00

94 lines
1.9 KiB
Julia
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# TODO: broadcasting cat
combine(x, h) = vcat(x, h .* trues(1, size(x, 2)))
# Stateful recurrence
mutable struct Recur{T}
cell::T
state
end
Recur(m) = Recur(m, hidden(m))
function (m::Recur)(xs...)
h, y = m.cell(m.state, xs...)
m.state = h
return y
end
Optimise.children(m::Recur) = (m.cell,)
Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
_truncate(x::AbstractArray) = x
_truncate(x::TrackedArray) = x.data
_truncate(x::Tuple) = _truncate.(x)
truncate!(m) = foreach(truncate!, Optimise.children(m))
truncate!(m::Recur) = (m.state = _truncate(m.state))
# Vanilla RNN
struct RNNCell{D,V}
d::D
h::V
end
RNNCell(in::Integer, out::Integer, σ = tanh; init = initn) =
RNNCell(Dense(in+out, out, σ, init = init), param(init(out)))
function (m::RNNCell)(h, x)
h = m.d(combine(x, h))
return h, h
end
hidden(m::RNNCell) = m.h
Optimise.children(m::RNNCell) = (m.d, m.h)
function Base.show(io::IO, m::RNNCell)
print(io, "RNNCell(", m.d, ")")
end
RNN(a...; ka...) = Recur(RNNCell(a...; ka...))
# LSTM
struct LSTMCell{D1,D2,V}
forget::D1
input::D1
output::D1
cell::D2
h::V; c::V
end
function LSTMCell(in, out; init = initn)
cell = LSTMCell([Dense(in+out, out, σ, init = initn) for _ = 1:3]...,
Dense(in+out, out, tanh, init = initn),
param(initn(out)), param(initn(out)))
cell.forget.b.data .= 1
return cell
end
function (m::LSTMCell)(h_, x)
h, c = h_
x = combine(x, h)
forget, input, output, cell =
m.forget(x), m.input(x), m.output(x), m.cell(x)
c = forget .* c .+ input .* cell
h = output .* tanh.(c)
return (h, c), h
end
hidden(m::LSTMCell) = (m.h, m.c)
Optimise.children(m::LSTMCell) =
(m.forget, m.input, m.output, m.cell, m.h, m.c)
Base.show(io::IO, m::LSTMCell) =
print(io, "LSTMCell(",
size(m.forget.W, 2) - size(m.forget.W, 1), ", ",
size(m.forget.W, 1), ')')
LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))