avoid adjoint on abstract type
This commit is contained in:
parent
b8fabad337
commit
49044dff7c
@ -286,15 +286,17 @@ end
|
|||||||
# Call overload for a CuGRU{T} given a CuArray{T} hidden state `h` and an arbitrary `x`:
# converts `x` with CuArray{T}(x) and calls `m` again with the converted input.
# T is restricted to Float32 or Float64.
# NOTE(review): the re-invocation presumably lands on a more specific method that is not
# visible in this excerpt — confirm. The line is shown twice, presumably as unchanged
# context on both sides of the diff.
(m::CuGRU{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
(m::CuGRU{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||||
# Call overload for a CuLSTM{T} whose state `h` is a 2-tuple of CuArray{T}:
# converts `x` with CuArray{T}(x) and calls `m` again with the converted input.
# T is restricted to Float32 or Float64.
# NOTE(review): the re-invocation presumably lands on a more specific method that is not
# visible in this excerpt — confirm. The line is shown twice, presumably as unchanged
# context on both sides of the diff.
(m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
(m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||||
|
|
||||||
# NOTE(review): two versions of the same adjoint alternate line by line below, separated by
# `|` rows — presumably the removed and the added side of the diff.
# First version: a single @adjoint method whose callable is typed Union{CuRNN,CuGRU}.
@adjoint function (m::Union{CuRNN,CuGRU})(x, h, Wi, Wh, b)
|
# Second version: loop over CuRNN and CuGRU so that @eval defines one @adjoint method per type.
for RNN in (CuRNN, CuGRU)
|
||||||
reserve, result = forwardTrain(desc(m), x, h)
|
@eval @adjoint function (m::$RNN)(x, h, Wi, Wh, b)
|
||||||
result, function (Δ)
|
# Forward pass; `reserve` is handed to backwardData and backwardWeights inside the closure.
reserve, result = forwardTrain(desc(m), x, h)
|
||||||
y, ho = result
|
# Return the forward result together with the closure that receives Δ.
result, function (Δ)
|
||||||
dy, dho = Δ
|
y, ho = result
|
||||||
h_ = hBatch(x, h)
|
dy, dho = Δ
|
||||||
dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
|
h_ = hBatch(x, h)
|
||||||
(dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
|
dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
|
||||||
nobacksies(:RNN, (dx, unbroadcast(h, dh), transpose(dWi), transpose(dWh), db))
|
(dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
|
||||||

# One tuple entry per argument (x, h, Wi, Wh, b): dh goes through unbroadcast against h,
# and dWi / dWh are transposed before being returned.
nobacksies(:RNN, (dx, unbroadcast(h, dh), transpose(dWi), transpose(dWh), db))
|
||||||

end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user