batching in rnns
This commit is contained in:
parent
830d7fa611
commit
ec02f1fabd
@ -1,3 +1,6 @@
|
|||||||
|
# TODO: broadcasting cat
|
||||||
|
combine(x, h) = vcat(x, h .* trues(1, size(x, 2)))
|
||||||
|
|
||||||
# Stateful recurrence
|
# Stateful recurrence
|
||||||
|
|
||||||
mutable struct Recur{T}
|
mutable struct Recur{T}
|
||||||
@ -24,7 +27,7 @@ RNNCell(in::Integer, out::Integer, init = initn) =
|
|||||||
RNNCell(Dense(in+out, out, init = initn), track(initn(out)))
|
RNNCell(Dense(in+out, out, init = initn), track(initn(out)))
|
||||||
|
|
||||||
function (m::RNNCell)(h, x)
|
function (m::RNNCell)(h, x)
|
||||||
h = m.d([x; h])
|
h = m.d(combine(x, h))
|
||||||
return h, h
|
return h, h
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -51,7 +54,7 @@ LSTMCell(in, out; init = initn) =
|
|||||||
|
|
||||||
function (m::LSTMCell)(h_, x)
|
function (m::LSTMCell)(h_, x)
|
||||||
h, c = h_
|
h, c = h_
|
||||||
x′ = [x; h]
|
x′ = combine(x, h)
|
||||||
forget, input, output, cell =
|
forget, input, output, cell =
|
||||||
m.forget(x′), m.input(x′), m.output(x′), m.cell(x′)
|
m.forget(x′), m.input(x′), m.output(x′), m.cell(x′)
|
||||||
c = forget .* c .+ input .* cell
|
c = forget .* c .+ input .* cell
|
||||||
|
Loading…
Reference in New Issue
Block a user