fix unbroadcast

This commit is contained in:
Mike Innes 2018-10-05 14:14:24 +01:00
parent 325d2ce212
commit c6740c5cdd

View File

@ -328,7 +328,7 @@ end
h_ = hBatch(x, data(h))
dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
nobacksies(:RNN, (dx, unbroadcast(size(h), dh), transpose(dWi), transpose(dWh), db))
nobacksies(:RNN, (dx, unbroadcast(h, dh), transpose(dWi), transpose(dWh), db))
end
end
@ -342,7 +342,7 @@ end
dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
nobacksies(:RNN,
(dx, unbroadcast(size(h), dh), unbroadcast(size(c), dc),
(dx, unbroadcast(h, dh), unbroadcast(c, dc),
transpose(dWi), transpose(dWh), db))
end
end