Check if CUDA availability changed during init.

This commit is contained in:
Tim Besard 2019-10-03 19:54:23 +02:00
parent e2b93bc78a
commit 63d196aa37
3 changed files with 18 additions and 11 deletions

View File

@ -22,15 +22,10 @@ export SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
using CUDAapi
if has_cuda()
try
using CuArrays
@eval has_cuarrays() = true
catch ex
@warn "CUDA is installed, but CuArrays.jl fails to load" exception=(ex,catch_backtrace())
@eval has_cuarrays() = false
end
using CuArrays
use_cuda() = true
else
has_cuarrays() = false
use_cuda() = false
end
include("utils.jl")
@ -47,8 +42,20 @@ include("data/Data.jl")
include("deprecations.jl")
if has_cuarrays()
if use_cuda()
include("cuda/cuda.jl")
end
function __init__()
if has_cuda() != use_cuda()
cachefile = if VERSION >= v"1.3-"
Base.compilecache_path(Base.PkgId(Flux))
else
abspath(DEPOT_PATH[1], Base.cache_file_entry(Base.PkgId(Flux)))
end
rm(cachefile)
error("Your set-up changed, and Flux.jl needs to be reconfigured. Please load the package again.")
end
end
end # module

View File

@ -73,7 +73,7 @@ end
cpu(m) = fmap(x -> adapt(Array, x), m)
const gpu_adaptor = if has_cuarrays()
const gpu_adaptor = if use_cuda()
CuArrays.cu
else
identity

View File

@ -37,7 +37,7 @@ import Adapt: adapt, adapt_structure
adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
if has_cuarrays()
if use_cuda()
import .CuArrays: CuArray, cudaconvert
import Base.Broadcast: BroadcastStyle, ArrayStyle
BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = ArrayStyle{CuArray}()