From 63d196aa370def3ea9883fb30648f9eccdf98819 Mon Sep 17 00:00:00 2001
From: Tim Besard
Date: Thu, 3 Oct 2019 19:54:23 +0200
Subject: [PATCH 1/3] Check if CUDA availability changed during init.

---
 src/Flux.jl    | 25 ++++++++++++++++---------
 src/functor.jl |  2 +-
 src/onehot.jl  |  2 +-
 3 files changed, 18 insertions(+), 11 deletions(-)

diff --git a/src/Flux.jl b/src/Flux.jl
index 0b57f81d..911d2ab5 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -22,15 +22,10 @@ export SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
 
 using CUDAapi
 if has_cuda()
-  try
-    using CuArrays
-    @eval has_cuarrays() = true
-  catch ex
-    @warn "CUDA is installed, but CuArrays.jl fails to load" exception=(ex,catch_backtrace())
-    @eval has_cuarrays() = false
-  end
+  using CuArrays
+  use_cuda() = true
 else
-  has_cuarrays() = false
+  use_cuda() = false
 end
 
 include("utils.jl")
@@ -47,8 +42,20 @@ include("data/Data.jl")
 
 include("deprecations.jl")
 
-if has_cuarrays()
+if use_cuda()
   include("cuda/cuda.jl")
 end
 
+function __init__()
+  if has_cuda() != use_cuda()
+    cachefile = if VERSION >= v"1.3-"
+      Base.compilecache_path(Base.PkgId(Flux))
+    else
+      abspath(DEPOT_PATH[1], Base.cache_file_entry(Base.PkgId(Flux)))
+    end
+    rm(cachefile)
+    error("Your set-up changed, and Flux.jl needs to be reconfigured. Please load the package again.")
+  end
+end
+
 end # module
diff --git a/src/functor.jl b/src/functor.jl
index 73483ab9..a3e053b0 100644
--- a/src/functor.jl
+++ b/src/functor.jl
@@ -73,7 +73,7 @@ end
 
 cpu(m) = fmap(x -> adapt(Array, x), m)
 
-const gpu_adaptor = if has_cuarrays()
+const gpu_adaptor = if use_cuda()
   CuArrays.cu
 else
   identity
diff --git a/src/onehot.jl b/src/onehot.jl
index fe93c5c5..9bce5dd8 100644
--- a/src/onehot.jl
+++ b/src/onehot.jl
@@ -37,7 +37,7 @@ import Adapt: adapt, adapt_structure
 
 adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
 
-if has_cuarrays()
+if use_cuda()
   import .CuArrays: CuArray, cudaconvert
   import Base.Broadcast: BroadcastStyle, ArrayStyle
   BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = ArrayStyle{CuArray}()

From 2369b2b3fdc2b6fcd68b67e7f7776621474f28ed Mon Sep 17 00:00:00 2001
From: Tim Besard
Date: Thu, 3 Oct 2019 21:10:20 +0200
Subject: [PATCH 2/3] Add an environment variable to disable CUDA usage.

---
 src/Flux.jl | 17 ++++++++++++++---
 1 file changed, 14 insertions(+), 3 deletions(-)

diff --git a/src/Flux.jl b/src/Flux.jl
index 911d2ab5..c0023e2c 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -20,9 +20,18 @@ export SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
   ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW, RADAM,
   InvDecay, ExpDecay, WeightDecay
 
+
+allow_cuda() = parse(Bool, get(ENV, "FLUX_USE_CUDA", "true"))
+const consider_cuda = allow_cuda()
+
 using CUDAapi
-if has_cuda()
-  using CuArrays
+if consider_cuda && has_cuda()
+  try
+    using CuArrays
+  catch
+    @error "CUDA is installed, but CuArrays.jl fails to load. Please fix the issue, or load Flux with FLUX_USE_CUDA=false."
+    rethrow()
+  end
   use_cuda() = true
 else
   use_cuda() = false
@@ -47,7 +56,9 @@ if use_cuda()
 end
 
 function __init__()
-  if has_cuda() != use_cuda()
+  # check if the GPU usage conditions that are baked in the precompilation image
+  # match the current situation, and force a recompilation if not.
+  if (allow_cuda() != consider_cuda) || (consider_cuda && has_cuda() != use_cuda())
     cachefile = if VERSION >= v"1.3-"
       Base.compilecache_path(Base.PkgId(Flux))
     else

From 8aea15e6e021e5055104694a87bc8ef6c54a2f48 Mon Sep 17 00:00:00 2001
From: Tim Besard
Date: Thu, 3 Oct 2019 21:28:55 +0200
Subject: [PATCH 3/3] Demote to const variables.

---
 src/Flux.jl    | 10 ++++------
 src/functor.jl |  2 +-
 src/onehot.jl  |  2 +-
 3 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/src/Flux.jl b/src/Flux.jl
index c0023e2c..95bdcd32 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -25,16 +25,14 @@ allow_cuda() = parse(Bool, get(ENV, "FLUX_USE_CUDA", "true"))
 const consider_cuda = allow_cuda()
 
 using CUDAapi
-if consider_cuda && has_cuda()
+const use_cuda = consider_cuda && has_cuda()
+if use_cuda
   try
     using CuArrays
   catch
     @error "CUDA is installed, but CuArrays.jl fails to load. Please fix the issue, or load Flux with FLUX_USE_CUDA=false."
     rethrow()
   end
-  use_cuda() = true
-else
-  use_cuda() = false
 end
 
 include("utils.jl")
@@ -51,14 +49,14 @@ include("data/Data.jl")
 
 include("deprecations.jl")
 
-if use_cuda()
+if use_cuda
   include("cuda/cuda.jl")
 end
 
 function __init__()
   # check if the GPU usage conditions that are baked in the precompilation image
   # match the current situation, and force a recompilation if not.
-  if (allow_cuda() != consider_cuda) || (consider_cuda && has_cuda() != use_cuda())
+  if (allow_cuda() != consider_cuda) || (consider_cuda && has_cuda() != use_cuda)
     cachefile = if VERSION >= v"1.3-"
       Base.compilecache_path(Base.PkgId(Flux))
     else
diff --git a/src/functor.jl b/src/functor.jl
index a3e053b0..798445b4 100644
--- a/src/functor.jl
+++ b/src/functor.jl
@@ -73,7 +73,7 @@ end
 
 cpu(m) = fmap(x -> adapt(Array, x), m)
 
-const gpu_adaptor = if use_cuda()
+const gpu_adaptor = if use_cuda
   CuArrays.cu
 else
   identity
diff --git a/src/onehot.jl b/src/onehot.jl
index 9bce5dd8..84747450 100644
--- a/src/onehot.jl
+++ b/src/onehot.jl
@@ -37,7 +37,7 @@ import Adapt: adapt, adapt_structure
 
 adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
 
-if use_cuda()
+if use_cuda
   import .CuArrays: CuArray, cudaconvert
   import Base.Broadcast: BroadcastStyle, ArrayStyle
   BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = ArrayStyle{CuArray}()
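
Note: for readers who want the end state rather than the incremental hunks, the CUDA-related
portion of src/Flux.jl after all three patches reads roughly as follows. This is a sketch
reassembled from the hunks above (unrelated include(...) calls are elided and marked as such,
and the comments are added here for exposition), so treat it as an illustration of the final
logic rather than a verbatim copy of the file.

# The FLUX_USE_CUDA opt-out is read at precompilation time and captured in constants,
# together with the result of the hardware check from CUDAapi.
allow_cuda() = parse(Bool, get(ENV, "FLUX_USE_CUDA", "true"))
const consider_cuda = allow_cuda()

using CUDAapi
const use_cuda = consider_cuda && has_cuda()
if use_cuda
  try
    using CuArrays
  catch
    @error "CUDA is installed, but CuArrays.jl fails to load. Please fix the issue, or load Flux with FLUX_USE_CUDA=false."
    rethrow()
  end
end

# ... unrelated include(...) calls elided ...

if use_cuda
  include("cuda/cuda.jl")
end

function __init__()
  # check if the GPU usage conditions that are baked in the precompilation image
  # match the current situation, and force a recompilation if not.
  if (allow_cuda() != consider_cuda) || (consider_cuda && has_cuda() != use_cuda)
    cachefile = if VERSION >= v"1.3-"
      Base.compilecache_path(Base.PkgId(Flux))
    else
      abspath(DEPOT_PATH[1], Base.cache_file_entry(Base.PkgId(Flux)))
    end
    rm(cachefile)
    error("Your set-up changed, and Flux.jl needs to be reconfigured. Please load the package again.")
  end
end

The design point worth noting: consider_cuda and use_cuda are evaluated once, at precompilation
time, so __init__ re-evaluates allow_cuda() and has_cuda() at load time and deletes the compile
cache whenever the fresh values disagree with the baked-in ones, forcing Julia to precompile
Flux again on the next load.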