diff --git a/Manifest.toml b/Manifest.toml
index 299a40b5..d10fc71b 100644
--- a/Manifest.toml
+++ b/Manifest.toml
@@ -46,9 +46,9 @@ version = "0.6.2"
 
 [[CUDAapi]]
 deps = ["Libdl", "Logging"]
-git-tree-sha1 = "9b2b4b71d6b7f946c9689bb4dea03ff92e3c7091"
+git-tree-sha1 = "e063efb91cfefd7e6afd92c435d01398107a500b"
 uuid = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
-version = "1.1.0"
+version = "1.2.0"
 
 [[CUDAdrv]]
 deps = ["CUDAapi", "Libdl", "Printf"]
@@ -105,8 +105,8 @@ uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
 version = "4.0.0"
 
 [[CuArrays]]
-deps = ["AbstractFFTs", "Adapt", "CEnum", "CUDAapi", "CUDAdrv", "CUDAnative", "GPUArrays", "LinearAlgebra", "MacroTools", "NNlib", "Printf", "Random", "Requires", "SparseArrays", "TimerOutputs"]
-git-tree-sha1 = "63b4a10d3a4f22ef215d0970483b18296717d1fb"
+deps = ["AbstractFFTs", "Adapt", "CEnum", "CUDAapi", "CUDAdrv", "CUDAnative", "DataStructures", "GPUArrays", "LinearAlgebra", "MacroTools", "NNlib", "Printf", "Random", "Requires", "SparseArrays", "TimerOutputs"]
+git-tree-sha1 = "4e638627673078c58b6e6bb789937822d83350ff"
 repo-rev = "tb/flux"
 repo-url = "https://github.com/JuliaGPU/CuArrays.jl.git"
 uuid = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
@@ -172,9 +172,9 @@ version = "0.10.3"
 
 [[GPUArrays]]
 deps = ["Adapt", "FFTW", "FillArrays", "LinearAlgebra", "Printf", "Random", "Serialization", "StaticArrays", "Test"]
-git-tree-sha1 = "b5009ac44b141ded5e6f04c4db83807970f56e91"
+git-tree-sha1 = "77e27264276fe97a7e7fb928bf8999a145abc018"
 uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
-version = "1.0.2"
+version = "1.0.3"
 
 [[IRTools]]
 deps = ["InteractiveUtils", "MacroTools", "Test"]
diff --git a/src/cuda/curnn.jl b/src/cuda/curnn.jl
index 86422d03..fb454729 100644
--- a/src/cuda/curnn.jl
+++ b/src/cuda/curnn.jl
@@ -56,7 +56,7 @@ unbroadcast(x::AbstractArray, Δ) =
 
 coerce_cuda(x::Union{CuArray,Nothing}) = x
 coerce_cuda(x::Tuple) = coerce_cuda.(x)
-coerce_cuda(x) = x .+ CuArrays.fill(0)
+coerce_cuda(x::AbstractArray) = x .+ CuArrays.fill(0)
 
 function struct_grad!(cx::Zygote.Context, x, x̄)
   for f in fieldnames(typeof(x))
@@ -69,28 +69,23 @@ end
 
 for RNN in (CuRNN, CuGRU)
   @eval @adjoint function (m::$RNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
-    reserve, (y, ho) = CUDNN.forwardTrain(desc(m), x, h)
+    (y, ho), back = CUDNN.pullback(desc(m), x, h)
     (ho, y), function (Δ)
-      dho, dy = coerce_cuda(Δ)
-      h_ = CUDNN.hBatch(x, h)
-      dx, dh = CUDNN.backwardData(descs[m], y, dy, dho, h_, reserve)
-      (dWi, dWh), db = CUDNN.backwardWeights(descs[m], x, h_, y, reserve)
-      dm = struct_grad!(__context__, m, (σ=nothing,Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing))
-      (dm, unbroadcast(h, dh), dx)
+      dho, dy = coerce_cuda(Δ) # Support FillArrays etc.
+      m̄ = back(dy, dho)
+      dm = struct_grad!(__context__, m, (σ=nothing,Wi=transpose(m̄.Wi),Wh=transpose(m̄.Wh),b=m̄.b,h=nothing))
+      (dm, unbroadcast(h, m̄.h), m̄.x)
     end
   end
 end
 
 @adjoint function (m::CuLSTM)((h, c)::Tuple{CuArray{T},CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
-  reserve, (y, ho, co) = CUDNN.forwardTrain(desc(m), x, h, c)
+  (y, ho, co), back = CUDNN.pullback(desc(m), x, h, c)
   ((ho, co), y), function (Δ)
-    dhc, dy = coerce_cuda(Δ)
+    dhc, dy = coerce_cuda(Δ) # Support FillArrays etc.
     dho, dco = dhc === nothing ? (nothing, nothing) : dhc
-    h_ = CUDNN.hBatch(x, h)
-    c_ = CUDNN.hBatch(x, c)
-    dx, dh, dc = CUDNN.backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
-    (dWi, dWh), db = CUDNN.backwardWeights(descs[m], x, h_, y, reserve)
-    dm = struct_grad!(__context__, m, (Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing,c=nothing))
-    (dm, (unbroadcast(h, dh), unbroadcast(c, dc)), dx)
+    m̄ = back(dy, dho, dco)
+    dm = struct_grad!(__context__, m, (σ=nothing,Wi=transpose(m̄.Wi),Wh=transpose(m̄.Wh),b=m̄.b,h=nothing,c=nothing))
+    (dm, (unbroadcast(h, m̄.h), unbroadcast(c, m̄.c)), m̄.x)
  end
 end
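
Note on the curnn.jl change: it swaps the old three-step CUDNN sequence (a forwardTrain reserve threaded through hBatch, backwardData, and backwardWeights) for a single pullback-style call on the CuArrays tb/flux branch. A minimal sketch of the new flow, assuming an already-constructed CuRNN `m` and suitably sized CuArrays `x`, `h`, `dy`, `dho`; only the call shapes are taken from the diff, everything else is illustrative:

    # Forward pass: returns the outputs and a backward closure in one call.
    (y, ho), back = CUDNN.pullback(desc(m), x, h)
    # Backward pass: one call replaces the old backwardData/backwardWeights pair,
    # and the reserve buffer is captured inside the closure rather than threaded by hand.
    m̄ = back(dy, dho)
    # m̄ bundles every gradient: m̄.Wi, m̄.Wh, m̄.b (weights and bias),
    # m̄.h (initial hidden state), and m̄.x (input).

For the LSTM adjoint the same closure takes an extra cell-state sensitivity, `back(dy, dho, dco)`, and the result additionally carries `m̄.c`.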