diff --git a/src/Flux.jl b/src/Flux.jl
index 0b57f81d..95bdcd32 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -20,17 +20,19 @@ export SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
   ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW, RADAM,
   InvDecay, ExpDecay, WeightDecay
 
+
+allow_cuda() = parse(Bool, get(ENV, "FLUX_USE_CUDA", "true"))
+const consider_cuda = allow_cuda()
+
 using CUDAapi
-if has_cuda()
+const use_cuda = consider_cuda && has_cuda()
+if use_cuda
   try
     using CuArrays
-    @eval has_cuarrays() = true
-  catch ex
-    @warn "CUDA is installed, but CuArrays.jl fails to load" exception=(ex,catch_backtrace())
-    @eval has_cuarrays() = false
+  catch
+    @error "CUDA is installed, but CuArrays.jl fails to load. Please fix the issue, or load Flux with FLUX_USE_CUDA=false."
+    rethrow()
   end
-else
-  has_cuarrays() = false
 end
 
 include("utils.jl")
@@ -47,8 +49,22 @@ include("data/Data.jl")
 
 include("deprecations.jl")
 
-if has_cuarrays()
+if use_cuda
   include("cuda/cuda.jl")
 end
 
+function __init__()
+  # check if the GPU usage conditions that are baked in the precompilation image
+  # match the current situation, and force a recompilation if not.
+  if (allow_cuda() != consider_cuda) || (consider_cuda && has_cuda() != use_cuda)
+    cachefile = if VERSION >= v"1.3-"
+      Base.compilecache_path(Base.PkgId(Flux))
+    else
+      abspath(DEPOT_PATH[1], Base.cache_file_entry(Base.PkgId(Flux)))
+    end
+    rm(cachefile)
+    error("Your set-up changed, and Flux.jl needs to be reconfigured. Please load the package again.")
+  end
+end
+
 end # module
diff --git a/src/functor.jl b/src/functor.jl
index 1d3e1bb2..b96d21c8 100644
--- a/src/functor.jl
+++ b/src/functor.jl
@@ -73,7 +73,7 @@ end
 
 cpu(m) = fmap(x -> adapt(Array, x), m)
 
-const gpu_adaptor = if has_cuarrays()
+const gpu_adaptor = if use_cuda
   CuArrays.cu
 else
   identity
diff --git a/src/onehot.jl b/src/onehot.jl
index fe93c5c5..84747450 100644
--- a/src/onehot.jl
+++ b/src/onehot.jl
@@ -37,7 +37,7 @@ import Adapt: adapt, adapt_structure
 
 adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
 
-if has_cuarrays()
+if use_cuda
   import .CuArrays: CuArray, cudaconvert
   import Base.Broadcast: BroadcastStyle, ArrayStyle
   BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = ArrayStyle{CuArray}()
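
For reviewers: a minimal sketch of how the new switch behaves from the user's
side. The FLUX_USE_CUDA variable, the `identity` fallback for gpu_adaptor, and
the cache invalidation in __init__ are all taken from the patch above; the
session itself is only illustrative.

    # Opt out of CUDA before Flux is (pre)compiled; allow_cuda() reads this
    # variable and defaults to "true".
    ENV["FLUX_USE_CUDA"] = "false"

    using Flux

    # With use_cuda == false, gpu_adaptor is `identity`, so moving a model
    # "to the GPU" is a no-op and it stays on the CPU.
    m = gpu(Dense(10, 5))

    # If the situation changes between sessions (the variable is flipped, or
    # CUDA becomes (un)available), __init__ deletes the stale precompilation
    # cache and errors, and a fresh `using Flux` recompiles the package.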