Merge pull request #916 from FluxML/tb/runtime_use_cuda
Check for CUDA availability at run time.
This commit is contained in:
commit
08804a06d2
@ -1,51 +1,27 @@
|
|||||||
before_script:
|
|
||||||
- export CI_DISABLE_CURNN_TEST=true
|
|
||||||
|
|
||||||
variables:
|
|
||||||
CI_IMAGE_TAG: 'cuda'
|
|
||||||
|
|
||||||
include:
|
include:
|
||||||
- 'https://raw.githubusercontent.com/JuliaGPU/gitlab-ci/master/templates/v4/common.yml'
|
- 'https://raw.githubusercontent.com/JuliaGPU/gitlab-ci/master/templates/v6.yml'
|
||||||
|
|
||||||
.flux:
|
image: nvidia/cuda:10.1-cudnn7-devel-ubuntu18.04
|
||||||
extends: .test
|
|
||||||
script:
|
|
||||||
- julia -e 'using InteractiveUtils;
|
|
||||||
versioninfo()'
|
|
||||||
- mkdir $JULIA_DEPOT_PATH # Pkg3.jl#325
|
|
||||||
- julia --project -e 'using Pkg;
|
|
||||||
Pkg.instantiate();
|
|
||||||
Pkg.build();
|
|
||||||
Pkg.test(; coverage=true);'
|
|
||||||
|
|
||||||
test:v1.0:
|
|
||||||
extends: .flux
|
|
||||||
variables:
|
|
||||||
CI_VERSION_TAG: 'v1.0'
|
|
||||||
|
|
||||||
test:v1.1:
|
julia:1.2:
|
||||||
extends: .flux
|
extends:
|
||||||
variables:
|
- .julia:1.2
|
||||||
CI_VERSION_TAG: 'v1.1'
|
- .test
|
||||||
|
tags:
|
||||||
|
- nvidia
|
||||||
|
|
||||||
test:v1.2:
|
julia:1.3:
|
||||||
extends: .flux
|
extends:
|
||||||
variables:
|
- .julia:1.3
|
||||||
CI_VERSION_TAG: 'v1.2'
|
- .test
|
||||||
|
tags:
|
||||||
test:v1.3:
|
- nvidia
|
||||||
extends: .flux
|
|
||||||
variables:
|
|
||||||
CI_VERSION_TAG: 'v1.3'
|
|
||||||
|
|
||||||
test:v1.0:
|
|
||||||
extends: .flux
|
|
||||||
variables:
|
|
||||||
CI_VERSION_TAG: 'v1.0'
|
|
||||||
|
|
||||||
test:dev:
|
|
||||||
extends: .flux
|
|
||||||
variables:
|
|
||||||
CI_VERSION_TAG: 'dev'
|
|
||||||
|
|
||||||
|
julia:nightly:
|
||||||
|
extends:
|
||||||
|
- .julia:nightly
|
||||||
|
- .test
|
||||||
|
tags:
|
||||||
|
- nvidia
|
||||||
allow_failure: true
|
allow_failure: true
|
||||||
|
@ -6,7 +6,8 @@ os:
|
|||||||
# - osx
|
# - osx
|
||||||
|
|
||||||
julia:
|
julia:
|
||||||
- 1.1
|
- 1.2
|
||||||
|
- 1.3
|
||||||
- nightly
|
- nightly
|
||||||
|
|
||||||
matrix:
|
matrix:
|
||||||
@ -16,7 +17,7 @@ matrix:
|
|||||||
jobs:
|
jobs:
|
||||||
include:
|
include:
|
||||||
- stage: "Documentation"
|
- stage: "Documentation"
|
||||||
julia: 1.0
|
julia: 1.2
|
||||||
os: linux
|
os: linux
|
||||||
script:
|
script:
|
||||||
- julia --project=docs/ -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd()));
|
- julia --project=docs/ -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd()));
|
||||||
|
@ -28,10 +28,10 @@ uuid = "9e28174c-4ba2-5203-b857-d8d62c4213ee"
|
|||||||
version = "0.8.10"
|
version = "0.8.10"
|
||||||
|
|
||||||
[[BinaryProvider]]
|
[[BinaryProvider]]
|
||||||
deps = ["Libdl", "Logging", "SHA"]
|
deps = ["Libdl", "SHA"]
|
||||||
git-tree-sha1 = "c7361ce8a2129f20b0e05a89f7070820cfed6648"
|
git-tree-sha1 = "5b08ed6036d9d3f0ee6369410b830f8873d4024c"
|
||||||
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
|
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
|
||||||
version = "0.5.6"
|
version = "0.5.8"
|
||||||
|
|
||||||
[[CEnum]]
|
[[CEnum]]
|
||||||
git-tree-sha1 = "62847acab40e6855a9b5905ccb99c2b5cf6b3ebb"
|
git-tree-sha1 = "62847acab40e6855a9b5905ccb99c2b5cf6b3ebb"
|
||||||
@ -40,9 +40,9 @@ version = "0.2.0"
|
|||||||
|
|
||||||
[[CSTParser]]
|
[[CSTParser]]
|
||||||
deps = ["Tokenize"]
|
deps = ["Tokenize"]
|
||||||
git-tree-sha1 = "c69698c3d4a7255bc1b4bc2afc09f59db910243b"
|
git-tree-sha1 = "99dda94f5af21a4565dc2b97edf6a95485f116c3"
|
||||||
uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f"
|
uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f"
|
||||||
version = "0.6.2"
|
version = "1.0.0"
|
||||||
|
|
||||||
[[CUDAapi]]
|
[[CUDAapi]]
|
||||||
deps = ["Libdl", "Logging"]
|
deps = ["Libdl", "Logging"]
|
||||||
@ -51,16 +51,16 @@ uuid = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
|
|||||||
version = "1.2.0"
|
version = "1.2.0"
|
||||||
|
|
||||||
[[CUDAdrv]]
|
[[CUDAdrv]]
|
||||||
deps = ["CUDAapi", "Libdl", "Printf"]
|
deps = ["CEnum", "Printf"]
|
||||||
git-tree-sha1 = "9ce99b5732c70e06ed97c042187baed876fb1698"
|
git-tree-sha1 = "96eabc95ebb83e361311330ffb574a3e2df73251"
|
||||||
uuid = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
|
uuid = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
|
||||||
version = "3.1.0"
|
version = "4.0.2"
|
||||||
|
|
||||||
[[CUDAnative]]
|
[[CUDAnative]]
|
||||||
deps = ["Adapt", "CUDAapi", "CUDAdrv", "DataStructures", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Printf", "TimerOutputs"]
|
deps = ["Adapt", "CEnum", "CUDAapi", "CUDAdrv", "DataStructures", "InteractiveUtils", "LLVM", "Libdl", "Printf", "TimerOutputs"]
|
||||||
git-tree-sha1 = "52ae1ce10ebfa686e227655c47b19add89308623"
|
git-tree-sha1 = "861a1a9e9741cc55c973a4688079f467a72337a7"
|
||||||
uuid = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
|
uuid = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
|
||||||
version = "2.3.1"
|
version = "2.5.1"
|
||||||
|
|
||||||
[[CodecZlib]]
|
[[CodecZlib]]
|
||||||
deps = ["BinaryProvider", "Libdl", "TranscodingStreams"]
|
deps = ["BinaryProvider", "Libdl", "TranscodingStreams"]
|
||||||
@ -88,9 +88,9 @@ version = "0.2.0"
|
|||||||
|
|
||||||
[[Compat]]
|
[[Compat]]
|
||||||
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
|
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
|
||||||
git-tree-sha1 = "84aa74986c5b9b898b0d1acaf3258741ee64754f"
|
git-tree-sha1 = "ed2c4abadf84c53d9e58510b5fc48912c2336fbb"
|
||||||
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
|
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
|
||||||
version = "2.1.0"
|
version = "2.2.0"
|
||||||
|
|
||||||
[[Conda]]
|
[[Conda]]
|
||||||
deps = ["JSON", "VersionParsing"]
|
deps = ["JSON", "VersionParsing"]
|
||||||
@ -105,23 +105,21 @@ uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
|
|||||||
version = "4.0.0"
|
version = "4.0.0"
|
||||||
|
|
||||||
[[CuArrays]]
|
[[CuArrays]]
|
||||||
deps = ["AbstractFFTs", "Adapt", "CEnum", "CUDAapi", "CUDAdrv", "CUDAnative", "DataStructures", "GPUArrays", "LinearAlgebra", "MacroTools", "NNlib", "Printf", "Random", "Requires", "SparseArrays", "TimerOutputs"]
|
deps = ["AbstractFFTs", "Adapt", "CEnum", "CUDAapi", "CUDAdrv", "CUDAnative", "DataStructures", "GPUArrays", "Libdl", "LinearAlgebra", "MacroTools", "NNlib", "Printf", "Random", "Requires", "SparseArrays", "TimerOutputs"]
|
||||||
git-tree-sha1 = "45683305171430978c17f496969dc9b6d3094a51"
|
git-tree-sha1 = "0d22d5a55e30e98617f258bb23688f141bfeae36"
|
||||||
repo-rev = "master"
|
|
||||||
repo-url = "https://github.com/JuliaGPU/CuArrays.jl.git"
|
|
||||||
uuid = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
|
uuid = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
|
||||||
version = "1.3.0"
|
version = "1.4.1"
|
||||||
|
|
||||||
[[DataAPI]]
|
[[DataAPI]]
|
||||||
git-tree-sha1 = "8903f0219d3472543fc4b2f5ebaf675a07f817c0"
|
git-tree-sha1 = "674b67f344687a88310213ddfa8a2b3c76cc4252"
|
||||||
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
|
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
|
||||||
version = "1.0.1"
|
version = "1.1.0"
|
||||||
|
|
||||||
[[DataStructures]]
|
[[DataStructures]]
|
||||||
deps = ["InteractiveUtils", "OrderedCollections"]
|
deps = ["InteractiveUtils", "OrderedCollections"]
|
||||||
git-tree-sha1 = "0809951a1774dc724da22d26e4289bbaab77809a"
|
git-tree-sha1 = "1fe8fad5fc84686dcbc674aa255bc867a64f8132"
|
||||||
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
|
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
|
||||||
version = "0.17.0"
|
version = "0.17.5"
|
||||||
|
|
||||||
[[Dates]]
|
[[Dates]]
|
||||||
deps = ["Printf"]
|
deps = ["Printf"]
|
||||||
@ -155,9 +153,9 @@ version = "1.0.1"
|
|||||||
|
|
||||||
[[FillArrays]]
|
[[FillArrays]]
|
||||||
deps = ["LinearAlgebra", "Random", "SparseArrays"]
|
deps = ["LinearAlgebra", "Random", "SparseArrays"]
|
||||||
git-tree-sha1 = "8fba6ddaf66b45dec830233cea0aae43eb1261ad"
|
git-tree-sha1 = "de38b0253ade98340fabaf220f368f6144541938"
|
||||||
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
|
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
|
||||||
version = "0.6.4"
|
version = "0.7.4"
|
||||||
|
|
||||||
[[FixedPointNumbers]]
|
[[FixedPointNumbers]]
|
||||||
git-tree-sha1 = "d14a6fa5890ea3a7e5dcab6811114f132fec2b4b"
|
git-tree-sha1 = "d14a6fa5890ea3a7e5dcab6811114f132fec2b4b"
|
||||||
@ -165,16 +163,16 @@ uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
|
|||||||
version = "0.6.1"
|
version = "0.6.1"
|
||||||
|
|
||||||
[[ForwardDiff]]
|
[[ForwardDiff]]
|
||||||
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"]
|
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "NaNMath", "Random", "SpecialFunctions", "StaticArrays"]
|
||||||
git-tree-sha1 = "4c4d727f1b7e0092134fabfab6396b8945c1ea5b"
|
git-tree-sha1 = "adf88d6da1f0294058f38295becf8807986bb7d0"
|
||||||
uuid = "f6369f11-7733-5829-9624-2563aa707210"
|
uuid = "f6369f11-7733-5829-9624-2563aa707210"
|
||||||
version = "0.10.3"
|
version = "0.10.5"
|
||||||
|
|
||||||
[[GPUArrays]]
|
[[GPUArrays]]
|
||||||
deps = ["Adapt", "FFTW", "FillArrays", "LinearAlgebra", "Printf", "Random", "Serialization", "StaticArrays", "Test"]
|
deps = ["Adapt", "FFTW", "FillArrays", "LinearAlgebra", "Printf", "Random", "Serialization", "Test"]
|
||||||
git-tree-sha1 = "77e27264276fe97a7e7fb928bf8999a145abc018"
|
git-tree-sha1 = "8d74ced24448c52b539a23d107bd2424ee139c0f"
|
||||||
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
|
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
|
||||||
version = "1.0.3"
|
version = "1.0.4"
|
||||||
|
|
||||||
[[IRTools]]
|
[[IRTools]]
|
||||||
deps = ["InteractiveUtils", "MacroTools", "Test"]
|
deps = ["InteractiveUtils", "MacroTools", "Test"]
|
||||||
@ -200,9 +198,9 @@ version = "0.7.2"
|
|||||||
|
|
||||||
[[LLVM]]
|
[[LLVM]]
|
||||||
deps = ["CEnum", "Libdl", "Printf", "Unicode"]
|
deps = ["CEnum", "Libdl", "Printf", "Unicode"]
|
||||||
git-tree-sha1 = "4a05f742837779a00bd8c9a18da6817367c4245d"
|
git-tree-sha1 = "74fe444b8b6d1ac01d639b2f9eaf395bcc2e24fc"
|
||||||
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
|
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
|
||||||
version = "1.3.0"
|
version = "1.3.2"
|
||||||
|
|
||||||
[[LibGit2]]
|
[[LibGit2]]
|
||||||
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
|
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
|
||||||
@ -234,9 +232,10 @@ uuid = "e89f7d12-3494-54d1-8411-f7d8b9ae1f27"
|
|||||||
version = "0.5.0"
|
version = "0.5.0"
|
||||||
|
|
||||||
[[Missings]]
|
[[Missings]]
|
||||||
git-tree-sha1 = "29858ce6c8ae629cf2d733bffa329619a1c843d0"
|
deps = ["DataAPI"]
|
||||||
|
git-tree-sha1 = "de0a5ce9e5289f27df672ffabef4d1e5861247d5"
|
||||||
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
|
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
|
||||||
version = "0.4.2"
|
version = "0.4.3"
|
||||||
|
|
||||||
[[Mmap]]
|
[[Mmap]]
|
||||||
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
|
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
|
||||||
@ -261,12 +260,12 @@ version = "1.1.0"
|
|||||||
|
|
||||||
[[Parsers]]
|
[[Parsers]]
|
||||||
deps = ["Dates", "Test"]
|
deps = ["Dates", "Test"]
|
||||||
git-tree-sha1 = "ef0af6c8601db18c282d092ccbd2f01f3f0cd70b"
|
git-tree-sha1 = "c56ecb484f286639f161e712b8311f5ab77e8d32"
|
||||||
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
|
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
|
||||||
version = "0.3.7"
|
version = "0.3.8"
|
||||||
|
|
||||||
[[Pkg]]
|
[[Pkg]]
|
||||||
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
|
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
|
||||||
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
|
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
|
||||||
|
|
||||||
[[Printf]]
|
[[Printf]]
|
||||||
@ -328,9 +327,9 @@ version = "0.8.0"
|
|||||||
|
|
||||||
[[StaticArrays]]
|
[[StaticArrays]]
|
||||||
deps = ["LinearAlgebra", "Random", "Statistics"]
|
deps = ["LinearAlgebra", "Random", "Statistics"]
|
||||||
git-tree-sha1 = "db23bbf50064c582b6f2b9b043c8e7e98ea8c0c6"
|
git-tree-sha1 = "1e9c5d89cba8047d518f1ffef432906ef1a3e8bd"
|
||||||
uuid = "90137ffa-7385-5640-81b9-e52037218182"
|
uuid = "90137ffa-7385-5640-81b9-e52037218182"
|
||||||
version = "0.11.0"
|
version = "0.12.0"
|
||||||
|
|
||||||
[[Statistics]]
|
[[Statistics]]
|
||||||
deps = ["LinearAlgebra", "SparseArrays"]
|
deps = ["LinearAlgebra", "SparseArrays"]
|
||||||
|
@ -5,7 +5,7 @@ version = "0.9.0"
|
|||||||
[deps]
|
[deps]
|
||||||
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
|
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
|
||||||
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
|
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
|
||||||
CUDAapi = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
|
CUDAdrv = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
|
||||||
CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193"
|
CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193"
|
||||||
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
|
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
|
||||||
CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
|
CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
|
||||||
@ -26,11 +26,11 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
|
|||||||
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
|
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
|
||||||
|
|
||||||
[compat]
|
[compat]
|
||||||
CUDAapi = "1.1"
|
CUDAdrv = "4.0.1"
|
||||||
CuArrays = "1.2"
|
CuArrays = "1.4"
|
||||||
NNlib = "0.6"
|
NNlib = "0.6"
|
||||||
Zygote = "0.3"
|
Zygote = "0.3"
|
||||||
julia = "1"
|
julia = "1.2"
|
||||||
|
|
||||||
[extras]
|
[extras]
|
||||||
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
|
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
|
||||||
|
46
src/Flux.jl
46
src/Flux.jl
@ -21,19 +21,9 @@ export SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
|
|||||||
ADAMW, RADAM, InvDecay, ExpDecay, WeightDecay
|
ADAMW, RADAM, InvDecay, ExpDecay, WeightDecay
|
||||||
|
|
||||||
|
|
||||||
allow_cuda() = parse(Bool, get(ENV, "FLUX_USE_CUDA", "true"))
|
ENV["CUDA_INIT_SILENT"] = true
|
||||||
const consider_cuda = allow_cuda()
|
using CUDAdrv, CuArrays
|
||||||
|
const use_cuda = Ref(false)
|
||||||
using CUDAapi
|
|
||||||
const use_cuda = consider_cuda && has_cuda()
|
|
||||||
if use_cuda
|
|
||||||
try
|
|
||||||
using CuArrays
|
|
||||||
catch
|
|
||||||
@error "CUDA is installed, but CuArrays.jl fails to load. Please fix the issue, or load Flux with FLUX_USE_CUDA=false."
|
|
||||||
rethrow()
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
include("utils.jl")
|
include("utils.jl")
|
||||||
include("onehot.jl")
|
include("onehot.jl")
|
||||||
@ -49,21 +39,23 @@ include("data/Data.jl")
|
|||||||
|
|
||||||
include("deprecations.jl")
|
include("deprecations.jl")
|
||||||
|
|
||||||
if use_cuda
|
|
||||||
include("cuda/cuda.jl")
|
|
||||||
end
|
|
||||||
|
|
||||||
function __init__()
|
function __init__()
|
||||||
# check if the GPU usage conditions that are baked in the precompilation image
|
if !CUDAdrv.functional()
|
||||||
# match the current situation, and force a recompilation if not.
|
@warn "CUDA available, but CUDAdrv.jl failed to load"
|
||||||
if (allow_cuda() != consider_cuda) || (consider_cuda && has_cuda() != use_cuda)
|
elseif length(devices()) == 0
|
||||||
cachefile = if VERSION >= v"1.3-"
|
@warn "CUDA available, but no GPU detected"
|
||||||
Base.compilecache_path(Base.PkgId(Flux))
|
elseif !CuArrays.functional()
|
||||||
else
|
@warn "CUDA GPU available, but CuArrays.jl failed to load"
|
||||||
abspath(DEPOT_PATH[1], Base.cache_file_entry(Base.PkgId(Flux)))
|
else
|
||||||
end
|
use_cuda[] = true
|
||||||
rm(cachefile)
|
|
||||||
error("Your set-up changed, and Flux.jl needs to be reconfigured. Please load the package again.")
|
# FIXME: this functionality should be conditional at run time by checking `use_cuda`
|
||||||
|
# (or even better, get moved to CuArrays.jl as much as possible)
|
||||||
|
if CuArrays.has_cudnn()
|
||||||
|
include(joinpath(@__DIR__, "cuda/cuda.jl"))
|
||||||
|
else
|
||||||
|
@warn "CUDA GPU available, but CuArrays.jl did not find libcudnn. Some functionality will not be available."
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -2,12 +2,8 @@ module CUDA
|
|||||||
|
|
||||||
using ..CuArrays
|
using ..CuArrays
|
||||||
|
|
||||||
if CuArrays.libcudnn !== nothing # TODO: use CuArrays.has_cudnn()
|
using CuArrays: CUDNN
|
||||||
using CuArrays: CUDNN
|
include("curnn.jl")
|
||||||
include("curnn.jl")
|
include("cudnn.jl")
|
||||||
include("cudnn.jl")
|
|
||||||
else
|
|
||||||
@warn "CUDNN is not installed, some functionality will not be available."
|
|
||||||
end
|
|
||||||
|
|
||||||
end
|
end
|
||||||
|
@ -73,13 +73,7 @@ end
|
|||||||
|
|
||||||
cpu(m) = fmap(x -> adapt(Array, x), m)
|
cpu(m) = fmap(x -> adapt(Array, x), m)
|
||||||
|
|
||||||
const gpu_adaptor = if use_cuda
|
gpu(x) = use_cuda[] ? fmap(CuArrays.cu, x) : x
|
||||||
CuArrays.cu
|
|
||||||
else
|
|
||||||
identity
|
|
||||||
end
|
|
||||||
|
|
||||||
gpu(x) = fmap(gpu_adaptor, x)
|
|
||||||
|
|
||||||
# Precision
|
# Precision
|
||||||
|
|
||||||
|
@ -37,12 +37,10 @@ import Adapt: adapt, adapt_structure
|
|||||||
|
|
||||||
adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
|
adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
|
||||||
|
|
||||||
if use_cuda
|
import .CuArrays: CuArray, cudaconvert
|
||||||
import .CuArrays: CuArray, cudaconvert
|
import Base.Broadcast: BroadcastStyle, ArrayStyle
|
||||||
import Base.Broadcast: BroadcastStyle, ArrayStyle
|
BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = ArrayStyle{CuArray}()
|
||||||
BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = ArrayStyle{CuArray}()
|
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
|
||||||
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
|
|
||||||
end
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
onehot(l, labels[, unk])
|
onehot(l, labels[, unk])
|
||||||
|
@ -53,8 +53,10 @@ end
|
|||||||
@test y[3,:] isa CuArray
|
@test y[3,:] isa CuArray
|
||||||
end
|
end
|
||||||
|
|
||||||
if CuArrays.libcudnn != nothing
|
if CuArrays.has_cudnn()
|
||||||
@info "Testing Flux/CUDNN"
|
@info "Testing Flux/CUDNN"
|
||||||
include("cudnn.jl")
|
include("cudnn.jl")
|
||||||
include("curnn.jl")
|
include("curnn.jl")
|
||||||
|
else
|
||||||
|
@warn "CUDNN unavailable, not testing GPU DNN support"
|
||||||
end
|
end
|
||||||
|
@ -19,7 +19,7 @@ include("layers/normalisation.jl")
|
|||||||
include("layers/stateless.jl")
|
include("layers/stateless.jl")
|
||||||
include("layers/conv.jl")
|
include("layers/conv.jl")
|
||||||
|
|
||||||
if isdefined(Flux, :CUDA)
|
if Flux.use_cuda[]
|
||||||
include("cuda/cuda.jl")
|
include("cuda/cuda.jl")
|
||||||
else
|
else
|
||||||
@warn "CUDA unavailable, not testing GPU support"
|
@warn "CUDA unavailable, not testing GPU support"
|
||||||
|
Loading…
Reference in New Issue
Block a user