# Flux.jl/src/layers/recurrent.jl
# TODO: broadcasting cat
# Vertically concatenate `x` with `h`, after widening `h` to `size(x, 2)`
# columns by broadcasting it against a 1×n row of `true`s.
function combine(x, h)
    ncols = size(x, 2)
    widened = h .* trues(1, ncols)
    return vcat(x, widened)
end
2017-09-03 06:12:44 +00:00
# Stateful recurrence
#
# `Recur` pairs a cell with a mutable `state`. Calling the wrapper forwards
# `(state, inputs...)` to the cell, which must return a `(new_state, output)`
# pair; the new state is stored and the output is returned.
mutable struct Recur{T}
    cell::T
    state
end

# Wrap `m`, taking the starting state from `hidden(m)`.
function Recur(m)
    return Recur(m, hidden(m))
end

function (m::Recur)(xs...)
    next_state, output = m.cell(m.state, xs...)
    m.state = next_state  # remember the state for the next call
    return output
end
2017-09-07 04:05:02 +00:00
# A `Recur` exposes only its wrapped cell as a child (the state is not listed).
function Optimise.children(m::Recur)
    return (m.cell,)
end
2017-09-05 23:25:34 +00:00
# Print the wrapper as `Recur(<cell>)`.
function Base.show(io::IO, m::Recur)
    return print(io, "Recur(", m.cell, ")")
end
2017-09-07 04:05:02 +00:00
# `_truncate` maps a state value to the value kept after truncation:
#   * a plain array is returned unchanged;
#   * a `TrackedArray` is replaced by its `data` field (presumably this drops
#     the tracked history — confirm against the `TrackedArray` definition);
#   * a tuple is handled element by element.
function _truncate(x::AbstractArray)
    return x
end

function _truncate(x::TrackedArray)
    return x.data
end

function _truncate(x::Tuple)
    return map(_truncate, x)
end
# Generic case: recurse into every child reported by `Optimise.children`.
function truncate!(m)
    for child in Optimise.children(m)
        truncate!(child)
    end
    return nothing
end

# `Recur` case: replace the stored state with its `_truncate`d version and
# return that new state.
function truncate!(m::Recur)
    truncated = _truncate(m.state)
    m.state = truncated
    return truncated
end
2017-09-06 18:03:25 +00:00
2017-09-03 06:12:44 +00:00
# Vanilla RNN
#
# `d` is the layer applied to the combined input/hidden value and `h` is the
# value reported by `hidden` as the starting state.
struct RNNCell{D,V}
    d::D
    h::V
end

# Build a cell from sizes: a `Dense(in + out, out, σ)` layer plus a `param`
# vector of length `out`, both initialised with `init`.
function RNNCell(in::Integer, out::Integer, σ = identity; init = initn)
    layer = Dense(in + out, out, σ, init = init)
    h0 = param(init(out))
    return RNNCell(layer, h0)
end
2017-09-03 06:12:44 +00:00
# One recurrence step: apply the layer to `combine(x, h)`. The result serves
# as both the new state and the output.
function (m::RNNCell)(h, x)
    next_h = m.d(combine(x, h))
    return next_h, next_h
end
# Starting state of an `RNNCell`: its stored `h` field.
function hidden(m::RNNCell)
    return m.h
end
2017-09-07 04:05:02 +00:00
# Children of an `RNNCell`: the layer and the initial hidden value.
function Optimise.children(m::RNNCell)
    return (m.d, m.h)
end
2017-09-03 06:12:44 +00:00
# Print the cell as `RNNCell(<layer>)`.
Base.show(io::IO, m::RNNCell) = print(io, "RNNCell(", m.d, ")")
2017-09-05 23:25:34 +00:00
# Convenience constructor: forward all arguments to `RNNCell` and wrap the
# resulting cell in a stateful `Recur`.
function RNN(a...; ka...)
    cell = RNNCell(a...; ka...)
    return Recur(cell)
end
2017-09-03 06:12:44 +00:00
# LSTM
#
# The three gates share one layer type (`D1`); the candidate-cell layer may be
# a different type (`D2`). `h` and `c` are the values reported by `hidden` as
# the starting state and share the type `V`.
struct LSTMCell{D1,D2,V}
    forget::D1
    input::D1
    output::D1
    cell::D2
    h::V
    c::V
end
2017-09-05 06:42:32 +00:00
# Build an LSTM cell from sizes.
#
# Arguments:
#   in, out - sizes; every internal layer is `Dense(in + out, out, ...)`.
#   init    - initialiser used for all four layers and for the two state
#             vectors (default `initn`). Previously this keyword was accepted
#             but ignored, with `initn` hard-coded throughout.
#
# The forget, input and output gates use `σ`; the candidate-cell layer uses
# `tanh`.
function LSTMCell(in, out; init = initn)
    gate() = Dense(in + out, out, σ, init = init)
    cell = LSTMCell(gate(), gate(), gate(),
                    Dense(in + out, out, tanh, init = init),
                    param(init(out)), param(init(out)))
    # Set the forget gate's bias to 1 (presumably so a fresh cell starts out
    # retaining its memory).
    cell.forget.b.data .= 1
    return cell
end
2017-09-03 06:12:44 +00:00
# One LSTM step. `h_` is the `(h, c)` state pair; returns the new pair together
# with the new `h` as the output.
function (m::LSTMCell)(h_, x)
    h_prev, c_prev = h_
    z = combine(x, h_prev)
    # Evaluate the layers in the same order as their fields are declared.
    f = m.forget(z)
    i = m.input(z)
    o = m.output(z)
    candidate = m.cell(z)
    c_next = f .* c_prev .+ i .* candidate
    h_next = o .* tanh.(c_next)
    return (h_next, c_next), h_next
end
2017-09-03 06:24:47 +00:00
# Starting state of an `LSTMCell`: the `(h, c)` pair stored on the cell.
function hidden(m::LSTMCell)
    return (m.h, m.c)
end
2017-09-05 23:25:34 +00:00
2017-09-06 22:59:07 +00:00
# Children of an `LSTMCell`: the four layers followed by the two state values.
function Optimise.children(m::LSTMCell)
    return (m.forget, m.input, m.output, m.cell, m.h, m.c)
end
2017-09-05 23:25:34 +00:00
# Print the cell as `LSTMCell(in, out)`, recovering both sizes from the forget
# gate's weight matrix: it has `out` rows and `in + out` columns.
function Base.show(io::IO, m::LSTMCell)
    nout = size(m.forget.W, 1)
    nin = size(m.forget.W, 2) - nout
    return print(io, "LSTMCell(", nin, ", ", nout, ')')
end
# Convenience constructor: forward all arguments to `LSTMCell` and wrap the
# resulting cell in a stateful `Recur`.
function LSTM(a...; ka...)
    cell = LSTMCell(a...; ka...)
    return Recur(cell)
end