move pullbacks to CuArrays
commit 46bc8e5e64
parent c5e56b7e04
@@ -46,9 +46,9 @@ version = "0.6.2"
 
 [[CUDAapi]]
 deps = ["Libdl", "Logging"]
-git-tree-sha1 = "9b2b4b71d6b7f946c9689bb4dea03ff92e3c7091"
+git-tree-sha1 = "e063efb91cfefd7e6afd92c435d01398107a500b"
 uuid = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
-version = "1.1.0"
+version = "1.2.0"
 
 [[CUDAdrv]]
 deps = ["CUDAapi", "Libdl", "Printf"]
@@ -105,8 +105,8 @@ uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
 version = "4.0.0"
 
 [[CuArrays]]
-deps = ["AbstractFFTs", "Adapt", "CEnum", "CUDAapi", "CUDAdrv", "CUDAnative", "GPUArrays", "LinearAlgebra", "MacroTools", "NNlib", "Printf", "Random", "Requires", "SparseArrays", "TimerOutputs"]
-git-tree-sha1 = "63b4a10d3a4f22ef215d0970483b18296717d1fb"
+deps = ["AbstractFFTs", "Adapt", "CEnum", "CUDAapi", "CUDAdrv", "CUDAnative", "DataStructures", "GPUArrays", "LinearAlgebra", "MacroTools", "NNlib", "Printf", "Random", "Requires", "SparseArrays", "TimerOutputs"]
+git-tree-sha1 = "4e638627673078c58b6e6bb789937822d83350ff"
 repo-rev = "tb/flux"
 repo-url = "https://github.com/JuliaGPU/CuArrays.jl.git"
 uuid = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
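Note: the Manifest entry above pins CuArrays to the tb/flux branch of https://github.com/JuliaGPU/CuArrays.jl. For readers reproducing this environment, a pin like that is normally produced with the standard Pkg API; a minimal sketch (package URL and branch taken from the Manifest entry, everything else is plain Pkg usage):

using Pkg
# Track the CuArrays branch referenced by the Manifest entry above;
# equivalent to `pkg> add https://github.com/JuliaGPU/CuArrays.jl.git#tb/flux` in the REPL.
Pkg.add(PackageSpec(url = "https://github.com/JuliaGPU/CuArrays.jl.git", rev = "tb/flux"))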
@@ -172,9 +172,9 @@ version = "0.10.3"
 
 [[GPUArrays]]
 deps = ["Adapt", "FFTW", "FillArrays", "LinearAlgebra", "Printf", "Random", "Serialization", "StaticArrays", "Test"]
-git-tree-sha1 = "b5009ac44b141ded5e6f04c4db83807970f56e91"
+git-tree-sha1 = "77e27264276fe97a7e7fb928bf8999a145abc018"
 uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
-version = "1.0.2"
+version = "1.0.3"
 
 [[IRTools]]
 deps = ["InteractiveUtils", "MacroTools", "Test"]
@@ -56,7 +56,7 @@ unbroadcast(x::AbstractArray, Δ) =
 coerce_cuda(x::Union{CuArray,Nothing}) = x
 coerce_cuda(x::Tuple) = coerce_cuda.(x)
 
-coerce_cuda(x) = x .+ CuArrays.fill(0)
+coerce_cuda(x::AbstractArray) = x .+ CuArrays.fill(0)
 
 function struct_grad!(cx::Zygote.Context, x, x̄)
   for f in fieldnames(typeof(x))
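The change above narrows the catch-all coerce_cuda method to AbstractArray. As the "Support FillArrays etc." comments further down suggest, the intent is that lazy CPU gradients produced by Zygote (e.g. FillArrays.Fill from the pullback of sum) get materialised onto the GPU by broadcasting against a zero-dimensional CuArray. A rough sketch of the effect, assuming a CUDA device and the CuArrays/FillArrays versions pinned in this Manifest (the example values are illustrative, not from the diff):

using CuArrays, FillArrays

# Same definitions as in the diff above.
coerce_cuda(x::Union{CuArray,Nothing}) = x
coerce_cuda(x::Tuple) = coerce_cuda.(x)
coerce_cuda(x::AbstractArray) = x .+ CuArrays.fill(0)

Δ = Fill(1f0, 5)   # lazy CPU gradient, as returned e.g. by the pullback of `sum`
coerce_cuda(Δ)     # broadcasting with a 0-d CuArray should yield a dense CuArray{Float32}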
@@ -69,28 +69,23 @@ end
 
 for RNN in (CuRNN, CuGRU)
   @eval @adjoint function (m::$RNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
-    reserve, (y, ho) = CUDNN.forwardTrain(desc(m), x, h)
+    (y, ho), back = CUDNN.pullback(desc(m), x, h)
     (ho, y), function (Δ)
-      dho, dy = coerce_cuda(Δ)
-      h_ = CUDNN.hBatch(x, h)
-      dx, dh = CUDNN.backwardData(descs[m], y, dy, dho, h_, reserve)
-      (dWi, dWh), db = CUDNN.backwardWeights(descs[m], x, h_, y, reserve)
-      dm = struct_grad!(__context__, m, (σ=nothing,Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing))
-      (dm, unbroadcast(h, dh), dx)
+      dho, dy = coerce_cuda(Δ) # Support FillArrays etc.
+      m̄ = back(dy, dho)
+      dm = struct_grad!(__context__, m, (σ=nothing,Wi=transpose(m̄.Wi),Wh=transpose(m̄.Wh),b=m̄.b,h=nothing))
+      (dm, unbroadcast(h, m̄.h), m̄.x)
     end
   end
 end
 
 @adjoint function (m::CuLSTM)((h, c)::Tuple{CuArray{T},CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
-  reserve, (y, ho, co) = CUDNN.forwardTrain(desc(m), x, h, c)
+  (y, ho, co), back = CUDNN.pullback(desc(m), x, h, c)
   ((ho, co), y), function (Δ)
-    dhc, dy = coerce_cuda(Δ)
+    dhc, dy = coerce_cuda(Δ) # Support FillArrays etc.
     dho, dco = dhc === nothing ? (nothing, nothing) : dhc
-    h_ = CUDNN.hBatch(x, h)
-    c_ = CUDNN.hBatch(x, c)
-    dx, dh, dc = CUDNN.backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
-    (dWi, dWh), db = CUDNN.backwardWeights(descs[m], x, h_, y, reserve)
-    dm = struct_grad!(__context__, m, (Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing,c=nothing))
-    (dm, (unbroadcast(h, dh), unbroadcast(c, dc)), dx)
+    m̄ = back(dy, dho, dco)
+    dm = struct_grad!(__context__, m, (σ=nothing,Wi=transpose(m̄.Wi),Wh=transpose(m̄.Wh),b=m̄.b,h=nothing,c=nothing))
+    (dm, (unbroadcast(h, m̄.h), unbroadcast(c, m̄.c)), m̄.x)
   end
 end
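For context, the adjoints above are what Zygote dispatches to when differentiating a recurrent layer whose parameters live on the GPU: the cuDNN forward/backward plumbing (previously forwardTrain, backwardData, backwardWeights) now sits behind CUDNN.pullback in CuArrays, which returns the outputs together with a closure producing the weight, hidden-state, and input gradients. A rough usage sketch, assuming a CUDA device, Flux with this patch, and the pinned CuArrays branch; RNN, gpu, params, and gradient are the usual Flux/Zygote API, while the exact dispatch to CuRNN is an assumption about how gpu-converted cells are typed:

using Flux, CuArrays

m = RNN(10, 5) |> gpu            # cell parameters become CuArrays, so calls hit the CuRNN adjoint
x = CuArrays.rand(Float32, 10)   # illustrative input
gs = gradient(params(m)) do
  sum(m(x))                      # forward pass goes through CUDNN.pullback as in the diff above
end
gs[m.cell.Wi]                    # gradient of the input weight matrix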