882: Check if CUDA availability changed during init. r=MikeInnes a=maleadt

With this PR, Flux uses CUDAapi to check during initialization whether CUDA is available, and forces recompilation if that check disagrees with the decision made during precompilation. This avoids the scenario where Flux was precompiled without GPU support and then keeps refusing to use the GPU even after the user fixes their CUDA/GPU set-up, because fixing the set-up does not by itself trigger recompilation (and we can't add precompilation dependencies on packages that don't exist).
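
In rough terms, the check added to `__init__` (shown in full in the diff below) compares the load-time situation against two constants that were baked in at precompilation time:

# Sketch of the consistency check; all names are taken from the diff below.
# `consider_cuda` and `use_cuda` are consts fixed when Flux was precompiled,
# while `allow_cuda()` and `has_cuda()` re-query the environment at load time.
if (allow_cuda() != consider_cuda) ||        # FLUX_USE_CUDA was toggled, or
   (consider_cuda && has_cuda() != use_cuda) # CUDA (dis)appeared since precompilation
  # the precompilation image is stale: delete the cache file and ask for a reload
end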

However, we can't do the same for the case where a GPU/CUDA is present but CuArrays fails to import (checking whether it imports during `__init__` would be much too expensive, if it is possible at all), so this PR removes support for having CUDA/a GPU while CuArrays is broken. That is a little risky now that Flux depends on CuArrays, but the package is pretty mature and I haven't seen many recent bug reports about it failing to load.
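
As an aside, the opt-out mentioned in the new error message is the `FLUX_USE_CUDA` environment variable, which `allow_cuda()` reads at load time; a minimal usage sketch:

# Must be set before `using Flux`; when unset, it defaults to "true".
ENV["FLUX_USE_CUDA"] = "false"
using Flux  # loads CPU-only, even on a machine with a working CUDA set-up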

Fixes https://github.com/FluxML/Flux.jl/pull/852#issuecomment-538028314

cc @MikeInnes @xukai92

Co-authored-by: Tim Besard <tim.besard@gmail.com>
bors[bot] 2019-10-08 13:24:49 +00:00 committed by GitHub
commit af0dcb2c63
3 changed files with 26 additions and 10 deletions

src/Flux.jl

@@ -20,17 +20,19 @@ export SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
   ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
   ADAMW, RADAM, InvDecay, ExpDecay, WeightDecay
 
+allow_cuda() = parse(Bool, get(ENV, "FLUX_USE_CUDA", "true"))
+const consider_cuda = allow_cuda()
+
 using CUDAapi
-if has_cuda()
+const use_cuda = consider_cuda && has_cuda()
+if use_cuda
   try
     using CuArrays
-    @eval has_cuarrays() = true
-  catch ex
-    @warn "CUDA is installed, but CuArrays.jl fails to load" exception=(ex,catch_backtrace())
-    @eval has_cuarrays() = false
+  catch
+    @error "CUDA is installed, but CuArrays.jl fails to load. Please fix the issue, or load Flux with FLUX_USE_CUDA=false."
+    rethrow()
   end
-else
-  has_cuarrays() = false
 end
 
 include("utils.jl")
@@ -47,8 +49,22 @@ include("data/Data.jl")
 
 include("deprecations.jl")
 
-if has_cuarrays()
+if use_cuda
   include("cuda/cuda.jl")
 end
 
+function __init__()
+  # check if the GPU usage conditions that are baked in the precompilation image
+  # match the current situation, and force a recompilation if not.
+  if (allow_cuda() != consider_cuda) || (consider_cuda && has_cuda() != use_cuda)
+    cachefile = if VERSION >= v"1.3-"
+      Base.compilecache_path(Base.PkgId(Flux))
+    else
+      abspath(DEPOT_PATH[1], Base.cache_file_entry(Base.PkgId(Flux)))
+    end
+    rm(cachefile)
+    error("Your set-up changed, and Flux.jl needs to be reconfigured. Please load the package again.")
+  end
+end
+
 end # module

src/functor.jl

@@ -73,7 +73,7 @@ end
 
 cpu(m) = fmap(x -> adapt(Array, x), m)
 
-const gpu_adaptor = if has_cuarrays()
+const gpu_adaptor = if use_cuda
   CuArrays.cu
 else
   identity

src/onehot.jl

@@ -37,7 +37,7 @@ import Adapt: adapt, adapt_structure
 
 adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
 
-if has_cuarrays()
+if use_cuda
   import .CuArrays: CuArray, cudaconvert
   import Base.Broadcast: BroadcastStyle, ArrayStyle
   BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = ArrayStyle{CuArray}()
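
With these changes, code inside Flux branches on the `use_cuda` constant instead of the removed `has_cuarrays()` predicate, and whenever `use_cuda` is true CuArrays is guaranteed to have loaded (a failed import now rethrows). A sketch of the resulting pattern (`Flux.use_cuda` is internal and unexported, so treat this as illustrative only):

using Flux
if Flux.use_cuda
  # CuArrays was imported successfully while loading Flux
  x = Flux.CuArrays.cu(rand(Float32, 3, 3))
end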