Demote to const variables.
This commit is contained in:
parent
2369b2b3fd
commit
8aea15e6e0
10
src/Flux.jl
10
src/Flux.jl
|
@ -25,16 +25,14 @@ allow_cuda() = parse(Bool, get(ENV, "FLUX_USE_CUDA", "true"))
|
|||
const consider_cuda = allow_cuda()
|
||||
|
||||
using CUDAapi
|
||||
if consider_cuda && has_cuda()
|
||||
const use_cuda = consider_cuda && has_cuda()
|
||||
if use_cuda
|
||||
try
|
||||
using CuArrays
|
||||
catch
|
||||
@error "CUDA is installed, but CuArrays.jl fails to load. Please fix the issue, or load Flux with FLUX_USE_CUDA=false."
|
||||
rethrow()
|
||||
end
|
||||
use_cuda() = true
|
||||
else
|
||||
use_cuda() = false
|
||||
end
|
||||
|
||||
include("utils.jl")
|
||||
|
@ -51,14 +49,14 @@ include("data/Data.jl")
|
|||
|
||||
include("deprecations.jl")
|
||||
|
||||
if use_cuda()
|
||||
if use_cuda
|
||||
include("cuda/cuda.jl")
|
||||
end
|
||||
|
||||
function __init__()
|
||||
# check if the GPU usage conditions that are baked in the precompilation image
|
||||
# match the current situation, and force a recompilation if not.
|
||||
if (allow_cuda() != consider_cuda) || (consider_cuda && has_cuda() != use_cuda())
|
||||
if (allow_cuda() != consider_cuda) || (consider_cuda && has_cuda() != use_cuda)
|
||||
cachefile = if VERSION >= v"1.3-"
|
||||
Base.compilecache_path(Base.PkgId(Flux))
|
||||
else
|
||||
|
|
|
@ -73,7 +73,7 @@ end
|
|||
|
||||
cpu(m) = fmap(x -> adapt(Array, x), m)
|
||||
|
||||
const gpu_adaptor = if use_cuda()
|
||||
const gpu_adaptor = if use_cuda
|
||||
CuArrays.cu
|
||||
else
|
||||
identity
|
||||
|
|
|
@ -37,7 +37,7 @@ import Adapt: adapt, adapt_structure
|
|||
|
||||
adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
|
||||
|
||||
if use_cuda()
|
||||
if use_cuda
|
||||
import .CuArrays: CuArray, cudaconvert
|
||||
import Base.Broadcast: BroadcastStyle, ArrayStyle
|
||||
BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = ArrayStyle{CuArray}()
|
||||
|
|
Loading…
Reference in New Issue