Add an environment variable to disable CUDA usage.

This commit is contained in:
Tim Besard 2019-10-03 21:10:20 +02:00
parent 63d196aa37
commit 2369b2b3fd

View File

@@ -20,9 +20,18 @@ export SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
   ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
   ADAMW, RADAM, InvDecay, ExpDecay, WeightDecay
 
+allow_cuda() = parse(Bool, get(ENV, "FLUX_USE_CUDA", "true"))
+const consider_cuda = allow_cuda()
+
 using CUDAapi
-if has_cuda()
-  using CuArrays
+if consider_cuda && has_cuda()
+  try
+    using CuArrays
+  catch
+    @error "CUDA is installed, but CuArrays.jl fails to load. Please fix the issue, or load Flux with FLUX_USE_CUDA=false."
+    rethrow()
+  end
   use_cuda() = true
 else
   use_cuda() = false
@@ -47,7 +56,9 @@ if use_cuda()
 end
 
 function __init__()
-  if has_cuda() != use_cuda()
+  # check if the GPU usage conditions that are baked in the precompilation image
+  # match the current situation, and force a recompilation if not.
+  if (allow_cuda() != consider_cuda) || (consider_cuda && has_cuda() != use_cuda())
     cachefile = if VERSION >= v"1.3-"
       Base.compilecache_path(Base.PkgId(Flux))
     else