update rnn structure

This commit is contained in:
Mike J Innes 2018-02-01 20:57:39 +00:00
parent af3ccf85ff
commit 0f1e7b5578

View File

@ -1,7 +1,6 @@
# TODO: broadcasting cat
combine(x::AbstractMatrix, h::AbstractVector) = vcat(x, h .* trues(1, size(x, 2)))
combine(x::AbstractVector, h::AbstractVector) = vcat(x, h)
combine(x::AbstractMatrix, h::AbstractMatrix) = vcat(x, h)
gate(h, n) = (1:h) + h*(n-1)
gate(x::AbstractVector, h, n) = x[gate(h,n)]
gate(x::AbstractMatrix, h, n) = x[gate(h,n),:]
# Stateful recurrence
@ -74,16 +73,22 @@ flip(f, xs) = reverse(f.(reverse(xs)))
# Vanilla RNN
struct RNNCell{D,V}
d::D
struct RNNCell{F,A,V}
σ::F
Wi::A
Wh::A
b::V
h::V
end
RNNCell(in::Integer, out::Integer, σ = tanh; initW = glorot_uniform, initb = zeros) =
RNNCell(Dense(in+out, out, σ, initW = initW, initb = initb), param(initW(out)))
RNNCell(in::Integer, out::Integer, σ = tanh;
init = glorot_uniform) =
RNNCell(σ, param(init(out, in)), param(init(out, out)),
param(zeros(out)), param(initn(out)))
function (m::RNNCell)(h, x)
h = m.d(combine(x, h))
σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
h = σ.(Wi*x .+ Wh*h .+ b)
return h, h
end
@ -91,8 +96,10 @@ hidden(m::RNNCell) = m.h
treelike(RNNCell)
function Base.show(io::IO, m::RNNCell)
print(io, "RNNCell(", m.d, ")")
function Base.show(io::IO, l::RNNCell)
print(io, "RNNCell(", size(l.Wi, 2), ", ", size(l.Wi, 1))
l.σ == identity || print(io, ", ", l.σ)
print(io, ")")
end
"""
@ -105,40 +112,41 @@ RNN(a...; ka...) = Recur(RNNCell(a...; ka...))
# LSTM
struct LSTMCell{D1,D2,V}
forget::D1
input::D1
output::D1
cell::D2
h::V; c::V
struct LSTMCell{A,V}
Wi::A
Wh::A
b::V
h::V
c::V
end
function LSTMCell(in, out; initW = glorot_uniform, initb = zeros)
cell = LSTMCell([Dense(in+out, out, σ, initW = initW, initb = initb) for _ = 1:3]...,
Dense(in+out, out, tanh, initW = initW, initb = initb),
param(initW(out)), param(initW(out)))
cell.forget.b.data .= 1
function LSTMCell(in::Integer, out::Integer;
init = glorot_uniform)
cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(zeros(out*4)),
param(initn(out)), param(initn(out)))
cell.b.data[gate(out, 2)] = 1
return cell
end
function (m::LSTMCell)(h_, x)
h, c = h_
x = combine(x, h)
forget, input, output, cell =
m.forget(x), m.input(x), m.output(x), m.cell(x)
h, c = h_ # TODO: nicer syntax on 0.7
b, o = m.b, length(h)
g = m.Wi*x .+ m.Wh*h .+ b
input = σ.(gate(g, o, 1))
forget = σ.(gate(g, o, 2))
cell = tanh.(gate(g, o, 3))
output = σ.(gate(g, o, 4))
c = forget .* c .+ input .* cell
h = output .* tanh.(c)
return (h, c), h
h = output .* tanh.(c)
return (h, c), h
end
hidden(m::LSTMCell) = (m.h, m.c)
treelike(LSTMCell)
Base.show(io::IO, m::LSTMCell) =
print(io, "LSTMCell(",
size(m.forget.W, 2) - size(m.forget.W, 1), ", ",
size(m.forget.W, 1), ')')
Base.show(io::IO, l::LSTMCell) =
print(io, "LSTMCell(", size(l.Wi, 2), ", ", size(l.Wi, 1), ")")
"""
LSTM(in::Integer, out::Integer, σ = tanh)
@ -153,26 +161,23 @@ LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
# GRU
struct GRUCell{D1,D2,V}
update::D1
reset::D1
candidate::D2
struct GRUCell{A,V}
Wi::A
Wh::A
b::V
h::V
end
function GRUCell(in, out)
cell = GRUCell(Dense(in+out, out, σ),
Dense(in+out, out, σ),
Dense(in+out, out, tanh),
param(initn(out)))
return cell
end
GRUCell(in, out; init = glorot_uniform) =
GRUCell(param(init(out*3, in)), param(init(out*3, out)),
param(zeros(out*3)), param(initn(out)))
function (m::GRUCell)(h, x)
x = combine(x, h)
z = m.update(x)
r = m.reset(x)
= m.candidate(combine(r.*h, x))
b, o = m.b, length(h)
gx, gh = m.Wi*x, m.Wh*h
r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1))
z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2))
= tanh.(gate(gx, o, 3) .+ r .* gate(gh, o, 3) .+ gate(b, o, 3))
h = (1.-z).* .+ z.*h
return h, h
end
@ -181,10 +186,8 @@ hidden(m::GRUCell) = m.h
treelike(GRUCell)
Base.show(io::IO, m::GRUCell) =
print(io, "GRUCell(",
size(m.update.W, 2) - size(m.update.W, 1), ", ",
size(m.update.W, 1), ')')
Base.show(io::IO, l::GRUCell) =
print(io, "GRUCell(", size(l.Wi, 2), ", ", size(l.Wi, 1), ")")
"""
GRU(in::Integer, out::Integer, σ = tanh)