fix cuda init

This commit is contained in:
Dhairya Gandhi 2019-09-22 22:02:05 +05:30
parent 787097f9ea
commit 6846551f57
3 changed files with 16 additions and 13 deletions

View File

@ -46,9 +46,9 @@ version = "0.6.2"
[[CUDAapi]]
deps = ["Libdl", "Logging"]
git-tree-sha1 = "9b2b4b71d6b7f946c9689bb4dea03ff92e3c7091"
git-tree-sha1 = "e063efb91cfefd7e6afd92c435d01398107a500b"
uuid = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
version = "1.1.0"
version = "1.2.0"
[[CUDAdrv]]
deps = ["CUDAapi", "Libdl", "Printf"]
@ -106,9 +106,7 @@ version = "4.0.0"
[[CuArrays]]
deps = ["AbstractFFTs", "Adapt", "CUDAapi", "CUDAdrv", "CUDAnative", "GPUArrays", "LinearAlgebra", "MacroTools", "NNlib", "Printf", "Random", "Requires", "SparseArrays", "TimerOutputs"]
git-tree-sha1 = "de756b0ed9ffe17890ce77b59bc76b10f96747e7"
repo-rev = "stable"
repo-url = "https://github.com/JuliaGPU/CuArrays.jl.git"
git-tree-sha1 = "46b48742a84bb839e74215b7e468a4a1c6ba30f9"
uuid = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
version = "1.2.1"
@ -390,7 +388,7 @@ version = "0.8.3"
[[Zygote]]
deps = ["DiffRules", "FFTW", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
git-tree-sha1 = "9186cb0b3b59219e4aba0840614d6a9d7282012e"
git-tree-sha1 = "38241b40ebd8748bcacad5e6c7ba3ab3cc7a15c9"
repo-rev = "master"
repo-url = "https://github.com/FluxML/Zygote.jl.git"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
@ -398,6 +396,8 @@ version = "0.3.4"
[[ZygoteRules]]
deps = ["MacroTools"]
git-tree-sha1 = "def5f96ac2895fd9b48435f6b97020979ee0a4c6"
git-tree-sha1 = "c4c29b30b8ff3be13d4244e78be7df2a42bc54d0"
repo-rev = "master"
repo-url = "https://github.com/FluxML/ZygoteRules.jl.git"
uuid = "700de1a5-db45-46bc-99cf-38207098b444"
version = "0.1.0"
version = "0.2.0"

View File

@ -24,6 +24,7 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[compat]
CUDAapi = "1.1"

View File

@ -83,12 +83,14 @@ function rnnWorkspaceSize(r::RNNDesc, seqlen, xdesc)
return Int(size[])
end
const workspace = [CuVector{UInt8}(undef, 1)]
const workspace = Ref{Union{Nothing,CuVector{UInt8}}}(nothing)
getworkspace(bytes) =
length(workspace[]) bytes ?
workspace[] :
(workspace[] = CuVector{UInt8}(undef, bytes))
function getworkspace(bytes)
if workspace[] === nothing || length(workspace[]) < bytes
workspace[] = CuVector{UInt8}(undef, bytes)
end
workspace[]
end
getworkspace(r::RNNDesc, seqlen, xdesc) =
getworkspace(rnnWorkspaceSize(r, seqlen, xdesc))