Add an environment variable to disable CUDA usage.
This commit is contained in:
parent
63d196aa37
commit
2369b2b3fd
17
src/Flux.jl
17
src/Flux.jl
@ -20,9 +20,18 @@ export SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
|
|||||||
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
|
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
|
||||||
ADAMW, RADAM, InvDecay, ExpDecay, WeightDecay
|
ADAMW, RADAM, InvDecay, ExpDecay, WeightDecay
|
||||||
|
|
||||||
|
|
||||||
|
allow_cuda() = parse(Bool, get(ENV, "FLUX_USE_CUDA", "true"))
|
||||||
|
const consider_cuda = allow_cuda()
|
||||||
|
|
||||||
using CUDAapi
|
using CUDAapi
|
||||||
if has_cuda()
|
if consider_cuda && has_cuda()
|
||||||
using CuArrays
|
try
|
||||||
|
using CuArrays
|
||||||
|
catch
|
||||||
|
@error "CUDA is installed, but CuArrays.jl fails to load. Please fix the issue, or load Flux with FLUX_USE_CUDA=false."
|
||||||
|
rethrow()
|
||||||
|
end
|
||||||
use_cuda() = true
|
use_cuda() = true
|
||||||
else
|
else
|
||||||
use_cuda() = false
|
use_cuda() = false
|
||||||
@ -47,7 +56,9 @@ if use_cuda()
|
|||||||
end
|
end
|
||||||
|
|
||||||
function __init__()
|
function __init__()
|
||||||
if has_cuda() != use_cuda()
|
# check if the GPU usage conditions that are baked in the precompilation image
|
||||||
|
# match the current situation, and force a recompilation if not.
|
||||||
|
if (allow_cuda() != consider_cuda) || (consider_cuda && has_cuda() != use_cuda())
|
||||||
cachefile = if VERSION >= v"1.3-"
|
cachefile = if VERSION >= v"1.3-"
|
||||||
Base.compilecache_path(Base.PkgId(Flux))
|
Base.compilecache_path(Base.PkgId(Flux))
|
||||||
else
|
else
|
||||||
|
Loading…
Reference in New Issue
Block a user