update rnn structure
This commit is contained in:
parent
af3ccf85ff
commit
0f1e7b5578
@ -1,7 +1,6 @@
|
||||
# TODO: broadcasting cat
|
||||
combine(x::AbstractMatrix, h::AbstractVector) = vcat(x, h .* trues(1, size(x, 2)))
|
||||
combine(x::AbstractVector, h::AbstractVector) = vcat(x, h)
|
||||
combine(x::AbstractMatrix, h::AbstractMatrix) = vcat(x, h)
|
||||
gate(h, n) = (1:h) + h*(n-1)
|
||||
gate(x::AbstractVector, h, n) = x[gate(h,n)]
|
||||
gate(x::AbstractMatrix, h, n) = x[gate(h,n),:]
|
||||
|
||||
# Stateful recurrence
|
||||
|
||||
@ -74,16 +73,22 @@ flip(f, xs) = reverse(f.(reverse(xs)))
|
||||
|
||||
# Vanilla RNN
|
||||
|
||||
struct RNNCell{D,V}
|
||||
d::D
|
||||
struct RNNCell{F,A,V}
|
||||
σ::F
|
||||
Wi::A
|
||||
Wh::A
|
||||
b::V
|
||||
h::V
|
||||
end
|
||||
|
||||
RNNCell(in::Integer, out::Integer, σ = tanh; initW = glorot_uniform, initb = zeros) =
|
||||
RNNCell(Dense(in+out, out, σ, initW = initW, initb = initb), param(initW(out)))
|
||||
RNNCell(in::Integer, out::Integer, σ = tanh;
|
||||
init = glorot_uniform) =
|
||||
RNNCell(σ, param(init(out, in)), param(init(out, out)),
|
||||
param(zeros(out)), param(initn(out)))
|
||||
|
||||
function (m::RNNCell)(h, x)
|
||||
h = m.d(combine(x, h))
|
||||
σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
|
||||
h = σ.(Wi*x .+ Wh*h .+ b)
|
||||
return h, h
|
||||
end
|
||||
|
||||
@ -91,8 +96,10 @@ hidden(m::RNNCell) = m.h
|
||||
|
||||
treelike(RNNCell)
|
||||
|
||||
function Base.show(io::IO, m::RNNCell)
|
||||
print(io, "RNNCell(", m.d, ")")
|
||||
function Base.show(io::IO, l::RNNCell)
|
||||
print(io, "RNNCell(", size(l.Wi, 2), ", ", size(l.Wi, 1))
|
||||
l.σ == identity || print(io, ", ", l.σ)
|
||||
print(io, ")")
|
||||
end
|
||||
|
||||
"""
|
||||
@ -105,40 +112,41 @@ RNN(a...; ka...) = Recur(RNNCell(a...; ka...))
|
||||
|
||||
# LSTM
|
||||
|
||||
struct LSTMCell{D1,D2,V}
|
||||
forget::D1
|
||||
input::D1
|
||||
output::D1
|
||||
cell::D2
|
||||
h::V; c::V
|
||||
struct LSTMCell{A,V}
|
||||
Wi::A
|
||||
Wh::A
|
||||
b::V
|
||||
h::V
|
||||
c::V
|
||||
end
|
||||
|
||||
function LSTMCell(in, out; initW = glorot_uniform, initb = zeros)
|
||||
cell = LSTMCell([Dense(in+out, out, σ, initW = initW, initb = initb) for _ = 1:3]...,
|
||||
Dense(in+out, out, tanh, initW = initW, initb = initb),
|
||||
param(initW(out)), param(initW(out)))
|
||||
cell.forget.b.data .= 1
|
||||
function LSTMCell(in::Integer, out::Integer;
|
||||
init = glorot_uniform)
|
||||
cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(zeros(out*4)),
|
||||
param(initn(out)), param(initn(out)))
|
||||
cell.b.data[gate(out, 2)] = 1
|
||||
return cell
|
||||
end
|
||||
|
||||
function (m::LSTMCell)(h_, x)
|
||||
h, c = h_
|
||||
x′ = combine(x, h)
|
||||
forget, input, output, cell =
|
||||
m.forget(x′), m.input(x′), m.output(x′), m.cell(x′)
|
||||
h, c = h_ # TODO: nicer syntax on 0.7
|
||||
b, o = m.b, length(h)
|
||||
g = m.Wi*x .+ m.Wh*h .+ b
|
||||
input = σ.(gate(g, o, 1))
|
||||
forget = σ.(gate(g, o, 2))
|
||||
cell = tanh.(gate(g, o, 3))
|
||||
output = σ.(gate(g, o, 4))
|
||||
c = forget .* c .+ input .* cell
|
||||
h = output .* tanh.(c)
|
||||
return (h, c), h
|
||||
h′ = output .* tanh.(c)
|
||||
return (h′, c), h′
|
||||
end
|
||||
|
||||
hidden(m::LSTMCell) = (m.h, m.c)
|
||||
|
||||
treelike(LSTMCell)
|
||||
|
||||
Base.show(io::IO, m::LSTMCell) =
|
||||
print(io, "LSTMCell(",
|
||||
size(m.forget.W, 2) - size(m.forget.W, 1), ", ",
|
||||
size(m.forget.W, 1), ')')
|
||||
Base.show(io::IO, l::LSTMCell) =
|
||||
print(io, "LSTMCell(", size(l.Wi, 2), ", ", size(l.Wi, 1), ")")
|
||||
|
||||
"""
|
||||
LSTM(in::Integer, out::Integer, σ = tanh)
|
||||
@ -153,26 +161,23 @@ LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
|
||||
|
||||
# GRU
|
||||
|
||||
struct GRUCell{D1,D2,V}
|
||||
update::D1
|
||||
reset::D1
|
||||
candidate::D2
|
||||
struct GRUCell{A,V}
|
||||
Wi::A
|
||||
Wh::A
|
||||
b::V
|
||||
h::V
|
||||
end
|
||||
|
||||
function GRUCell(in, out)
|
||||
cell = GRUCell(Dense(in+out, out, σ),
|
||||
Dense(in+out, out, σ),
|
||||
Dense(in+out, out, tanh),
|
||||
param(initn(out)))
|
||||
return cell
|
||||
end
|
||||
GRUCell(in, out; init = glorot_uniform) =
|
||||
GRUCell(param(init(out*3, in)), param(init(out*3, out)),
|
||||
param(zeros(out*3)), param(initn(out)))
|
||||
|
||||
function (m::GRUCell)(h, x)
|
||||
x′ = combine(x, h)
|
||||
z = m.update(x′)
|
||||
r = m.reset(x′)
|
||||
h̃ = m.candidate(combine(r.*h, x))
|
||||
b, o = m.b, length(h)
|
||||
gx, gh = m.Wi*x, m.Wh*h
|
||||
r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1))
|
||||
z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2))
|
||||
h̃ = tanh.(gate(gx, o, 3) .+ r .* gate(gh, o, 3) .+ gate(b, o, 3))
|
||||
h′ = (1.-z).*h̃ .+ z.*h
|
||||
return h′, h′
|
||||
end
|
||||
@ -181,10 +186,8 @@ hidden(m::GRUCell) = m.h
|
||||
|
||||
treelike(GRUCell)
|
||||
|
||||
Base.show(io::IO, m::GRUCell) =
|
||||
print(io, "GRUCell(",
|
||||
size(m.update.W, 2) - size(m.update.W, 1), ", ",
|
||||
size(m.update.W, 1), ')')
|
||||
Base.show(io::IO, l::GRUCell) =
|
||||
print(io, "GRUCell(", size(l.Wi, 2), ", ", size(l.Wi, 1), ")")
|
||||
|
||||
"""
|
||||
GRU(in::Integer, out::Integer, σ = tanh)
|
||||
|
Loading…
Reference in New Issue
Block a user