Add an environment variable to disable CUDA usage.
This commit is contained in:
parent
63d196aa37
commit
2369b2b3fd
17
src/Flux.jl
17
src/Flux.jl
@ -20,9 +20,18 @@ export SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
|
||||
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
|
||||
ADAMW, RADAM, InvDecay, ExpDecay, WeightDecay
|
||||
|
||||
|
||||
allow_cuda() = parse(Bool, get(ENV, "FLUX_USE_CUDA", "true"))
|
||||
const consider_cuda = allow_cuda()
|
||||
|
||||
using CUDAapi
|
||||
if has_cuda()
|
||||
using CuArrays
|
||||
if consider_cuda && has_cuda()
|
||||
try
|
||||
using CuArrays
|
||||
catch
|
||||
@error "CUDA is installed, but CuArrays.jl fails to load. Please fix the issue, or load Flux with FLUX_USE_CUDA=false."
|
||||
rethrow()
|
||||
end
|
||||
use_cuda() = true
|
||||
else
|
||||
use_cuda() = false
|
||||
@ -47,7 +56,9 @@ if use_cuda()
|
||||
end
|
||||
|
||||
function __init__()
|
||||
if has_cuda() != use_cuda()
|
||||
# check if the GPU usage conditions that are baked in the precompilation image
|
||||
# match the current situation, and force a recompilation if not.
|
||||
if (allow_cuda() != consider_cuda) || (consider_cuda && has_cuda() != use_cuda())
|
||||
cachefile = if VERSION >= v"1.3-"
|
||||
Base.compilecache_path(Base.PkgId(Flux))
|
||||
else
|
||||
|
Loading…
Reference in New Issue
Block a user