From 0f1e7b55780723ec1ecc282ba98b506cbf1480fe Mon Sep 17 00:00:00 2001
From: Mike J Innes
Date: Thu, 1 Feb 2018 20:57:39 +0000
Subject: [PATCH] update rnn structure

---
 src/layers/recurrent.jl | 105 +++++++++++++++++++++--------------------
 1 file changed, 54 insertions(+), 51 deletions(-)

diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index 7510266e..50adfc86 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -1,7 +1,6 @@
-# TODO: broadcasting cat
-combine(x::AbstractMatrix, h::AbstractVector) = vcat(x, h .* trues(1, size(x, 2)))
-combine(x::AbstractVector, h::AbstractVector) = vcat(x, h)
-combine(x::AbstractMatrix, h::AbstractMatrix) = vcat(x, h)
+gate(h, n) = (1:h) + h*(n-1)
+gate(x::AbstractVector, h, n) = x[gate(h,n)]
+gate(x::AbstractMatrix, h, n) = x[gate(h,n),:]
 
 # Stateful recurrence
 
@@ -74,16 +73,22 @@ flip(f, xs) = reverse(f.(reverse(xs)))
 
 # Vanilla RNN
 
-struct RNNCell{D,V}
-  d::D
+struct RNNCell{F,A,V}
+  σ::F
+  Wi::A
+  Wh::A
+  b::V
   h::V
 end
 
-RNNCell(in::Integer, out::Integer, σ = tanh; initW = glorot_uniform, initb = zeros) =
-  RNNCell(Dense(in+out, out, σ, initW = initW, initb = initb), param(initW(out)))
+RNNCell(in::Integer, out::Integer, σ = tanh;
+        init = glorot_uniform) =
+  RNNCell(σ, param(init(out, in)), param(init(out, out)),
+          param(zeros(out)), param(initn(out)))
 
 function (m::RNNCell)(h, x)
-  h = m.d(combine(x, h))
+  σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
+  h = σ.(Wi*x .+ Wh*h .+ b)
   return h, h
 end
 
@@ -91,8 +96,10 @@ hidden(m::RNNCell) = m.h
 
 treelike(RNNCell)
 
-function Base.show(io::IO, m::RNNCell)
-  print(io, "RNNCell(", m.d, ")")
+function Base.show(io::IO, l::RNNCell)
+  print(io, "RNNCell(", size(l.Wi, 2), ", ", size(l.Wi, 1))
+  l.σ == identity || print(io, ", ", l.σ)
+  print(io, ")")
 end
 
 """
@@ -105,40 +112,41 @@ RNN(a...; ka...) = Recur(RNNCell(a...; ka...))
 
 # LSTM
 
-struct LSTMCell{D1,D2,V}
-  forget::D1
-  input::D1
-  output::D1
-  cell::D2
-  h::V; c::V
+struct LSTMCell{A,V}
+  Wi::A
+  Wh::A
+  b::V
+  h::V
+  c::V
 end
 
-function LSTMCell(in, out; initW = glorot_uniform, initb = zeros)
-  cell = LSTMCell([Dense(in+out, out, σ, initW = initW, initb = initb) for _ = 1:3]...,
-                  Dense(in+out, out, tanh, initW = initW, initb = initb),
-                  param(initW(out)), param(initW(out)))
-  cell.forget.b.data .= 1
+function LSTMCell(in::Integer, out::Integer;
+                  init = glorot_uniform)
+  cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(zeros(out*4)),
+                  param(initn(out)), param(initn(out)))
+  cell.b.data[gate(out, 2)] = 1
   return cell
 end
 
 function (m::LSTMCell)(h_, x)
-  h, c = h_
-  x′ = combine(x, h)
-  forget, input, output, cell =
-    m.forget(x′), m.input(x′), m.output(x′), m.cell(x′)
+  h, c = h_ # TODO: nicer syntax on 0.7
+  b, o = m.b, length(h)
+  g = m.Wi*x .+ m.Wh*h .+ b
+  input = σ.(gate(g, o, 1))
+  forget = σ.(gate(g, o, 2))
+  cell = tanh.(gate(g, o, 3))
+  output = σ.(gate(g, o, 4))
   c = forget .* c .+ input .* cell
-  h = output .* tanh.(c)
-  return (h, c), h
+  h′ = output .* tanh.(c)
+  return (h′, c), h′
 end
 
 hidden(m::LSTMCell) = (m.h, m.c)
 
 treelike(LSTMCell)
 
-Base.show(io::IO, m::LSTMCell) =
-  print(io, "LSTMCell(",
-        size(m.forget.W, 2) - size(m.forget.W, 1), ", ",
-        size(m.forget.W, 1), ')')
+Base.show(io::IO, l::LSTMCell) =
+  print(io, "LSTMCell(", size(l.Wi, 2), ", ", size(l.Wi, 1), ")")
 
 """
     LSTM(in::Integer, out::Integer, σ = tanh)
@@ -153,26 +161,23 @@ LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
 
 # GRU
 
-struct GRUCell{D1,D2,V}
-  update::D1
-  reset::D1
-  candidate::D2
+struct GRUCell{A,V}
+  Wi::A
+  Wh::A
+  b::V
   h::V
 end
 
-function GRUCell(in, out)
-  cell = GRUCell(Dense(in+out, out, σ),
-                 Dense(in+out, out, σ),
-                 Dense(in+out, out, tanh),
-                 param(initn(out)))
-  return cell
-end
+GRUCell(in, out; init = glorot_uniform) =
+  GRUCell(param(init(out*3, in)), param(init(out*3, out)),
+          param(zeros(out*3)), param(initn(out)))
 
 function (m::GRUCell)(h, x)
-  x′ = combine(x, h)
-  z = m.update(x′)
-  r = m.reset(x′)
-  h̃ = m.candidate(combine(r.*h, x))
+  b, o = m.b, length(h)
+  gx, gh = m.Wi*x, m.Wh*h
+  r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1))
+  z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2))
+  h̃ = tanh.(gate(gx, o, 3) .+ r .* gate(gh, o, 3) .+ gate(b, o, 3))
   h′ = (1.-z).*h̃ .+ z.*h
   return h′, h′
 end
@@ -181,10 +186,8 @@ hidden(m::GRUCell) = m.h
 
 treelike(GRUCell)
 
-Base.show(io::IO, m::GRUCell) =
-  print(io, "GRUCell(",
-        size(m.update.W, 2) - size(m.update.W, 1), ", ",
-        size(m.update.W, 1), ')')
+Base.show(io::IO, l::GRUCell) =
+  print(io, "GRUCell(", size(l.Wi, 2), ", ", size(l.Wi, 1), ")")
 
 """
     GRU(in::Integer, out::Integer, σ = tanh)