diff --git a/src/Flux.jl b/src/Flux.jl
index 0b57f81d..911d2ab5 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -22,15 +22,10 @@ export SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,

 using CUDAapi
 if has_cuda()
-  try
-    using CuArrays
-    @eval has_cuarrays() = true
-  catch ex
-    @warn "CUDA is installed, but CuArrays.jl fails to load" exception=(ex,catch_backtrace())
-    @eval has_cuarrays() = false
-  end
+  using CuArrays
+  use_cuda() = true
 else
-  has_cuarrays() = false
+  use_cuda() = false
 end

 include("utils.jl")
@@ -47,8 +42,20 @@ include("data/Data.jl")

 include("deprecations.jl")

-if has_cuarrays()
+if use_cuda()
   include("cuda/cuda.jl")
 end

+function __init__()
+  if has_cuda() != use_cuda()
+    cachefile = if VERSION >= v"1.3-"
+      Base.compilecache_path(Base.PkgId(Flux))
+    else
+      abspath(DEPOT_PATH[1], Base.cache_file_entry(Base.PkgId(Flux)))
+    end
+    rm(cachefile)
+    error("Your set-up changed, and Flux.jl needs to be reconfigured. Please load the package again.")
+  end
+end
+
 end # module
diff --git a/src/functor.jl b/src/functor.jl
index 73483ab9..a3e053b0 100644
--- a/src/functor.jl
+++ b/src/functor.jl
@@ -73,7 +73,7 @@ end

 cpu(m) = fmap(x -> adapt(Array, x), m)

-const gpu_adaptor = if has_cuarrays()
+const gpu_adaptor = if use_cuda()
   CuArrays.cu
 else
   identity
diff --git a/src/onehot.jl b/src/onehot.jl
index fe93c5c5..9bce5dd8 100644
--- a/src/onehot.jl
+++ b/src/onehot.jl
@@ -37,7 +37,7 @@ import Adapt: adapt, adapt_structure

 adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))

-if has_cuarrays()
+if use_cuda()
   import .CuArrays: CuArray, cudaconvert
   import Base.Broadcast: BroadcastStyle, ArrayStyle
   BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = ArrayStyle{CuArray}()
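
The core of this patch is a CUDA flag that is decided once at precompile time and then re-checked at load time in __init__. A minimal sketch of that pattern follows, using a hypothetical GPUFlagDemo module (not part of Flux) and assuming CUDAapi.jl is installed; the cache-file deletion from the patch is omitted here because the relevant Base internals differ between Julia versions.

# Minimal sketch of the precompile-time flag pattern used above.
# GPUFlagDemo is a hypothetical name; assumes CUDAapi.jl is available.
module GPUFlagDemo

using CUDAapi

# Evaluated once while the module is precompiled; the chosen method
# definition is baked into the compile cache.
if has_cuda()
  use_cuda() = true
else
  use_cuda() = false
end

function __init__()
  # Runs every time the module is loaded. If the live CUDA check no longer
  # agrees with the cached flag, the precompiled image is stale.
  if has_cuda() != use_cuda()
    error("CUDA availability changed since precompilation; please reload the package.")
  end
end

end # module

The patch itself goes one step further: before erroring it removes the stale compile cache file, so the next `using Flux` precompiles against the machine's current CUDA configuration.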