update rnn structure
This commit is contained in:
parent
af3ccf85ff
commit
0f1e7b5578
@ -1,7 +1,6 @@
|
|||||||
# TODO: broadcasting cat
|
# Index range of the `n`-th block of `h` consecutive rows: `(n-1)*h + 1 : n*h`.
# Used to slice one gate out of a stacked weight/bias array.
# The offset is added with broadcasting (`.+`): plain `range + scalar` is a
# MethodError on Julia 0.7 and later, while `.+` works on old and new versions.
gate(h, n) = (1:h) .+ h*(n-1)
|
||||||
# Stack a vector `h` underneath every column of the matrix `x`.
# `h` is first tiled to one column per column of `x` by multiplying with a
# 1×size(x, 2) row of `true`s, then concatenated vertically below `x`.
function combine(x::AbstractMatrix, h::AbstractVector)
    ncols = size(x, 2)
    tiled = h .* trues(1, ncols)
    return vcat(x, tiled)
end
|
# Select the `n`-th gate (a block of `h` consecutive entries) from the vector `x`.
# The entries to take are given by the index method `gate(h, n)`.
function gate(x::AbstractVector, h, n)
    rows = gate(h, n)
    return x[rows]
end
|
||||||
# Concatenate the vector `h` onto the end of the vector `x`.
function combine(x::AbstractVector, h::AbstractVector)
    return vcat(x, h)
end
|
# Select the `n`-th gate (a block of `h` consecutive rows) from the matrix `x`,
# keeping every column. The rows to take are given by the index method `gate(h, n)`.
function gate(x::AbstractMatrix, h, n)
    rows = gate(h, n)
    return x[rows, :]
end
|
||||||
# Stack the matrix `h` vertically underneath the matrix `x`.
function combine(x::AbstractMatrix, h::AbstractMatrix)
    return vcat(x, h)
end
|
|
||||||
|
|
||||||
# Stateful recurrence
|
# Stateful recurrence
|
||||||
|
|
||||||
@ -74,16 +73,22 @@ flip(f, xs) = reverse(f.(reverse(xs)))
|
|||||||
|
|
||||||
# Vanilla RNN
|
# Vanilla RNN
|
||||||
|
|
||||||
struct RNNCell{D,V}
|
struct RNNCell{F,A,V}
|
||||||
d::D
|
σ::F
|
||||||
|
Wi::A
|
||||||
|
Wh::A
|
||||||
|
b::V
|
||||||
h::V
|
h::V
|
||||||
end
|
end
|
||||||
|
|
||||||
RNNCell(in::Integer, out::Integer, σ = tanh; initW = glorot_uniform, initb = zeros) =
|
RNNCell(in::Integer, out::Integer, σ = tanh;
|
||||||
RNNCell(Dense(in+out, out, σ, initW = initW, initb = initb), param(initW(out)))
|
init = glorot_uniform) =
|
||||||
|
RNNCell(σ, param(init(out, in)), param(init(out, out)),
|
||||||
|
param(zeros(out)), param(initn(out)))
|
||||||
|
|
||||||
function (m::RNNCell)(h, x)
|
function (m::RNNCell)(h, x)
|
||||||
h = m.d(combine(x, h))
|
σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
|
||||||
|
h = σ.(Wi*x .+ Wh*h .+ b)
|
||||||
return h, h
|
return h, h
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -91,8 +96,10 @@ hidden(m::RNNCell) = m.h
|
|||||||
|
|
||||||
treelike(RNNCell)
|
treelike(RNNCell)
|
||||||
|
|
||||||
function Base.show(io::IO, m::RNNCell)
|
function Base.show(io::IO, l::RNNCell)
|
||||||
print(io, "RNNCell(", m.d, ")")
|
print(io, "RNNCell(", size(l.Wi, 2), ", ", size(l.Wi, 1))
|
||||||
|
l.σ == identity || print(io, ", ", l.σ)
|
||||||
|
print(io, ")")
|
||||||
end
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@ -105,40 +112,41 @@ RNN(a...; ka...) = Recur(RNNCell(a...; ka...))
|
|||||||
|
|
||||||
# LSTM
|
# LSTM
|
||||||
|
|
||||||
struct LSTMCell{D1,D2,V}
|
struct LSTMCell{A,V}
|
||||||
forget::D1
|
Wi::A
|
||||||
input::D1
|
Wh::A
|
||||||
output::D1
|
b::V
|
||||||
cell::D2
|
h::V
|
||||||
h::V; c::V
|
c::V
|
||||||
end
|
end
|
||||||
|
|
||||||
function LSTMCell(in, out; initW = glorot_uniform, initb = zeros)
|
function LSTMCell(in::Integer, out::Integer;
|
||||||
cell = LSTMCell([Dense(in+out, out, σ, initW = initW, initb = initb) for _ = 1:3]...,
|
init = glorot_uniform)
|
||||||
Dense(in+out, out, tanh, initW = initW, initb = initb),
|
cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(zeros(out*4)),
|
||||||
param(initW(out)), param(initW(out)))
|
param(initn(out)), param(initn(out)))
|
||||||
cell.forget.b.data .= 1
|
cell.b.data[gate(out, 2)] = 1
|
||||||
return cell
|
return cell
|
||||||
end
|
end
|
||||||
|
|
||||||
function (m::LSTMCell)(h_, x)
|
function (m::LSTMCell)(h_, x)
|
||||||
h, c = h_
|
h, c = h_ # TODO: nicer syntax on 0.7
|
||||||
x′ = combine(x, h)
|
b, o = m.b, length(h)
|
||||||
forget, input, output, cell =
|
g = m.Wi*x .+ m.Wh*h .+ b
|
||||||
m.forget(x′), m.input(x′), m.output(x′), m.cell(x′)
|
input = σ.(gate(g, o, 1))
|
||||||
|
forget = σ.(gate(g, o, 2))
|
||||||
|
cell = tanh.(gate(g, o, 3))
|
||||||
|
output = σ.(gate(g, o, 4))
|
||||||
c = forget .* c .+ input .* cell
|
c = forget .* c .+ input .* cell
|
||||||
h = output .* tanh.(c)
|
h′ = output .* tanh.(c)
|
||||||
return (h, c), h
|
return (h′, c), h′
|
||||||
end
|
end
|
||||||
|
|
||||||
# State of an LSTM cell as the tuple of its `h` and `c` fields.
function hidden(m::LSTMCell)
    return (m.h, m.c)
end
|
# State of an LSTM cell as the tuple of its `h` and `c` fields.
function hidden(m::LSTMCell)
    return (m.h, m.c)
end
|
||||||
|
|
||||||
treelike(LSTMCell)
|
treelike(LSTMCell)
|
||||||
|
|
||||||
Base.show(io::IO, m::LSTMCell) =
|
Base.show(io::IO, l::LSTMCell) =
|
||||||
print(io, "LSTMCell(",
|
print(io, "LSTMCell(", size(l.Wi, 2), ", ", size(l.Wi, 1), ")")
|
||||||
size(m.forget.W, 2) - size(m.forget.W, 1), ", ",
|
|
||||||
size(m.forget.W, 1), ')')
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
LSTM(in::Integer, out::Integer)
|
LSTM(in::Integer, out::Integer)
|
||||||
@ -153,26 +161,23 @@ LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
|
|||||||
|
|
||||||
# GRU
|
# GRU
|
||||||
|
|
||||||
struct GRUCell{D1,D2,V}
|
struct GRUCell{A,V}
|
||||||
update::D1
|
Wi::A
|
||||||
reset::D1
|
Wh::A
|
||||||
candidate::D2
|
b::V
|
||||||
h::V
|
h::V
|
||||||
end
|
end
|
||||||
|
|
||||||
function GRUCell(in, out)
|
GRUCell(in, out; init = glorot_uniform) =
|
||||||
cell = GRUCell(Dense(in+out, out, σ),
|
GRUCell(param(init(out*3, in)), param(init(out*3, out)),
|
||||||
Dense(in+out, out, σ),
|
param(zeros(out*3)), param(initn(out)))
|
||||||
Dense(in+out, out, tanh),
|
|
||||||
param(initn(out)))
|
|
||||||
return cell
|
|
||||||
end
|
|
||||||
|
|
||||||
function (m::GRUCell)(h, x)
|
function (m::GRUCell)(h, x)
|
||||||
x′ = combine(x, h)
|
b, o = m.b, length(h)
|
||||||
z = m.update(x′)
|
gx, gh = m.Wi*x, m.Wh*h
|
||||||
r = m.reset(x′)
|
r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1))
|
||||||
h̃ = m.candidate(combine(r.*h, x))
|
z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2))
|
||||||
|
h̃ = tanh.(gate(gx, o, 3) .+ r .* gate(gh, o, 3) .+ gate(b, o, 3))
|
||||||
h′ = (1.-z).*h̃ .+ z.*h
|
h′ = (1.-z).*h̃ .+ z.*h
|
||||||
return h′, h′
|
return h′, h′
|
||||||
end
|
end
|
||||||
@ -181,10 +186,8 @@ hidden(m::GRUCell) = m.h
|
|||||||
|
|
||||||
treelike(GRUCell)
|
treelike(GRUCell)
|
||||||
|
|
||||||
Base.show(io::IO, m::GRUCell) =
|
Base.show(io::IO, l::GRUCell) =
|
||||||
print(io, "GRUCell(",
|
print(io, "GRUCell(", size(l.Wi, 2), ", ", size(l.Wi, 1), ")")
|
||||||
size(m.update.W, 2) - size(m.update.W, 1), ", ",
|
|
||||||
size(m.update.W, 1), ')')
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
GRU(in::Integer, out::Integer)
|
GRU(in::Integer, out::Integer)
|
||||||
|
Loading…
Reference in New Issue
Block a user