tuple support
This commit is contained in:
parent
b348b20452
commit
368b1f53b4
|
@ -299,6 +299,7 @@ unbroadcast(x::AbstractArray, Δ) =
|
|||
trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
|
||||
|
||||
coerce_cuda(x::Union{CuArray,Nothing}) = x
|
||||
coerce_cuda(x::Tuple) = coerce_cuda.(x)
|
||||
|
||||
coerce_cuda(x) = x .+ CuArrays.fill(0)
|
||||
|
||||
|
@ -315,7 +316,7 @@ for RNN in (CuRNN, CuGRU)
|
|||
@eval @adjoint function (m::$RNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
||||
reserve, (y, ho) = forwardTrain(desc(m), x, h)
|
||||
(ho, y), function (Δ)
|
||||
dho, dy = coerce_cuda.(Δ)
|
||||
dho, dy = coerce_cuda(Δ)
|
||||
h_ = hBatch(x, h)
|
||||
dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
|
||||
(dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
|
||||
|
@ -328,7 +329,7 @@ end
|
|||
@adjoint function (m::CuLSTM)((h, c)::Tuple{CuArray{T},CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
||||
reserve, (y, ho, co) = forwardTrain(desc(m), x, h, c)
|
||||
((ho, co), y), function (Δ)
|
||||
dhc, dy = coerce_cuda.(Δ)
|
||||
dhc, dy = coerce_cuda(Δ)
|
||||
dho, dco = dhc === nothing ? (nothing, nothing) : dhc
|
||||
h_ = hBatch(x, h)
|
||||
c_ = hBatch(x, c)
|
||||
|
|
Loading…
Reference in New Issue