diff --git a/src/cuda/curnn.jl b/src/cuda/curnn.jl
index 02f78a96..4cc7313d 100644
--- a/src/cuda/curnn.jl
+++ b/src/cuda/curnn.jl
@@ -286,15 +286,17 @@ end
 (m::CuGRU{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
 (m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
 
-@adjoint function (m::Union{CuRNN,CuGRU})(x, h, Wi, Wh, b)
-  reserve, result = forwardTrain(desc(m), x, h)
-  result, function (Δ)
-    y, ho = result
-    dy, dho = Δ
-    h_ = hBatch(x, h)
-    dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
-    (dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
-    nobacksies(:RNN, (dx, unbroadcast(h, dh), transpose(dWi), transpose(dWh), db))
+for RNN in (CuRNN, CuGRU)
+  @eval @adjoint function (m::$RNN)(x, h, Wi, Wh, b)
+    reserve, result = forwardTrain(desc(m), x, h)
+    result, function (Δ)
+      y, ho = result
+      dy, dho = Δ
+      h_ = hBatch(x, h)
+      dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
+      (dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
+      nobacksies(:RNN, (dx, unbroadcast(h, dh), transpose(dWi), transpose(dWh), db))
+    end
   end
 end
 
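For context (not part of the diff): the change replaces one `@adjoint` method over `Union{CuRNN,CuGRU}` with a top-level `for`/`@eval` loop that stamps out an identical adjoint for each concrete type, splicing the type in with `$RNN`. A minimal standalone sketch of that metaprogramming pattern, using toy types (`Foo`, `Bar`, `describe`) that are illustrative only and not from Flux:

```julia
# Toy stand-ins for CuRNN/CuGRU; any concrete types work.
struct Foo end
struct Bar end

# Generate one method per type at top level. Inside @eval,
# $T splices the type object into the generated definition,
# just as $RNN does in the diff above.
for T in (Foo, Bar)
  @eval describe(::$T) = string("instance of ", $T)
end

describe(Foo())  # "instance of Foo"
describe(Bar())  # "instance of Bar"
```

The upshot is that the compiler sees a separate, concretely-typed method for each element of the tuple, rather than a single method whose signature dispatches on the `Union`.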