batching in rnns

This commit is contained in:
Mike J Innes 2017-09-05 02:29:31 -04:00
parent 830d7fa611
commit ec02f1fabd

View File

@ -1,3 +1,6 @@
# TODO: broadcasting cat
combine(x, h) = vcat(x, h .* trues(1, size(x, 2)))
# Stateful recurrence
mutable struct Recur{T}
@ -24,7 +27,7 @@ RNNCell(in::Integer, out::Integer, init = initn) =
RNNCell(Dense(in+out, out, init = initn), track(initn(out)))
function (m::RNNCell)(h, x)
h = m.d([x; h])
h = m.d(combine(x, h))
return h, h
end
@ -51,7 +54,7 @@ LSTMCell(in, out; init = initn) =
function (m::LSTMCell)(h_, x)
h, c = h_
x = [x; h]
x = combine(x, h)
forget, input, output, cell =
m.forget(x), m.input(x), m.output(x), m.cell(x)
c = forget .* c .+ input .* cell