update rnn structure

This commit is contained in:
Mike J Innes 2018-02-01 20:57:39 +00:00
parent af3ccf85ff
commit 0f1e7b5578

View File

@ -1,7 +1,6 @@
# TODO: broadcasting cat gate(h, n) = (1:h) + h*(n-1)
combine(x::AbstractMatrix, h::AbstractVector) = vcat(x, h .* trues(1, size(x, 2))) gate(x::AbstractVector, h, n) = x[gate(h,n)]
combine(x::AbstractVector, h::AbstractVector) = vcat(x, h) gate(x::AbstractMatrix, h, n) = x[gate(h,n),:]
combine(x::AbstractMatrix, h::AbstractMatrix) = vcat(x, h)
# Stateful recurrence # Stateful recurrence
@ -74,16 +73,22 @@ flip(f, xs) = reverse(f.(reverse(xs)))
# Vanilla RNN # Vanilla RNN
struct RNNCell{D,V} struct RNNCell{F,A,V}
d::D σ::F
Wi::A
Wh::A
b::V
h::V h::V
end end
RNNCell(in::Integer, out::Integer, σ = tanh; initW = glorot_uniform, initb = zeros) = RNNCell(in::Integer, out::Integer, σ = tanh;
RNNCell(Dense(in+out, out, σ, initW = initW, initb = initb), param(initW(out))) init = glorot_uniform) =
RNNCell(σ, param(init(out, in)), param(init(out, out)),
param(zeros(out)), param(initn(out)))
function (m::RNNCell)(h, x) function (m::RNNCell)(h, x)
h = m.d(combine(x, h)) σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
h = σ.(Wi*x .+ Wh*h .+ b)
return h, h return h, h
end end
@ -91,8 +96,10 @@ hidden(m::RNNCell) = m.h
treelike(RNNCell) treelike(RNNCell)
function Base.show(io::IO, m::RNNCell) function Base.show(io::IO, l::RNNCell)
print(io, "RNNCell(", m.d, ")") print(io, "RNNCell(", size(l.Wi, 2), ", ", size(l.Wi, 1))
l.σ == identity || print(io, ", ", l.σ)
print(io, ")")
end end
""" """
@ -105,40 +112,41 @@ RNN(a...; ka...) = Recur(RNNCell(a...; ka...))
# LSTM # LSTM
struct LSTMCell{D1,D2,V} struct LSTMCell{A,V}
forget::D1 Wi::A
input::D1 Wh::A
output::D1 b::V
cell::D2 h::V
h::V; c::V c::V
end end
function LSTMCell(in, out; initW = glorot_uniform, initb = zeros) function LSTMCell(in::Integer, out::Integer;
cell = LSTMCell([Dense(in+out, out, σ, initW = initW, initb = initb) for _ = 1:3]..., init = glorot_uniform)
Dense(in+out, out, tanh, initW = initW, initb = initb), cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(zeros(out*4)),
param(initW(out)), param(initW(out))) param(initn(out)), param(initn(out)))
cell.forget.b.data .= 1 cell.b.data[gate(out, 2)] = 1
return cell return cell
end end
function (m::LSTMCell)(h_, x) function (m::LSTMCell)(h_, x)
h, c = h_ h, c = h_ # TODO: nicer syntax on 0.7
x = combine(x, h) b, o = m.b, length(h)
forget, input, output, cell = g = m.Wi*x .+ m.Wh*h .+ b
m.forget(x), m.input(x), m.output(x), m.cell(x) input = σ.(gate(g, o, 1))
forget = σ.(gate(g, o, 2))
cell = tanh.(gate(g, o, 3))
output = σ.(gate(g, o, 4))
c = forget .* c .+ input .* cell c = forget .* c .+ input .* cell
h = output .* tanh.(c) h = output .* tanh.(c)
return (h, c), h return (h, c), h
end end
hidden(m::LSTMCell) = (m.h, m.c) hidden(m::LSTMCell) = (m.h, m.c)
treelike(LSTMCell) treelike(LSTMCell)
Base.show(io::IO, m::LSTMCell) = Base.show(io::IO, l::LSTMCell) =
print(io, "LSTMCell(", print(io, "LSTMCell(", size(l.Wi, 2), ", ", size(l.Wi, 1), ")")
size(m.forget.W, 2) - size(m.forget.W, 1), ", ",
size(m.forget.W, 1), ')')
""" """
LSTM(in::Integer, out::Integer, σ = tanh) LSTM(in::Integer, out::Integer, σ = tanh)
@ -153,26 +161,23 @@ LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
# GRU # GRU
struct GRUCell{D1,D2,V} struct GRUCell{A,V}
update::D1 Wi::A
reset::D1 Wh::A
candidate::D2 b::V
h::V h::V
end end
function GRUCell(in, out) GRUCell(in, out; init = glorot_uniform) =
cell = GRUCell(Dense(in+out, out, σ), GRUCell(param(init(out*3, in)), param(init(out*3, out)),
Dense(in+out, out, σ), param(zeros(out*3)), param(initn(out)))
Dense(in+out, out, tanh),
param(initn(out)))
return cell
end
function (m::GRUCell)(h, x) function (m::GRUCell)(h, x)
x = combine(x, h) b, o = m.b, length(h)
z = m.update(x) gx, gh = m.Wi*x, m.Wh*h
r = m.reset(x) r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1))
h̃ = m.candidate(combine(r.*h, x)) z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2))
h̃ = tanh.(gate(gx, o, 3) .+ r .* gate(gh, o, 3) .+ gate(b, o, 3))
h = (1.-z).*h̃ .+ z.*h h = (1.-z).*h̃ .+ z.*h
return h, h return h, h
end end
@ -181,10 +186,8 @@ hidden(m::GRUCell) = m.h
treelike(GRUCell) treelike(GRUCell)
Base.show(io::IO, m::GRUCell) = Base.show(io::IO, l::GRUCell) =
print(io, "GRUCell(", print(io, "GRUCell(", size(l.Wi, 2), ", ", size(l.Wi, 1), ")")
size(m.update.W, 2) - size(m.update.W, 1), ", ",
size(m.update.W, 1), ')')
""" """
GRU(in::Integer, out::Integer, σ = tanh) GRU(in::Integer, out::Integer, σ = tanh)