Demote to const variables.
This commit is contained in:
parent
2369b2b3fd
commit
8aea15e6e0
10
src/Flux.jl
10
src/Flux.jl
@ -25,16 +25,14 @@ allow_cuda() = parse(Bool, get(ENV, "FLUX_USE_CUDA", "true"))
|
|||||||
const consider_cuda = allow_cuda()
|
const consider_cuda = allow_cuda()
|
||||||
|
|
||||||
using CUDAapi
|
using CUDAapi
|
||||||
if consider_cuda && has_cuda()
|
const use_cuda = consider_cuda && has_cuda()
|
||||||
|
if use_cuda
|
||||||
try
|
try
|
||||||
using CuArrays
|
using CuArrays
|
||||||
catch
|
catch
|
||||||
@error "CUDA is installed, but CuArrays.jl fails to load. Please fix the issue, or load Flux with FLUX_USE_CUDA=false."
|
@error "CUDA is installed, but CuArrays.jl fails to load. Please fix the issue, or load Flux with FLUX_USE_CUDA=false."
|
||||||
rethrow()
|
rethrow()
|
||||||
end
|
end
|
||||||
use_cuda() = true
|
|
||||||
else
|
|
||||||
use_cuda() = false
|
|
||||||
end
|
end
|
||||||
|
|
||||||
include("utils.jl")
|
include("utils.jl")
|
||||||
@ -51,14 +49,14 @@ include("data/Data.jl")
|
|||||||
|
|
||||||
include("deprecations.jl")
|
include("deprecations.jl")
|
||||||
|
|
||||||
if use_cuda()
|
if use_cuda
|
||||||
include("cuda/cuda.jl")
|
include("cuda/cuda.jl")
|
||||||
end
|
end
|
||||||
|
|
||||||
function __init__()
|
function __init__()
|
||||||
# check if the GPU usage conditions that are baked in the precompilation image
|
# check if the GPU usage conditions that are baked in the precompilation image
|
||||||
# match the current situation, and force a recompilation if not.
|
# match the current situation, and force a recompilation if not.
|
||||||
if (allow_cuda() != consider_cuda) || (consider_cuda && has_cuda() != use_cuda())
|
if (allow_cuda() != consider_cuda) || (consider_cuda && has_cuda() != use_cuda)
|
||||||
cachefile = if VERSION >= v"1.3-"
|
cachefile = if VERSION >= v"1.3-"
|
||||||
Base.compilecache_path(Base.PkgId(Flux))
|
Base.compilecache_path(Base.PkgId(Flux))
|
||||||
else
|
else
|
||||||
|
@ -73,7 +73,7 @@ end
|
|||||||
|
|
||||||
cpu(m) = fmap(x -> adapt(Array, x), m)
|
cpu(m) = fmap(x -> adapt(Array, x), m)
|
||||||
|
|
||||||
const gpu_adaptor = if use_cuda()
|
const gpu_adaptor = if use_cuda
|
||||||
CuArrays.cu
|
CuArrays.cu
|
||||||
else
|
else
|
||||||
identity
|
identity
|
||||||
|
@ -37,7 +37,7 @@ import Adapt: adapt, adapt_structure
|
|||||||
|
|
||||||
adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
|
adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
|
||||||
|
|
||||||
if use_cuda()
|
if use_cuda
|
||||||
import .CuArrays: CuArray, cudaconvert
|
import .CuArrays: CuArray, cudaconvert
|
||||||
import Base.Broadcast: BroadcastStyle, ArrayStyle
|
import Base.Broadcast: BroadcastStyle, ArrayStyle
|
||||||
BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = ArrayStyle{CuArray}()
|
BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = ArrayStyle{CuArray}()
|
||||||
|
Loading…
Reference in New Issue
Block a user