fix cuda init
This commit is contained in:
parent
787097f9ea
commit
6846551f57
|
@ -46,9 +46,9 @@ version = "0.6.2"
|
|||
|
||||
[[CUDAapi]]
|
||||
deps = ["Libdl", "Logging"]
|
||||
git-tree-sha1 = "9b2b4b71d6b7f946c9689bb4dea03ff92e3c7091"
|
||||
git-tree-sha1 = "e063efb91cfefd7e6afd92c435d01398107a500b"
|
||||
uuid = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
|
||||
version = "1.1.0"
|
||||
version = "1.2.0"
|
||||
|
||||
[[CUDAdrv]]
|
||||
deps = ["CUDAapi", "Libdl", "Printf"]
|
||||
|
@ -106,9 +106,7 @@ version = "4.0.0"
|
|||
|
||||
[[CuArrays]]
|
||||
deps = ["AbstractFFTs", "Adapt", "CUDAapi", "CUDAdrv", "CUDAnative", "GPUArrays", "LinearAlgebra", "MacroTools", "NNlib", "Printf", "Random", "Requires", "SparseArrays", "TimerOutputs"]
|
||||
git-tree-sha1 = "de756b0ed9ffe17890ce77b59bc76b10f96747e7"
|
||||
repo-rev = "stable"
|
||||
repo-url = "https://github.com/JuliaGPU/CuArrays.jl.git"
|
||||
git-tree-sha1 = "46b48742a84bb839e74215b7e468a4a1c6ba30f9"
|
||||
uuid = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
|
||||
version = "1.2.1"
|
||||
|
||||
|
@ -390,7 +388,7 @@ version = "0.8.3"
|
|||
|
||||
[[Zygote]]
|
||||
deps = ["DiffRules", "FFTW", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
|
||||
git-tree-sha1 = "9186cb0b3b59219e4aba0840614d6a9d7282012e"
|
||||
git-tree-sha1 = "38241b40ebd8748bcacad5e6c7ba3ab3cc7a15c9"
|
||||
repo-rev = "master"
|
||||
repo-url = "https://github.com/FluxML/Zygote.jl.git"
|
||||
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
|
||||
|
@ -398,6 +396,8 @@ version = "0.3.4"
|
|||
|
||||
[[ZygoteRules]]
|
||||
deps = ["MacroTools"]
|
||||
git-tree-sha1 = "def5f96ac2895fd9b48435f6b97020979ee0a4c6"
|
||||
git-tree-sha1 = "c4c29b30b8ff3be13d4244e78be7df2a42bc54d0"
|
||||
repo-rev = "master"
|
||||
repo-url = "https://github.com/FluxML/ZygoteRules.jl.git"
|
||||
uuid = "700de1a5-db45-46bc-99cf-38207098b444"
|
||||
version = "0.1.0"
|
||||
version = "0.2.0"
|
||||
|
|
|
@ -24,6 +24,7 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
|
|||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
|
||||
ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
|
||||
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
|
||||
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
|
||||
|
||||
[compat]
|
||||
CUDAapi = "1.1"
|
||||
|
|
|
@ -83,12 +83,14 @@ function rnnWorkspaceSize(r::RNNDesc, seqlen, xdesc)
|
|||
return Int(size[])
|
||||
end
|
||||
|
||||
const workspace = [CuVector{UInt8}(undef, 1)]
|
||||
const workspace = Ref{Union{Nothing,CuVector{UInt8}}}(nothing)
|
||||
|
||||
getworkspace(bytes) =
|
||||
length(workspace[]) ≥ bytes ?
|
||||
workspace[] :
|
||||
(workspace[] = CuVector{UInt8}(undef, bytes))
|
||||
function getworkspace(bytes)
|
||||
if workspace[] === nothing || length(workspace[]) < bytes
|
||||
workspace[] = CuVector{UInt8}(undef, bytes)
|
||||
end
|
||||
workspace[]
|
||||
end
|
||||
|
||||
getworkspace(r::RNNDesc, seqlen, xdesc) =
|
||||
getworkspace(rnnWorkspaceSize(r, seqlen, xdesc))
|
||||
|
|
Loading…
Reference in New Issue