batching in rnns

This commit is contained in:
Mike J Innes 2017-09-05 02:29:31 -04:00
parent 830d7fa611
commit ec02f1fabd

View File

@ -1,3 +1,6 @@
# TODO: broadcasting cat
combine(x, h) = vcat(x, h .* trues(1, size(x, 2)))
# Stateful recurrence # Stateful recurrence
mutable struct Recur{T} mutable struct Recur{T}
@ -24,7 +27,7 @@ RNNCell(in::Integer, out::Integer, init = initn) =
RNNCell(Dense(in+out, out, init = initn), track(initn(out))) RNNCell(Dense(in+out, out, init = initn), track(initn(out)))
function (m::RNNCell)(h, x) function (m::RNNCell)(h, x)
h = m.d([x; h]) h = m.d(combine(x, h))
return h, h return h, h
end end
@ -51,7 +54,7 @@ LSTMCell(in, out; init = initn) =
function (m::LSTMCell)(h_, x) function (m::LSTMCell)(h_, x)
h, c = h_ h, c = h_
x = [x; h] x = combine(x, h)
forget, input, output, cell = forget, input, output, cell =
m.forget(x), m.input(x), m.output(x), m.cell(x) m.forget(x), m.input(x), m.output(x), m.cell(x)
c = forget .* c .+ input .* cell c = forget .* c .+ input .* cell