Demote to const variables.

This commit is contained in:
Tim Besard 2019-10-03 21:28:55 +02:00
parent 2369b2b3fd
commit 8aea15e6e0
3 changed files with 6 additions and 8 deletions

View File

@@ -25,16 +25,14 @@ allow_cuda() = parse(Bool, get(ENV, "FLUX_USE_CUDA", "true"))
 const consider_cuda = allow_cuda()
 
 using CUDAapi
-if consider_cuda && has_cuda()
+const use_cuda = consider_cuda && has_cuda()
+if use_cuda
   try
     using CuArrays
   catch
     @error "CUDA is installed, but CuArrays.jl fails to load. Please fix the issue, or load Flux with FLUX_USE_CUDA=false."
     rethrow()
   end
-  use_cuda() = true
-else
-  use_cuda() = false
 end
 
 include("utils.jl")
@@ -51,14 +49,14 @@ include("data/Data.jl")
 include("deprecations.jl")
 
-if use_cuda()
+if use_cuda
   include("cuda/cuda.jl")
 end
 
 function __init__()
   # check if the GPU usage conditions that are baked in the precompilation image
   # match the current situation, and force a recompilation if not.
-  if (allow_cuda() != consider_cuda) || (consider_cuda && has_cuda() != use_cuda())
+  if (allow_cuda() != consider_cuda) || (consider_cuda && has_cuda() != use_cuda)
     cachefile = if VERSION >= v"1.3-"
       Base.compilecache_path(Base.PkgId(Flux))
     else

View File

@@ -73,7 +73,7 @@ end
 cpu(m) = fmap(x -> adapt(Array, x), m)
 
-const gpu_adaptor = if use_cuda()
+const gpu_adaptor = if use_cuda
   CuArrays.cu
 else
   identity

View File

@@ -37,7 +37,7 @@ import Adapt: adapt, adapt_structure
 adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
 
-if use_cuda()
+if use_cuda
   import .CuArrays: CuArray, cudaconvert
   import Base.Broadcast: BroadcastStyle, ArrayStyle
   BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = ArrayStyle{CuArray}()