avoid adjoint on abstract type

This commit is contained in:
Mike Innes 2019-08-19 14:39:09 +01:00
parent b8fabad337
commit 49044dff7c

View File

@@ -286,15 +286,17 @@ end
 (m::CuGRU{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
 (m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
-@adjoint function (m::Union{CuRNN,CuGRU})(x, h, Wi, Wh, b)
-  reserve, result = forwardTrain(desc(m), x, h)
-  result, function (Δ)
-    y, ho = result
-    dy, dho = Δ
-    h_ = hBatch(x, h)
-    dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
-    (dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
-    nobacksies(:RNN, (dx, unbroadcast(h, dh), transpose(dWi), transpose(dWh), db))
-  end
-end
+for RNN in (CuRNN, CuGRU)
+  @eval @adjoint function (m::$RNN)(x, h, Wi, Wh, b)
+    reserve, result = forwardTrain(desc(m), x, h)
+    result, function (Δ)
+      y, ho = result
+      dy, dho = Δ
+      h_ = hBatch(x, h)
+      dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
+      (dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
+      nobacksies(:RNN, (dx, unbroadcast(h, dh), transpose(dWi), transpose(dWh), db))
+    end
+  end
+end