Check for CUDA availability at run time.
parent 7104fd9332
commit 39ab740fb7
Project.toml

@@ -5,7 +5,7 @@ version = "0.9.0"
 [deps]
 AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
-CUDAapi = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
+CUDAdrv = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
 CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193"
 Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
 CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
src/Flux.jl (40 changed lines)
@@ -21,19 +21,9 @@ export SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
   ADAMW, RADAM, InvDecay, ExpDecay, WeightDecay
 
-allow_cuda() = parse(Bool, get(ENV, "FLUX_USE_CUDA", "true"))
-const consider_cuda = allow_cuda()
-
-using CUDAapi
-const use_cuda = consider_cuda && has_cuda()
-
-if use_cuda
-  try
-    using CuArrays
-  catch
-    @error "CUDA is installed, but CuArrays.jl fails to load. Please fix the issue, or load Flux with FLUX_USE_CUDA=false."
-    rethrow()
-  end
-end
+ENV["CUDA_INIT_SILENT"] = true
+using CUDAdrv, CuArrays
+
+const use_cuda = Ref(false)
 
 include("utils.jl")
 include("onehot.jl")
@@ -49,21 +39,19 @@ include("data/Data.jl")
 
 include("deprecations.jl")
 
-if use_cuda
-  include("cuda/cuda.jl")
-end
+include("cuda/cuda.jl")
 
 function __init__()
-  # check if the GPU usage conditions that are baked in the precompilation image
-  # match the current situation, and force a recompilation if not.
-  if (allow_cuda() != consider_cuda) || (consider_cuda && has_cuda() != use_cuda)
-    cachefile = if VERSION >= v"1.3-"
-      Base.compilecache_path(Base.PkgId(Flux))
-    else
-      abspath(DEPOT_PATH[1], Base.cache_file_entry(Base.PkgId(Flux)))
-    end
-    rm(cachefile)
-    error("Your set-up changed, and Flux.jl needs to be reconfigured. Please load the package again.")
-  end
+  if !CUDAdrv.functional()
+    @warn "CUDA available, but CUDAdrv.jl failed to load"
+  elseif length(devices()) == 0
+    @warn "CUDA available, but no GPU detected"
+  elseif !CuArrays.functional()
+    @warn "CUDA GPU available, but CuArrays.jl failed to load"
+  elseif !CuArrays.has_cudnn()
+    @warn "CUDA GPU available, but CuArrays.jl did not find libcudnn"
+  else
+    use_cuda[] = true
+  end
 end
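
With this hunk the CUDA checks run inside __init__ at load time, and downstream code reads the use_cuda Ref instead of a compile-time constant. A minimal illustrative sketch of consulting that flag (not part of the diff):

    using Flux

    # use_cuda is the Ref(false) defined above; __init__ flips it to true only
    # when CUDAdrv, a visible GPU, CuArrays and libcudnn are all usable.
    if Flux.use_cuda[]
        @info "Flux will use the GPU through CuArrays"
    else
        @info "Flux will stay on the CPU"
    end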
src/cuda/cuda.jl

@@ -2,12 +2,8 @@ module CUDA
 
 using ..CuArrays
 
-if CuArrays.libcudnn !== nothing # TODO: use CuArrays.has_cudnn()
-  using CuArrays: CUDNN
-  include("curnn.jl")
-  include("cudnn.jl")
-else
-  @warn "CUDNN is not installed, some functionality will not be available."
-end
+using CuArrays: CUDNN
+include("curnn.jl")
+include("cudnn.jl")
 
 end
@@ -73,13 +73,7 @@ end
 
 cpu(m) = fmap(x -> adapt(Array, x), m)
 
-const gpu_adaptor = if use_cuda
-  CuArrays.cu
-else
-  identity
-end
-
-gpu(x) = fmap(gpu_adaptor, x)
+gpu(x) = use_cuda[] ? fmap(CuArrays.cu, x) : x
 
 # Precision
 
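
An illustrative usage sketch of the new call-time behaviour (not part of the diff; the layer size and data are arbitrary):

    using Flux

    m  = Dense(10, 2)             # plain CPU model
    gm = gpu(m)                   # parameters moved to CuArrays only when use_cuda[] is true,
                                  # otherwise the model is returned unchanged
    x  = gpu(rand(Float32, 10))
    gm(x)                         # runs on whichever device the parameters ended up on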
src/onehot.jl

@@ -37,12 +37,10 @@ import Adapt: adapt, adapt_structure
 
 adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
 
-if use_cuda
-  import .CuArrays: CuArray, cudaconvert
-  import Base.Broadcast: BroadcastStyle, ArrayStyle
-  BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = ArrayStyle{CuArray}()
-  cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
-end
+import .CuArrays: CuArray, cudaconvert
+import Base.Broadcast: BroadcastStyle, ArrayStyle
+BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = ArrayStyle{CuArray}()
+cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
 
 """
     onehot(l, labels[, unk])
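
A small sketch of what these definitions enable, mirroring the usual GPU test pattern; it assumes a working CUDA setup so that use_cuda[] ended up true:

    using Flux

    y  = Flux.onehotbatch([1, 2, 3], 1:3)   # OneHotMatrix on the CPU
    cy = gpu(y)                              # adapt_structure above rewraps the data field as a CuArray
    cy isa Flux.OneHotMatrix                 # still a OneHotMatrix, now GPU-backed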
@@ -53,8 +53,6 @@ end
   @test y[3,:] isa CuArray
 end
 
-if CuArrays.libcudnn != nothing
-  @info "Testing Flux/CUDNN"
-  include("cudnn.jl")
-  include("curnn.jl")
-end
+@info "Testing Flux/CUDNN"
+include("cudnn.jl")
+include("curnn.jl")