Replace Requires with direct CuArrays dependency.

This commit is contained in:
Tim Besard 2019-08-27 09:33:15 +02:00
parent 6494f73c78
commit 6ad3cdd138
10 changed files with 177 additions and 93 deletions

View File

@ -1,5 +1,11 @@
# This file is machine-generated - editing it directly is not advised # This file is machine-generated - editing it directly is not advised
[[AbstractFFTs]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "380e36c66edfa099cd90116b24c1ce8cafccac40"
uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
version = "0.4.1"
[[AbstractTrees]] [[AbstractTrees]]
deps = ["Markdown", "Test"] deps = ["Markdown", "Test"]
git-tree-sha1 = "6621d9645702c1c4e6970cc6a3eae440c768000b" git-tree-sha1 = "6621d9645702c1c4e6970cc6a3eae440c768000b"
@ -7,10 +13,10 @@ uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
version = "0.2.1" version = "0.2.1"
[[Adapt]] [[Adapt]]
deps = ["LinearAlgebra", "Test"] deps = ["LinearAlgebra"]
git-tree-sha1 = "53d8fec4f662088c1202530e338a11a919407f3b" git-tree-sha1 = "82dab828020b872fa9efd3abec1152b075bc7cbf"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "0.4.2" version = "1.0.0"
[[Base64]] [[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
@ -22,34 +28,57 @@ uuid = "9e28174c-4ba2-5203-b857-d8d62c4213ee"
version = "0.8.10" version = "0.8.10"
[[BinaryProvider]] [[BinaryProvider]]
deps = ["Libdl", "Pkg", "SHA", "Test"] deps = ["Libdl", "Logging", "SHA"]
git-tree-sha1 = "055eb2690182ebc31087859c3dd8598371d3ef9e" git-tree-sha1 = "c7361ce8a2129f20b0e05a89f7070820cfed6648"
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232" uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
version = "0.5.3" version = "0.5.6"
[[CEnum]]
git-tree-sha1 = "62847acab40e6855a9b5905ccb99c2b5cf6b3ebb"
uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
version = "0.2.0"
[[CSTParser]] [[CSTParser]]
deps = ["LibGit2", "Test", "Tokenize"] deps = ["Tokenize"]
git-tree-sha1 = "437c93bc191cd55957b3f8dee7794b6131997c56" git-tree-sha1 = "c69698c3d4a7255bc1b4bc2afc09f59db910243b"
uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f" uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f"
version = "0.5.2" version = "0.6.2"
[[CUDAapi]]
deps = ["Libdl", "Logging"]
git-tree-sha1 = "9b2b4b71d6b7f946c9689bb4dea03ff92e3c7091"
uuid = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
version = "1.1.0"
[[CUDAdrv]]
deps = ["CUDAapi", "Libdl", "Printf"]
git-tree-sha1 = "9ce99b5732c70e06ed97c042187baed876fb1698"
uuid = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
version = "3.1.0"
[[CUDAnative]]
deps = ["Adapt", "CUDAapi", "CUDAdrv", "DataStructures", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Printf", "TimerOutputs"]
git-tree-sha1 = "36cbb94f74cd3e5db774134a68dc5d033ae2c87e"
uuid = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
version = "2.2.1"
[[CodecZlib]] [[CodecZlib]]
deps = ["BinaryProvider", "Libdl", "Test", "TranscodingStreams"] deps = ["BinaryProvider", "Libdl", "TranscodingStreams"]
git-tree-sha1 = "36bbf5374c661054d41410dc53ff752972583b9b" git-tree-sha1 = "05916673a2627dd91b4969ff8ba6941bc85a960e"
uuid = "944b1d66-785c-5afd-91f1-9de20f533193" uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
version = "0.5.2" version = "0.6.0"
[[ColorTypes]] [[ColorTypes]]
deps = ["FixedPointNumbers", "Random", "Test"] deps = ["FixedPointNumbers", "Random"]
git-tree-sha1 = "f73b0e10f2a5756de7019818a41654686da06b09" git-tree-sha1 = "10050a24b09e8e41b951e9976b109871ce98d965"
uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
version = "0.7.5" version = "0.8.0"
[[Colors]] [[Colors]]
deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Printf", "Reexport", "Test"] deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Printf", "Reexport"]
git-tree-sha1 = "9f0a0210450acb91c730b730a994f8eef1d3d543" git-tree-sha1 = "c9c1845d6bf22e34738bee65c357a69f416ed5d1"
uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" uuid = "5ae59095-9a9b-59fe-a467-6f913c188581"
version = "0.9.5" version = "0.9.6"
[[CommonSubexpressions]] [[CommonSubexpressions]]
deps = ["Test"] deps = ["Test"]
@ -63,17 +92,34 @@ git-tree-sha1 = "84aa74986c5b9b898b0d1acaf3258741ee64754f"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "2.1.0" version = "2.1.0"
[[Conda]]
deps = ["JSON", "VersionParsing"]
git-tree-sha1 = "9a11d428dcdc425072af4aea19ab1e8c3e01c032"
uuid = "8f4d0f93-b110-5947-807f-2305c1781a2d"
version = "1.3.0"
[[Crayons]] [[Crayons]]
deps = ["Test"] deps = ["Test"]
git-tree-sha1 = "f621b8ef51fd2004c7cf157ea47f027fdeac5523" git-tree-sha1 = "f621b8ef51fd2004c7cf157ea47f027fdeac5523"
uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
version = "4.0.0" version = "4.0.0"
[[CuArrays]]
deps = ["AbstractFFTs", "Adapt", "CUDAapi", "CUDAdrv", "CUDAnative", "GPUArrays", "LinearAlgebra", "MacroTools", "NNlib", "Printf", "Random", "Requires", "SparseArrays", "TimerOutputs"]
git-tree-sha1 = "46b48742a84bb839e74215b7e468a4a1c6ba30f9"
uuid = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
version = "1.2.1"
[[DataAPI]]
git-tree-sha1 = "8903f0219d3472543fc4b2f5ebaf675a07f817c0"
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
version = "1.0.1"
[[DataStructures]] [[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"] deps = ["InteractiveUtils", "OrderedCollections"]
git-tree-sha1 = "ca971f03e146cf144a9e2f2ce59674f5bf0e8038" git-tree-sha1 = "0809951a1774dc724da22d26e4289bbaab77809a"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.15.0" version = "0.17.0"
[[Dates]] [[Dates]]
deps = ["Printf"] deps = ["Printf"]
@ -99,11 +145,22 @@ version = "0.0.10"
deps = ["Random", "Serialization", "Sockets"] deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
[[FFTW]]
deps = ["AbstractFFTs", "BinaryProvider", "Conda", "Libdl", "LinearAlgebra", "Reexport", "Test"]
git-tree-sha1 = "e1a479d3c972f20c9a70563eec740bbfc786f515"
uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
version = "0.3.0"
[[FillArrays]]
deps = ["LinearAlgebra", "Random", "SparseArrays"]
git-tree-sha1 = "8fba6ddaf66b45dec830233cea0aae43eb1261ad"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "0.6.4"
[[FixedPointNumbers]] [[FixedPointNumbers]]
deps = ["Test"] git-tree-sha1 = "d14a6fa5890ea3a7e5dcab6811114f132fec2b4b"
git-tree-sha1 = "b8045033701c3b10bf2324d7203404be7aef88ba"
uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
version = "0.5.3" version = "0.6.1"
[[ForwardDiff]] [[ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"] deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"]
@ -111,15 +168,33 @@ git-tree-sha1 = "4c4d727f1b7e0092134fabfab6396b8945c1ea5b"
uuid = "f6369f11-7733-5829-9624-2563aa707210" uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.3" version = "0.10.3"
[[GPUArrays]]
deps = ["Adapt", "FFTW", "FillArrays", "LinearAlgebra", "Printf", "Random", "Serialization", "StaticArrays", "Test"]
git-tree-sha1 = "dd169c636d1d3656a9faca772f5bd7c226a61254"
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
version = "1.0.1"
[[InteractiveUtils]] [[InteractiveUtils]]
deps = ["Markdown"] deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
[[JSON]]
deps = ["Dates", "Mmap", "Parsers", "Unicode"]
git-tree-sha1 = "b34d7cef7b337321e97d22242c3c2b91f476748e"
uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
version = "0.21.0"
[[Juno]] [[Juno]]
deps = ["Base64", "Logging", "Media", "Profile", "Test"] deps = ["Base64", "Logging", "Media", "Profile", "Test"]
git-tree-sha1 = "4e4a8d43aa7ecec66cadaf311fbd1e5c9d7b9175" git-tree-sha1 = "30d94657a422d09cb97b6f86f04f750fa9c50df8"
uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d" uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
version = "0.7.0" version = "0.7.2"
[[LLVM]]
deps = ["CEnum", "Libdl", "Printf", "Unicode"]
git-tree-sha1 = "52cfea426bd248a427aace7d88eb5d45b84ea297"
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
version = "1.2.0"
[[LibGit2]] [[LibGit2]]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
@ -135,10 +210,10 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
[[MacroTools]] [[MacroTools]]
deps = ["CSTParser", "Compat", "DataStructures", "Test"] deps = ["CSTParser", "Compat", "DataStructures", "Test", "Tokenize"]
git-tree-sha1 = "daecd9e452f38297c686eba90dba2a6d5da52162" git-tree-sha1 = "d6e9dedb8c92c3465575442da456aec15a89ff76"
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
version = "0.5.0" version = "0.5.1"
[[Markdown]] [[Markdown]]
deps = ["Base64"] deps = ["Base64"]
@ -151,10 +226,10 @@ uuid = "e89f7d12-3494-54d1-8411-f7d8b9ae1f27"
version = "0.5.0" version = "0.5.0"
[[Missings]] [[Missings]]
deps = ["Dates", "InteractiveUtils", "SparseArrays", "Test"] deps = ["SparseArrays", "Test"]
git-tree-sha1 = "d1d2585677f2bd93a97cfeb8faa7a0de0f982042" git-tree-sha1 = "f0719736664b4358aa9ec173077d4285775f8007"
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
version = "0.4.0" version = "0.4.1"
[[Mmap]] [[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804" uuid = "a63ad114-7e13-5084-954f-fe012c677804"
@ -177,6 +252,12 @@ git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.1.0" version = "1.1.0"
[[Parsers]]
deps = ["Dates", "Test"]
git-tree-sha1 = "db2b35dedab3c0e46dc15996d170af07a5ab91c9"
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
version = "0.3.6"
[[Pkg]] [[Pkg]]
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
@ -239,20 +320,20 @@ uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "0.7.2" version = "0.7.2"
[[StaticArrays]] [[StaticArrays]]
deps = ["InteractiveUtils", "LinearAlgebra", "Random", "Statistics", "Test"] deps = ["LinearAlgebra", "Random", "Statistics"]
git-tree-sha1 = "3841b39ed5f047db1162627bf5f80a9cd3e39ae2" git-tree-sha1 = "db23bbf50064c582b6f2b9b043c8e7e98ea8c0c6"
uuid = "90137ffa-7385-5640-81b9-e52037218182" uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "0.10.3" version = "0.11.0"
[[Statistics]] [[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"] deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[[StatsBase]] [[StatsBase]]
deps = ["DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"] deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"]
git-tree-sha1 = "8a0f4b09c7426478ab677245ab2b0b68552143c7" git-tree-sha1 = "c53e809e63fe5cf5de13632090bc3520649c9950"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.30.0" version = "0.32.0"
[[Test]] [[Test]]
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
@ -265,22 +346,21 @@ uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
version = "0.5.0" version = "0.5.0"
[[Tokenize]] [[Tokenize]]
deps = ["Printf", "Test"] git-tree-sha1 = "dfcdbbfb2d0370716c815cbd6f8a364efb6f42cf"
git-tree-sha1 = "3e83f60b74911d3042d3550884ca2776386a02b8"
uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624" uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624"
version = "0.5.3" version = "0.5.6"
[[Tracker]] [[Tracker]]
deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics", "Test"] deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics", "Test"]
git-tree-sha1 = "0bec1b68c63a0e8a58d3944261cbf4cc9577c8a1" git-tree-sha1 = "1aa443d3b4bfa91a8aec32f169a479cb87309910"
uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
version = "0.2.0" version = "0.2.3"
[[TranscodingStreams]] [[TranscodingStreams]]
deps = ["Random", "Test"] deps = ["Random", "Test"]
git-tree-sha1 = "a25d8e5a28c3b1b06d3859f30757d43106791919" git-tree-sha1 = "7c53c35547de1c5b9d46a4797cf6d8253807108c"
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
version = "0.9.4" version = "0.9.5"
[[URIParser]] [[URIParser]]
deps = ["Test", "Unicode"] deps = ["Test", "Unicode"]
@ -295,8 +375,14 @@ uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
[[Unicode]] [[Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
[[VersionParsing]]
deps = ["Compat"]
git-tree-sha1 = "c9d5aa108588b978bd859554660c8a5c4f2f7669"
uuid = "81def892-9a0e-5fdd-b105-ffc91e053289"
version = "1.1.3"
[[ZipFile]] [[ZipFile]]
deps = ["BinaryProvider", "Libdl", "Printf", "Test"] deps = ["BinaryProvider", "Libdl", "Printf"]
git-tree-sha1 = "5f6f663890dfb9bad6af75a86a43f67904e5050e" git-tree-sha1 = "580ce62b6c14244916cc28ad54f8a2e2886f843d"
uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
version = "0.8.1" version = "0.8.3"

View File

@ -5,8 +5,10 @@ version = "0.8.3"
[deps] [deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDAapi = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193" CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d" Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@ -16,7 +18,6 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SHA = "ea8e919c-243c-51af-8825-aaa63cd721ce" SHA = "ea8e919c-243c-51af-8825-aaa63cd721ce"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
@ -27,6 +28,7 @@ ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
NNlib = "0.6" NNlib = "0.6"
Tracker = "0.2" Tracker = "0.2"
julia = "0.7, 1" julia = "0.7, 1"
CuArrays = "1.3"
[extras] [extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

View File

@ -3,7 +3,7 @@ module Flux
# Zero Flux Given # Zero Flux Given
using Base: tail using Base: tail
using MacroTools, Juno, Requires, Reexport, Statistics, Random using MacroTools, Juno, Reexport, Statistics, Random
using MacroTools: @forward using MacroTools: @forward
export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool, export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
@ -24,6 +24,17 @@ export SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
ADAMW, InvDecay, ExpDecay, WeightDecay ADAMW, InvDecay, ExpDecay, WeightDecay
using CUDAapi
if has_cuda()
try
using CuArrays
@eval has_cuarrays() = true
catch ex
@warn "CUDA is installed, but CuArrays.jl fails to load" exception=(ex,catch_backtrace())
@eval has_cuarrays() = false
end
end
include("utils.jl") include("utils.jl")
include("onehot.jl") include("onehot.jl")
include("treelike.jl") include("treelike.jl")
@ -36,6 +47,8 @@ include("layers/normalise.jl")
include("data/Data.jl") include("data/Data.jl")
@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" include("cuda/cuda.jl") if has_cuarrays()
include("cuda/cuda.jl")
end
end # module end # module

View File

@ -1,38 +1,12 @@
module CUDA module CUDA
using ..CuArrays using ..CuArrays
import ..CuArrays.CUDAdrv: CuPtr, CU_NULL
using Pkg.TOML
function version_check() if has_cudnn()
major_version = 1
project = joinpath(dirname(pathof(CuArrays)), "../Project.toml")
project = TOML.parse(String(read(project)))
version = VersionNumber(get(project, "version", "0.0.0"))
if version.major != major_version
@warn """
Flux is only supported with CuArrays v$major_version.x.
Try running `] pin CuArrays@$major_version`.
"""
end
end
version_check()
if !applicable(CuArray{UInt8}, undef, 1)
(T::Type{<:CuArray})(::UndefInitializer, sz...) = T(sz...)
end
if CuArrays.libcudnn != nothing
if isdefined(CuArrays, :libcudnn_handle)
handle() = CuArrays.libcudnn_handle[]
else
handle() = CuArrays.CUDNN.handle()
end
include("curnn.jl") include("curnn.jl")
include("cudnn.jl") include("cudnn.jl")
else else
@warn("CUDNN is not installed, some functionality will not be available.") @warn "CUDNN is not installed, some functionality will not be available."
end end
end end

View File

@ -1,9 +1,13 @@
using .CuArrays.CUDNN: @check, cudnnStatus_t, cudnnTensorDescriptor_t, using CuArrays: libcudnn
using CuArrays.CUDNN: @check, handle, cudnnStatus_t, cudnnTensorDescriptor_t,
cudnnBatchNormMode_t, cudnnHandle_t, cudnnDataType, TensorDesc, FilterDesc cudnnBatchNormMode_t, cudnnHandle_t, cudnnDataType, TensorDesc, FilterDesc
using .CuArrays: libcudnn
import ..Flux: data import CuArrays.CUDAdrv: CuPtr, CU_NULL
using LinearAlgebra using LinearAlgebra
import ..Flux: data
mutable struct DropoutDesc mutable struct DropoutDesc
ptr::Ptr{Nothing} ptr::Ptr{Nothing}
states::CuVector{UInt8} states::CuVector{UInt8}

View File

@ -1,6 +1,9 @@
using .CuArrays.CUDNN: @check, cudnnStatus_t, cudnnTensorDescriptor_t, using CuArrays: libcudnn
using CuArrays.CUDNN: @check, cudnnStatus_t, cudnnTensorDescriptor_t,
cudnnBatchNormMode_t, cudnnHandle_t, cudnnDataType, TensorDesc, FilterDesc cudnnBatchNormMode_t, cudnnHandle_t, cudnnDataType, TensorDesc, FilterDesc
using .CuArrays: libcudnn
import CuArrays.CUDAdrv: CuPtr, CU_NULL
using LinearAlgebra using LinearAlgebra
const RNN_RELU = 0 # Stock RNN with ReLu activation const RNN_RELU = 0 # Stock RNN with ReLu activation
@ -223,8 +226,8 @@ end
import ..Flux: Flux, relu import ..Flux: Flux, relu
import ..Tracker: TrackedArray import ..Tracker: TrackedArray
using .CuArrays.CUDAnative using CuArrays.CUDAnative
using .CuArrays: @cuindex, cudims using CuArrays: @cuindex, cudims
function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray) function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
function kernel(dst, src) function kernel(dst, src)

View File

@ -37,7 +37,7 @@ import Adapt: adapt, adapt_structure
adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data)) adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin if has_cuarrays()
import .CuArrays: CuArray, cudaconvert import .CuArrays: CuArray, cudaconvert
import Base.Broadcast: BroadcastStyle, ArrayStyle import Base.Broadcast: BroadcastStyle, ArrayStyle
BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = ArrayStyle{CuArray}() BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = ArrayStyle{CuArray}()

View File

@ -60,10 +60,10 @@ end
cpu(m) = mapleaves(x -> adapt(Array, x), m) cpu(m) = mapleaves(x -> adapt(Array, x), m)
gpu_adaptor = identity const gpu_adaptor = if has_cuarrays()
CuArrays.cu
@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin else
global gpu_adaptor = CuArrays.cu identity
end end
gpu(x) = mapleaves(gpu_adaptor, x) gpu(x) = mapleaves(gpu_adaptor, x)

View File

@ -48,7 +48,7 @@ end
@test y[3,:] isa CuArray @test y[3,:] isa CuArray
end end
if CuArrays.libcudnn != nothing if has_cudnn() != nothing
@info "Testing Flux/CUDNN" @info "Testing Flux/CUDNN"
include("cudnn.jl") include("cudnn.jl")
if !haskey(ENV, "CI_DISABLE_CURNN_TEST") if !haskey(ENV, "CI_DISABLE_CURNN_TEST")

View File

@ -26,8 +26,10 @@ include("layers/conv.jl")
include("tracker.jl") include("tracker.jl")
if Base.find_package("CuArrays") != nothing if isdefined(Flux, :CUDA)
include("cuda/cuda.jl") include("cuda/cuda.jl")
else
@warn "CUDA unavailable, not testing GPU support"
end end
end end