Add an environment variable to disable CUDA usage.

This commit is contained in:
Tim Besard 2019-10-03 21:10:20 +02:00
parent 63d196aa37
commit 2369b2b3fd

View File

@ -20,9 +20,18 @@ export SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
ADAMW, RADAM, InvDecay, ExpDecay, WeightDecay
allow_cuda() = parse(Bool, get(ENV, "FLUX_USE_CUDA", "true"))
const consider_cuda = allow_cuda()
using CUDAapi
if has_cuda()
using CuArrays
if consider_cuda && has_cuda()
try
using CuArrays
catch
@error "CUDA is installed, but CuArrays.jl fails to load. Please fix the issue, or load Flux with FLUX_USE_CUDA=false."
rethrow()
end
use_cuda() = true
else
use_cuda() = false
@ -47,7 +56,9 @@ if use_cuda()
end
function __init__()
if has_cuda() != use_cuda()
# check if the GPU usage conditions that are baked in the precompilation image
# match the current situation, and force a recompilation if not.
if (allow_cuda() != consider_cuda) || (consider_cuda && has_cuda() != use_cuda())
cachefile = if VERSION >= v"1.3-"
Base.compilecache_path(Base.PkgId(Flux))
else