avoid adjoint on abstract type
This commit is contained in:
parent
b8fabad337
commit
49044dff7c
@ -286,15 +286,17 @@ end
|
||||
(m::CuGRU{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||
(m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||
|
||||
@adjoint function (m::Union{CuRNN,CuGRU})(x, h, Wi, Wh, b)
|
||||
reserve, result = forwardTrain(desc(m), x, h)
|
||||
result, function (Δ)
|
||||
y, ho = result
|
||||
dy, dho = Δ
|
||||
h_ = hBatch(x, h)
|
||||
dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
|
||||
(dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
|
||||
nobacksies(:RNN, (dx, unbroadcast(h, dh), transpose(dWi), transpose(dWh), db))
|
||||
for RNN in (CuRNN, CuGRU)
|
||||
@eval @adjoint function (m::$RNN)(x, h, Wi, Wh, b)
|
||||
reserve, result = forwardTrain(desc(m), x, h)
|
||||
result, function (Δ)
|
||||
y, ho = result
|
||||
dy, dho = Δ
|
||||
h_ = hBatch(x, h)
|
||||
dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
|
||||
(dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
|
||||
nobacksies(:RNN, (dx, unbroadcast(h, dh), transpose(dWi), transpose(dWh), db))
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user