avoid adjoint on abstract type

This commit is contained in:
Mike Innes 2019-08-19 14:39:09 +01:00
parent b8fabad337
commit 49044dff7c

View File

@@ -286,7 +286,8 @@ end
# Fallback call for a `CuGRU{T}`: convert the input `x` to a `CuArray{T}` and
# invoke the layer again with the converted input.
function (m::CuGRU{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64}
    xconverted = CuArray{T}(x)
    return m(h, xconverted)
end
# Fallback call for a `CuLSTM{T}`: `h` is a 2-tuple of `CuArray{T}`; convert the
# input `x` to a `CuArray{T}` and invoke the layer again with the converted input.
function (m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x) where T <: Union{Float32,Float64}
    xconverted = CuArray{T}(x)
    return m(h, xconverted)
end
@adjoint function (m::Union{CuRNN,CuGRU})(x, h, Wi, Wh, b)
for RNN in (CuRNN, CuGRU)
@eval @adjoint function (m::$RNN)(x, h, Wi, Wh, b)
reserve, result = forwardTrain(desc(m), x, h)
result, function (Δ)
y, ho = result
@@ -296,6 +297,7 @@ end
(dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
nobacksies(:RNN, (dx, unbroadcast(h, dh), transpose(dWi), transpose(dWh), db))
end
end
end
@adjoint function (m::CuLSTM)(x, h, c, Wi, Wh, b)