Merge #882
882: Check if CUDA availability changed during init. r=MikeInnes a=maleadt With this PR, Flux checks using CUDAapi if CUDA is available during initialization, and forces recompilation if that does not agree with what was decided during precompilation. This avoids the scenario where Flux was precompiled without GPU support, consequently not allowing use of the GPU even if the user fixed his CUDA/GPU set-up because that does not force recompilation (and we can't add precompilation dependencies on stuff that doesn't exist). However, we can't do the same for the case where we have a GPU/CUDA but CuArrays fails to import (checking if it imports during `__init__` would be much too expensive, if even possible), so this PR removes support for having CUDA/a GPU but CuArrays being broken. That's a little risky now that Flux depends on CuArrays, but the package is pretty mature and I haven't seen many bug reports failing to load it recently. Fixes https://github.com/FluxML/Flux.jl/pull/852#issuecomment-538028314 cc @MikeInnes @xukai92 Co-authored-by: Tim Besard <tim.besard@gmail.com>
This commit is contained in:
commit
af0dcb2c63
32
src/Flux.jl
32
src/Flux.jl
@ -20,17 +20,19 @@ export SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
|
||||
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
|
||||
ADAMW, RADAM, InvDecay, ExpDecay, WeightDecay
|
||||
|
||||
|
||||
allow_cuda() = parse(Bool, get(ENV, "FLUX_USE_CUDA", "true"))
|
||||
const consider_cuda = allow_cuda()
|
||||
|
||||
using CUDAapi
|
||||
if has_cuda()
|
||||
const use_cuda = consider_cuda && has_cuda()
|
||||
if use_cuda
|
||||
try
|
||||
using CuArrays
|
||||
@eval has_cuarrays() = true
|
||||
catch ex
|
||||
@warn "CUDA is installed, but CuArrays.jl fails to load" exception=(ex,catch_backtrace())
|
||||
@eval has_cuarrays() = false
|
||||
catch
|
||||
@error "CUDA is installed, but CuArrays.jl fails to load. Please fix the issue, or load Flux with FLUX_USE_CUDA=false."
|
||||
rethrow()
|
||||
end
|
||||
else
|
||||
has_cuarrays() = false
|
||||
end
|
||||
|
||||
include("utils.jl")
|
||||
@ -47,8 +49,22 @@ include("data/Data.jl")
|
||||
|
||||
include("deprecations.jl")
|
||||
|
||||
if has_cuarrays()
|
||||
if use_cuda
|
||||
include("cuda/cuda.jl")
|
||||
end
|
||||
|
||||
function __init__()
|
||||
# check if the GPU usage conditions that are baked in the precompilation image
|
||||
# match the current situation, and force a recompilation if not.
|
||||
if (allow_cuda() != consider_cuda) || (consider_cuda && has_cuda() != use_cuda)
|
||||
cachefile = if VERSION >= v"1.3-"
|
||||
Base.compilecache_path(Base.PkgId(Flux))
|
||||
else
|
||||
abspath(DEPOT_PATH[1], Base.cache_file_entry(Base.PkgId(Flux)))
|
||||
end
|
||||
rm(cachefile)
|
||||
error("Your set-up changed, and Flux.jl needs to be reconfigured. Please load the package again.")
|
||||
end
|
||||
end
|
||||
|
||||
end # module
|
||||
|
@ -73,7 +73,7 @@ end
|
||||
|
||||
cpu(m) = fmap(x -> adapt(Array, x), m)
|
||||
|
||||
const gpu_adaptor = if has_cuarrays()
|
||||
const gpu_adaptor = if use_cuda
|
||||
CuArrays.cu
|
||||
else
|
||||
identity
|
||||
|
@ -37,7 +37,7 @@ import Adapt: adapt, adapt_structure
|
||||
|
||||
adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
|
||||
|
||||
if has_cuarrays()
|
||||
if use_cuda
|
||||
import .CuArrays: CuArray, cudaconvert
|
||||
import Base.Broadcast: BroadcastStyle, ArrayStyle
|
||||
BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = ArrayStyle{CuArray}()
|
||||
|
Loading…
Reference in New Issue
Block a user