batching in rnns
This commit is contained in:
parent
830d7fa611
commit
ec02f1fabd
@ -1,3 +1,6 @@
|
||||
# TODO: broadcasting cat
|
||||
combine(x, h) = vcat(x, h .* trues(1, size(x, 2)))
|
||||
|
||||
# Stateful recurrence
|
||||
|
||||
mutable struct Recur{T}
|
||||
@ -24,7 +27,7 @@ RNNCell(in::Integer, out::Integer, init = initn) =
|
||||
RNNCell(Dense(in+out, out, init = initn), track(initn(out)))
|
||||
|
||||
function (m::RNNCell)(h, x)
|
||||
h = m.d([x; h])
|
||||
h = m.d(combine(x, h))
|
||||
return h, h
|
||||
end
|
||||
|
||||
@ -51,7 +54,7 @@ LSTMCell(in, out; init = initn) =
|
||||
|
||||
function (m::LSTMCell)(h_, x)
|
||||
h, c = h_
|
||||
x′ = [x; h]
|
||||
x′ = combine(x, h)
|
||||
forget, input, output, cell =
|
||||
m.forget(x′), m.input(x′), m.output(x′), m.cell(x′)
|
||||
c = forget .* c .+ input .* cell
|
||||
|
Loading…
Reference in New Issue
Block a user