From 2369b2b3fdc2b6fcd68b67e7f7776621474f28ed Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Thu, 3 Oct 2019 21:10:20 +0200 Subject: [PATCH] Add an environment variable to disable CUDA usage. --- src/Flux.jl | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/Flux.jl b/src/Flux.jl index 911d2ab5..c0023e2c 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -20,9 +20,18 @@ export SGD, Descent, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW, RADAM, InvDecay, ExpDecay, WeightDecay + +allow_cuda() = parse(Bool, get(ENV, "FLUX_USE_CUDA", "true")) +const consider_cuda = allow_cuda() + using CUDAapi -if has_cuda() - using CuArrays +if consider_cuda && has_cuda() + try + using CuArrays + catch + @error "CUDA is installed, but CuArrays.jl fails to load. Please fix the issue, or load Flux with FLUX_USE_CUDA=false." + rethrow() + end use_cuda() = true else use_cuda() = false @@ -47,7 +56,9 @@ if use_cuda() end function __init__() - if has_cuda() != use_cuda() + # check if the GPU usage conditions that are baked in the precompilation image + # match the current situation, and force a recompilation if not. + if (allow_cuda() != consider_cuda) || (consider_cuda && has_cuda() != use_cuda()) cachefile = if VERSION >= v"1.3-" Base.compilecache_path(Base.PkgId(Flux)) else