Check if CUDA availability changed during init.
This commit is contained in:
parent
e2b93bc78a
commit
63d196aa37
25
src/Flux.jl
25
src/Flux.jl
@ -22,15 +22,10 @@ export SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
|
|||||||
|
|
||||||
using CUDAapi
|
using CUDAapi
|
||||||
if has_cuda()
|
if has_cuda()
|
||||||
try
|
using CuArrays
|
||||||
using CuArrays
|
use_cuda() = true
|
||||||
@eval has_cuarrays() = true
|
|
||||||
catch ex
|
|
||||||
@warn "CUDA is installed, but CuArrays.jl fails to load" exception=(ex,catch_backtrace())
|
|
||||||
@eval has_cuarrays() = false
|
|
||||||
end
|
|
||||||
else
|
else
|
||||||
has_cuarrays() = false
|
use_cuda() = false
|
||||||
end
|
end
|
||||||
|
|
||||||
include("utils.jl")
|
include("utils.jl")
|
||||||
@ -47,8 +42,20 @@ include("data/Data.jl")
|
|||||||
|
|
||||||
include("deprecations.jl")
|
include("deprecations.jl")
|
||||||
|
|
||||||
if has_cuarrays()
|
if use_cuda()
|
||||||
include("cuda/cuda.jl")
|
include("cuda/cuda.jl")
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function __init__()
|
||||||
|
if has_cuda() != use_cuda()
|
||||||
|
cachefile = if VERSION >= v"1.3-"
|
||||||
|
Base.compilecache_path(Base.PkgId(Flux))
|
||||||
|
else
|
||||||
|
abspath(DEPOT_PATH[1], Base.cache_file_entry(Base.PkgId(Flux)))
|
||||||
|
end
|
||||||
|
rm(cachefile)
|
||||||
|
error("Your set-up changed, and Flux.jl needs to be reconfigured. Please load the package again.")
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
end # module
|
end # module
|
||||||
|
@ -73,7 +73,7 @@ end
|
|||||||
|
|
||||||
cpu(m) = fmap(x -> adapt(Array, x), m)
|
cpu(m) = fmap(x -> adapt(Array, x), m)
|
||||||
|
|
||||||
const gpu_adaptor = if has_cuarrays()
|
const gpu_adaptor = if use_cuda()
|
||||||
CuArrays.cu
|
CuArrays.cu
|
||||||
else
|
else
|
||||||
identity
|
identity
|
||||||
|
@ -37,7 +37,7 @@ import Adapt: adapt, adapt_structure
|
|||||||
|
|
||||||
adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
|
adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
|
||||||
|
|
||||||
if has_cuarrays()
|
if use_cuda()
|
||||||
import .CuArrays: CuArray, cudaconvert
|
import .CuArrays: CuArray, cudaconvert
|
||||||
import Base.Broadcast: BroadcastStyle, ArrayStyle
|
import Base.Broadcast: BroadcastStyle, ArrayStyle
|
||||||
BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = ArrayStyle{CuArray}()
|
BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = ArrayStyle{CuArray}()
|
||||||
|
Loading…
Reference in New Issue
Block a user