From c6740c5cdd735e91869cf7615e711cfa47679f8f Mon Sep 17 00:00:00 2001 From: Mike Innes Date: Fri, 5 Oct 2018 14:14:24 +0100 Subject: [PATCH] fix unbroadcast --- src/cuda/cudnn.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl index f033595a..61609b0d 100644 --- a/src/cuda/cudnn.jl +++ b/src/cuda/cudnn.jl @@ -328,7 +328,7 @@ end h_ = hBatch(x, data(h)) dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve) (dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve) - nobacksies(:RNN, (dx, unbroadcast(size(h), dh), transpose(dWi), transpose(dWh), db)) + nobacksies(:RNN, (dx, unbroadcast(h, dh), transpose(dWi), transpose(dWh), db)) end end @@ -342,7 +342,7 @@ end dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve) (dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve) nobacksies(:RNN, - (dx, unbroadcast(size(h), dh), unbroadcast(size(c), dc), + (dx, unbroadcast(h, dh), unbroadcast(c, dc), transpose(dWi), transpose(dWh), db)) end end