diff --git a/src/cuda/curnn.jl b/src/cuda/curnn.jl index c097d6fe..b57e81f8 100644 --- a/src/cuda/curnn.jl +++ b/src/cuda/curnn.jl @@ -306,7 +306,7 @@ end h_ = hBatch(x, data(h)) dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve) (dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve) - nobacksies(:RNN, (dx, unbroadcast(size(h), dh), transpose(dWi), transpose(dWh), db)) + nobacksies(:RNN, (dx, unbroadcast(h, dh), transpose(dWi), transpose(dWh), db)) end end @@ -320,7 +320,7 @@ end dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve) (dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve) nobacksies(:RNN, - (dx, unbroadcast(size(h), dh), unbroadcast(size(c), dc), + (dx, unbroadcast(h, dh), unbroadcast(c, dc), transpose(dWi), transpose(dWh), db)) end end