diff --git a/src/Flux.jl b/src/Flux.jl index 78670e65..5afa1fc0 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -39,24 +39,13 @@ include("data/Data.jl") include("deprecations.jl") +include("cuda/cuda.jl") + function __init__() - precompiling = ccall(:jl_generating_output, Cint, ()) != 0 - - # we don't want to include the CUDA module when precompiling, - # or we could end up replacing it at run time (triggering a warning) - precompiling && return - - if !CuArrays.functional() - # nothing to do here, and either CuArrays or one of its dependencies will have warned - else - use_cuda[] = true - - # FIXME: this functionality should be conditional at run time by checking `use_cuda` - # (or even better, get moved to CuArrays.jl as much as possible) - if CuArrays.has_cudnn() - include(joinpath(@__DIR__, "cuda/cuda.jl")) - else - @warn "CuArrays.jl did not find libcudnn. Some functionality will not be available." + use_cuda[] = CuArrays.functional() # Can be overridden after load with `Flux.use_cuda[] = false` + if CuArrays.functional() + if !CuArrays.has_cudnn() + @warn "CuArrays.jl found cuda, but did not find libcudnn. Some functionality will not be available." end end end