avoid adjoint on abstract type

This commit is contained in:
Mike Innes 2019-08-19 14:39:09 +01:00
parent b8fabad337
commit 49044dff7c

View File

@@ -286,15 +286,17 @@ end
# Call overload for a `CuGRU{T}` whose hidden state is a `CuArray{T}`:
# build `CuArray{T}(x)` from the input and call `m` again with it.
function (m::CuGRU{T})(h::CuArray{T}, x) where {T<:Union{Float32,Float64}}
    return m(h, CuArray{T}(x))
end
# Call overload for a `CuLSTM{T}` whose state is a 2-tuple of `CuArray{T}`:
# build `CuArray{T}(x)` from the input and call `m` again with it.
function (m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x) where {T<:Union{Float32,Float64}}
    return m(h, CuArray{T}(x))
end
@adjoint function (m::Union{CuRNN,CuGRU})(x, h, Wi, Wh, b)
reserve, result = forwardTrain(desc(m), x, h)
result, function (Δ)
y, ho = result
dy, dho = Δ
h_ = hBatch(x, h)
dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
(dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
nobacksies(:RNN, (dx, unbroadcast(h, dh), transpose(dWi), transpose(dWh), db))
# Define the same custom adjoint once per type in `(CuRNN, CuGRU)`.
# `@eval` splices each type into the method signature, so a separate method
# is generated for every entry rather than one method on their `Union`.
for RNN in (CuRNN, CuGRU)
    @eval @adjoint function (m::$RNN)(x, h, Wi, Wh, b)
        # `reserve` comes out of the forward pass and is consumed by both
        # backward calls in the pullback below.
        reserve, result = forwardTrain(desc(m), x, h)

        # Pullback: takes the cotangent `Δ` of `result` and returns the
        # gradients for `(x, h, Wi, Wh, b)`, wrapped by `nobacksies`.
        function rnn_pullback(Δ)
            y, _ = result
            dy, dho = Δ
            h_ = hBatch(x, h)
            # `backwardData` runs before `backwardWeights`; both take `reserve`.
            dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
            (dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
            grads = (dx, unbroadcast(h, dh), transpose(dWi), transpose(dWh), db)
            return nobacksies(:RNN, grads)
        end

        return result, rnn_pullback
    end
end