commit 9279d79e63

.gitattributes (vendored, 1 change)
@@ -1 +1,2 @@
paper/* linguist-documentation
CITATION.bib linguist-detectable=false
.github/FUNDING.yml (vendored, new file, 1 change)
@@ -0,0 +1 @@
custom: https://numfocus.salsalabs.org/donate-to-julia/index.html
@@ -1,37 +1,41 @@
before_script:
  - export CI_DISABLE_CURNN_TEST=true

variables:
  CI_IMAGE_TAG: 'cuda'

include:
  - 'https://raw.githubusercontent.com/JuliaGPU/gitlab-ci/master/templates/v3/common.yml'
  - 'https://raw.githubusercontent.com/JuliaGPU/gitlab-ci/master/templates/v6.yml'

.flux:
  extends: .test
  script:
    - julia -e 'using InteractiveUtils;
                versioninfo()'
    - mkdir $JULIA_DEPOT_PATH # Pkg3.jl#325
    - julia -e 'using Pkg;
                Pkg.add("CuArrays");'
    - julia --project -e 'using Pkg;
                          Pkg.instantiate();
                          Pkg.build();
                          Pkg.test(; coverage=true);'
  image: nvidia/cuda:10.1-cudnn7-devel-ubuntu18.04

test:v1.0:
  extends: .flux
  variables:
    CI_VERSION_TAG: 'v1.0'
  only:
    - staging
    - trying

test:v1.1:
  extends: .flux
  variables:
    CI_VERSION_TAG: 'v1.1'
  only:
    - staging
    - trying
julia:1.0:
  extends:
    - .julia:1.0
    - .test
  tags:
    - nvidia

julia:1.1:
  extends:
    - .julia:1.1
    - .test
  tags:
    - nvidia

julia:1.2:
  extends:
    - .julia:1.2
    - .test
  tags:
    - nvidia

julia:1.3:
  extends:
    - .julia:1.3
    - .test
  tags:
    - nvidia

julia:nightly:
  extends:
    - .julia:nightly
    - .test
  tags:
    - nvidia
  allow_failure: true
@@ -7,6 +7,8 @@ os:

julia:
  - 1.0
  - 1.2
  - 1.3
  - nightly

matrix:

@@ -16,7 +18,7 @@ matrix:
jobs:
  include:
    - stage: "Documentation"
      julia: 1.0
      julia: 1.2
      os: linux
      script:
        - julia --project=docs/ -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd()));
Manifest.toml (244 changes)
@ -1,4 +1,8 @@
|
||||
# This file is machine-generated - editing it directly is not advised
|
||||
[[AbstractFFTs]]
|
||||
deps = ["LinearAlgebra"]
|
||||
git-tree-sha1 = "380e36c66edfa099cd90116b24c1ce8cafccac40"
|
||||
uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
|
||||
version = "0.4.1"
|
||||
|
||||
[[AbstractTrees]]
|
||||
deps = ["Markdown", "Test"]
|
||||
@ -7,10 +11,10 @@ uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
|
||||
version = "0.2.1"
|
||||
|
||||
[[Adapt]]
|
||||
deps = ["LinearAlgebra", "Test"]
|
||||
git-tree-sha1 = "53d8fec4f662088c1202530e338a11a919407f3b"
|
||||
deps = ["LinearAlgebra"]
|
||||
git-tree-sha1 = "82dab828020b872fa9efd3abec1152b075bc7cbf"
|
||||
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
|
||||
version = "0.4.2"
|
||||
version = "1.0.0"
|
||||
|
||||
[[Base64]]
|
||||
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
|
||||
@ -22,34 +26,51 @@ uuid = "9e28174c-4ba2-5203-b857-d8d62c4213ee"
|
||||
version = "0.8.10"
|
||||
|
||||
[[BinaryProvider]]
|
||||
deps = ["Libdl", "Pkg", "SHA", "Test"]
|
||||
git-tree-sha1 = "055eb2690182ebc31087859c3dd8598371d3ef9e"
|
||||
deps = ["Libdl", "SHA"]
|
||||
git-tree-sha1 = "5b08ed6036d9d3f0ee6369410b830f8873d4024c"
|
||||
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
|
||||
version = "0.5.3"
|
||||
version = "0.5.8"
|
||||
|
||||
[[CSTParser]]
|
||||
deps = ["LibGit2", "Test", "Tokenize"]
|
||||
git-tree-sha1 = "437c93bc191cd55957b3f8dee7794b6131997c56"
|
||||
uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f"
|
||||
version = "0.5.2"
|
||||
[[CEnum]]
|
||||
git-tree-sha1 = "62847acab40e6855a9b5905ccb99c2b5cf6b3ebb"
|
||||
uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
|
||||
version = "0.2.0"
|
||||
|
||||
[[CUDAapi]]
|
||||
deps = ["Libdl", "Logging"]
|
||||
git-tree-sha1 = "6eee47385c81ed3b3f716b745697869c712c2df3"
|
||||
uuid = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
|
||||
version = "2.0.0"
|
||||
|
||||
[[CUDAdrv]]
|
||||
deps = ["CEnum", "CUDAapi", "Printf"]
|
||||
git-tree-sha1 = "0f39fddace3324707469ace7fbcbc7b28d5cf921"
|
||||
uuid = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
|
||||
version = "4.0.4"
|
||||
|
||||
[[CUDAnative]]
|
||||
deps = ["Adapt", "CEnum", "CUDAapi", "CUDAdrv", "DataStructures", "InteractiveUtils", "LLVM", "Libdl", "Printf", "TimerOutputs"]
|
||||
git-tree-sha1 = "93f6c917ab2a9b5bb54f8f738f4ec1a6693cb716"
|
||||
uuid = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
|
||||
version = "2.5.5"
|
||||
|
||||
[[CodecZlib]]
|
||||
deps = ["BinaryProvider", "Libdl", "Test", "TranscodingStreams"]
|
||||
git-tree-sha1 = "36bbf5374c661054d41410dc53ff752972583b9b"
|
||||
deps = ["BinaryProvider", "Libdl", "TranscodingStreams"]
|
||||
git-tree-sha1 = "05916673a2627dd91b4969ff8ba6941bc85a960e"
|
||||
uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
|
||||
version = "0.5.2"
|
||||
version = "0.6.0"
|
||||
|
||||
[[ColorTypes]]
|
||||
deps = ["FixedPointNumbers", "Random", "Test"]
|
||||
git-tree-sha1 = "f73b0e10f2a5756de7019818a41654686da06b09"
|
||||
deps = ["FixedPointNumbers", "Random"]
|
||||
git-tree-sha1 = "10050a24b09e8e41b951e9976b109871ce98d965"
|
||||
uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
|
||||
version = "0.7.5"
|
||||
version = "0.8.0"
|
||||
|
||||
[[Colors]]
|
||||
deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Printf", "Reexport", "Test"]
|
||||
git-tree-sha1 = "9f0a0210450acb91c730b730a994f8eef1d3d543"
|
||||
deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Printf", "Reexport"]
|
||||
git-tree-sha1 = "c9c1845d6bf22e34738bee65c357a69f416ed5d1"
|
||||
uuid = "5ae59095-9a9b-59fe-a467-6f913c188581"
|
||||
version = "0.9.5"
|
||||
version = "0.9.6"
|
||||
|
||||
[[CommonSubexpressions]]
|
||||
deps = ["Test"]
|
||||
@ -59,21 +80,34 @@ version = "0.2.0"
|
||||
|
||||
[[Compat]]
|
||||
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
|
||||
git-tree-sha1 = "84aa74986c5b9b898b0d1acaf3258741ee64754f"
|
||||
git-tree-sha1 = "ed2c4abadf84c53d9e58510b5fc48912c2336fbb"
|
||||
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
|
||||
version = "2.1.0"
|
||||
version = "2.2.0"
|
||||
|
||||
[[Crayons]]
|
||||
deps = ["Test"]
|
||||
git-tree-sha1 = "f621b8ef51fd2004c7cf157ea47f027fdeac5523"
|
||||
uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
|
||||
version = "4.0.0"
|
||||
[[Conda]]
|
||||
deps = ["JSON", "VersionParsing"]
|
||||
git-tree-sha1 = "9a11d428dcdc425072af4aea19ab1e8c3e01c032"
|
||||
uuid = "8f4d0f93-b110-5947-807f-2305c1781a2d"
|
||||
version = "1.3.0"
|
||||
|
||||
[[CuArrays]]
|
||||
deps = ["AbstractFFTs", "Adapt", "CEnum", "CUDAapi", "CUDAdrv", "CUDAnative", "DataStructures", "GPUArrays", "Libdl", "LinearAlgebra", "MacroTools", "NNlib", "Printf", "Random", "Requires", "SparseArrays", "TimerOutputs"]
|
||||
git-tree-sha1 = "7e00178b18672ee2cf37244ac2a273b6b0701b04"
|
||||
repo-rev = "master"
|
||||
repo-url = "https://github.com/JuliaGPU/CuArrays.jl.git"
|
||||
uuid = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
|
||||
version = "1.4.7"
|
||||
|
||||
[[DataAPI]]
|
||||
git-tree-sha1 = "674b67f344687a88310213ddfa8a2b3c76cc4252"
|
||||
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
|
||||
version = "1.1.0"
|
||||
|
||||
[[DataStructures]]
|
||||
deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
|
||||
git-tree-sha1 = "ca971f03e146cf144a9e2f2ce59674f5bf0e8038"
|
||||
deps = ["InteractiveUtils", "OrderedCollections"]
|
||||
git-tree-sha1 = "a1b652fb77ae8ca7ea328fa7ba5aa151036e5c10"
|
||||
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
|
||||
version = "0.15.0"
|
||||
version = "0.17.6"
|
||||
|
||||
[[Dates]]
|
||||
deps = ["Printf"]
|
||||
@ -90,36 +124,71 @@ uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
|
||||
version = "0.0.4"
|
||||
|
||||
[[DiffRules]]
|
||||
deps = ["Random", "Test"]
|
||||
git-tree-sha1 = "dc0869fb2f5b23466b32ea799bd82c76480167f7"
|
||||
deps = ["NaNMath", "Random", "SpecialFunctions"]
|
||||
git-tree-sha1 = "f734b5f6bc9c909027ef99f6d91d5d9e4b111eed"
|
||||
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
|
||||
version = "0.0.10"
|
||||
version = "0.1.0"
|
||||
|
||||
[[Distributed]]
|
||||
deps = ["Random", "Serialization", "Sockets"]
|
||||
deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"]
|
||||
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
|
||||
|
||||
[[FFTW]]
|
||||
deps = ["AbstractFFTs", "BinaryProvider", "Conda", "Libdl", "LinearAlgebra", "Reexport", "Test"]
|
||||
git-tree-sha1 = "6c5b420da0b8c12098048561b8d58f81adea506f"
|
||||
uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
|
||||
version = "1.0.1"
|
||||
|
||||
[[FillArrays]]
|
||||
deps = ["LinearAlgebra", "Random", "SparseArrays"]
|
||||
git-tree-sha1 = "1a9fe4e1323f38de0ba4da49eafd15b25ec62298"
|
||||
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
|
||||
version = "0.8.2"
|
||||
|
||||
[[FixedPointNumbers]]
|
||||
deps = ["Test"]
|
||||
git-tree-sha1 = "b8045033701c3b10bf2324d7203404be7aef88ba"
|
||||
git-tree-sha1 = "d14a6fa5890ea3a7e5dcab6811114f132fec2b4b"
|
||||
uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
|
||||
version = "0.5.3"
|
||||
version = "0.6.1"
|
||||
|
||||
[[ForwardDiff]]
|
||||
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"]
|
||||
git-tree-sha1 = "4c4d727f1b7e0092134fabfab6396b8945c1ea5b"
|
||||
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "NaNMath", "Random", "SpecialFunctions", "StaticArrays"]
|
||||
git-tree-sha1 = "da46ac97b17793eba44ff366dc6cb70f1238a738"
|
||||
uuid = "f6369f11-7733-5829-9624-2563aa707210"
|
||||
version = "0.10.3"
|
||||
version = "0.10.7"
|
||||
|
||||
[[GPUArrays]]
|
||||
deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization"]
|
||||
git-tree-sha1 = "a0a3b927b1a06e63fb8b91950cc7df340b7d912c"
|
||||
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
|
||||
version = "2.0.0"
|
||||
|
||||
[[IRTools]]
|
||||
deps = ["InteractiveUtils", "MacroTools", "Test"]
|
||||
git-tree-sha1 = "72421971e60917b8cd7737f9577c4f0f87eab306"
|
||||
uuid = "7869d1d1-7146-5819-86e3-90919afe41df"
|
||||
version = "0.3.0"
|
||||
|
||||
[[InteractiveUtils]]
|
||||
deps = ["Markdown"]
|
||||
deps = ["LinearAlgebra", "Markdown"]
|
||||
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
|
||||
|
||||
[[JSON]]
|
||||
deps = ["Dates", "Mmap", "Parsers", "Unicode"]
|
||||
git-tree-sha1 = "b34d7cef7b337321e97d22242c3c2b91f476748e"
|
||||
uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
|
||||
version = "0.21.0"
|
||||
|
||||
[[Juno]]
|
||||
deps = ["Base64", "Logging", "Media", "Profile", "Test"]
|
||||
git-tree-sha1 = "4e4a8d43aa7ecec66cadaf311fbd1e5c9d7b9175"
|
||||
git-tree-sha1 = "30d94657a422d09cb97b6f86f04f750fa9c50df8"
|
||||
uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
|
||||
version = "0.7.0"
|
||||
version = "0.7.2"
|
||||
|
||||
[[LLVM]]
|
||||
deps = ["CEnum", "Libdl", "Printf", "Unicode"]
|
||||
git-tree-sha1 = "74fe444b8b6d1ac01d639b2f9eaf395bcc2e24fc"
|
||||
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
|
||||
version = "1.3.2"
|
||||
|
||||
[[LibGit2]]
|
||||
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
|
||||
@ -135,10 +204,10 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
|
||||
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
|
||||
|
||||
[[MacroTools]]
|
||||
deps = ["CSTParser", "Compat", "DataStructures", "Test"]
|
||||
git-tree-sha1 = "daecd9e452f38297c686eba90dba2a6d5da52162"
|
||||
deps = ["Compat", "DataStructures", "Test"]
|
||||
git-tree-sha1 = "82921f0e3bde6aebb8e524efc20f4042373c0c06"
|
||||
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
|
||||
version = "0.5.0"
|
||||
version = "0.5.2"
|
||||
|
||||
[[Markdown]]
|
||||
deps = ["Base64"]
|
||||
@ -151,10 +220,10 @@ uuid = "e89f7d12-3494-54d1-8411-f7d8b9ae1f27"
|
||||
version = "0.5.0"
|
||||
|
||||
[[Missings]]
|
||||
deps = ["Dates", "InteractiveUtils", "SparseArrays", "Test"]
|
||||
git-tree-sha1 = "d1d2585677f2bd93a97cfeb8faa7a0de0f982042"
|
||||
deps = ["DataAPI"]
|
||||
git-tree-sha1 = "de0a5ce9e5289f27df672ffabef4d1e5861247d5"
|
||||
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
|
||||
version = "0.4.0"
|
||||
version = "0.4.3"
|
||||
|
||||
[[Mmap]]
|
||||
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
|
||||
@ -166,10 +235,9 @@ uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
|
||||
version = "0.6.0"
|
||||
|
||||
[[NaNMath]]
|
||||
deps = ["Compat"]
|
||||
git-tree-sha1 = "ce3b85e484a5d4c71dd5316215069311135fa9f2"
|
||||
git-tree-sha1 = "928b8ca9b2791081dc71a51c55347c27c618760f"
|
||||
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
|
||||
version = "0.3.2"
|
||||
version = "0.3.3"
|
||||
|
||||
[[OrderedCollections]]
|
||||
deps = ["Random", "Serialization", "Test"]
|
||||
@ -177,6 +245,12 @@ git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1"
|
||||
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
|
||||
version = "1.1.0"
|
||||
|
||||
[[Parsers]]
|
||||
deps = ["Dates", "Test"]
|
||||
git-tree-sha1 = "0139ba59ce9bc680e2925aec5b7db79065d60556"
|
||||
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
|
||||
version = "0.3.10"
|
||||
|
||||
[[Pkg]]
|
||||
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
|
||||
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
|
||||
@ -233,54 +307,42 @@ deps = ["LinearAlgebra", "Random"]
|
||||
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
|
||||
|
||||
[[SpecialFunctions]]
|
||||
deps = ["BinDeps", "BinaryProvider", "Libdl", "Test"]
|
||||
git-tree-sha1 = "0b45dc2e45ed77f445617b99ff2adf0f5b0f23ea"
|
||||
deps = ["BinDeps", "BinaryProvider", "Libdl"]
|
||||
git-tree-sha1 = "3bdd374b6fd78faf0119b8c5d538788dbf910c6e"
|
||||
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
|
||||
version = "0.7.2"
|
||||
version = "0.8.0"
|
||||
|
||||
[[StaticArrays]]
|
||||
deps = ["InteractiveUtils", "LinearAlgebra", "Random", "Statistics", "Test"]
|
||||
git-tree-sha1 = "3841b39ed5f047db1162627bf5f80a9cd3e39ae2"
|
||||
deps = ["LinearAlgebra", "Random", "Statistics"]
|
||||
git-tree-sha1 = "5a3bcb6233adabde68ebc97be66e95dcb787424c"
|
||||
uuid = "90137ffa-7385-5640-81b9-e52037218182"
|
||||
version = "0.10.3"
|
||||
version = "0.12.1"
|
||||
|
||||
[[Statistics]]
|
||||
deps = ["LinearAlgebra", "SparseArrays"]
|
||||
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
|
||||
|
||||
[[StatsBase]]
|
||||
deps = ["DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"]
|
||||
git-tree-sha1 = "8a0f4b09c7426478ab677245ab2b0b68552143c7"
|
||||
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"]
|
||||
git-tree-sha1 = "c53e809e63fe5cf5de13632090bc3520649c9950"
|
||||
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
|
||||
version = "0.30.0"
|
||||
version = "0.32.0"
|
||||
|
||||
[[Test]]
|
||||
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
|
||||
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
|
||||
|
||||
[[TimerOutputs]]
|
||||
deps = ["Crayons", "Printf", "Test", "Unicode"]
|
||||
git-tree-sha1 = "b80671c06f8f8bae08c55d67b5ce292c5ae2660c"
|
||||
deps = ["Printf"]
|
||||
git-tree-sha1 = "311765af81bbb48d7bad01fb016d9c328c6ede03"
|
||||
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
|
||||
version = "0.5.0"
|
||||
|
||||
[[Tokenize]]
|
||||
deps = ["Printf", "Test"]
|
||||
git-tree-sha1 = "3e83f60b74911d3042d3550884ca2776386a02b8"
|
||||
uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624"
|
||||
version = "0.5.3"
|
||||
|
||||
[[Tracker]]
|
||||
deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics", "Test"]
|
||||
git-tree-sha1 = "0bec1b68c63a0e8a58d3944261cbf4cc9577c8a1"
|
||||
uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
|
||||
version = "0.2.0"
|
||||
|
||||
[[TranscodingStreams]]
|
||||
deps = ["Random", "Test"]
|
||||
git-tree-sha1 = "a25d8e5a28c3b1b06d3859f30757d43106791919"
|
||||
git-tree-sha1 = "7c53c35547de1c5b9d46a4797cf6d8253807108c"
|
||||
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
|
||||
version = "0.9.4"
|
||||
version = "0.9.5"
|
||||
|
||||
[[URIParser]]
|
||||
deps = ["Test", "Unicode"]
|
||||
@ -289,14 +351,32 @@ uuid = "30578b45-9adc-5946-b283-645ec420af67"
|
||||
version = "0.4.0"
|
||||
|
||||
[[UUIDs]]
|
||||
deps = ["Random", "SHA"]
|
||||
deps = ["Random"]
|
||||
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
|
||||
|
||||
[[Unicode]]
|
||||
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
|
||||
|
||||
[[VersionParsing]]
|
||||
deps = ["Compat"]
|
||||
git-tree-sha1 = "c9d5aa108588b978bd859554660c8a5c4f2f7669"
|
||||
uuid = "81def892-9a0e-5fdd-b105-ffc91e053289"
|
||||
version = "1.1.3"
|
||||
|
||||
[[ZipFile]]
|
||||
deps = ["BinaryProvider", "Libdl", "Printf", "Test"]
|
||||
git-tree-sha1 = "5f6f663890dfb9bad6af75a86a43f67904e5050e"
|
||||
deps = ["BinaryProvider", "Libdl", "Printf"]
|
||||
git-tree-sha1 = "580ce62b6c14244916cc28ad54f8a2e2886f843d"
|
||||
uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
|
||||
version = "0.8.1"
|
||||
version = "0.8.3"
|
||||
|
||||
[[Zygote]]
|
||||
deps = ["DiffRules", "FFTW", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
|
||||
git-tree-sha1 = "e4245b9c5362346e154b62842a89a18e0210b92b"
|
||||
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
|
||||
version = "0.4.1"
|
||||
|
||||
[[ZygoteRules]]
|
||||
deps = ["MacroTools"]
|
||||
git-tree-sha1 = "b3b4882cc9accf6731a08cc39543fbc6b669dca8"
|
||||
uuid = "700de1a5-db45-46bc-99cf-38207098b444"
|
||||
version = "0.2.0"
|
||||
|
NEWS.md (14 changes)
@@ -1,6 +1,20 @@
# v0.10.0
* The default AD engine has switched from [Tracker to Zygote.jl](https://github.com/FluxML/Flux.jl/pull/669) (see the sketch below).
  - The dependency on Tracker.jl has been removed.
  - Flux no longer depends on a specialised `TrackedArray` type and can be used directly with normal `Array` implementations.
  - Tracker compatibility is maintained in most common cases, but Zygote will be the preferred AD backend for Flux from now on.
* The CUDNN wrappers have been [moved from Flux into CuArrays](https://github.com/FluxML/Flux.jl/pull/874) to better support the CUDA backend, improve the user experience, and make Flux leaner.
* `*crossentropy` functions now [work as expected with CuArrays](https://github.com/FluxML/Flux.jl/pull/926) ([PR for binarycrossentropy](https://github.com/FluxML/Flux.jl/pull/940)).
* Added [clearer docs](https://github.com/FluxML/Flux.jl/pull/904) around training and the Optimiser interface.
* [Layer initialisations](https://github.com/FluxML/Flux.jl/pull/937) have been improved, with a clearer API for extending them for other purposes.
* [Better messaging around CUDA availability](https://github.com/FluxML/Flux.jl/pull/924), with hooks to initialise the GPU as default where possible.
* `@treelike` has been formalised as a [functor](https://github.com/FluxML/Flux.jl/pull/865), with an effective deprecation.
* `testmode!` is deprecated in favour of [istraining](https://github.com/FluxML/Flux.jl/pull/669).
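To make the switch concrete, here is a minimal sketch of the Zygote-era interface (an editorial illustration based on the `docs/src/models/basics.md` changes further down in this diff, not part of the original release notes): parameters are plain arrays, and gradients come from `gradient` and `params`.

```julia
using Flux

W, b = rand(2, 5), rand(2)          # plain arrays — no param() / TrackedArray needed
predict(x) = W*x .+ b
loss(x, y) = sum((predict(x) .- y).^2)

x, y = rand(5), rand(2)
gs = gradient(() -> loss(x, y), params(W, b))  # Zygote-backed implicit gradients
gs[W], gs[b]                                   # gradients with respect to W and b
```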
# v0.9.0
* [Depthwise convolutional layer API changes](https://github.com/FluxML/Flux.jl/pull/756) from `in => mult` to `in => out` channel specification, deprecating the implicit `out` constructor.
* New [SkipConnection](https://github.com/FluxML/Flux.jl/pull/446), which can be used to train residual neural network architectures.
* New [RADAM](https://github.com/FluxML/Flux.jl/pull/842) optimiser.

# v0.8.0
Project.toml (25 changes)
@@ -1,35 +1,46 @@
name = "Flux"
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
version = "0.8.3"
version = "0.10.0"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SHA = "ea8e919c-243c-51af-8825-aaa63cd721ce"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
AbstractTrees = "0.2"
Adapt = "1"
CodecZlib = "0.5, 0.6"
Colors = "0.8, 0.9"
CuArrays = "1.4.3"
Juno = "0.5, 0.6, 0.7"
MacroTools = "0.3, 0.4, 0.5"
NNlib = "0.6"
Tracker = "0.2"
julia = "0.7, 1"
Reexport = "0.2"
StatsBase = "0"
ZipFile = "0.7, 0.8"
Zygote = "0.4"
julia = "1"

[extras]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
test = ["Test", "Documenter"]
README.md (88 changes)
@@ -7,93 +7,9 @@
Flux is an elegant approach to machine learning. It's a 100% pure-Julia stack, and provides lightweight abstractions on top of Julia's native GPU and AD support. Flux makes the easy things easy while remaining fully hackable.

```julia
julia> Pkg.add("Flux")
] add Flux
```

See the [documentation](https://fluxml.github.io/Flux.jl/) or the [model zoo](https://github.com/FluxML/model-zoo/) for examples.

If you use Flux in research, please cite the following paper:

```
@article{innes:2018,
  author  = {Mike Innes},
  title   = {Flux: Elegant Machine Learning with Julia},
  journal = {Journal of Open Source Software},
  year    = {2018},
  doi     = {10.21105/joss.00602},
}
```

## Features

Flux has powerful high-level features, and common architectures can be defined in a few lines.

```julia
model = Chain(
  Dense(768, 128, σ),
  LSTM(128, 256),
  LSTM(256, 128),
  Dense(128, 10),
  softmax)

loss(x, y) = crossentropy(model(x), y)

Flux.train!(loss, data, ADAM(...))
```

Yet you can easily strip away the layers, and directly write the mathematics for your problem. Flux will seamlessly take gradients of any Julia code, so your model looks just like the paper.

```julia
W = param(randn(2, 10))
b = param(randn(2))

y(x) = σ.(W * x .+ b)
```

If that's *still* not enough, you can go as deep as you want, even writing your own CUDA kernels with [CUDAnative](https://github.com/JuliaGPU/CUDAnative.jl)! All this can be freely mixed-and-matched in a single model or script, and it all runs interactively via Jupyter or Juno.

```julia
function gpu_add(a, b, c)
  i = (blockIdx().x-1) * blockDim().x + threadIdx().x
  c[i] = a[i] + b[i]
  return nothing
end
```

Unusual architectures are no problem in Flux, as you can use all the loops, control flow and even macros that you're used to. Here's a Tree RNN in 4 lines.

```julia
tree() = rand() < 0.5 ? rand(10) : (tree(), tree()) # dummy data

shrink = Dense(20, 10)
combine(a, b) = shrink([a; b])

model(x) = x
model(x::Tuple) = combine(model(x[1]), model(x[2]))

model(tree()) # Sample output
```

Despite this flexibility, Julia's advanced compiler lets us do some powerful optimisations. For example, this definition of `sigmoid` automatically gets fused into a *single* GPU kernel – so it's really fast.

```julia
sigmoid(xs) = 1 ./ (1 .+ exp.(.-xs))
```

Similarly, Flux is the first dynamic framework to support [compiling to the browser](https://fluxml.github.io/experiments/) and model import via [formats like ONNX](https://github.com/FluxML/ONNX.jl/), both of which are thinly-veiled compiler problems.

For more on our philosophy on machine learning, check out our article [On Machine Learning & Programming Languages](https://julialang.org/blog/2017/12/ml&pl).

## Contributing & Help

For general questions and help, check out Julia's [community forum](https://discourse.julialang.org/c/domain/ML).

Flux development is carried out via our [GitHub issues](https://github.com/FluxML/Flux.jl/issues), so feel free to open feature requests or PRs here.

For more informal discussions we'd love to have you on the [Julia slack](https://slackinvite.julialang.org/), where we hang out on the #machine-learning channel.

## Related Packages

Check out [Metalhead.jl](https://github.com/FluxML/Metalhead.jl) for common computer vision datasets and trained models.

[MLDatasets.jl](https://github.com/JuliaML/MLDatasets.jl) provides further common datasets.
If you use Flux in research, please see [our papers](CITATION.bib) for appropriate citations.
REQUIRE (13 changes, file deleted)
@@ -1,13 +0,0 @@
julia 1.0
Juno
MacroTools 0.3.3
NNlib
Requires
Adapt 0.4
CodecZlib
Colors
ZipFile
AbstractTrees
Reexport
StatsBase
Tracker
@ -1,205 +1,56 @@
|
||||
# This file is machine-generated - editing it directly is not advised
|
||||
|
||||
[[AbstractTrees]]
|
||||
deps = ["Markdown", "Test"]
|
||||
git-tree-sha1 = "6621d9645702c1c4e6970cc6a3eae440c768000b"
|
||||
uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
|
||||
version = "0.2.1"
|
||||
|
||||
[[Adapt]]
|
||||
deps = ["LinearAlgebra", "Test"]
|
||||
git-tree-sha1 = "53d8fec4f662088c1202530e338a11a919407f3b"
|
||||
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
|
||||
version = "0.4.2"
|
||||
|
||||
[[Base64]]
|
||||
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
|
||||
|
||||
[[BinDeps]]
|
||||
deps = ["Compat", "Libdl", "SHA", "URIParser"]
|
||||
git-tree-sha1 = "12093ca6cdd0ee547c39b1870e0c9c3f154d9ca9"
|
||||
uuid = "9e28174c-4ba2-5203-b857-d8d62c4213ee"
|
||||
version = "0.8.10"
|
||||
|
||||
[[BinaryProvider]]
|
||||
deps = ["Libdl", "Pkg", "SHA", "Test"]
|
||||
git-tree-sha1 = "055eb2690182ebc31087859c3dd8598371d3ef9e"
|
||||
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
|
||||
version = "0.5.3"
|
||||
|
||||
[[CSTParser]]
|
||||
deps = ["LibGit2", "Test", "Tokenize"]
|
||||
git-tree-sha1 = "437c93bc191cd55957b3f8dee7794b6131997c56"
|
||||
uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f"
|
||||
version = "0.5.2"
|
||||
|
||||
[[CodecZlib]]
|
||||
deps = ["BinaryProvider", "Libdl", "Test", "TranscodingStreams"]
|
||||
git-tree-sha1 = "36bbf5374c661054d41410dc53ff752972583b9b"
|
||||
uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
|
||||
version = "0.5.2"
|
||||
|
||||
[[ColorTypes]]
|
||||
deps = ["FixedPointNumbers", "Random", "Test"]
|
||||
git-tree-sha1 = "f73b0e10f2a5756de7019818a41654686da06b09"
|
||||
uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
|
||||
version = "0.7.5"
|
||||
|
||||
[[Colors]]
|
||||
deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Printf", "Reexport", "Test"]
|
||||
git-tree-sha1 = "9f0a0210450acb91c730b730a994f8eef1d3d543"
|
||||
uuid = "5ae59095-9a9b-59fe-a467-6f913c188581"
|
||||
version = "0.9.5"
|
||||
|
||||
[[CommonSubexpressions]]
|
||||
deps = ["Test"]
|
||||
git-tree-sha1 = "efdaf19ab11c7889334ca247ff4c9f7c322817b0"
|
||||
uuid = "bbf7d656-a473-5ed7-a52c-81e309532950"
|
||||
version = "0.2.0"
|
||||
|
||||
[[Compat]]
|
||||
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
|
||||
git-tree-sha1 = "84aa74986c5b9b898b0d1acaf3258741ee64754f"
|
||||
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
|
||||
version = "2.1.0"
|
||||
|
||||
[[Crayons]]
|
||||
deps = ["Test"]
|
||||
git-tree-sha1 = "f621b8ef51fd2004c7cf157ea47f027fdeac5523"
|
||||
uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
|
||||
version = "4.0.0"
|
||||
|
||||
[[DataStructures]]
|
||||
deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
|
||||
git-tree-sha1 = "ca971f03e146cf144a9e2f2ce59674f5bf0e8038"
|
||||
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
|
||||
version = "0.15.0"
|
||||
|
||||
[[Dates]]
|
||||
deps = ["Printf"]
|
||||
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
|
||||
|
||||
[[DelimitedFiles]]
|
||||
deps = ["Mmap"]
|
||||
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
|
||||
|
||||
[[DiffResults]]
|
||||
deps = ["Compat", "StaticArrays"]
|
||||
git-tree-sha1 = "34a4a1e8be7bc99bc9c611b895b5baf37a80584c"
|
||||
uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
|
||||
version = "0.0.4"
|
||||
|
||||
[[DiffRules]]
|
||||
deps = ["Random", "Test"]
|
||||
git-tree-sha1 = "dc0869fb2f5b23466b32ea799bd82c76480167f7"
|
||||
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
|
||||
version = "0.0.10"
|
||||
|
||||
[[Distributed]]
|
||||
deps = ["Random", "Serialization", "Sockets"]
|
||||
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
|
||||
|
||||
[[DocStringExtensions]]
|
||||
deps = ["LibGit2", "Markdown", "Pkg", "Test"]
|
||||
git-tree-sha1 = "4d30e889c9f106a51ffa4791a88ffd4765bf20c3"
|
||||
git-tree-sha1 = "0513f1a8991e9d83255e0140aace0d0fc4486600"
|
||||
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
|
||||
version = "0.7.0"
|
||||
version = "0.8.0"
|
||||
|
||||
[[Documenter]]
|
||||
deps = ["Base64", "DocStringExtensions", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "Pkg", "REPL", "Random", "Test", "Unicode"]
|
||||
git-tree-sha1 = "13a6d15102410d8e70146533b759fc48d844a1d0"
|
||||
deps = ["Base64", "DocStringExtensions", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"]
|
||||
git-tree-sha1 = "c61d6eedbc3c4323c08b64af12d29c8ee0fcbb5f"
|
||||
uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
|
||||
version = "0.22.3"
|
||||
|
||||
[[FixedPointNumbers]]
|
||||
deps = ["Test"]
|
||||
git-tree-sha1 = "b8045033701c3b10bf2324d7203404be7aef88ba"
|
||||
uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
|
||||
version = "0.5.3"
|
||||
|
||||
[[Flux]]
|
||||
deps = ["AbstractTrees", "Adapt", "CodecZlib", "Colors", "DelimitedFiles", "Juno", "LinearAlgebra", "MacroTools", "NNlib", "Pkg", "Printf", "Random", "Reexport", "Requires", "SHA", "Statistics", "StatsBase", "Tracker", "ZipFile"]
|
||||
path = ".."
|
||||
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
|
||||
version = "0.8.2+"
|
||||
|
||||
[[ForwardDiff]]
|
||||
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"]
|
||||
git-tree-sha1 = "4c4d727f1b7e0092134fabfab6396b8945c1ea5b"
|
||||
uuid = "f6369f11-7733-5829-9624-2563aa707210"
|
||||
version = "0.10.3"
|
||||
version = "0.23.2"
|
||||
|
||||
[[InteractiveUtils]]
|
||||
deps = ["Markdown"]
|
||||
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
|
||||
|
||||
[[JSON]]
|
||||
deps = ["Dates", "Distributed", "Mmap", "Sockets", "Test", "Unicode"]
|
||||
git-tree-sha1 = "1f7a25b53ec67f5e9422f1f551ee216503f4a0fa"
|
||||
deps = ["Dates", "Mmap", "Parsers", "Unicode"]
|
||||
git-tree-sha1 = "b34d7cef7b337321e97d22242c3c2b91f476748e"
|
||||
uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
|
||||
version = "0.20.0"
|
||||
|
||||
[[Juno]]
|
||||
deps = ["Base64", "Logging", "Media", "Profile", "Test"]
|
||||
git-tree-sha1 = "4e4a8d43aa7ecec66cadaf311fbd1e5c9d7b9175"
|
||||
uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
|
||||
version = "0.7.0"
|
||||
version = "0.21.0"
|
||||
|
||||
[[LibGit2]]
|
||||
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
|
||||
|
||||
[[Libdl]]
|
||||
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
|
||||
|
||||
[[LinearAlgebra]]
|
||||
deps = ["Libdl"]
|
||||
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
|
||||
|
||||
[[Logging]]
|
||||
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
|
||||
|
||||
[[MacroTools]]
|
||||
deps = ["CSTParser", "Compat", "DataStructures", "Test"]
|
||||
git-tree-sha1 = "daecd9e452f38297c686eba90dba2a6d5da52162"
|
||||
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
|
||||
version = "0.5.0"
|
||||
|
||||
[[Markdown]]
|
||||
deps = ["Base64"]
|
||||
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
|
||||
|
||||
[[Media]]
|
||||
deps = ["MacroTools", "Test"]
|
||||
git-tree-sha1 = "75a54abd10709c01f1b86b84ec225d26e840ed58"
|
||||
uuid = "e89f7d12-3494-54d1-8411-f7d8b9ae1f27"
|
||||
version = "0.5.0"
|
||||
|
||||
[[Missings]]
|
||||
deps = ["Dates", "InteractiveUtils", "SparseArrays", "Test"]
|
||||
git-tree-sha1 = "d1d2585677f2bd93a97cfeb8faa7a0de0f982042"
|
||||
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
|
||||
version = "0.4.0"
|
||||
|
||||
[[Mmap]]
|
||||
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
|
||||
|
||||
[[NNlib]]
|
||||
deps = ["Libdl", "LinearAlgebra", "Requires", "Statistics", "TimerOutputs"]
|
||||
git-tree-sha1 = "0c667371391fc6bb31f7f12f96a56a17098b3de8"
|
||||
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
|
||||
version = "0.6.0"
|
||||
|
||||
[[NaNMath]]
|
||||
deps = ["Compat"]
|
||||
git-tree-sha1 = "ce3b85e484a5d4c71dd5316215069311135fa9f2"
|
||||
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
|
||||
version = "0.3.2"
|
||||
|
||||
[[OrderedCollections]]
|
||||
deps = ["Random", "Serialization", "Test"]
|
||||
git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1"
|
||||
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
|
||||
version = "1.1.0"
|
||||
[[Parsers]]
|
||||
deps = ["Dates", "Test"]
|
||||
git-tree-sha1 = "db2b35dedab3c0e46dc15996d170af07a5ab91c9"
|
||||
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
|
||||
version = "0.3.6"
|
||||
|
||||
[[Pkg]]
|
||||
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
|
||||
@ -209,10 +60,6 @@ uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
|
||||
deps = ["Unicode"]
|
||||
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
|
||||
|
||||
[[Profile]]
|
||||
deps = ["Printf"]
|
||||
uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"
|
||||
|
||||
[[REPL]]
|
||||
deps = ["InteractiveUtils", "Markdown", "Sockets"]
|
||||
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
|
||||
@ -221,106 +68,22 @@ uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
|
||||
deps = ["Serialization"]
|
||||
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
|
||||
|
||||
[[Reexport]]
|
||||
deps = ["Pkg"]
|
||||
git-tree-sha1 = "7b1d07f411bc8ddb7977ec7f377b97b158514fe0"
|
||||
uuid = "189a3867-3050-52da-a836-e630ba90ab69"
|
||||
version = "0.2.0"
|
||||
|
||||
[[Requires]]
|
||||
deps = ["Test"]
|
||||
git-tree-sha1 = "f6fbf4ba64d295e146e49e021207993b6b48c7d1"
|
||||
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
|
||||
version = "0.5.2"
|
||||
|
||||
[[SHA]]
|
||||
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
|
||||
|
||||
[[Serialization]]
|
||||
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
|
||||
|
||||
[[SharedArrays]]
|
||||
deps = ["Distributed", "Mmap", "Random", "Serialization"]
|
||||
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
|
||||
|
||||
[[Sockets]]
|
||||
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
|
||||
|
||||
[[SortingAlgorithms]]
|
||||
deps = ["DataStructures", "Random", "Test"]
|
||||
git-tree-sha1 = "03f5898c9959f8115e30bc7226ada7d0df554ddd"
|
||||
uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
|
||||
version = "0.3.1"
|
||||
|
||||
[[SparseArrays]]
|
||||
deps = ["LinearAlgebra", "Random"]
|
||||
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
|
||||
|
||||
[[SpecialFunctions]]
|
||||
deps = ["BinDeps", "BinaryProvider", "Libdl", "Test"]
|
||||
git-tree-sha1 = "0b45dc2e45ed77f445617b99ff2adf0f5b0f23ea"
|
||||
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
|
||||
version = "0.7.2"
|
||||
|
||||
[[StaticArrays]]
|
||||
deps = ["InteractiveUtils", "LinearAlgebra", "Random", "Statistics", "Test"]
|
||||
git-tree-sha1 = "3841b39ed5f047db1162627bf5f80a9cd3e39ae2"
|
||||
uuid = "90137ffa-7385-5640-81b9-e52037218182"
|
||||
version = "0.10.3"
|
||||
|
||||
[[Statistics]]
|
||||
deps = ["LinearAlgebra", "SparseArrays"]
|
||||
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
|
||||
|
||||
[[StatsBase]]
|
||||
deps = ["DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"]
|
||||
git-tree-sha1 = "8a0f4b09c7426478ab677245ab2b0b68552143c7"
|
||||
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
|
||||
version = "0.30.0"
|
||||
|
||||
[[Test]]
|
||||
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
|
||||
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
|
||||
|
||||
[[TimerOutputs]]
|
||||
deps = ["Crayons", "Printf", "Test", "Unicode"]
|
||||
git-tree-sha1 = "b80671c06f8f8bae08c55d67b5ce292c5ae2660c"
|
||||
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
|
||||
version = "0.5.0"
|
||||
|
||||
[[Tokenize]]
|
||||
deps = ["Printf", "Test"]
|
||||
git-tree-sha1 = "3e83f60b74911d3042d3550884ca2776386a02b8"
|
||||
uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624"
|
||||
version = "0.5.3"
|
||||
|
||||
[[Tracker]]
|
||||
deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics", "Test"]
|
||||
git-tree-sha1 = "0bec1b68c63a0e8a58d3944261cbf4cc9577c8a1"
|
||||
uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
|
||||
version = "0.2.0"
|
||||
|
||||
[[TranscodingStreams]]
|
||||
deps = ["Random", "Test"]
|
||||
git-tree-sha1 = "a25d8e5a28c3b1b06d3859f30757d43106791919"
|
||||
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
|
||||
version = "0.9.4"
|
||||
|
||||
[[URIParser]]
|
||||
deps = ["Test", "Unicode"]
|
||||
git-tree-sha1 = "6ddf8244220dfda2f17539fa8c9de20d6c575b69"
|
||||
uuid = "30578b45-9adc-5946-b283-645ec420af67"
|
||||
version = "0.4.0"
|
||||
|
||||
[[UUIDs]]
|
||||
deps = ["Random", "SHA"]
|
||||
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
|
||||
|
||||
[[Unicode]]
|
||||
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
|
||||
|
||||
[[ZipFile]]
|
||||
deps = ["BinaryProvider", "Libdl", "Printf", "Test"]
|
||||
git-tree-sha1 = "5f6f663890dfb9bad6af75a86a43f67904e5050e"
|
||||
uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
|
||||
version = "0.8.1"
|
||||
|
@@ -1,4 +1,2 @@
[deps]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
docs/make.jl (18 changes)
@@ -1,12 +1,13 @@
using Pkg;
Pkg.activate(joinpath(@__DIR__, "..")); Pkg.instantiate()
Pkg.activate(); Pkg.instantiate()

pushfirst!(LOAD_PATH, joinpath(@__DIR__, ".."))

using Documenter, Flux, NNlib

makedocs(modules=[Flux, NNlib],
         doctest = true,
         analytics = "UA-36890222-9",
         sitename = "Flux",
         # Uncomment below for local build
         #format = Documenter.HTML(prettyurls = false),
         assets = ["assets/flux.css"],
         pages = ["Home" => "index.md",
                  "Building Models" =>
                    ["Basics" => "models/basics.md",
@@ -20,8 +21,9 @@ makedocs(modules=[Flux, NNlib],
                  "GPU Support" => "gpu.md",
                  "Saving & Loading" => "saving.md",
                  "Performance Tips" => "performance.md",
                  "Internals" =>
                    ["Backpropagation" => "internals/tracker.md"],
                  "Community" => "community.md"])
                  "Community" => "community.md"],
         format = Documenter.HTML(assets = ["assets/flux.css"],
                                  analytics = "UA-36890222-9",
                                  prettyurls = haskey(ENV, "CI")))

deploydocs(repo = "github.com/FluxML/Flux.jl.git")
@@ -1,5 +1,5 @@
# Community

All Flux users are welcome to join our community on the [Julia forum](https://discourse.julialang.org/), the [slack](https://discourse.julialang.org/t/announcing-a-julia-slack/4866) (channel #machine-learning), or Flux's [Gitter](https://gitter.im/FluxML/Lobby). If you have questions or issues we'll try to help you out.
All Flux users are welcome to join our community on the [Julia forum](https://discourse.julialang.org/), or the [slack](https://discourse.julialang.org/t/announcing-a-julia-slack/4866) (channel #machine-learning). If you have questions or issues we'll try to help you out.

If you're interested in hacking on Flux, the [source code](https://github.com/FluxML/Flux.jl) is open and easy to understand -- it's all just the same Julia code you work with normally. You might be interested in our [intro issues](https://github.com/FluxML/Flux.jl/issues?q=is%3Aopen+is%3Aissue+label%3A%22help+wanted%22) to get started.
@@ -1,14 +1,6 @@
# GPU Support

## Installation

To get GPU support for NVIDIA graphics cards, you need to install `CuArrays.jl`

**Steps needed**

1. Install [NVIDIA toolkit](https://developer.nvidia.com/cuda-downloads)
2. Install [NVIDIA cuDNN library](https://developer.nvidia.com/cudnn)
3. In Julia's terminal run `]add CuArrays`
NVIDIA GPU support should work out of the box on systems with CUDA and CUDNN installed. For more details see the [CuArrays](https://github.com/JuliaGPU/CuArrays.jl) readme.

## GPU Usage

@@ -33,16 +25,16 @@ loss(x, y) # ~ 3

Note that we convert both the parameters (`W`, `b`) and the data set (`x`, `y`) to cuda arrays. Taking derivatives and training works exactly as before.

If you define a structured model, like a `Dense` layer or `Chain`, you just need to convert the internal parameters. Flux provides `mapleaves`, which allows you to alter all parameters of a model at once.
If you define a structured model, like a `Dense` layer or `Chain`, you just need to convert the internal parameters. Flux provides `fmap`, which allows you to alter all parameters of a model at once.

```julia
d = Dense(10, 5, σ)
d = mapleaves(cu, d)
d = fmap(cu, d)
d.W # Tracked CuArray
d(cu(rand(10))) # CuArray output

m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
m = mapleaves(cu, m)
m = fmap(cu, m)
d(cu(rand(10)))
```
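As an editorial aside (not part of the original page): with the out-of-the-box CuArrays support described above, a typical workflow is to move both the model and the data with `gpu`, which falls back to a no-op when no GPU is available.

```julia
using Flux

m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax) |> gpu  # move the model's parameters
x = rand(Float32, 10) |> gpu                             # move the data as well
m(x)                                                     # runs on the GPU when one is present
```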
@@ -1,184 +0,0 @@
# Flux.Tracker

Backpropagation, or reverse-mode automatic differentiation, is handled by the `Flux.Tracker` module.

```julia
julia> using Flux.Tracker
```

Here we discuss some more advanced uses of this module, as well as covering its internals.

## Taking Gradients

In the [basics section](../models/basics.md) we covered basic usage of the `gradient` function.

```julia
using Flux.Tracker

Tracker.gradient((a, b) -> a*b, 2, 3) # (3.0 (tracked), 2.0 (tracked))
```

`gradient` is actually just a thin wrapper around the backpropagator-based interface, `forward`.

```julia
using Flux.Tracker: forward

y, back = forward((a, b) -> a*b, 2, 3) # (6.0 (tracked), Flux.Tracker.#9)

back(1) # (3.0 (tracked), 2.0 (tracked))
```

The `forward` function returns two results. The first, `y`, is the original value of the function (perhaps with tracking applied). The second, `back`, is a new function which, given a sensitivity, returns the sensitivity of the inputs to `forward` (we call this a "backpropagator"). One use of this interface is to provide custom sensitivities when outputs are not scalar.

```julia
julia> y, back = forward((a, b) -> a.*b, [1,2,3],[4,5,6])
(param([4.0, 10.0, 18.0]), Flux.Tracker.#9)

julia> back([1,1,1])
(param([4.0, 5.0, 6.0]), param([1.0, 2.0, 3.0]))
```

We can also take gradients in-place. This can be useful if you only care about first-order gradients.

```julia
a, b = param(2), param(3)

c = a*b # 6.0 (tracked)

Tracker.back!(c)

Tracker.grad(a), Tracker.grad(b) # (3.0, 2.0)
```

## Tracked Arrays

The `param` function converts a normal Julia array into a new object that, while behaving like an array, tracks extra information that allows us to calculate derivatives. For example, say we multiply two parameters:

```julia
julia> W = param([1 2; 3 4])
Tracked 2×2 Array{Float64,2}:
 1.0  2.0
 3.0  4.0

julia> x = param([5, 6])
Tracked 2-element Array{Float64,1}:
 5.0
 6.0

julia> y = W*x
Tracked 2-element Array{Float64,1}:
 17.0
 39.0
```

The output `y` is also a `TrackedArray` object. We can now backpropagate sensitivities to `W` and `x` via the `back!` function, and see the gradients accumulated in the `W` and `x` tracked arrays:

```julia
julia> Tracker.back!(y, [1, -1])

julia> W.grad
2×2 Array{Float64,2}:
  5.0   6.0
 -5.0  -6.0

julia> x.grad
2-element Array{Float64,1}:
 -2.0
 -2.0
```

You may sometimes want to drop derivative information and just get the plain value back. You can do this by calling `Tracker.data(W)`.

## Custom Gradients

We can hook in to the processes above to implement custom gradients for a function or kernel. For a toy example, imagine a custom implementation of `minus`:

```julia
minus(a, b) = a - b
```

Firstly, we must tell the tracker system to stop when it sees a call to `minus`, and record it. We can do this using dispatch:

```julia
using Flux.Tracker: TrackedArray, track, @grad

minus(a::TrackedArray, b::TrackedArray) = track(minus, a, b)
```

`track` takes care of building a new `Tracked` object and recording the operation on the tape. We just need to provide a gradient definition.

```julia
@grad function minus(a, b)
  return minus(data(a), data(b)), Δ -> (Δ, -Δ)
end
```

This is essentially just a way of overloading the `forward` function we saw above. We strip tracking from `a` and `b` so that we are calling the original definition of `minus` (otherwise, we'd just try to track the call again and hit an infinite regress).

Note that in the backpropagator we don't call `data(a)`; we *do* in fact want to track this, since nested AD will take a derivative through the backpropagator itself. For example, the gradient of `*` might look like this.

```julia
@grad a * b = data(a)*data(b), Δ -> (Δ*b, a*Δ)
```

We can then calculate the first derivative of `minus` as follows:

```julia
a = param([1,2,3])
b = param([3,2,1])

c = minus(a, b) # [-2.0 (tracked), 0.0 (tracked), 2.0 (tracked)]

Tracker.back!(c, 1)
Tracker.grad(a) # [1.00, 1.00, 1.00]
Tracker.grad(b) # [-1.00, -1.00, -1.00]
```

For multi-argument functions with custom gradients, you likely want to catch not just `minus(::TrackedArray, ::TrackedArray)` but also `minus(::Array, ::TrackedArray)` and so on. To do so, just define those extra signatures as needed:

```julia
minus(a::AbstractArray, b::TrackedArray) = Tracker.track(minus, a, b)
minus(a::TrackedArray, b::AbstractArray) = Tracker.track(minus, a, b)
```

## Tracked Internals

All `Tracked*` objects (`TrackedArray`, `TrackedReal`) are light wrappers around the `Tracked` type, which you can access via the `.tracker` field.

```julia
julia> x.tracker
Flux.Tracker.Tracked{Array{Float64,1}}(0x00000000, Flux.Tracker.Call{Nothing,Tuple{}}(nothing, ()), true, [5.0, 6.0], [-2.0, -2.0])
```

The `Tracker` stores the gradient of a given object, which we've seen before.

```julia
julia> x.tracker.grad
2-element Array{Float64,1}:
 -2.0
 -2.0
```

The tracker also contains a `Call` object, which simply represents a function call that was made at some point during the forward pass. For example, the `+` call would look like this:

```julia
julia> Tracker.Call(+, 1, 2)
Flux.Tracker.Call{Base.#+,Tuple{Int64,Int64}}(+, (1, 2))
```

In the case of the `y` we produced above, we can see that it stores the call that produced it -- that is, `W*x`.

```julia
julia> y.tracker.f
Flux.Tracker.Call{...}(*, (param([1.0 2.0; 3.0 4.0]), param([5.0, 6.0])))
```

Because the arguments to a call may themselves be tracked arrays that store their own calls, `Tracker` ends up forming a data structure that records everything that happened during the forward pass (often known as a *tape*).

When we call `back!(y, [1, -1])`, the sensitivities `[1, -1]` simply get forwarded to `y`'s call (`*`), effectively calling

```julia
Tracker.back(*, [1, -1], W, x)
```

which in turn calculates the sensitivities of the arguments (`W` and `x`) and back-propagates through their calls. This is recursive, so it will walk the entire program graph and propagate gradients to the original model parameters.
@ -5,55 +5,56 @@
|
||||
Flux's core feature is taking gradients of Julia code. The `gradient` function takes another Julia function `f` and a set of arguments, and returns the gradient with respect to each argument. (It's a good idea to try pasting these examples in the Julia terminal.)
|
||||
|
||||
```jldoctest basics
|
||||
julia> using Flux.Tracker
|
||||
julia> using Flux
|
||||
|
||||
julia> f(x) = 3x^2 + 2x + 1;
|
||||
|
||||
julia> df(x) = Tracker.gradient(f, x; nest = true)[1]; # df/dx = 6x + 2
|
||||
julia> df(x) = gradient(f, x)[1]; # df/dx = 6x + 2
|
||||
|
||||
julia> df(2)
|
||||
14.0 (tracked)
|
||||
14
|
||||
|
||||
julia> d2f(x) = Tracker.gradient(df, x; nest = true)[1]; # d²f/dx² = 6
|
||||
julia> d2f(x) = gradient(df, x)[1]; # d²f/dx² = 6
|
||||
|
||||
julia> d2f(2)
|
||||
6.0 (tracked)
|
||||
6
|
||||
```
|
||||
|
||||
(We'll learn more about why these numbers show up as `(tracked)` below.)
|
||||
|
||||
When a function has many parameters, we can pass them all in explicitly:
|
||||
When a function has many parameters, we can get gradients of each one at the same time:
|
||||
|
||||
```jldoctest basics
|
||||
julia> f(W, b, x) = W * x + b;
|
||||
julia> f(x, y) = sum((x .- y).^2);
|
||||
|
||||
julia> Tracker.gradient(f, 2, 3, 4)
|
||||
(4.0 (tracked), 1.0 (tracked), 2.0 (tracked))
|
||||
julia> gradient(f, [2, 1], [2, 0])
|
||||
([0, 2], [0, -2])
|
||||
```
|
||||
|
||||
But machine learning models can have *hundreds* of parameters! Flux offers a nice way to handle this. We can tell Flux to treat something as a parameter via `param`. Then we can collect these together and tell `gradient` to collect the gradients of all `params` at once.
|
||||
But machine learning models can have *hundreds* of parameters! To handle this, Flux lets you work with collections of parameters, via `params`. You can get the gradient of all parameters used in a program without explicitly passing them in.
|
||||
|
||||
```jldoctest basics
|
||||
julia> using Flux
|
||||
|
||||
julia> W = param(2)
|
||||
2.0 (tracked)
|
||||
julia> x = [2, 1];
|
||||
|
||||
julia> b = param(3)
|
||||
3.0 (tracked)
|
||||
julia> y = [2, 0];
|
||||
|
||||
julia> f(x) = W * x + b;
|
||||
julia> gs = gradient(params(x, y)) do
|
||||
f(x, y)
|
||||
end
|
||||
Grads(...)
|
||||
|
||||
julia> grads = Tracker.gradient(() -> f(4), params(W, b));
|
||||
julia> gs[x]
|
||||
2-element Array{Int64,1}:
|
||||
0
|
||||
2
|
||||
|
||||
julia> grads[W]
|
||||
4.0 (tracked)
|
||||
|
||||
julia> grads[b]
|
||||
1.0 (tracked)
|
||||
julia> gs[y]
|
||||
2-element Array{Int64,1}:
|
||||
0
|
||||
-2
|
||||
```
|
||||
|
||||
There are a few things to notice here. Firstly, `W` and `b` now show up as *tracked*. Tracked things behave like normal numbers or arrays, but keep records of everything you do with them, allowing Flux to calculate their gradients. `gradient` takes a zero-argument function; no arguments are necessary because the `params` tell it what to differentiate.
|
||||
Here, `gradient` takes a zero-argument function; no arguments are necessary because the `params` tell it what to differentiate.
|
||||
|
||||
This will come in really handy when dealing with big, complicated models. For now, though, let's start with something simple.
|
||||
|
||||
@ -76,26 +77,20 @@ x, y = rand(5), rand(2) # Dummy data
|
||||
loss(x, y) # ~ 3
|
||||
```
|
||||
|
||||
To improve the prediction we can take the gradients of `W` and `b` with respect to the loss and perform gradient descent. Let's tell Flux that `W` and `b` are parameters, just like we did above.
|
||||
To improve the prediction we can take the gradients of `W` and `b` with respect to the loss and perform gradient descent.
|
||||
|
||||
```julia
|
||||
using Flux.Tracker
|
||||
using Flux
|
||||
|
||||
W = param(W)
|
||||
b = param(b)
|
||||
|
||||
gs = Tracker.gradient(() -> loss(x, y), params(W, b))
|
||||
gs = gradient(() -> loss(x, y), params(W, b))
|
||||
```
|
||||
|
||||
Now that we have gradients, we can pull them out and update `W` to train the model. The `update!(W, Δ)` function applies `W = W + Δ`, which we can use for gradient descent.
|
||||
Now that we have gradients, we can pull them out and update `W` to train the model.
|
||||
|
||||
```julia
|
||||
using Flux.Tracker: update!
|
||||
W̄ = gs[W]
|
||||
|
||||
Δ = gs[W]
|
||||
|
||||
# Update the parameter and reset the gradient
|
||||
update!(W, -0.1Δ)
|
||||
W .-= 0.1 .* W̄
|
||||
|
||||
loss(x, y) # ~ 2.5
|
||||
```
|
||||
@ -111,12 +106,12 @@ It's common to create more complex models than the linear regression above. For
|
||||
```julia
|
||||
using Flux
|
||||
|
||||
W1 = param(rand(3, 5))
|
||||
b1 = param(rand(3))
|
||||
W1 = rand(3, 5)
|
||||
b1 = rand(3)
|
||||
layer1(x) = W1 * x .+ b1
|
||||
|
||||
W2 = param(rand(2, 3))
|
||||
b2 = param(rand(2))
|
||||
W2 = rand(2, 3)
|
||||
b2 = rand(2)
|
||||
layer2(x) = W2 * x .+ b2
|
||||
|
||||
model(x) = layer2(σ.(layer1(x)))
|
||||
@ -128,8 +123,8 @@ This works but is fairly unwieldy, with a lot of repetition – especially as we
|
||||
|
||||
```julia
|
||||
function linear(in, out)
|
||||
W = param(randn(out, in))
|
||||
b = param(randn(out))
|
||||
W = randn(out, in)
|
||||
b = randn(out)
|
||||
x -> W * x .+ b
|
||||
end
|
||||
|
||||
@ -150,7 +145,7 @@ struct Affine
|
||||
end
|
||||
|
||||
Affine(in::Integer, out::Integer) =
|
||||
Affine(param(randn(out, in)), param(randn(out)))
|
||||
Affine(randn(out, in), randn(out))
|
||||
|
||||
# Overload call, so the object can be used as a function
|
||||
(m::Affine)(x) = m.W * x .+ m.b
|
||||
@ -220,7 +215,7 @@ m(5) # => 26
|
||||
Flux provides a set of helpers for custom layers, which you can enable by calling
|
||||
|
||||
```julia
|
||||
Flux.@treelike Affine
|
||||
Flux.@functor Affine
|
||||
```
|
||||
|
||||
This enables a useful extra set of functionality for our `Affine` layer, such as [collecting its parameters](../training/optimisers.md) or [moving it to the GPU](../gpu.md).
|
||||
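As a rough sketch of what this enables, assuming the `Affine` layer defined above, collecting the parameters and moving the layer to the GPU might look like:

```julia
a = Affine(10, 5)

# After `Flux.@functor Affine`, Flux can walk the struct's fields,
# so `params` finds `a.W` and `a.b` automatically:
ps = params(a)

# and `gpu` can move every array inside the layer (a no-op without a GPU):
a_gpu = gpu(a)
```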
|
@ -59,7 +59,6 @@ swish
|
||||
These layers don't affect the structure of the network but may improve training times or reduce overfitting.
|
||||
|
||||
```@docs
|
||||
Flux.testmode!
|
||||
BatchNorm
|
||||
Dropout
|
||||
AlphaDropout
|
||||
|
@ -101,26 +101,4 @@ m = Chain(LSTM(10, 15), Dense(15, 5))
|
||||
m.(seq)
|
||||
```
|
||||
|
||||
## Truncating Gradients
|
||||
|
||||
By default, calculating the gradients in a recurrent layer involves its entire history. For example, if we call the model on 100 inputs, we'll have to calculate the gradient for those 100 calls. If we then calculate another 10 inputs we have to calculate 110 gradients – this accumulates and quickly becomes expensive.
|
||||
|
||||
To avoid this we can *truncate* the gradient calculation, forgetting the history.
|
||||
|
||||
```julia
|
||||
truncate!(m)
|
||||
```
|
||||
|
||||
Calling `truncate!` wipes the slate clean, so we can call the model with more inputs without building up an expensive gradient computation.
|
||||
|
||||
`truncate!` makes sense when you are working with multiple chunks of a large sequence, but we may also want to work with a set of independent sequences. In this case the hidden state should be completely reset to its original value, throwing away any accumulated information. `reset!` does this for you.
|
||||
|
||||
In general, when training with recurrent layers in your model, you'll want to call `reset!` or `truncate!` for each loss calculation:
|
||||
|
||||
```julia
|
||||
function loss(x,y)
|
||||
l = Flux.mse(m(x), y)
|
||||
Flux.reset!(m)
|
||||
return l
|
||||
end
|
||||
```
|
||||
Finally, we can reset the hidden state of the cell back to its initial value using `reset!(m)`.
|
||||
|
@ -15,6 +15,8 @@ loss(x, y) = crossentropy(softmax(m(x)), y)
|
||||
We can regularise this by taking the (L2) norm of the parameters, `m.W` and `m.b`.
|
||||
|
||||
```julia
|
||||
using LinearAlgebra
|
||||
|
||||
penalty() = norm(m.W) + norm(m.b)
|
||||
loss(x, y) = crossentropy(softmax(m(x)), y) + penalty()
|
||||
```
|
||||
@ -48,15 +50,17 @@ loss(rand(28^2), rand(10))
|
||||
One can also easily add per-layer regularisation via the `activations` function:
|
||||
|
||||
```julia
|
||||
julia> using Flux: activations
|
||||
|
||||
julia> c = Chain(Dense(10,5,σ),Dense(5,2),softmax)
|
||||
Chain(Dense(10, 5, NNlib.σ), Dense(5, 2), NNlib.softmax)
|
||||
Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
|
||||
|
||||
julia> activations(c, rand(10))
|
||||
3-element Array{Any,1}:
|
||||
param([0.71068, 0.831145, 0.751219, 0.227116, 0.553074])
|
||||
param([0.0330606, -0.456104])
|
||||
param([0.61991, 0.38009])
|
||||
Float32[0.84682214, 0.6704139, 0.42177814, 0.257832, 0.36255655]
|
||||
Float32[0.1501253, 0.073269576]
|
||||
Float32[0.5192045, 0.48079553]
|
||||
|
||||
julia> sum(norm, ans)
|
||||
2.639678767773633 (tracked)
|
||||
2.1166067f0
|
||||
```
|
||||
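For a model with many layers, writing each parameter out by hand gets tedious. One shortcut, sketched here assuming the `m` and `loss` above are already defined, is to sum the norm over everything returned by `params`:

```julia
using LinearAlgebra

# Penalise the L2 norm of every trainable array in the model at once.
penalty() = sum(norm, params(m))
loss(x, y) = crossentropy(softmax(m(x)), y) + penalty()
```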
|
@ -14,11 +14,11 @@ Which means allocations occur much faster.
|
||||
And you use less memory.
|
||||
|
||||
|
||||
## Make sure your custom activation functions preserve the type of their inputs
|
||||
Not only should your activation functions be [type-stable](https://docs.julialang.org/en/v1/manual/performance-tips/#Write-%22type-stable%22-functions-1),
|
||||
## Make sure your activation and loss functions preserve the type of their inputs
|
||||
Not only should your activation and loss functions be [type-stable](https://docs.julialang.org/en/v1/manual/performance-tips/#Write-%22type-stable%22-functions-1),
|
||||
they should also preserve the type of their inputs.
|
||||
|
||||
A very artificial example using an activatioon function like
|
||||
A very artificial example using an activation function like
|
||||
|
||||
```
|
||||
my_tanh(x) = Float64(tanh(x))
|
||||
@ -26,6 +26,7 @@ A very artificial example using an activatioon function like
|
||||
|
||||
will result in performance on `Float32` input orders of magnitude slower than the normal `tanh` would,
|
||||
because it results in having to use slow mixed type multiplication in the dense layers.
|
||||
Similar situations can occur in the loss function during backpropagation.
|
||||
|
||||
This means that if you change your data, say, from `Float64` to `Float32` (which should give a speedup: see above),
|
||||
you will see a large slow-down
|
||||
@ -41,7 +42,7 @@ While one could change your activation function (e.g. to use `0.01f0x`) to avoid
|
||||
the idiomatic (and safe) way is to use `oftype`.
|
||||
|
||||
```
|
||||
leaky_tanh(x) = oftype(x/1, 0.01) + tanh(x)
|
||||
leaky_tanh(x) = oftype(x/1, 0.01)x + tanh(x)
|
||||
```
|
||||
|
||||
|
||||
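A quick way to sanity-check such a function is to confirm that a `Float32` input gives a `Float32` output; this is only a rough sketch of the idea:

```
leaky_tanh(x) = oftype(x/1, 0.01)x + tanh(x)

# Both checks should return `true` if the function preserves its input type.
leaky_tanh(1f0) isa Float32
leaky_tanh(1.0) isa Float64
```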
@ -60,7 +61,7 @@ end
|
||||
|
||||
It is much faster to concatenate them into a matrix,
|
||||
as this will hit BLAS matrix-matrix multiplication, which is much faster than the equivalent sequence of matrix-vector multiplications.
|
||||
Even though this means allocating new memory to store them contiguously.
|
||||
The improvement is enough that it is worthwhile allocating new memory to store them contiguously.
|
||||
|
||||
```julia
|
||||
x_batch = reduce(hcat, xs)
|
||||
@ -73,4 +74,4 @@ end
|
||||
```
|
||||
|
||||
When doing this kind of concatenation use `reduce(hcat, xs)` rather than `hcat(xs...)`.
|
||||
This will avoid the splatting penality, and will hit the optimised `reduce` method.
|
||||
This will avoid the splatting penalty, and will hit the optimised `reduce` method.
|
||||
|
@ -53,7 +53,7 @@ julia> using Flux
|
||||
julia> model = Chain(Dense(10,5,relu),Dense(5,2),softmax)
|
||||
Chain(Dense(10, 5, NNlib.relu), Dense(5, 2), NNlib.softmax)
|
||||
|
||||
julia> weights = Tracker.data.(params(model));
|
||||
julia> weights = params(model);
|
||||
|
||||
julia> using BSON: @save
|
||||
|
||||
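Loading the weights back into a freshly constructed model is the mirror image. The sketch below assumes the weights were saved to a file named `"mymodel.bson"` (the name is only illustrative) and uses `Flux.loadparams!`:

```julia
using Flux
using BSON: @load

# Rebuild the same architecture, then overwrite its parameters with the
# saved ones (assumes `@save "mymodel.bson" weights` was run earlier).
model = Chain(Dense(10, 5, relu), Dense(5, 2), softmax)
@load "mymodel.bson" weights
Flux.loadparams!(model, weights)
```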
@ -113,6 +113,6 @@ You can even store optimiser state alongside the model, to resume training
|
||||
exactly where you left off.
|
||||
|
||||
```julia
|
||||
opt = ADAM(params(model))
|
||||
opt = ADAM()
|
||||
@save "model-$(now()).bson" model opt
|
||||
```
|
||||
|
@ -3,25 +3,25 @@
|
||||
Consider a [simple linear regression](../models/basics.md). We create some dummy data, calculate a loss, and backpropagate to calculate gradients for the parameters `W` and `b`.
|
||||
|
||||
```julia
|
||||
using Flux, Flux.Tracker
|
||||
using Flux
|
||||
|
||||
W = param(rand(2, 5))
|
||||
b = param(rand(2))
|
||||
W = rand(2, 5)
|
||||
b = rand(2)
|
||||
|
||||
predict(x) = W*x .+ b
|
||||
predict(x) = (W * x) .+ b
|
||||
loss(x, y) = sum((predict(x) .- y).^2)
|
||||
|
||||
x, y = rand(5), rand(2) # Dummy data
|
||||
l = loss(x, y) # ~ 3
|
||||
|
||||
θ = Params([W, b])
|
||||
grads = Tracker.gradient(() -> loss(x, y), θ)
|
||||
grads = gradient(() -> loss(x, y), θ)
|
||||
```
|
||||
|
||||
We want to update each parameter, using the gradient, in order to improve (reduce) the loss. Here's one way to do that:
|
||||
|
||||
```julia
|
||||
using Flux.Tracker: grad, update!
|
||||
using Flux: update!
|
||||
|
||||
η = 0.1 # Learning Rate
|
||||
for p in (W, b)
|
||||
@ -58,3 +58,83 @@ AMSGrad
|
||||
NADAM
|
||||
ADAMW
|
||||
```
|
||||
|
||||
## Optimiser Interface
|
||||
|
||||
Flux's optimisers are built around a `struct` that holds all the optimiser parameters along with a definition of how to apply the update rule associated with it. We do this via the `apply!` function, which takes the optimiser as the first argument, followed by the parameter and its corresponding gradient.
|
||||
|
||||
In this manner Flux also allows one to create custom optimisers that can be used seamlessly. Let's work through this with a simple example.
|
||||
|
||||
```julia
|
||||
mutable struct Momentum
|
||||
eta
|
||||
rho
|
||||
velocity
|
||||
end
|
||||
|
||||
Momentum(eta::Real, rho::Real) = Momentum(eta, rho, IdDict())
|
||||
```
|
||||
|
||||
The `Momentum` type will act as our optimiser in this case. Notice that we have added the hyperparameters as fields, along with the velocity, which we will use as our state dictionary. Each parameter in our model will get an entry in it. We can now define the rule that is applied when this optimiser is invoked.
|
||||
|
||||
```julia
|
||||
function apply!(o::Momentum, x, Δ)
|
||||
η, ρ = o.eta, o.rho
|
||||
v = get!(o.velocity, x, zero(x))::typeof(x)
|
||||
@. v = ρ * v - η * Δ
|
||||
@. Δ = -v
|
||||
end
|
||||
```
|
||||
|
||||
This is the basic definition of a Momentum update rule given by:
|
||||
|
||||
```math
|
||||
v = ρ * v - η * Δ
|
||||
w = w - v
|
||||
```
|
||||
|
||||
`apply!` defines the update rule for an optimiser `opt`, given a parameter and its gradient, and returns the modified gradient, i.e. the step that will be subtracted from the parameter. Here, the velocity entry for each parameter `x` is fetched from (or initialised in) the running state `o.velocity`, and that state is updated in the process.
|
||||
|
||||
Flux internally calls on this function via the `update!` function. It shares the API with `apply!` but ensures that multiple parameters are handled gracefully.
|
||||
|
||||
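To see how such a custom rule slots into training, here is a rough sketch. It assumes the `Momentum` defined above is in scope and that its `apply!` was written as a method of `Flux.Optimise.apply!`, so that `update!` can dispatch to it:

```julia
using Flux

W, b = rand(2, 5), rand(2)
loss(x, y) = sum((W*x .+ b .- y).^2)

opt = Momentum(0.01, 0.9)      # the custom optimiser defined above
x, y = rand(5), rand(2)

θ = params(W, b)
gs = gradient(() -> loss(x, y), θ)

# `update!` calls `apply!(opt, p, gs[p])` for each parameter `p`
# and subtracts the returned step from `p` in place.
Flux.Optimise.update!(opt, θ, gs)
```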
## Composing Optimisers
|
||||
|
||||
Flux defines a special kind of optimiser simply called `Optimiser`, which takes an arbitrary list of optimisers as input. Its behaviour is similar to the usual optimisers, but it acts by calling the optimisers listed in it sequentially. Each optimiser produces a modified gradient
|
||||
that will be fed into the next, and the resultant update will be applied to the parameter as usual. A classic use case is when adding decays is desirable. Flux defines some basic decays, including `ExpDecay`, `InvDecay` etc.
|
||||
|
||||
```julia
|
||||
opt = Optimiser(ExpDecay(0.001, 0.1, 1000, 1e-4), Descent())
|
||||
```
|
||||
|
||||
Here we apply exponential decay to the `Descent` optimiser. The defaults of `ExpDecay` mean that its learning rate will be decayed every 1000 steps.
|
||||
It is then applied like any other optimiser.
|
||||
|
||||
```julia
|
||||
w = randn(10, 10)
|
||||
w1 = randn(10,10)
|
||||
ps = Params([w, w1])
|
||||
|
||||
loss(x) = Flux.mse(w * x, w1 * x)
|
||||
|
||||
loss(rand(10)) # around 9
|
||||
|
||||
for t = 1:10^5
|
||||
θ = Params([w, w1])
|
||||
θ̄ = gradient(() -> loss(rand(10)), θ)
|
||||
Flux.Optimise.update!(opt, θ, θ̄)
|
||||
end
|
||||
|
||||
loss(rand(10)) # around 0.9
|
||||
```
|
||||
|
||||
In this manner it is possible to compose optimisers for some added flexibility.
|
||||
|
||||
## Decays
|
||||
|
||||
Similar to optimisers, Flux also defines some simple decays that can be used in conjunction with other optimisers, or standalone.
|
||||
|
||||
```@docs
|
||||
ExpDecay
|
||||
InvDecay
|
||||
WeightDecay
|
||||
```
|
||||
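As a small sketch of the standalone use mentioned here, a weight decay (with an arbitrarily chosen coefficient) can be chained in front of any optimiser via `Optimiser`:

```julia
# Apply L2 weight decay to the gradients before handing them to ADAM.
opt = Optimiser(WeightDecay(1e-4), ADAM())
```

Here `WeightDecay` adds `wd .* p` to each parameter's gradient, which corresponds to an L2 penalty on the parameters.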
|
@ -1,8 +1,9 @@
|
||||
# Training
|
||||
|
||||
To actually train a model we need three things:
|
||||
To actually train a model we need four things:
|
||||
|
||||
* An *objective function* that evaluates how well a model is doing given some input data.
|
||||
* The trainable parameters of the model.
|
||||
* A collection of data points that will be provided to the objective function.
|
||||
* An [optimiser](optimisers.md) that will update the model parameters appropriately.
|
||||
|
||||
@ -32,6 +33,14 @@ Flux.train!(loss, ps, data, opt)
|
||||
|
||||
The objective will almost always be defined in terms of some *cost function* that measures the distance of the prediction `m(x)` from the target `y`. Flux has several of these built in, like `mse` for mean squared error or `crossentropy` for cross entropy loss, but you can calculate it however you want.
|
||||
|
||||
At first glance it may seem strange that the model we want to train is not among the arguments of `Flux.train!`. However, the target of the optimiser is not the model itself but the objective function, which measures the departure between modelled and observed data. In other words, the model is implicitly defined by the objective function, so there is no need to pass it explicitly. Passing the objective function, instead of the model and a cost function separately, provides more flexibility and the possibility of optimising the calculations.
|
||||
|
||||
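For instance, here is a minimal sketch (sizes and data are made up) showing how the model enters the call only through the objective:

```julia
using Flux

m = Dense(10, 2)
loss(x, y) = Flux.mse(m(x), y)   # the model `m` is captured by the closure

ps = params(m)
data = [(rand(10), rand(2))]     # a single dummy data point
opt = Descent(0.1)

# `train!` never receives `m` directly; it differentiates `loss` w.r.t. `ps`.
Flux.train!(loss, ps, data, opt)
```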
## Model parameters
|
||||
|
||||
The model to be trained must have a set of tracked parameters that are used to calculate the gradients of the objective function. In the [basics](../models/basics.md) section it is explained how to create models with such parameters. The second argument of the function `Flux.train!` must be an object containing those parameters, which can be obtained from a model `m` as `params(m)`.
|
||||
|
||||
Such an object contains a reference to the model's parameters, not a copy, so that after training the model behaves according to their updated values.
|
||||
|
||||
## Datasets
|
||||
|
||||
The `data` argument provides a collection of data to train with (usually a set of inputs `x` and target outputs `y`). For example, here's a dummy data set with only one data point:
|
||||
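A minimal sketch of such a data set, with made-up input and target sizes, might be:

```julia
x = rand(784)
y = rand(2)
data = [(x, y)]
```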
|
46
src/Flux.jl
@ -3,30 +3,30 @@ module Flux
|
||||
# Zero Flux Given
|
||||
|
||||
using Base: tail
|
||||
using MacroTools, Juno, Requires, Reexport, Statistics, Random
|
||||
using Zygote, MacroTools, Juno, Reexport, Statistics, Random
|
||||
using MacroTools: @forward
|
||||
@reexport using NNlib
|
||||
using Zygote: Params, @adjoint, gradient, pullback, @nograd
|
||||
export gradient
|
||||
|
||||
export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
|
||||
DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
|
||||
SkipConnection,
|
||||
params, mapleaves, cpu, gpu, f32, f64
|
||||
|
||||
@reexport using NNlib
|
||||
|
||||
using Tracker
|
||||
using Tracker: data
|
||||
export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param
|
||||
SkipConnection, params, fmap, cpu, gpu, f32, f64
|
||||
|
||||
include("optimise/Optimise.jl")
|
||||
using .Optimise
|
||||
using .Optimise: @epochs
|
||||
export SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
|
||||
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
|
||||
ADAMW, InvDecay, ExpDecay, WeightDecay
|
||||
ADAMW, RADAM, InvDecay, ExpDecay, WeightDecay
|
||||
|
||||
|
||||
using CuArrays
|
||||
const use_cuda = Ref(false)
|
||||
|
||||
include("utils.jl")
|
||||
include("onehot.jl")
|
||||
include("treelike.jl")
|
||||
include("functor.jl")
|
||||
|
||||
include("layers/stateless.jl")
|
||||
include("layers/basic.jl")
|
||||
@ -36,6 +36,28 @@ include("layers/normalise.jl")
|
||||
|
||||
include("data/Data.jl")
|
||||
|
||||
@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" include("cuda/cuda.jl")
|
||||
include("deprecations.jl")
|
||||
|
||||
function __init__()
|
||||
precompiling = ccall(:jl_generating_output, Cint, ()) != 0
|
||||
|
||||
# we don't want to include the CUDA module when precompiling,
|
||||
# or we could end up replacing it at run time (triggering a warning)
|
||||
precompiling && return
|
||||
|
||||
if !CuArrays.functional()
|
||||
# nothing to do here, and either CuArrays or one of its dependencies will have warned
|
||||
else
|
||||
use_cuda[] = true
|
||||
|
||||
# FIXME: this functionality should be conditional at run time by checking `use_cuda`
|
||||
# (or even better, get moved to CuArrays.jl as much as possible)
|
||||
if CuArrays.has_cudnn()
|
||||
include(joinpath(@__DIR__, "cuda/cuda.jl"))
|
||||
else
|
||||
@warn "CuArrays.jl did not find libcudnn. Some functionality will not be available."
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
end # module
|
||||
|
@ -1,38 +1,9 @@
|
||||
module CUDA
|
||||
|
||||
using ..CuArrays
|
||||
import ..CuArrays.CUDAdrv: CuPtr, CU_NULL
|
||||
using Pkg.TOML
|
||||
|
||||
function version_check()
|
||||
major_version = 1
|
||||
project = joinpath(dirname(pathof(CuArrays)), "../Project.toml")
|
||||
project = TOML.parse(String(read(project)))
|
||||
version = VersionNumber(get(project, "version", "0.0.0"))
|
||||
if version.major != major_version
|
||||
@warn """
|
||||
Flux is only supported with CuArrays v$major_version.x.
|
||||
Try running `] pin CuArrays@$major_version`.
|
||||
"""
|
||||
end
|
||||
end
|
||||
|
||||
version_check()
|
||||
|
||||
if !applicable(CuArray{UInt8}, undef, 1)
|
||||
(T::Type{<:CuArray})(::UndefInitializer, sz...) = T(sz...)
|
||||
end
|
||||
|
||||
if CuArrays.libcudnn != nothing
|
||||
if isdefined(CuArrays, :libcudnn_handle)
|
||||
handle() = CuArrays.libcudnn_handle[]
|
||||
else
|
||||
handle() = CuArrays.CUDNN.handle()
|
||||
end
|
||||
include("curnn.jl")
|
||||
include("cudnn.jl")
|
||||
else
|
||||
@warn("CUDNN is not installed, some functionality will not be available.")
|
||||
end
|
||||
using CuArrays: CUDNN
|
||||
include("curnn.jl")
|
||||
include("cudnn.jl")
|
||||
|
||||
end
|
||||
|
@ -1,228 +1,8 @@
|
||||
using .CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, cudnnTensorDescriptor_t,
|
||||
cudnnBatchNormMode_t, cudnnHandle_t, cudnnDataType, TensorDesc, FilterDesc
|
||||
import ..Flux: data
|
||||
using LinearAlgebra
|
||||
import CuArrays.CUDNN: batchnorm, ∇batchnorm
|
||||
|
||||
mutable struct DropoutDesc
|
||||
ptr::Ptr{Nothing}
|
||||
states::CuVector{UInt8}
|
||||
end
|
||||
(BN::Flux.BatchNorm)(x::Union{CuArray{T,2},CuArray{T,4},CuArray{T,5}}, cache = nothing) where T<:Union{Float32, Float64} =
|
||||
BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum; cache = cache, alpha = 1, beta = 0, eps = BN.ϵ, training = Flux.istraining()))
|
||||
|
||||
Base.unsafe_convert(::Type{Ptr{Nothing}}, dd::DropoutDesc) = dd.ptr
|
||||
|
||||
function DropoutDesc(ρ::Real; seed::Integer=0)
|
||||
d = [C_NULL]
|
||||
s = Csize_t[0]
|
||||
@check ccall((:cudnnCreateDropoutDescriptor,libcudnn), cudnnStatus_t, (Ptr{Ptr{Nothing}},), d)
|
||||
@check ccall((:cudnnDropoutGetStatesSize,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Csize_t}),handle(),s)
|
||||
states = CuArray{UInt8}(undef, s[]) # TODO: can we drop this when ρ=0?
|
||||
desc = DropoutDesc(d[], states)
|
||||
@check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Nothing},Cfloat,CuPtr{Nothing},Csize_t,Culonglong),
|
||||
desc,handle(),ρ,states,length(states),seed)
|
||||
finalizer(desc) do x
|
||||
@check ccall((:cudnnDestroyDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x)
|
||||
end
|
||||
return desc
|
||||
end
|
||||
|
||||
const BATCHNORM_SPATIAL = 1
|
||||
const BATCHNORM_ACTIVATION = 0
|
||||
const BATCHNORM_MIN_EPS = 1e-5
|
||||
|
||||
@inline _wsize(y) = (map(_ -> 1, size(y)[1:end-2])..., size(y)[end-1], 1)
|
||||
|
||||
@inline _reddims(y) = (collect(1:ndims(y)-2)..., ndims(y))
|
||||
|
||||
mutable struct BNCache
|
||||
mean
|
||||
ivar
|
||||
end
|
||||
|
||||
BNCache() = BNCache(nothing, nothing)
|
||||
|
||||
# NOTE: CuDNN supports only 4D and 5D Tensors for BatchNorm Operations
|
||||
# so reshape a 2D Tensor into 4D
|
||||
batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T, 2},
|
||||
running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
|
||||
cache = nothing, alpha = T(1), beta = T(0),
|
||||
eps = T(1e-5), training = true) where T<:Union{Float32, Float64} =
|
||||
dropdims(batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), running_mean, running_var, momentum,
|
||||
cache = cache, alpha = alpha, beta = beta, eps = eps, training = training), dims = (1, 2))
|
||||
|
||||
function batchnorm(g::CuArray{T}, b::CuArray{T}, x::Union{CuArray{T, 4},CuArray{T,5}},
|
||||
running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
|
||||
cache = nothing, alpha = T(1), beta = T(0),
|
||||
eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
|
||||
y = similar(x)
|
||||
cudnnBNForward!(y, g, b, x, running_mean, running_var, momentum, cache = cache,
|
||||
alpha = alpha, beta = beta, eps = eps, training = training)
|
||||
y
|
||||
end
|
||||
|
||||
function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray{T},
|
||||
running_mean::CuArray{T}, running_var::CuArray{T},
|
||||
momentum; cache = nothing,
|
||||
alpha = T(1), beta = T(0),
|
||||
eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
|
||||
dims = _wsize(x)
|
||||
if eps < BATCHNORM_MIN_EPS
|
||||
# warn("eps ",eps," is too small for CuDNN so eps has been assigned the value ", BATCHNORM_MIN_EPS)
|
||||
eps = BATCHNORM_MIN_EPS
|
||||
end
|
||||
xd = TensorDesc(x)
|
||||
yd = TensorDesc(y)
|
||||
gd = TensorDesc(T, dims)
|
||||
|
||||
if training
|
||||
|
||||
if cache !== nothing
|
||||
mean = zeros(CuArray{T}, dims...)
|
||||
ivar = ones(CuArray{T}, dims...)
|
||||
else
|
||||
mean = CU_NULL
|
||||
ivar = CU_NULL
|
||||
end
|
||||
|
||||
@check ccall((:cudnnBatchNormalizationForwardTraining, libcudnn), cudnnStatus_t,
|
||||
(cudnnHandle_t,cudnnBatchNormMode_t,
|
||||
Ptr{T}, Ptr{T},
|
||||
Ptr{Nothing}, CuPtr{T},
|
||||
Ptr{Nothing}, CuPtr{T},
|
||||
Ptr{Nothing}, CuPtr{T}, CuPtr{T},
|
||||
Cdouble, CuPtr{T}, CuPtr{T},
|
||||
Cdouble, CuPtr{T}, CuPtr{T}),
|
||||
handle(), BATCHNORM_SPATIAL,
|
||||
Ref(T(alpha)), Ref(T(beta)),
|
||||
xd, x,
|
||||
yd, y,
|
||||
gd, g, b,
|
||||
momentum, running_mean, running_var,
|
||||
eps, mean, ivar)
|
||||
|
||||
if cache !== nothing
|
||||
cache.mean = mean
|
||||
cache.ivar = ivar
|
||||
end
|
||||
else
|
||||
@check ccall((:cudnnBatchNormalizationForwardInference, libcudnn), cudnnStatus_t,
|
||||
(Ptr{cudnnHandle_t},cudnnBatchNormMode_t,
|
||||
Ptr{T}, Ptr{T},
|
||||
Ptr{Nothing}, CuPtr{T},
|
||||
Ptr{Nothing}, CuPtr{T},
|
||||
Ptr{Nothing}, CuPtr{T}, CuPtr{T},
|
||||
CuPtr{T}, CuPtr{T},
|
||||
Cdouble),
|
||||
handle(), BATCHNORM_SPATIAL,
|
||||
Ref(T(alpha)), Ref(T(beta)),
|
||||
xd, x,
|
||||
yd, y,
|
||||
gd, g, b,
|
||||
running_mean, running_var,
|
||||
eps)
|
||||
end
|
||||
end
|
||||
|
||||
function ∇batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T, 2}, dy::CuArray{T, 2},
|
||||
running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
|
||||
cache = nothing, eps = T(1e-5), alpha = T(1),
|
||||
beta = T(0), training = true) where T<:Union{Float32, Float64}
|
||||
dg, db, dx = ∇batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), reshape(dy, 1, 1, size(dy, 1),
|
||||
size(dy, 2)), running_mean, running_var, momentum, cache = cache, eps = eps,
|
||||
alpha = alpha, beta = beta, training = training)
|
||||
(dg, db, dropdims(dx, dims = (1, 2)))
|
||||
end
|
||||
|
||||
function ∇batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
|
||||
running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
|
||||
cache = nothing, eps = T(1e-5), alpha = T(1),
|
||||
beta = T(0), training = true) where T<:Union{Float32, Float64}
|
||||
dg = similar(g)
|
||||
db = similar(b)
|
||||
dx = similar(x)
|
||||
cudnnBNBackward!(dg, g, db, dx, x, dy, running_mean, running_var, T(momentum),
|
||||
training = training, cache = cache, eps = eps, alpha = alpha, beta = beta)
|
||||
(dg, db, dx)
|
||||
end
|
||||
|
||||
function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
|
||||
dx::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
|
||||
running_mean::CuArray{T}, running_var::CuArray{T},
|
||||
momentum; cache = nothing, eps = T(1e-5),
|
||||
alpha = T(1), beta = T(0),
|
||||
dalpha = T(1), dbeta = T(0), training = true) where T<:Union{Float32, Float64}
|
||||
if training
|
||||
xd = TensorDesc(x)
|
||||
dyd = TensorDesc(dy)
|
||||
dxd = TensorDesc(dx)
|
||||
gd = TensorDesc(T, _wsize(x))
|
||||
if cache !== nothing
|
||||
mean, ivar = cache.mean, cache.ivar
|
||||
info("mean and ivar are fetched from the cache")
|
||||
else
|
||||
mean, ivar = CU_NULL, CU_NULL
|
||||
end
|
||||
|
||||
if eps < BATCHNORM_MIN_EPS
|
||||
eps = BATCHNORM_MIN_EPS
|
||||
end
|
||||
|
||||
@check ccall((:cudnnBatchNormalizationBackward, libcudnn), cudnnStatus_t,
|
||||
(cudnnHandle_t,cudnnBatchNormMode_t,
|
||||
Ptr{T}, Ptr{T},
|
||||
Ptr{T}, Ptr{T},
|
||||
Ptr{Nothing}, CuPtr{T},
|
||||
Ptr{Nothing}, CuPtr{T},
|
||||
Ptr{Nothing}, CuPtr{T},
|
||||
Ptr{Nothing}, CuPtr{T}, CuPtr{T}, CuPtr{T},
|
||||
Cdouble, CuPtr{T}, CuPtr{T}),
|
||||
handle(), BATCHNORM_SPATIAL,
|
||||
Ref(T(alpha)), Ref(T(beta)),
|
||||
Ref(T(dalpha)), Ref(T(dbeta)),
|
||||
xd, x,
|
||||
dyd, dy,
|
||||
dxd, dx,
|
||||
gd, g, dg, db,
|
||||
eps, mean, ivar)
|
||||
else
|
||||
ivar = 1 ./ sqrt.(reshape(running_var, _wsize(x)) .+ eps)
|
||||
dx .= dy .* reshape(g, _wsize(x)) .* ivar
|
||||
dg .= squeeze(sum(dy .* (x .- reshape(running_mean, _wsize(x))) .* ivar, _reddims(dy)), dims = (1,2,4))
|
||||
db .= squeeze(sum(dy, _reddims(dy)), dims = (1,2,4))
|
||||
end
|
||||
end
|
||||
|
||||
# Flux Interface
|
||||
|
||||
(BN::Flux.BatchNorm)(x::Union{CuParam{T,2},CuParam{T,4},CuParam{T,5}}, cache = nothing) where T<:Union{Float32, Float64} =
|
||||
BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum; cache = cache, alpha = 1, beta = 0, eps = BN.ϵ, training = BN.active))
|
||||
|
||||
batchnorm(g::TrackedArray, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T},
|
||||
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
||||
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
||||
|
||||
batchnorm(g::TrackedArray, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T},
|
||||
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
||||
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
||||
|
||||
batchnorm(g::TrackedArray, b::CuArray{T}, x::TrackedArray, running_mean::CuArray{T},
|
||||
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
||||
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
||||
|
||||
batchnorm(g::CuArray{T}, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T},
|
||||
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
||||
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
||||
|
||||
batchnorm(g::CuArray{T}, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T},
|
||||
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
||||
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
||||
|
||||
batchnorm(g::TrackedArray, b::CuArray{T}, x::CuArray{T}, running_mean::CuArray{T},
|
||||
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
||||
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
||||
|
||||
batchnorm(g::CuArray{T}, b::CuArray{T}, x::TrackedArray, running_mean::CuArray{T},
|
||||
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
||||
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
||||
|
||||
@grad batchnorm(g, b, x, running_mean, running_var, momentum; kw...) =
|
||||
batchnorm(data.((g, b, x))..., running_mean, running_var, momentum; kw...), Δ -> (nobacksies(:batchnorm, ∇batchnorm(data.((g, b, x, Δ))..., running_mean, running_var, momentum; kw...))..., nothing, nothing, nothing)
|
||||
@adjoint batchnorm(g, b, x, running_mean, running_var, momentum; kw...) =
|
||||
batchnorm(g, b, x, running_mean, running_var, momentum; kw...), Δ -> (∇batchnorm(g, b, x, Δ, running_mean, running_var, momentum; kw...)..., nothing, nothing, nothing)
|
||||
|
@ -1,325 +1,91 @@
|
||||
using .CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, cudnnTensorDescriptor_t,
|
||||
cudnnBatchNormMode_t, cudnnHandle_t, cudnnDataType, TensorDesc, FilterDesc
|
||||
using LinearAlgebra
|
||||
|
||||
const RNN_RELU = 0 # Stock RNN with ReLu activation
|
||||
const RNN_TANH = 1 # Stock RNN with tanh activation
|
||||
const LSTM = 2 # LSTM with no peephole connections
|
||||
const GRU = 3 # Using h' = tanh(r * Uh(t-1) + Wx) and h = (1 - z) * h' + z * h(t-1)
|
||||
|
||||
const LINEAR_INPUT = 0
|
||||
const SKIP_INPUT = 1
|
||||
|
||||
const UNIDIRECTIONAL = 0
|
||||
const BIDIRECTIONAL = 1
|
||||
|
||||
const RNN_ALGO_STANDARD = 0
|
||||
const RNN_ALGO_PERSIST_STATIC = 1
|
||||
const RNN_ALGO_PERSIST_DYNAMIC = 2
|
||||
|
||||
# param layout:
|
||||
# RNN: [weight, bias] × [input, hidden]
|
||||
# GRU: [weight, bias] × [input, hidden] × [reset, update, newmem]
|
||||
# LSTM: [weight, bias] × [input, hidden] × [input, forget, newmem, output]
|
||||
|
||||
function params(w::CuVector, input, hidden, n = 1)
|
||||
slice(offset, shape) = reshape(view(w, offset.+(1:prod(shape))), shape)
|
||||
wx = slice(0, (input, hidden*n))
|
||||
wh = slice(length(wx), (hidden, hidden*n))
|
||||
bias = view(w, length(wx)+length(wh) .+ (1:hidden*n))
|
||||
(wx, wh), bias
|
||||
end
|
||||
|
||||
mutable struct RNNDesc{T}
|
||||
mode::Int
|
||||
input::Int
|
||||
hidden::Int
|
||||
params::CuVector{T}
|
||||
weights::NTuple{2,CuMatrix{T}}
|
||||
bias::CuVector{T}
|
||||
ptr::Ptr{Nothing}
|
||||
end
|
||||
|
||||
Base.unsafe_convert(::Type{Ptr{Nothing}}, d::RNNDesc) = d.ptr
|
||||
|
||||
function rnnParamSize(T, r, input)
|
||||
size = Csize_t[0]
|
||||
@check ccall((:cudnnGetRNNParamsSize, libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Ptr{Nothing},Ptr{Csize_t},Cint),
|
||||
handle(), r, TensorDesc(T, (1,input,1)), size, cudnnDataType(T))
|
||||
return Int(size[])÷sizeof(T)
|
||||
end
|
||||
|
||||
ngates(mode) = [1, 1, 4, 3][mode+1]
|
||||
ngates(r::RNNDesc) = ngates(r.mode)
|
||||
|
||||
function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T
|
||||
d = [C_NULL]
|
||||
@check ccall((:cudnnCreateRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Ptr{Nothing}},),d)
|
||||
|
||||
dropoutDesc = DropoutDesc(0)
|
||||
inputMode = LINEAR_INPUT
|
||||
direction = UNIDIRECTIONAL
|
||||
algo = RNN_ALGO_STANDARD
|
||||
@check ccall((:cudnnSetRNNDescriptor_v6,libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Cint,Cint,Ptr{Nothing},Cint,Cint,Cint,Cint,Cint),
|
||||
handle(),d[],hidden,layers,dropoutDesc,inputMode,direction,mode,algo,cudnnDataType(T))
|
||||
|
||||
w = cuzeros(T, rnnParamSize(T, d[], input))
|
||||
# TODO: avoid reserve allocation here
|
||||
rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., d[])
|
||||
finalizer(rd) do x
|
||||
@check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x)
|
||||
end
|
||||
return rd
|
||||
end
|
||||
|
||||
function rnnWorkspaceSize(r::RNNDesc, seqlen, xdesc)
|
||||
size = Csize_t[0]
|
||||
@check ccall((:cudnnGetRNNWorkspaceSize, libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Cint,Ptr{Ptr{Nothing}},Ptr{Csize_t}),
|
||||
handle(), r, seqlen, xdesc, size)
|
||||
return Int(size[])
|
||||
end
|
||||
|
||||
const workspace = [CuVector{UInt8}(undef, 1)]
|
||||
|
||||
getworkspace(bytes) =
|
||||
length(workspace[]) ≥ bytes ?
|
||||
workspace[] :
|
||||
(workspace[] = CuVector{UInt8}(undef, bytes))
|
||||
|
||||
getworkspace(r::RNNDesc, seqlen, xdesc) =
|
||||
getworkspace(rnnWorkspaceSize(r, seqlen, xdesc))
|
||||
|
||||
function rnnTrainingReserveSize(r::RNNDesc, seqlen, xdesc)
|
||||
size = Csize_t[0]
|
||||
@check ccall((:cudnnGetRNNTrainingReserveSize,libcudnn), cudnnStatus_t, (Ptr{Nothing}, Ptr{Nothing}, Cint, Ptr{Ptr{Nothing}}, Ptr{Csize_t}),
|
||||
handle(), r, seqlen, xdesc, size)
|
||||
return Int(size[])
|
||||
end
|
||||
|
||||
function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
|
||||
workspace, reserve=nothing) where T
|
||||
if reserve == nothing
|
||||
@check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t,
|
||||
(Ptr{Nothing}, Ptr{Nothing}, Cint,
|
||||
Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
|
||||
Ptr{Nothing}, CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
|
||||
Ptr{Nothing}, CuPtr{T},
|
||||
CuPtr{Nothing}, Csize_t),
|
||||
handle(), rnn, seqlen,
|
||||
xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
|
||||
workspace, length(workspace))
|
||||
else
|
||||
@check ccall((:cudnnRNNForwardTraining, libcudnn), cudnnStatus_t,
|
||||
(Ptr{Nothing}, Ptr{Nothing}, Cint,
|
||||
Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
|
||||
CuPtr{Nothing}, Csize_t, CuPtr{Nothing}, Csize_t),
|
||||
handle(), rnn, seqlen,
|
||||
xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
|
||||
workspace, length(workspace), reserve, length(reserve))
|
||||
end
|
||||
end
|
||||
|
||||
xDesc(x) = [TensorDesc(eltype(x), (1, size(x, 1), size(x, 2)))]
|
||||
|
||||
hDesc(h::Nothing) = C_NULL, CU_NULL
|
||||
hDesc(x::Integer) = (@assert x == 0; hDesc(nothing))
|
||||
function hDesc(h::CuArray)
|
||||
TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h
|
||||
end
|
||||
|
||||
# TODO: can we just manipulate strides here?
|
||||
# TODO: should use repmat, but this isn't implemented.
|
||||
hBatch(x::AbstractVector, h::CuVector) = h
|
||||
hBatch(x::AbstractMatrix, h::CuVector) = h .* cuones(1, size(x, 2))
|
||||
hBatch(x::AbstractMatrix, h::CuMatrix) = h .* cuones(1, size(h,2) == 1 ? size(x,2) : 1)
|
||||
|
||||
function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing, train = Val{false}) where T
|
||||
h = hBatch(x, h_)
|
||||
c = c_ == nothing ? nothing : hBatch(x, c_)
|
||||
@assert size(x, 1) == rnn.input
|
||||
@assert size(h, 1) == rnn.hidden
|
||||
@assert size(x, 2) == size(h, 2)
|
||||
seqLength = 1
|
||||
xdesc = xDesc(x)
|
||||
y = x isa AbstractVector ? similar(x, rnn.hidden) : similar(x, rnn.hidden, size(x, 2))
|
||||
ho = similar(h)
|
||||
ydesc = xDesc(y)
|
||||
workspace = getworkspace(rnn, seqLength, xdesc)
|
||||
reserve = train == Val{true} ?
|
||||
CuVector{UInt8}(undef, rnnTrainingReserveSize(rnn, seqLength, xdesc)) :
|
||||
nothing
|
||||
co = c == nothing ? c : similar(c)
|
||||
cudnnRNNForward(rnn, seqLength,
|
||||
xdesc, x,
|
||||
hDesc(h)...,
|
||||
hDesc(c)...,
|
||||
FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
|
||||
ydesc, y,
|
||||
hDesc(ho)...,
|
||||
hDesc(co)...,
|
||||
workspace, reserve)
|
||||
result = c == nothing ? (y, ho) : (y, ho, co)
|
||||
return train == Val{true} ? (reserve, result) : result
|
||||
end
|
||||
|
||||
forwardTrain(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing) where T =
|
||||
forward(rnn, x, h, c, Val{true})
|
||||
|
||||
function cudnnRNNBackwardData(rnn::RNNDesc{T}, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
|
||||
wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, rs) where T
|
||||
@check ccall((:cudnnRNNBackwardData,libcudnn),cudnnStatus_t,
|
||||
(Ptr{Nothing}, Ptr{Nothing}, Cint,
|
||||
Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
|
||||
Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing},
|
||||
CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
|
||||
CuPtr{Nothing}, Csize_t, CuPtr{Nothing}, Csize_t),
|
||||
handle(), rnn, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
|
||||
wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, length(ws), rs, length(rs))
|
||||
end
|
||||
|
||||
function backwardData(rnn::RNNDesc{T}, y, dy_, dho, dco, h, c, reserve) where T
|
||||
# Same as above, any more efficient way?
|
||||
dy = dy_ isa Integer ? zero(y) : dy_
|
||||
yd = xDesc(y)
|
||||
dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2))
|
||||
dh = similar(h)
|
||||
dc = c == nothing ? nothing : similar(c)
|
||||
cudnnRNNBackwardData(rnn, 1,
|
||||
yd, y, yd, dy, hDesc(dho)..., hDesc(dco)...,
|
||||
FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
|
||||
hDesc(h)..., hDesc(c)..., xDesc(dx), dx, hDesc(dh)..., hDesc(dc)...,
|
||||
workspace[], reserve)
|
||||
return c == nothing ? (dx, dh) : (dx, dh, dc)
|
||||
end
|
||||
|
||||
backwardData(rnn, y, dy, dho, hx, reserve) =
|
||||
backwardData(rnn, y, dy, dho, nothing, hx, nothing, reserve)
|
||||
|
||||
function cudnnRNNBackwardWeights(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, yd, y, dwd, dw,
|
||||
workspace, reserve) where T
|
||||
@check ccall((:cudnnRNNBackwardWeights,libcudnn), cudnnStatus_t,
|
||||
(Ptr{Nothing}, Ptr{Nothing}, Cint, # handle, rnnDesc, seqLength
|
||||
Ptr{Ptr{Nothing}}, CuPtr{T}, #x
|
||||
Ptr{Nothing}, CuPtr{T}, #hx
|
||||
Ptr{Ptr{Nothing}}, CuPtr{T}, #y
|
||||
CuPtr{Nothing}, Csize_t, #ws
|
||||
Ptr{Nothing}, CuPtr{T}, #dw
|
||||
CuPtr{Nothing}, Csize_t), #rs
|
||||
handle(), rnn, seqlen, xd, x, hd, h, yd, y,
|
||||
workspace, length(workspace), dwd, dw, reserve, length(reserve))
|
||||
end
|
||||
|
||||
function backwardWeights(rnn::RNNDesc{T}, x, h, y, reserve) where T
|
||||
dw = zero(rnn.params)
|
||||
cudnnRNNBackwardWeights(rnn, 1,
|
||||
xDesc(x), x, hDesc(h)..., xDesc(y), y,
|
||||
FilterDesc(T, (1, 1, length(dw))), dw,
|
||||
workspace[], reserve)
|
||||
return params(dw, rnn.input, rnn.hidden, ngates(rnn))
|
||||
end
|
||||
|
||||
# Interface
|
||||
|
||||
import ..Flux: Flux, relu
|
||||
import ..Tracker: TrackedArray
|
||||
using .CuArrays.CUDAnative
|
||||
using .CuArrays: @cuindex, cudims
|
||||
using CuArrays.CUDAnative
|
||||
using CuArrays: @cuindex, cudims
|
||||
|
||||
function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
|
||||
function kernel(dst, src)
|
||||
I = @cuindex dst
|
||||
dst[I...] = src[reverse(I)...]
|
||||
return
|
||||
end
|
||||
blk, thr = cudims(dst)
|
||||
@cuda blocks=blk threads=thr kernel(dst, src)
|
||||
return dst
|
||||
end
|
||||
|
||||
CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}}
|
||||
CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuParam{T,2},<:CuParam{T,1}}
|
||||
CuGRU{T} = Flux.GRUCell{<:CuParam{T,2},<:CuParam{T,1}}
|
||||
CuLSTM{T} = Flux.LSTMCell{<:CuParam{T,2},<:CuParam{T,1}}
|
||||
CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuArray{T,2},<:CuArray{T,1}}
|
||||
CuGRU{T} = Flux.GRUCell{<:CuArray{T,2},<:CuArray{T,1}}
|
||||
CuLSTM{T} = Flux.LSTMCell{<:CuArray{T,2},<:CuArray{T,1}}
|
||||
CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}}
|
||||
|
||||
function copyparams!(m::CuRNNs, d::RNNDesc)
|
||||
Wi, Wh = d.weights
|
||||
copy_transpose!(Wi, Flux.data(m.Wi))
|
||||
copy_transpose!(Wh, Flux.data(m.Wh))
|
||||
copy_transpose!(d.bias, Flux.data(m.b))
|
||||
return
|
||||
end
|
||||
|
||||
function RNNDesc(m::CuRNNs{T}) where T
|
||||
function CUDNN.RNNDesc(m::CuRNNs{T}) where T
|
||||
h, i = length(m.h), size(m.Wi, 2)
|
||||
mode = m isa CuRNN ?
|
||||
(m.σ == tanh ? RNN_TANH : RNN_RELU) :
|
||||
m isa CuGRU ? GRU : LSTM
|
||||
r = RNNDesc{T}(mode, i, h)
|
||||
(m.σ == tanh ? CUDNN.CUDNN_RNN_TANH : CUDNN.CUDNN_RNN_RELU) :
|
||||
m isa CuGRU ? CUDNN.CUDNN_GRU : CUDNN.CUDNN_LSTM
|
||||
r = CUDNN.RNNDesc{T}(mode, i, h)
|
||||
return r
|
||||
end
|
||||
|
||||
const descs = WeakKeyDict()
|
||||
|
||||
function desc(rnn)
|
||||
d = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = RNNDesc(rnn))
|
||||
copyparams!(rnn, d)
|
||||
d = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = CUDNN.RNNDesc(rnn))
|
||||
CUDNN.setweights!(d, rnn.Wi, rnn.Wh, rnn.b)
|
||||
return d
|
||||
end
|
||||
|
||||
import Flux.Tracker
|
||||
import Flux.Tracker: data, istracked, track, unbroadcast, @grad, nobacksies
|
||||
import Zygote
|
||||
using Zygote: @adjoint
|
||||
|
||||
istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...))
|
||||
|
||||
function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
|
||||
result = istrain(m, h, x) ?
|
||||
track(m, x, h, m.Wi, m.Wh, m.b) :
|
||||
forward(desc(m), x, h)
|
||||
return result[2], result[1]
|
||||
function (m::CuRNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
||||
y, h′ = CUDNN.forward(desc(m), x, h)
|
||||
return h′, y
|
||||
end
|
||||
|
||||
function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
|
||||
result = istrain(m, h, x) ?
|
||||
track(m, x, h, m.Wi, m.Wh, m.b) :
|
||||
forward(desc(m), x, h)
|
||||
return result[2], result[1]
|
||||
function (m::CuGRU{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
||||
y, h′ = CUDNN.forward(desc(m), x, h)
|
||||
return h′, y
|
||||
end
|
||||
|
||||
function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64}
|
||||
result = istrain(m, h, x) ?
|
||||
track(m, x, h[1], h[2], m.Wi, m.Wh, m.b) :
|
||||
forward(desc(m), x, h[1], h[2])
|
||||
return (result[2], result[3]), result[1]
|
||||
function (m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
||||
y, h′, c′ = CUDNN.forward(desc(m), x, h[1], h[2])
|
||||
return (h′, c′), y
|
||||
end
|
||||
|
||||
(m::CuRNN{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||
(m::CuGRU{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||
(m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||
(m::CuRNN{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||
(m::CuGRU{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||
(m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||
|
||||
@grad function (m::Union{CuRNN,CuGRU})(x, h, Wi, Wh, b)
|
||||
reserve, result = forwardTrain(desc(m), data(x), data(h))
|
||||
result, function (Δ)
|
||||
y, ho = result
|
||||
dy, dho = Δ
|
||||
h_ = hBatch(x, data(h))
|
||||
dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
|
||||
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
|
||||
nobacksies(:RNN, (dx, unbroadcast(h, dh), transpose(dWi), transpose(dWh), db))
|
||||
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
|
||||
|
||||
unbroadcast(x::AbstractArray, Δ) =
|
||||
size(x) == size(Δ) ? Δ :
|
||||
length(x) == length(Δ) ? trim(x, Δ) :
|
||||
trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
|
||||
|
||||
coerce_cuda(x::Union{CuArray,Nothing}) = x
|
||||
coerce_cuda(x::Tuple) = coerce_cuda.(x)
|
||||
|
||||
coerce_cuda(x::AbstractArray) = x .+ CuArrays.fill(0)
|
||||
|
||||
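# Record the gradient named tuple `x̄` for the struct `x` in Zygote's context:
# accumulate each field into any matching implicit parameters, then merge `x̄`
# into the mutable gradient slot Zygote keeps for `x` itself.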
function struct_grad!(cx::Zygote.Context, x, x̄)
|
||||
for f in fieldnames(typeof(x))
|
||||
Zygote.accum_param(cx, getfield(x, f), getfield(x̄, f))
|
||||
end
|
||||
dx = Zygote.grad_mut(cx, x)
|
||||
dx[] = Zygote.accum(dx[], x̄)
|
||||
return dx
|
||||
end
|
||||
|
||||
for RNN in (CuRNN, CuGRU)
|
||||
@eval @adjoint function (m::$RNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
||||
(y, ho), back = CUDNN.pullback(desc(m), x, h)
|
||||
(ho, y), function (Δ)
|
||||
dho, dy = coerce_cuda(Δ) # Support FillArrays etc.
|
||||
m̄ = back(dy, dho)
|
||||
dm = struct_grad!(__context__, m, (σ=nothing,Wi=transpose(m̄.Wi),Wh=transpose(m̄.Wh),b=m̄.b,h=nothing))
|
||||
(dm, unbroadcast(h, m̄.h), m̄.x)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
@grad function (m::CuLSTM)(x, h, c, Wi, Wh, b)
|
||||
reserve, result = forwardTrain(desc(m), data.((x, h, c))...)
|
||||
result, function (Δ)
|
||||
y, ho = result
|
||||
dy, dho, dco = Δ
|
||||
h_ = hBatch(x, data(h))
|
||||
c_ = hBatch(x, data(c))
|
||||
dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
|
||||
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
|
||||
nobacksies(:RNN,
|
||||
(dx, unbroadcast(h, dh), unbroadcast(c, dc),
|
||||
transpose(dWi), transpose(dWh), db))
|
||||
@adjoint function (m::CuLSTM)((h, c)::Tuple{CuArray{T},CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
||||
(y, ho, co), back = CUDNN.pullback(desc(m), x, h, c)
|
||||
((ho, co), y), function (Δ)
|
||||
dhc, dy = coerce_cuda(Δ) # Support FillArrays etc.
|
||||
dho, dco = dhc === nothing ? (nothing, nothing) : dhc
|
||||
m̄ = back(dy, dho, dco)
|
||||
dm = struct_grad!(__context__, m, (σ=nothing,Wi=transpose(m̄.Wi),Wh=transpose(m̄.Wh),b=m̄.b,h=nothing,c=nothing))
|
||||
(dm, (unbroadcast(h, m̄.h), unbroadcast(c, m̄.c)), m̄.x)
|
||||
end
|
||||
end
|
||||
|
@ -1,8 +1,4 @@
|
||||
|
||||
"""
|
||||
|
||||
Iris
|
||||
|
||||
Fisher's classic iris dataset.
|
||||
|
||||
Measurements from 3 different species of iris: setosa, versicolor and
|
||||
@ -39,6 +35,8 @@ Get the labels of the iris dataset, a 150 element array of strings listing the
|
||||
species of each example.
|
||||
|
||||
```jldoctest
|
||||
julia> using Flux
|
||||
|
||||
julia> labels = Flux.Data.Iris.labels();
|
||||
|
||||
julia> summary(labels)
|
||||
@ -63,6 +61,8 @@ elements. It has a row for each feature (sepal length, sepal width,
|
||||
petal length, petal width) and a column for each example.
|
||||
|
||||
```jldoctest
|
||||
julia> using Flux
|
||||
|
||||
julia> features = Flux.Data.Iris.features();
|
||||
|
||||
julia> summary(features)
|
||||
@ -81,6 +81,5 @@ function features()
|
||||
iris = readdlm(deps("iris.data"), ',')
|
||||
Matrix{Float64}(iris[1:end, 1:4]')
|
||||
end
|
||||
|
||||
end
|
||||
|
||||
|
||||
|
2
src/deprecations.jl
Normal file
@ -0,0 +1,2 @@
|
||||
@deprecate param(x) x
|
||||
@deprecate data(x) x
|
85
src/functor.jl
Normal file
@ -0,0 +1,85 @@
|
||||
import Adapt: adapt, adapt_storage
|
||||
using Zygote: IdSet
|
||||
|
||||
functor(x) = (), _ -> x
|
||||
|
||||
functor(x::Tuple) = x, y -> y
|
||||
functor(x::NamedTuple) = x, y -> y
|
||||
|
||||
functor(x::AbstractArray) = x, y -> y
|
||||
functor(x::AbstractArray{<:Number}) = (), _ -> x
|
||||
|
||||
function makefunctor(m::Module, T, fs = fieldnames(T))
|
||||
@eval m begin
|
||||
Flux.functor(x::$T) = ($([:($f=x.$f) for f in fs]...),), y -> $T(y...)
|
||||
end
|
||||
end
|
||||
|
||||
function functorm(T, fs = nothing)
|
||||
fs == nothing || isexpr(fs, :tuple) || error("@functor T (a, b)")
|
||||
fs = fs == nothing ? [] : [:($(map(QuoteNode, fs.args)...),)]
|
||||
:(makefunctor(@__MODULE__, $(esc(T)), $(fs...)))
|
||||
end
|
||||
|
||||
macro functor(args...)
|
||||
functorm(args...)
|
||||
end
|
||||
|
||||
isleaf(x) = functor(x)[1] === ()
|
||||
|
||||
function fmap1(f, x)
|
||||
func, re = functor(x)
|
||||
re(map(f, func))
|
||||
end
|
||||
|
||||
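# Recursively apply `f` to every leaf of `x` (as determined by `functor`/`isleaf`),
# rebuilding the surrounding structure. The identity-keyed cache ensures arrays
# shared between several layers are transformed only once.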
function fmap(f, x; cache = IdDict())
|
||||
haskey(cache, x) && return cache[x]
|
||||
cache[x] = isleaf(x) ? f(x) : fmap1(x -> fmap(f, x, cache = cache), x)
|
||||
end
|
||||
|
||||
trainable(m) = functor(m)[1]
|
||||
|
||||
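# Walk the `trainable` fields of a model and push every numeric leaf array into `p`,
# using `seen` to avoid visiting the same container twice.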
params!(p::Params, x::AbstractArray{<:Number}, seen = IdSet()) = push!(p, x)
|
||||
|
||||
function params!(p::Params, x, seen = IdSet())
|
||||
x in seen && return
|
||||
push!(seen, x)
|
||||
for child in trainable(x)
|
||||
params!(p, child, seen)
|
||||
end
|
||||
end
|
||||
|
||||
function params(m...)
|
||||
ps = Params()
|
||||
params!(ps, m)
|
||||
return ps
|
||||
end
|
||||
|
||||
# Deprecated stuff
|
||||
macro treelike(args...)
|
||||
functorm(args...)
|
||||
end
|
||||
mapleaves(f, x) = fmap(f, x)
|
||||
|
||||
function loadparams!(m, xs)
|
||||
for (p, x) in zip(params(m), xs)
|
||||
size(p) == size(x) ||
|
||||
error("Expected param size $(size(p)), got $(size(x))")
|
||||
copyto!(p, x)
|
||||
end
|
||||
end
|
||||
|
||||
# CPU/GPU movement conveniences
|
||||
|
||||
cpu(m) = fmap(x -> adapt(Array, x), m)
|
||||
|
||||
gpu(x) = use_cuda[] ? fmap(CuArrays.cu, x) : x
|
||||
|
||||
# Precision
|
||||
|
||||
adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Real}) = convert.(T, xs)
|
||||
|
||||
paramtype(T::Type{<:Real}, m) = fmap(x -> adapt(T, x), m)
|
||||
|
||||
f32(m) = paramtype(Float32, m)
|
||||
f64(m) = paramtype(Float64, m)
|
@ -24,8 +24,7 @@ end
|
||||
@forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
|
||||
Base.iterate, Base.lastindex
|
||||
|
||||
children(c::Chain) = c.layers
|
||||
mapchildren(f, c::Chain) = Chain(f.(c.layers)...)
|
||||
functor(c::Chain) = c.layers, ls -> Chain(ls...)
|
||||
|
||||
applychain(::Tuple{}, x) = x
|
||||
applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x))
|
||||
@ -45,19 +44,23 @@ end
|
||||
# it might be replaced in the future for better performance
|
||||
# see issue https://github.com/FluxML/Flux.jl/issues/702
|
||||
# Johnny Chen -- @johnnychen94
|
||||
# only slightly changed to better handle interaction with Zygote @dsweber2
|
||||
"""
|
||||
activations(c::Chain, input)
|
||||
Calculate the forward results of each layers in Chain `c` with `input` as model input.
|
||||
"""
|
||||
function activations(c::Chain, input)
|
||||
rst = []
|
||||
for l in c
|
||||
x = get(rst, length(rst), input)
|
||||
push!(rst, l(x))
|
||||
end
|
||||
return rst
|
||||
extraChain(c.layers, input)
|
||||
end
|
||||
|
||||
function extraChain(fs::Tuple, x)
|
||||
res = first(fs)(x)
|
||||
return (res, extraChain(Base.tail(fs), res)...)
|
||||
end
|
||||
|
||||
extraChain(::Tuple{}, x) = ()
|
||||
|
||||
|
||||
|
||||
"""
|
||||
Dense(in::Integer, out::Integer, σ = identity)
|
||||
@ -89,10 +92,10 @@ Dense(W, b) = Dense(W, b, identity)
|
||||
|
||||
function Dense(in::Integer, out::Integer, σ = identity;
|
||||
initW = glorot_uniform, initb = zeros)
|
||||
return Dense(param(initW(out, in)), param(initb(out)), σ)
|
||||
return Dense(initW(out, in), initb(out), σ)
|
||||
end
|
||||
|
||||
@treelike Dense
|
||||
@functor Dense
|
||||
|
||||
function (a::Dense)(x::AbstractArray)
|
||||
W, b, σ = a.W, a.b, a.σ
|
||||
@ -110,7 +113,7 @@ end
|
||||
(a::Dense{<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
|
||||
invoke(a, Tuple{AbstractArray}, x)
|
||||
|
||||
(a::Dense{<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
|
||||
(a::Dense{<:Any,W})(x::AbstractArray{<:AbstractFloat}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
|
||||
a(T.(x))
|
||||
|
||||
"""
|
||||
@ -129,9 +132,9 @@ struct Diagonal{T}
|
||||
end
|
||||
|
||||
Diagonal(in::Integer; initα = ones, initβ = zeros) =
|
||||
Diagonal(param(initα(in)), param(initβ(in)))
|
||||
Diagonal(initα(in), initβ(in))
|
||||
|
||||
@treelike Diagonal
|
||||
@functor Diagonal
|
||||
|
||||
function (a::Diagonal)(x)
|
||||
α, β = a.α, a.β
|
||||
@ -184,41 +187,42 @@ function Maxout(f, n_alts)
|
||||
return Maxout(over)
|
||||
end
|
||||
|
||||
@treelike Maxout
|
||||
@functor Maxout
|
||||
|
||||
function (mo::Maxout)(input::AbstractArray)
|
||||
mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
|
||||
end
|
||||
|
||||
"""
|
||||
SkipConnection(layers...)
|
||||
SkipConnection(layers, connection)
|
||||
|
||||
Creates a Skip Connection, which constitutes of a layer or Chain of consecutive layers
|
||||
and a shortcut connection linking the input to the block to the
|
||||
output through a user-supplied callable.
|
||||
Creates a skip connection, consisting of a layer or `Chain` of consecutive layers
|
||||
plus a shortcut connection. The connection function will combine the result of the layers
|
||||
with the original input, to give the final output.
|
||||
|
||||
`SkipConnection` requires the output dimension to be the same as the input.
|
||||
The simplest 'ResNet'-type connection is just `SkipConnection(layer, +)`,
|
||||
and requires the output of the layers to be the same shape as the input.
|
||||
Here is a more complicated example:
|
||||
```
|
||||
m = Conv((3,3), 4=>7, pad=(1,1))
|
||||
x = ones(5,5,4,10);
|
||||
size(m(x)) == (5, 5, 7, 10)
|
||||
|
||||
A 'ResNet'-type skip-connection with identity shortcut would simply be
|
||||
```julia
|
||||
SkipConnection(layer, (a,b) -> a + b)
|
||||
sm = SkipConnection(m, (mx, x) -> cat(mx, x, dims=3))
|
||||
size(sm(x)) == (5, 5, 11, 10)
|
||||
```
|
||||
"""
|
||||
|
||||
struct SkipConnection
|
||||
layers
|
||||
connection #user can pass arbitrary connections here, such as (a,b) -> a + b
|
||||
end
|
||||
|
||||
@treelike SkipConnection
|
||||
@functor SkipConnection
|
||||
|
||||
function (skip::SkipConnection)(input)
|
||||
#We apply the layers to the input and return the result of the application of the layers and the original input
|
||||
skip.connection(skip.layers(input), input)
|
||||
end
|
||||
|
||||
function Base.show(io::IO, b::SkipConnection)
|
||||
print(io, "SkipConnection(")
|
||||
join(io, b.layers, ", ")
|
||||
print(io, ")")
|
||||
print(io, "SkipConnection(", b.layers, ", ", b.connection, ")")
|
||||
end
|
||||
|
@ -42,10 +42,10 @@ end
|
||||
|
||||
Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
|
||||
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
|
||||
Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ,
|
||||
Conv(init(k..., ch...), zeros(ch[2]), σ,
|
||||
stride = stride, pad = pad, dilation = dilation)
|
||||
|
||||
@treelike Conv
|
||||
@functor Conv
|
||||
|
||||
function (c::Conv)(x::AbstractArray)
|
||||
# TODO: breaks gpu broadcast :(
|
||||
@ -74,8 +74,10 @@ end
|
||||
|
||||
Standard convolutional transpose layer. `size` should be a tuple like `(2, 2)`.
|
||||
`in` and `out` specify the number of input and output channels respectively.
|
||||
|
||||
Data should be stored in WHCN order. In other words, a 100×100 RGB image would
|
||||
be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
|
||||
|
||||
Takes the keyword arguments `pad`, `stride` and `dilation`.
|
||||
"""
|
||||
struct ConvTranspose{N,M,F,A,V}
|
||||
@ -97,10 +99,10 @@ end
|
||||
|
||||
ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
|
||||
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
|
||||
ConvTranspose(param(init(k..., reverse(ch)...)), param(zeros(ch[2])), σ,
|
||||
ConvTranspose(init(k..., reverse(ch)...), zeros(ch[2]), σ,
|
||||
stride = stride, pad = pad, dilation = dilation)
|
||||
|
||||
@treelike ConvTranspose
|
||||
@functor ConvTranspose
|
||||
|
||||
function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
|
||||
# Calculate size of "input", from ∇conv_data()'s perspective...
|
||||
@ -116,6 +118,9 @@ function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
|
||||
)
|
||||
end
|
||||
|
||||
# TODO: Find proper fix for https://github.com/FluxML/Flux.jl/issues/900
|
||||
@nograd conv_transpose_dims
|
||||
|
||||
function (c::ConvTranspose)(x::AbstractArray)
|
||||
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
|
||||
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
|
||||
@ -169,8 +174,8 @@ function DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ =
|
||||
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N
|
||||
@assert ch[2] % ch[1] == 0 "Output channels must be integer multiple of input channels"
|
||||
return DepthwiseConv(
|
||||
param(init(k..., div(ch[2], ch[1]), ch[1])),
|
||||
param(zeros(ch[2])),
|
||||
init(k..., div(ch[2], ch[1]), ch[1]),
|
||||
zeros(ch[2]),
|
||||
σ;
|
||||
stride = stride,
|
||||
pad = pad,
|
||||
@ -178,7 +183,7 @@ function DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ =
|
||||
)
|
||||
end
|
||||
|
||||
@treelike DepthwiseConv
|
||||
@functor DepthwiseConv
|
||||
|
||||
function (c::DepthwiseConv)(x)
|
||||
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
|
||||
@ -198,6 +203,7 @@ end
|
||||
|
||||
(a::DepthwiseConv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
|
||||
a(T.(x))
|
||||
|
||||
"""
|
||||
CrossCor(size, in=>out)
|
||||
CrossCor(size, in=>out, relu)
|
||||
@ -238,10 +244,10 @@ end
|
||||
|
||||
CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
|
||||
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
|
||||
CrossCor(param(init(k..., ch...)), param(zeros(ch[2])), σ,
|
||||
CrossCor(init(k..., ch...), zeros(ch[2]), σ,
|
||||
stride = stride, pad = pad, dilation = dilation)
|
||||
|
||||
@treelike CrossCor
|
||||
@functor CrossCor
|
||||
|
||||
function crosscor(x, w, ddims::DenseConvDims)
|
||||
ddims = DenseConvDims(ddims, F=true)
|
||||
|
@ -1,17 +1,20 @@
|
||||
"""
|
||||
testmode!(m)
|
||||
testmode!(m, false)
|
||||
istraining() = false
|
||||
|
||||
Put layers like [`Dropout`](@ref) and [`BatchNorm`](@ref) into testing mode
|
||||
(or back to training mode with `false`).
|
||||
"""
|
||||
function testmode!(m, val::Bool=true)
|
||||
prefor(x -> _testmode!(x, val), m)
|
||||
return m
|
||||
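# Under Zygote differentiation this adjoint makes `istraining()` evaluate to `true`,
# so layers such as `AlphaDropout` and `BatchNorm` only apply their training-time
# behaviour while gradients are being taken.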
@adjoint istraining() = true, _ -> nothing
|
||||
|
||||
_dropout_shape(s, ::Colon) = size(s)
|
||||
_dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(size(s)))...)
|
||||
|
||||
_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
|
||||
|
||||
dropout(x, p; dims = :) = x
|
||||
|
||||
@adjoint function dropout(x, p; dims = :)
|
||||
y = rand!(similar(x, _dropout_shape(x, dims)))
|
||||
y .= _dropout_kernel.(y, p, 1 - p)
|
||||
return x .* y, Δ -> (Δ .* y, nothing)
|
||||
end
|
||||
|
||||
_testmode!(m, test) = nothing
|
||||
|
||||
"""
|
||||
Dropout(p, dims = :)
|
||||
|
||||
@ -19,48 +22,25 @@ A Dropout layer. For each input, either sets that input to `0` (with probability
|
||||
`p`) or scales it by `1/(1-p)`. The `dims` argument specifies the unbroadcasted
|
||||
dimensions, i.e. `dims=1` does dropout along columns and `dims=2` along rows. This is
|
||||
used as a regularisation, i.e. it reduces overfitting during training. See also [`dropout`](@ref).
|
||||
|
||||
Does nothing to the input once in [`testmode!`](@ref).
|
||||
"""
|
||||
mutable struct Dropout{F}
|
||||
mutable struct Dropout{F,D}
|
||||
p::F
|
||||
dims::Union{Colon, Int, NTuple{N, Int} where N}
|
||||
active::Bool
|
||||
dims::D
|
||||
end
|
||||
|
||||
function Dropout(p; dims = :)
|
||||
@assert 0 ≤ p ≤ 1
|
||||
Dropout{typeof(p)}(p, dims, true)
|
||||
Dropout{typeof(p),typeof(dims)}(p, dims)
|
||||
end
|
||||
|
||||
_dropout_shape(s, ::Colon) = size(s)
|
||||
_dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(size(s)))...)
|
||||
(a::Dropout)(x) = dropout(x, a.p; dims = a.dims)
|
||||
|
||||
_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
|
||||
|
||||
|
||||
"""
|
||||
dropout(x, p; dims = :)
|
||||
|
||||
The dropout function. For each input, either sets that input to `0` (with probability
|
||||
`p`) or scales it by `1/(1-p)`. The `dims` argument specifies the unbroadcasted
|
||||
dimensions, i.e. `dims=1` does dropout along columns and `dims=2` along rows. This is
|
||||
used as a regularisation, i.e. it reduces overfitting during training.
|
||||
"""
|
||||
function dropout(x, p; dims = :)
|
||||
y = similar(x, _dropout_shape(x, dims))
|
||||
rand!(y)
|
||||
y .= _dropout_kernel.(y, p, 1 - p)
|
||||
return x .* y
|
||||
function Base.show(io::IO, d::Dropout)
|
||||
print(io, "Dropout(", d.p)
|
||||
d.dims != (:) && print(io, ", dims = $(repr(d.dims))")
|
||||
print(io, ")")
|
||||
end
|
||||
|
||||
function (a::Dropout)(x)
|
||||
a.active || return x
|
||||
return dropout(x, a.p; dims = a.dims)
|
||||
end
|
||||
|
||||
_testmode!(a::Dropout, test) = (a.active = !test)
|
||||
|
||||
"""
|
||||
AlphaDropout(p)
|
||||
A dropout layer. It is used in Self-Normalizing Neural Networks.
|
||||
@ -69,29 +49,25 @@ The AlphaDropout layer ensures that mean and variance of activations remains the
|
||||
"""
|
||||
mutable struct AlphaDropout{F}
|
||||
p::F
|
||||
active::Bool
|
||||
end
|
||||
|
||||
function AlphaDropout(p)
|
||||
@assert 0 ≤ p ≤ 1
|
||||
AlphaDropout(p,true)
|
||||
function AlphaDropout(p)
|
||||
@assert 0 ≤ p ≤ 1
|
||||
new{typeof(p)}(p)
|
||||
end
|
||||
end
|
||||
|
||||
function (a::AlphaDropout)(x)
|
||||
a.active || return x
|
||||
istraining() || return x
|
||||
λ = eltype(x)(1.0507009873554804934193349852946)
|
||||
α = eltype(x)(1.6732632423543772848170429916717)
|
||||
α1 = eltype(x)(-λ*α)
|
||||
noise = randn(eltype(x), size(x))
|
||||
x = @. x*(noise > (1 - a.p)) + α1 * (noise <= (1 - a.p))
|
||||
x = @. x*(noise > (1 - a.p)) + α1 * (noise < (1 - a.p))
|
||||
A = (a.p + a.p * (1 - a.p) * α1 ^ 2)^0.5
|
||||
B = -A * α1 * (1 - a.p)
|
||||
x = @. A * x + B
|
||||
return x
|
||||
end
|
||||
|
||||
_testmode!(a::AlphaDropout, test) = (a.active = !test)
|
||||
|
||||
"""
|
||||
LayerNorm(h::Integer)
|
||||
|
||||
@ -106,7 +82,7 @@ end
|
||||
LayerNorm(h::Integer) =
|
||||
LayerNorm(Diagonal(h))
|
||||
|
||||
@treelike LayerNorm
|
||||
@functor LayerNorm
|
||||
|
||||
(a::LayerNorm)(x) = a.diag(normalise(x))
|
||||
|
||||
@ -151,25 +127,25 @@ mutable struct BatchNorm{F,V,W,N}
|
||||
σ²::W # moving std
|
||||
ϵ::N
|
||||
momentum::N
|
||||
active::Bool
|
||||
end
|
||||
|
||||
BatchNorm(chs::Integer, λ = identity;
|
||||
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
|
||||
BatchNorm(λ, param(initβ(chs)), param(initγ(chs)),
|
||||
zeros(chs), ones(chs), ϵ, momentum, true)
|
||||
BatchNorm(λ, initβ(chs), initγ(chs),
|
||||
zeros(chs), ones(chs), ϵ, momentum)
|
||||
|
||||
trainable(bn::BatchNorm) = (bn.β, bn.γ)
|
||||
|
||||
function (BN::BatchNorm)(x)
|
||||
size(x, ndims(x)-1) == length(BN.β) ||
|
||||
error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))")
|
||||
dims = length(size(x))
|
||||
channels = size(x, dims-1)
|
||||
affine_shape = ones(Int, dims)
|
||||
affine_shape[end-1] = channels
|
||||
m = prod(size(x)[1:end-2]) * size(x)[end]
|
||||
affine_shape = ntuple(i->i == ndims(x) - 1 ? size(x, i) : 1, ndims(x))
|
||||
m = div(prod(size(x)), channels)
|
||||
γ = reshape(BN.γ, affine_shape...)
|
||||
β = reshape(BN.β, affine_shape...)
|
||||
if !BN.active
|
||||
if !istraining()
|
||||
μ = reshape(BN.μ, affine_shape...)
|
||||
σ² = reshape(BN.σ², affine_shape...)
|
||||
ϵ = BN.ϵ
|
||||
@ -178,11 +154,12 @@ function (BN::BatchNorm)(x)
|
||||
axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
|
||||
μ = mean(x, dims = axes)
|
||||
σ² = sum((x .- μ) .^ 2, dims = axes) ./ m
|
||||
ϵ = data(convert(T, BN.ϵ))
|
||||
ϵ = convert(T, BN.ϵ)
|
||||
# update moving mean/std
|
||||
mtm = data(convert(T, BN.momentum))
|
||||
BN.μ = (1 - mtm) .* BN.μ .+ mtm .* reshape(data(μ), :)
|
||||
BN.σ² = (1 - mtm) .* BN.σ² .+ (mtm * m / (m - 1)) .* reshape(data(σ²), :)
|
||||
mtm = BN.momentum
|
||||
S = eltype(BN.μ)
|
||||
BN.μ = (1 - mtm) .* BN.μ .+ mtm .* S.(reshape(μ, :))
|
||||
BN.σ² = (1 - mtm) .* BN.σ² .+ (mtm * m / (m - 1)) .* S.(reshape(σ², :))
|
||||
end
|
||||
|
||||
let λ = BN.λ
|
||||
@ -191,13 +168,7 @@ function (BN::BatchNorm)(x)
|
||||
end
|
||||
end
|
||||
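For reference, a short sketch (with illustrative sizes) of what the `ntuple`-based `affine_shape` and the element count `m` introduced above evaluate to for a 4-D input:

```julia
x = rand(Float32, 7, 7, 3, 2)   # (W, H, C, N)
channels = size(x, ndims(x) - 1)
affine_shape = ntuple(i -> i == ndims(x) - 1 ? size(x, i) : 1, ndims(x))  # (1, 1, 3, 1)
m = div(prod(size(x)), channels)                                          # 7*7*2 = 98
```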
|
||||
children(BN::BatchNorm) =
|
||||
(BN.λ, BN.β, BN.γ, BN.μ, BN.σ², BN.ϵ, BN.momentum, BN.active)
|
||||
|
||||
mapchildren(f, BN::BatchNorm) = # e.g. mapchildren(cu, BN)
|
||||
BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ²), BN.ϵ, BN.momentum, BN.active)
|
||||
|
||||
_testmode!(BN::BatchNorm, test) = (BN.active = !test)
|
||||
@functor BatchNorm
|
||||
|
||||
function Base.show(io::IO, l::BatchNorm)
|
||||
print(io, "BatchNorm($(join(size(l.β), ", "))")
|
||||
@ -244,13 +215,14 @@ mutable struct InstanceNorm{F,V,W,N}
|
||||
σ²::W # moving std
|
||||
ϵ::N
|
||||
momentum::N
|
||||
active::Bool
|
||||
end
|
||||
|
||||
InstanceNorm(chs::Integer, λ = identity;
|
||||
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
|
||||
InstanceNorm(λ, param(initβ(chs)), param(initγ(chs)),
|
||||
zeros(chs), ones(chs), ϵ, momentum, true)
|
||||
InstanceNorm(λ, initβ(chs), initγ(chs),
|
||||
zeros(chs), ones(chs), ϵ, momentum)
|
||||
|
||||
trainable(in::InstanceNorm) = (in.β, in.γ)
|
||||
|
||||
function (in::InstanceNorm)(x)
|
||||
size(x, ndims(x)-1) == length(in.β) ||
|
||||
@ -261,28 +233,26 @@ function (in::InstanceNorm)(x)
|
||||
dims = length(size(x))
|
||||
c = size(x, dims-1)
|
||||
bs = size(x, dims)
|
||||
affine_shape = ones(Int, dims)
|
||||
affine_shape[end-1] = c
|
||||
affine_shape[end] = bs
|
||||
m = prod(size(x)[1:end-2])
|
||||
affine_shape = ntuple(i->i == ndims(x) - 1 || i == ndims(x) ? size(x, i) : 1, ndims(x))
|
||||
m = div(prod(size(x)), c*bs)
|
||||
γ, β = expand_inst(in.γ, affine_shape), expand_inst(in.β, affine_shape)
|
||||
|
||||
if !in.active
|
||||
if !istraining()
|
||||
μ = expand_inst(in.μ, affine_shape)
|
||||
σ² = expand_inst(in.σ², affine_shape)
|
||||
ϵ = in.ϵ
|
||||
else
|
||||
T = eltype(x)
|
||||
|
||||
ϵ = data(convert(T, in.ϵ))
|
||||
ϵ = convert(T, in.ϵ)
|
||||
axes = 1:dims-2 # axes to reduce along (all but channels and batch size axes)
|
||||
μ = mean(x, dims = axes)
|
||||
σ² = mean((x .- μ) .^ 2, dims = axes)
|
||||
|
||||
S = eltype(in.μ)
|
||||
# update moving mean/std
|
||||
mtm = data(convert(T, in.momentum))
|
||||
in.μ = dropdims(mean(repeat((1 - mtm) .* in.μ, outer=[1, bs]) .+ mtm .* reshape(data(μ), (c, bs)), dims = 2), dims=2)
|
||||
in.σ² = dropdims(mean((repeat((1 - mtm) .* in.σ², outer=[1, bs]) .+ (mtm * m / (m - 1)) .* reshape(data(σ²), (c, bs))), dims = 2), dims=2)
|
||||
mtm = in.momentum
|
||||
in.μ = dropdims(mean(repeat((1 - mtm) .* in.μ, outer=[1, bs]) .+ mtm .* S.(reshape(μ, (c, bs))), dims = 2), dims=2)
|
||||
in.σ² = dropdims(mean((repeat((1 - mtm) .* in.σ², outer=[1, bs]) .+ (mtm * m / (m - 1)) .* S.(reshape(σ², (c, bs)))), dims = 2), dims=2)
|
||||
end
|
||||
|
||||
let λ = in.λ
|
||||
@ -291,13 +261,7 @@ function (in::InstanceNorm)(x)
|
||||
end
|
||||
end
|
||||
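The InstanceNorm version parallels the BatchNorm sketch above, except that both the channel and the batch dimensions are kept, so statistics are computed per (channel, batch) pair:

```julia
x = rand(Float32, 7, 7, 3, 2)   # (W, H, C, N)
c, bs = size(x, ndims(x) - 1), size(x, ndims(x))
affine_shape = ntuple(i -> i == ndims(x) - 1 || i == ndims(x) ? size(x, i) : 1, ndims(x))  # (1, 1, 3, 2)
m = div(prod(size(x)), c * bs)   # 7*7 = 49
```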
|
||||
children(in::InstanceNorm) =
|
||||
(in.λ, in.β, in.γ, in.μ, in.σ², in.ϵ, in.momentum, in.active)
|
||||
|
||||
mapchildren(f, in::InstanceNorm) = # e.g. mapchildren(cu, in)
|
||||
InstanceNorm(in.λ, f(in.β), f(in.γ), f(in.μ), f(in.σ²), in.ϵ, in.momentum, in.active)
|
||||
|
||||
_testmode!(in::InstanceNorm, test) = (in.active = !test)
|
||||
@functor InstanceNorm
|
||||
|
||||
function Base.show(io::IO, l::InstanceNorm)
|
||||
print(io, "InstanceNorm($(join(size(l.β), ", "))")
|
||||
@ -327,7 +291,6 @@ m = Chain(Conv((3,3), 1=>32, leakyrelu;pad = 1),
|
||||
|
||||
Link : https://arxiv.org/pdf/1803.08494.pdf
|
||||
"""
|
||||
|
||||
mutable struct GroupNorm{F,V,W,N,T}
|
||||
G::T # number of groups
|
||||
λ::F # activation function
|
||||
@ -337,13 +300,14 @@ mutable struct GroupNorm{F,V,W,N,T}
|
||||
σ²::W # moving std
|
||||
ϵ::N
|
||||
momentum::N
|
||||
active::Bool
|
||||
end
|
||||
|
||||
GroupNorm(chs::Integer, G::Integer, λ = identity;
|
||||
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
|
||||
GroupNorm(G, λ, param(initβ(chs)), param(initγ(chs)),
|
||||
zeros(G,1), ones(G,1), ϵ, momentum, true)
|
||||
GroupNorm(G, λ, initβ(chs), initγ(chs),
|
||||
zeros(G,1), ones(G,1), ϵ, momentum)
|
||||
|
||||
trainable(gn::GroupNorm) = (gn.β, gn.γ)
|
||||
|
||||
function(gn::GroupNorm)(x)
|
||||
size(x,ndims(x)-1) == length(gn.β) || error("Group Norm expected $(length(gn.β)) channels, but got $(size(x,ndims(x)-1)) channels")
|
||||
@ -355,20 +319,17 @@ function(gn::GroupNorm)(x)
|
||||
channels = size(x, dims-1)
|
||||
batches = size(x,dims)
|
||||
channels_per_group = div(channels,groups)
|
||||
affine_shape = ones(Int, dims)
|
||||
affine_shape = ntuple(i->i == ndims(x) - 1 ? size(x, i) : 1, ndims(x))
|
||||
|
||||
# Output reshaped to (W,H...,C/G,G,N)
|
||||
affine_shape[end-1] = channels
|
||||
|
||||
μ_affine_shape = ones(Int,dims + 1)
|
||||
μ_affine_shape[end-1] = groups
|
||||
μ_affine_shape = ntuple(i->i == ndims(x) ? groups : 1, ndims(x) + 1)
|
||||
|
||||
m = prod(size(x)[1:end-2]) * channels_per_group
|
||||
γ = reshape(gn.γ, affine_shape...)
|
||||
β = reshape(gn.β, affine_shape...)
|
||||
|
||||
y = reshape(x,((size(x))[1:end-2]...,channels_per_group,groups,batches))
|
||||
if !gn.active
|
||||
if !istraining()
|
||||
og_shape = size(x)
|
||||
μ = reshape(gn.μ, μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
|
||||
σ² = reshape(gn.σ², μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
|
||||
@ -380,12 +341,12 @@ function(gn::GroupNorm)(x)
|
||||
μ = mean(y, dims = axes)
|
||||
σ² = mean((y .- μ) .^ 2, dims = axes)
|
||||
|
||||
ϵ = data(convert(T, gn.ϵ))
|
||||
ϵ = convert(T, gn.ϵ)
|
||||
# update moving mean/std
|
||||
mtm = data(convert(T, gn.momentum))
|
||||
|
||||
gn.μ = mean((1 - mtm) .* gn.μ .+ mtm .* reshape(data(μ), (groups,batches)),dims=2)
|
||||
gn.σ² = mean((1 - mtm) .* gn.σ² .+ (mtm * m / (m - 1)) .* reshape(data(σ²), (groups,batches)),dims=2)
|
||||
mtm = gn.momentum
|
||||
S = eltype(gn.μ)
|
||||
gn.μ = mean((1 - mtm) .* gn.μ .+ mtm .* S.(reshape(μ, (groups,batches))),dims=2)
|
||||
gn.σ² = mean((1 - mtm) .* gn.σ² .+ (mtm * m / (m - 1)) .* S.(reshape(σ², (groups,batches))),dims=2)
|
||||
end
|
||||
|
||||
let λ = gn.λ
|
||||
@ -397,13 +358,7 @@ function(gn::GroupNorm)(x)
|
||||
end
|
||||
end
|
||||
|
||||
children(gn::GroupNorm) =
|
||||
(gn.λ, gn.β, gn.γ, gn.μ, gn.σ², gn.ϵ, gn.momentum, gn.active)
|
||||
|
||||
mapchildren(f, gn::GroupNorm) = # e.g. mapchildren(cu, BN)
|
||||
GroupNorm(gn.G,gn.λ, f(gn.β), f(gn.γ), f(gn.μ), f(gn.σ²), gn.ϵ, gn.momentum, gn.active)
|
||||
|
||||
_testmode!(gn::GroupNorm, test) = (gn.active = !test)
|
||||
@functor GroupNorm
|
||||
|
||||
function Base.show(io::IO, l::GroupNorm)
|
||||
print(io, "GroupNorm($(join(size(l.β), ", "))")
|
||||
|
@ -1,5 +1,5 @@
|
||||
gate(h, n) = (1:h) .+ h*(n-1)
|
||||
gate(x::AbstractVector, h, n) = x[gate(h,n)]
|
||||
gate(x::AbstractVector, h, n) = @view x[gate(h,n)]
|
||||
gate(x::AbstractMatrix, h, n) = x[gate(h,n),:]
|
||||
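The `gate` helpers above slice the stacked per-gate parameters used by the RNN cells below. A quick illustration, assuming access to the internal helper as `Flux.gate`:

```julia
using Flux

out = 5
b = collect(1.0:4out)                   # e.g. an LSTM bias stacking 4 gates of length `out`
Flux.gate(b, out, 2) == b[out+1:2out]   # true: the slice belonging to the second gate
```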
|
||||
# Stateful recurrence
|
||||
@ -38,25 +38,10 @@ function (m::Recur)(xs...)
|
||||
return y
|
||||
end
|
||||
|
||||
@treelike Recur cell, init
|
||||
@functor Recur cell, init
|
||||
|
||||
Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
|
||||
|
||||
_truncate(x::AbstractArray) = Tracker.data(x)
|
||||
_truncate(x::Tuple) = _truncate.(x)
|
||||
|
||||
"""
|
||||
truncate!(rnn)
|
||||
|
||||
Truncates the gradient of the hidden state in recurrent layers. The value of the
|
||||
state is preserved. See also `reset!`.
|
||||
|
||||
Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to
|
||||
|
||||
rnn.state = Tracker.data(rnn.state)
|
||||
"""
|
||||
truncate!(m) = prefor(x -> x isa Recur && (x.state = _truncate(x.state)), m)
|
||||
|
||||
"""
|
||||
reset!(rnn)
|
||||
|
||||
@ -67,7 +52,8 @@ Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to
|
||||
|
||||
rnn.state = hidden(rnn.cell)
|
||||
"""
|
||||
reset!(m) = prefor(x -> x isa Recur && (x.state = x.init), m)
|
||||
reset!(m::Recur) = (m.state = m.init)
|
||||
reset!(m) = foreach(reset!, functor(m)[1])
|
||||
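A usage sketch of the functor-based `reset!` above: it walks the model tree and resets every `Recur` layer's state back to its initial value (layer sizes are illustrative):

```julia
using Flux

m = Chain(LSTM(10, 5), Dense(5, 2))
x = rand(Float32, 10)
m(x)             # advances the LSTM's hidden state
Flux.reset!(m)   # hidden state is back to its initial value
```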
|
||||
flip(f, xs) = reverse(f.(reverse(xs)))
|
||||
|
||||
@ -83,8 +69,8 @@ end
|
||||
|
||||
RNNCell(in::Integer, out::Integer, σ = tanh;
|
||||
init = glorot_uniform) =
|
||||
RNNCell(σ, param(init(out, in)), param(init(out, out)),
|
||||
param(init(out)), param(zeros(out)))
|
||||
RNNCell(σ, init(out, in), init(out, out),
|
||||
init(out), zeros(out))
|
||||
|
||||
function (m::RNNCell)(h, x)
|
||||
σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
|
||||
@ -94,7 +80,7 @@ end
|
||||
|
||||
hidden(m::RNNCell) = m.h
|
||||
|
||||
@treelike RNNCell
|
||||
@functor RNNCell
|
||||
|
||||
function Base.show(io::IO, l::RNNCell)
|
||||
print(io, "RNNCell(", size(l.Wi, 2), ", ", size(l.Wi, 1))
|
||||
@ -122,9 +108,9 @@ end
|
||||
|
||||
function LSTMCell(in::Integer, out::Integer;
|
||||
init = glorot_uniform)
|
||||
cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(init(out*4)),
|
||||
param(zeros(out)), param(zeros(out)))
|
||||
cell.b.data[gate(out, 2)] .= 1
|
||||
cell = LSTMCell(init(out * 4, in), init(out * 4, out), init(out * 4),
|
||||
zeros(out), zeros(out))
|
||||
cell.b[gate(out, 2)] .= 1
|
||||
return cell
|
||||
end
|
||||
|
||||
@ -142,7 +128,7 @@ end
|
||||
|
||||
hidden(m::LSTMCell) = (m.h, m.c)
|
||||
|
||||
@treelike LSTMCell
|
||||
@functor LSTMCell
|
||||
|
||||
Base.show(io::IO, l::LSTMCell) =
|
||||
print(io, "LSTMCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷4, ")")
|
||||
@ -168,8 +154,8 @@ mutable struct GRUCell{A,V}
|
||||
end
|
||||
|
||||
GRUCell(in, out; init = glorot_uniform) =
|
||||
GRUCell(param(init(out*3, in)), param(init(out*3, out)),
|
||||
param(init(out*3)), param(zeros(out)))
|
||||
GRUCell(init(out * 3, in), init(out * 3, out),
|
||||
init(out * 3), zeros(out))
|
||||
|
||||
function (m::GRUCell)(h, x)
|
||||
b, o = m.b, size(h, 1)
|
||||
@ -183,7 +169,7 @@ end
|
||||
|
||||
hidden(m::GRUCell) = m.h
|
||||
|
||||
@treelike GRUCell
|
||||
@functor GRUCell
|
||||
|
||||
Base.show(io::IO, l::GRUCell) =
|
||||
print(io, "GRUCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷3, ")")
|
||||
|
@ -1,13 +1,24 @@
|
||||
using CuArrays
|
||||
using NNlib: logsoftmax, logσ
|
||||
|
||||
# Cost functions
|
||||
|
||||
mse(ŷ, y) = sum((ŷ .- y).^2) * 1 // length(y)
|
||||
|
||||
function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
|
||||
-sum(y .* log.(ŷ) .* weight) * 1 // size(y, 2)
|
||||
function _crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat, weight::Nothing)
|
||||
return -sum(y .* log.(ŷ)) * 1 // size(y, 2)
|
||||
end
|
||||
|
||||
function _crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat, weight::Number)
|
||||
return -sum(y .* log.(ŷ)) .* weight * 1 // size(y, 2)
|
||||
end
|
||||
|
||||
function _crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat, weight::AbstractVector)
|
||||
return -sum(y .* log.(ŷ) .* weight) * 1 // size(y, 2)
|
||||
end
|
||||
|
||||
crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight=nothing) = _crossentropy(ŷ, y, weight)
|
||||
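The dispatch above replaces a single keyword branch with one `_crossentropy` method per weight type. A usage sketch with illustrative values:

```julia
using Flux: softmax, onehotbatch, crossentropy

ŷ = softmax(rand(Float32, 3, 5))               # predicted class probabilities
y = Float32.(onehotbatch(rand(1:3, 5), 1:3))   # one-hot targets

crossentropy(ŷ, y)                              # weight = nothing (default)
crossentropy(ŷ, y, weight = 2f0)                # scalar weight
crossentropy(ŷ, y, weight = Float32[1, 2, 3])   # per-class weight vector
```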
|
||||
function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
|
||||
return -sum(y .* logsoftmax(logŷ) .* weight) * 1 // size(y, 2)
|
||||
end
|
||||
@ -25,6 +36,9 @@ Return `-y*log(ŷ + ϵ) - (1-y)*log(1-ŷ + ϵ)`. The ϵ term provides numerica
|
||||
"""
|
||||
binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ)
|
||||
|
||||
# Re-definition to fix interaction with CuArrays.
|
||||
CuArrays.@cufunc binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ)
|
||||
|
||||
"""
|
||||
logitbinarycrossentropy(logŷ, y)
|
||||
|
||||
@ -39,18 +53,34 @@ but it is more numerically stable.
|
||||
"""
|
||||
logitbinarycrossentropy(logŷ, y) = (1 - y)*logŷ - logσ(logŷ)
|
||||
|
||||
# Re-definition to fix interaction with CuArrays.
|
||||
CuArrays.@cufunc logitbinarycrossentropy(logŷ, y) = (1 - y)*logŷ - logσ(logŷ)
|
||||
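As a quick sanity check (not part of the diff), the logit form agrees numerically with `binarycrossentropy` applied to `σ(logŷ)`, up to the `ϵ` stabiliser:

```julia
using Flux: σ, binarycrossentropy, logitbinarycrossentropy

logŷ, y = 1.3, 1.0
isapprox(logitbinarycrossentropy(logŷ, y),
         binarycrossentropy(σ(logŷ), y); atol = 1e-6)   # true
```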
|
||||
"""
|
||||
normalise(x::AbstractArray; dims=1)
|
||||
|
||||
Normalises x to mean 0 and standard deviation 1, across the dimensions given by dims. Defaults to normalising over columns.
|
||||
Normalises `x` to mean 0 and standard deviation 1, across the dimensions given by `dims`. Defaults to normalising over columns.
|
||||
|
||||
julia> a = reshape(collect(1:9), 3, 3)
|
||||
3×3 Array{Int64,2}:
|
||||
1 4 7
|
||||
2 5 8
|
||||
3 6 9
|
||||
|
||||
julia> normalise(a)
|
||||
3×3 Array{Float64,2}:
|
||||
-1.22474 -1.22474 -1.22474
|
||||
0.0 0.0 0.0
|
||||
1.22474 1.22474 1.22474
|
||||
|
||||
julia> normalise(a, dims=2)
|
||||
3×3 Array{Float64,2}:
|
||||
-1.22474 0.0 1.22474
|
||||
-1.22474 0.0 1.22474
|
||||
-1.22474 0.0 1.22474
|
||||
"""
|
||||
function normalise(x::AbstractArray; dims=1)
|
||||
μ′ = mean(x, dims = dims)
|
||||
σ′ = std(x, dims = dims, mean = μ′, corrected=false)
|
||||
return (x .- μ′) ./ σ′
|
||||
end
|
||||
|
||||
function normalise(x::AbstractArray, dims)
|
||||
Base.depwarn("`normalise(x::AbstractArray, dims)` is deprecated, use `normalise(a, dims=dims)` instead.", :normalise)
|
||||
normalise(x, dims = dims)
|
||||
end
|
||||
|
@ -37,12 +37,10 @@ import Adapt: adapt, adapt_structure
|
||||
|
||||
adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
|
||||
|
||||
@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
|
||||
import .CuArrays: CuArray, cudaconvert
|
||||
import Base.Broadcast: BroadcastStyle, ArrayStyle
|
||||
BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = ArrayStyle{CuArray}()
|
||||
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
|
||||
end
|
||||
import .CuArrays: CuArray, cudaconvert
|
||||
import Base.Broadcast: BroadcastStyle, ArrayStyle
|
||||
BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = ArrayStyle{CuArray}()
|
||||
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
|
||||
|
||||
"""
|
||||
onehot(l, labels[, unk])
|
||||
@ -54,17 +52,19 @@ it will error.
|
||||
## Examples
|
||||
|
||||
```jldoctest
|
||||
julia> using Flux: onehot
|
||||
|
||||
julia> onehot(:b, [:a, :b, :c])
|
||||
3-element Flux.OneHotVector:
|
||||
false
|
||||
true
|
||||
false
|
||||
0
|
||||
1
|
||||
0
|
||||
|
||||
julia> onehot(:c, [:a, :b, :c])
|
||||
3-element Flux.OneHotVector:
|
||||
false
|
||||
false
|
||||
true
|
||||
0
|
||||
0
|
||||
1
|
||||
```
|
||||
"""
|
||||
function onehot(l, labels)
|
||||
@ -88,12 +88,13 @@ Create an [`OneHotMatrix`](@ref) with a batch of labels based on possible `label
|
||||
## Examples
|
||||
|
||||
```jldoctest
|
||||
julia> onehotbatch([:b, :a, :b], [:a, :b, :c])
|
||||
3×3 Flux.OneHotMatrix:
|
||||
false true false
|
||||
true false true
|
||||
false false false
|
||||
julia> using Flux: onehotbatch
|
||||
|
||||
julia> onehotbatch([:b, :a, :b], [:a, :b, :c])
|
||||
3×3 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}:
|
||||
0 1 0
|
||||
1 0 1
|
||||
0 0 0
|
||||
```
|
||||
"""
|
||||
onehotbatch(ls, labels, unk...) =
|
||||
@ -106,9 +107,9 @@ Base.argmax(xs::OneHotVector) = xs.ix
|
||||
|
||||
Inverse operations of [`onehot`](@ref).
|
||||
|
||||
## Examples
|
||||
|
||||
```jldoctest
|
||||
julia> using Flux: onecold
|
||||
|
||||
julia> onecold([true, false, false], [:a, :b, :c])
|
||||
:a
|
||||
|
||||
@ -124,15 +125,6 @@ onecold(y::AbstractMatrix, labels...) =
|
||||
onecold(y::OneHotMatrix, labels...) =
|
||||
mapreduce(x -> Flux.onecold(x, labels...), |, y.data, dims = 2, init = 0)
|
||||
|
||||
function argmax(xs...)
|
||||
Base.depwarn("`argmax(...)` is deprecated, use `onecold(...)` instead.", :argmax)
|
||||
return onecold(xs...)
|
||||
end
|
||||
|
||||
# Ambiguity hack
|
||||
|
||||
a::TrackedMatrix * b::OneHotVector = invoke(*, Tuple{AbstractMatrix,OneHotVector}, a, b)
|
||||
a::TrackedMatrix * b::OneHotMatrix = invoke(*, Tuple{AbstractMatrix,OneHotMatrix}, a, b)
|
||||
|
||||
onecold(x::TrackedVector, l...) = onecold(data(x), l...)
|
||||
onecold(x::TrackedMatrix, l...) = onecold(data(x), l...)
|
||||
# TODO probably still want this as a custom adjoint Zygote
|
||||
# onecold(x::TrackedVector, l...) = onecold(data(x), l...)
|
||||
# onecold(x::TrackedMatrix, l...) = onecold(data(x), l...)
|
||||
|
@ -2,11 +2,10 @@ module Optimise
|
||||
|
||||
export train!,
|
||||
SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
|
||||
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,
|
||||
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,RADAM,
|
||||
InvDecay, ExpDecay, WeightDecay, stop, Optimiser
|
||||
|
||||
include("optimisers.jl")
|
||||
include("train.jl")
|
||||
include("deprecations.jl")
|
||||
|
||||
end
|
||||
|
@ -1,126 +0,0 @@
|
||||
using Base: depwarn
|
||||
using Flux: Params
|
||||
|
||||
check_decay(opt, decay) = decay == 0 ? opt : Optimiser(opt, InvDecay(decay))
|
||||
|
||||
# legacy update rule
|
||||
updaterule(opt, ps) = () -> _update_params!(opt, ps)
|
||||
|
||||
function SGD(params::Union{AbstractArray, Params}, η = 0.1; decay = 0.)
|
||||
depwarn("SGD(params) is deprecated; use Descent(η::Float64) instead", :SGD)
|
||||
|
||||
ps = params
|
||||
opt = Descent(η)
|
||||
opt = check_decay(opt, decay)
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
function Momentum(params::Union{AbstractArray, Params}, η = 0.01; ρ = 0.9, decay = 0.)
|
||||
depwarn("Momentum(params) is deprecated; use Momentum(η::Float64) instead", :Momentum)
|
||||
|
||||
ps = params
|
||||
opt = Momentum(η, ρ)
|
||||
opt = check_decay(opt, decay)
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
function Nesterov(params::Union{AbstractArray, Params}, η = 0.001; ρ = 0.9, decay = 0.)
|
||||
depwarn("Nesterov(params) is deprecated; use Nesterov(η::Float64) instead", :Nesterov)
|
||||
|
||||
ps = params
|
||||
opt = Nesterov(η, ρ)
|
||||
opt = check_decay(opt, decay)
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
function RMSProp(params::Union{AbstractArray, Params}, η = 0.001; ρ = 0.9, decay = 0.)
|
||||
depwarn("RMSProp(params) is deprecated; use RMSProp(η::Float64) instead", :RMSProp)
|
||||
|
||||
ps = params
|
||||
opt = RMSProp(η, ρ)
|
||||
opt = check_decay(opt, decay)
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
function ADAM(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
|
||||
depwarn("ADAM(params) is deprecated; use ADAM(η::Float64) instead", :ADAM)
|
||||
|
||||
ps = params
|
||||
β = (β1, β2)
|
||||
opt = ADAM(η, β)
|
||||
opt = check_decay(opt, decay)
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
function ADAGrad(params::Union{AbstractArray, Params}, η::Float64 = 0.1; decay = 0.)
|
||||
depwarn("ADAGrad(params) is deprecated; use ADAGrad(η::Float64) instead", :ADAGrad)
|
||||
|
||||
ps = params
|
||||
opt = ADAGrad(η)
|
||||
opt = check_decay(opt, decay)
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
function ADADelta(params::Union{AbstractArray, Params}, ρ::Float64 = 0.9; decay = 0.)
|
||||
depwarn("ADADelta(params) is deprecated; use ADADelta(η::Float64) instead", :ADADelta)
|
||||
|
||||
ps = params
|
||||
opt = ADADelta(ρ)
|
||||
opt = check_decay(opt, decay)
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
function AdaMax(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
|
||||
depwarn("AdaMax(params) is deprecated; use AdaMax(η::Float64) instead", :AdaMax)
|
||||
|
||||
ps = params
|
||||
β = (β1, β2)
|
||||
opt = AdaMax(η, β)
|
||||
opt = check_decay(opt, decay)
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
function AMSGrad(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
|
||||
depwarn("AMSGrad(params) is deprecated; use AMSGrad(η::Float64) instead", :AMSGrad)
|
||||
|
||||
ps = params
|
||||
β = (β1, β2)
|
||||
opt = AMSGrad(η, β)
|
||||
opt = check_decay(opt, decay)
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
function NADAM(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
|
||||
depwarn("NADAM(params) is deprecated; use NADAM(η::Float64) instead", :NADAM)
|
||||
|
||||
ps = params
|
||||
β = (β1, β2)
|
||||
opt = NADAM(η, β)
|
||||
opt = check_decay(opt, decay)
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
function ADAMW(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
|
||||
depwarn("ADAMW(params) is deprecated; use ADAMW(η::Float64) instead", :ADAMW)
|
||||
|
||||
ps = params
|
||||
β = (β1, β2)
|
||||
opt = ADAMW(η, β)
|
||||
opt = check_decay(opt, decay)
|
||||
decay != 0 && (opt = Optimiser(opt, WeightDecay(decay)))
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
# Old training loop
|
||||
|
||||
struct OldOptimiser
|
||||
func
|
||||
end
|
||||
|
||||
_update_params!(opt::OldOptimiser, ps) = opt.func()
|
||||
|
||||
# Train function
|
||||
function train!(loss, data, opt; cb = () -> ())
|
||||
depwarn("train!(loss, data, opt) is deprecated; use train!(loss, params, data, opt) instead", :train!)
|
||||
train!(loss, (), data, OldOptimiser(opt); cb = cb)
|
||||
end
|
@ -7,10 +7,28 @@ const ϵ = 1e-8
|
||||
# TODO: should use weak refs
|
||||
|
||||
"""
|
||||
Descent(η)
|
||||
Descent(η)
|
||||
|
||||
Classic gradient descent optimiser with learning rate `η`.
|
||||
For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`.
|
||||
For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`
|
||||
|
||||
## Parameters
|
||||
- Learning Rate (η): The amount by which the gradients are discounted before updating the weights. Defaults to `0.1`.
|
||||
|
||||
## Example
|
||||
```julia-repl
|
||||
opt = Descent() # uses default η (0.1)
|
||||
|
||||
opt = Descent(0.3) # use provided η
|
||||
|
||||
ps = params(model)
|
||||
|
||||
gs = gradient(ps) do
|
||||
loss(x, y)
|
||||
end
|
||||
|
||||
Flux.Optimise.update!(opt, ps, gs)
|
||||
```
|
||||
"""
|
||||
mutable struct Descent
|
||||
eta::Float64
|
||||
@ -23,9 +41,20 @@ function apply!(o::Descent, x, Δ)
|
||||
end
|
||||
|
||||
"""
|
||||
Momentum(params, η = 0.01; ρ = 0.9)
|
||||
Momentum(η, ρ)
|
||||
|
||||
Gradient descent with learning rate `η` and momentum `ρ`.
|
||||
|
||||
## Parameters
|
||||
- Learning Rate (`η`): Amount by which gradients are discounted before updating the weights. Defaults to `0.01`.
|
||||
- Momentum (`ρ`): Parameter that accelerates descent in the relevant direction and dampens oscillations. Defaults to `0.9`.
|
||||
|
||||
## Examples
|
||||
```julia
|
||||
opt = Momentum() # uses defaults of η = 0.01 and ρ = 0.9
|
||||
|
||||
opt = Momentum(0.01, 0.99)
|
||||
```
|
||||
"""
|
||||
mutable struct Momentum
|
||||
eta::Float64
|
||||
@ -37,15 +66,26 @@ Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict())
|
||||
|
||||
function apply!(o::Momentum, x, Δ)
|
||||
η, ρ = o.eta, o.rho
|
||||
v = get!(o.velocity, x, zero(x))::typeof(data(x))
|
||||
v = get!(o.velocity, x, zero(x))::typeof(x)
|
||||
@. v = ρ * v - η * Δ
|
||||
@. Δ = -v
|
||||
end
|
||||
|
||||
"""
|
||||
Nesterov(eta, ρ = 0.9)
|
||||
Nesterov(η, ρ)
|
||||
|
||||
Gradient descent with learning rate `η` and Nesterov momentum `ρ`.
|
||||
|
||||
## Parameters
|
||||
- Learning Rate (η): Amount by which the gradients are discounted before updating the weights. Defaults to `0.001`.
|
||||
- Nesterov Momentum (ρ): Parameter controlling the amount of Nesterov momentum to be applied. Defaults to `0.9`.
|
||||
|
||||
## Examples
|
||||
```julia
|
||||
opt = Nesterov() # uses defaults η = 0.001 and ρ = 0.9
|
||||
|
||||
opt = Nesterov(0.003, 0.95)
|
||||
```
|
||||
"""
|
||||
mutable struct Nesterov
|
||||
eta::Float64
|
||||
@ -57,18 +97,30 @@ Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict())
|
||||
|
||||
function apply!(o::Nesterov, x, Δ)
|
||||
η, ρ = o.eta, o.rho
|
||||
v = get!(o.velocity, x, zero(x))::typeof(data(x))
|
||||
v = get!(o.velocity, x, zero(x))::typeof(x)
|
||||
d = @. ρ^2 * v - (1+ρ) * η * Δ
|
||||
@. v = ρ*v - η*Δ
|
||||
@. Δ = -d
|
||||
end
|
||||
|
||||
"""
|
||||
RMSProp(η = 0.001, ρ = 0.9)
|
||||
RMSProp(η, ρ)
|
||||
|
||||
Implements the RMSProp algorithm. Often a good choice for recurrent networks. Parameters other than the learning rate generally don't need tuning.
|
||||
|
||||
## Parameters
|
||||
- Learning Rate (η): Defaults to `0.001`.
|
||||
- Rho (ρ): Defaults to `0.9`.
|
||||
|
||||
## Examples
|
||||
```julia
|
||||
opt = RMSProp() # uses default η = 0.001 and ρ = 0.9
|
||||
|
||||
opt = RMSProp(0.002, 0.95)
|
||||
```
|
||||
|
||||
## References
|
||||
[RMSProp](https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
|
||||
optimiser. Parameters other than learning rate don't need tuning. Often a good
|
||||
choice for recurrent networks.
|
||||
"""
|
||||
mutable struct RMSProp
|
||||
eta::Float64
|
||||
@ -80,14 +132,28 @@ RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict())
|
||||
|
||||
function apply!(o::RMSProp, x, Δ)
|
||||
η, ρ = o.eta, o.rho
|
||||
acc = get!(o.acc, x, zero(x))::typeof(data(x))
|
||||
acc = get!(o.acc, x, zero(x))::typeof(x)
|
||||
@. acc = ρ * acc + (1 - ρ) * Δ^2
|
||||
@. Δ *= η / (√acc + ϵ)
|
||||
end
|
||||
|
||||
"""
|
||||
ADAM(η = 0.001, β = (0.9, 0.999))
|
||||
ADAM(η, β::Tuple)
|
||||
|
||||
Implements the ADAM optimiser.
|
||||
|
||||
## Parameters
|
||||
- Learning Rate (`η`): Defaults to `0.001`.
|
||||
- Beta (`β::Tuple`): The first element refers to β1 and the second to β2. Defaults to `(0.9, 0.999)`.
|
||||
|
||||
## Examples
|
||||
|
||||
```julia
|
||||
opt = ADAM() # uses the default η = 0.001 and β = (0.9, 0.999)
|
||||
|
||||
opt = ADAM(0.001, (0.9, 0.8))
|
||||
```
|
||||
## References
|
||||
[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
|
||||
"""
|
||||
mutable struct ADAM
|
||||
@ -109,10 +175,67 @@ function apply!(o::ADAM, x, Δ)
|
||||
end
|
||||
|
||||
"""
|
||||
AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08)
|
||||
RADAM(η, β::Tuple)
|
||||
|
||||
[AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser. Variant of ADAM based on
|
||||
the ∞-norm.
|
||||
Implements the rectified ADAM (RADAM) optimiser.
|
||||
|
||||
## Parameters
|
||||
- Learning Rate (η): Defaults to `0.001`
|
||||
- Beta (β::Tuple): The first element refers to β1 and the second to β2. Defaults to `(0.9, 0.999)`.
|
||||
|
||||
## Examples
|
||||
|
||||
```julia
|
||||
opt = RADAM() # uses the default η = 0.001 and β = (0.9, 0.999)
|
||||
|
||||
opt = RADAM(0.001, (0.9, 0.8))
|
||||
```
|
||||
|
||||
## References
|
||||
[RADAM](https://arxiv.org/pdf/1908.03265v1.pdf) optimiser (Rectified ADAM).
|
||||
"""
|
||||
mutable struct RADAM
|
||||
eta::Float64
|
||||
beta::Tuple{Float64,Float64}
|
||||
state::IdDict
|
||||
end
|
||||
|
||||
RADAM(η = 0.001, β = (0.9, 0.999)) = RADAM(η, β, IdDict())
|
||||
|
||||
function apply!(o::RADAM, x, Δ)
|
||||
η, β = o.eta, o.beta
|
||||
ρ∞ = 2/(1-β[2])-1
|
||||
mt, vt, βp, t = get!(o.state, x, (zero(x), zero(x), β, 1))
|
||||
@. mt = β[1] * mt + (1 - β[1]) * Δ
|
||||
@. vt = β[2] * vt + (1 - β[2]) * Δ^2
|
||||
ρ = ρ∞ - 2t*βp[2]/(1-βp[2])
|
||||
if ρ > 4
|
||||
r = sqrt((ρ-4)*(ρ-2)*ρ∞/((ρ∞-4)*(ρ∞-2)*ρ))
|
||||
@. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η * r
|
||||
else
|
||||
@. Δ = mt / (1 - βp[1]) * η
|
||||
end
|
||||
o.state[x] = (mt, vt, βp .* β, t+1)
|
||||
return Δ
|
||||
end
|
||||
|
||||
"""
|
||||
AdaMax(η, β::Tuple)
|
||||
|
||||
Variant of ADAM based on ∞-norm.
|
||||
|
||||
## Parameters
|
||||
- Learning Rate (η): Defaults to `0.001`
|
||||
- Beta (β::Tuple): The first element refers to β1 and the second to β2. Defaults to `(0.9, 0.999)`.
|
||||
|
||||
## Examples
|
||||
```julia
|
||||
opt = AdaMax() # uses default η and β
|
||||
|
||||
opt = AdaMax(0.001, (0.9, 0.995))
|
||||
```
|
||||
## References
|
||||
[AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser.
|
||||
"""
|
||||
mutable struct AdaMax
|
||||
eta::Float64
|
||||
@ -133,8 +256,21 @@ function apply!(o::AdaMax, x, Δ)
|
||||
end
|
||||
|
||||
"""
|
||||
ADAGrad(η = 0.1; ϵ = 1e-8)
|
||||
ADAGrad(η)
|
||||
|
||||
Implements AdaGrad. It has parameter specific learning rates based on how frequently it is updated.
|
||||
|
||||
## Parameters
|
||||
- Learning Rate (η): Defaults to `0.1`
|
||||
|
||||
## Examples
|
||||
```julia
|
||||
opt = ADAGrad() # uses default η = 0.1
|
||||
|
||||
opt = ADAGrad(0.001)
|
||||
```
|
||||
|
||||
## References
|
||||
[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
|
||||
Parameters don't need tuning.
|
||||
"""
|
||||
@ -147,16 +283,27 @@ ADAGrad(η = 0.1) = ADAGrad(η, IdDict())
|
||||
|
||||
function apply!(o::ADAGrad, x, Δ)
|
||||
η = o.eta
|
||||
acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(data(x))
|
||||
acc = get!(o.acc, x, fill!(zero(x), ϵ))::typeof(x)
|
||||
@. acc += Δ^2
|
||||
@. Δ *= η / (√acc + ϵ)
|
||||
end
|
||||
|
||||
"""
|
||||
ADADelta(ρ = 0.9, ϵ = 1e-8)
|
||||
ADADelta(ρ)
|
||||
|
||||
[ADADelta](https://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
|
||||
tuning.
|
||||
Version of ADAGrad that adapts its learning rate based on a window of past gradient updates. Parameters don't need tuning.
|
||||
|
||||
## Parameters
|
||||
- Rho (ρ): Factor by which gradient is decayed at each time step. Defaults to `0.9`.
|
||||
|
||||
## Examples
|
||||
```julia
|
||||
opt = ADADelta() # uses default ρ = 0.9
|
||||
opt = ADADelta(0.89)
|
||||
```
|
||||
|
||||
## References
|
||||
[ADADelta](https://arxiv.org/abs/1212.5701) optimiser.
|
||||
"""
|
||||
mutable struct ADADelta
|
||||
rho::Float64
|
||||
@ -175,10 +322,22 @@ function apply!(o::ADADelta, x, Δ)
|
||||
end
|
||||
|
||||
"""
|
||||
AMSGrad(η = 0.001, β = (0.9, 0.999))
|
||||
AMSGrad(η, β::Tuple)
|
||||
|
||||
[AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need
|
||||
tuning.
|
||||
Implements the AMSGrad version of the ADAM optimiser. Parameters don't need tuning.
|
||||
|
||||
## Parameters
|
||||
- Learning Rate (η): Defaults to `0.001`.
|
||||
- Beta (β::Tuple): The first element refers to β1 and the second to β2. Defaults to `(0.9, 0.999)`.
|
||||
|
||||
## Examples
|
||||
```julia
|
||||
opt = AMSGrad() # uses default η and β
|
||||
opt = AMSGrad(0.001, (0.89, 0.995))
|
||||
```
|
||||
|
||||
## References
|
||||
[AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser.
|
||||
"""
|
||||
mutable struct AMSGrad
|
||||
eta::Float64
|
||||
@ -190,18 +349,30 @@ AMSGrad(η = 0.001, β = (0.9, 0.999)) = AMSGrad(η, β, IdDict())
|
||||
|
||||
function apply!(o::AMSGrad, x, Δ)
|
||||
η, β = o.eta, o.beta
|
||||
mt, vt, v̂t = get!(o.state, x, (fill(ϵ, size(x)), fill(ϵ, size(x)), fill(ϵ, size(x))))
|
||||
mt, vt, v̂t = get!(o.state, x, (fill!(zero(x), ϵ), fill!(zero(x), ϵ), fill!(zero(x), ϵ)))
|
||||
@. mt = β[1] * mt + (1 - β[1]) * Δ
|
||||
@. vt = β[2] * vt + (1 - β[2]) * Δ ^ 2
|
||||
@. v̂t = max.(v̂t, vt)
|
||||
@. v̂t = max(v̂t, vt)
|
||||
@. Δ = η * mt / (√v̂t + ϵ)
|
||||
end
|
||||
|
||||
"""
|
||||
NADAM(η = 0.001, β = (0.9, 0.999))
|
||||
NADAM(η, β::Tuple)
|
||||
|
||||
[NADAM](http://cs229.stanford.edu/proj2015/054_report.pdf) optimiser. Parameters don't need
|
||||
tuning.
|
||||
Nesterov variant of ADAM. Parameters don't need tuning.
|
||||
|
||||
## Parameters
|
||||
- Learning Rate (η): Defaults to `0.001`.
|
||||
- Beta (β::Tuple): The first element refers to β1 and the second to β2. Defaults to `(0.9, 0.999)`.
|
||||
|
||||
## Examples
|
||||
```julia
|
||||
opt = NADAM() # uses default η and β
|
||||
opt = NADAM(0.002, (0.89, 0.995))
|
||||
```
|
||||
|
||||
## References
|
||||
[NADAM](http://cs229.stanford.edu/proj2015/054_report.pdf) optimiser.
|
||||
"""
|
||||
mutable struct NADAM
|
||||
eta::Float64
|
||||
@ -213,8 +384,7 @@ NADAM(η = 0.001, β = (0.9, 0.999)) = NADAM(η, β, IdDict())
|
||||
|
||||
function apply!(o::NADAM, x, Δ)
|
||||
η, β = o.eta, o.beta
|
||||
β1p, β2p = o.beta
|
||||
mt, vt = get!(o.state, x, (zero(x), zero(x)))
|
||||
mt, vt, (β1p, β2p) = get!(o.state, x, (zero(x), zero(x), o.beta))
|
||||
@. mt = β[1] * mt + (1 - β[1]) * Δ
|
||||
@. vt = β[2] * vt + (1 - β[2]) * Δ^2
|
||||
@. Δ = (β[1] * mt / (1 - β[1] * β1p) + (1 - β[1]) * Δ / (1 - β1p)) / (√(vt * β[2] / (1 - β2p)) + ϵ) * η
|
||||
@ -223,9 +393,23 @@ function apply!(o::NADAM, x, Δ)
|
||||
end
|
||||
|
||||
"""
|
||||
ADAMW((η = 0.001, β = (0.9, 0.999), decay = 0)
|
||||
ADAMW(η, β::Tuple, decay)
|
||||
|
||||
[ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam.
|
||||
Variant of ADAM defined by fixing weight decay regularization.
|
||||
|
||||
## Parameters
|
||||
- Learning Rate (η): Defaults to `0.001`.
|
||||
- Beta (β::Tuple): The first element refers to β1 and the second to β2. Defaults to (0.9, 0.999).
|
||||
- decay: Decay applied to weights during optimisation. Defaults to 0.
|
||||
|
||||
## Examples
|
||||
```julia
|
||||
opt = ADAMW() # uses default η, β and decay
|
||||
opt = ADAMW(0.001, (0.89, 0.995), 0.1)
|
||||
```
|
||||
|
||||
## References
|
||||
[ADAMW](https://arxiv.org/abs/1711.05101)
|
||||
"""
|
||||
ADAMW(η = 0.001, β = (0.9, 0.999), decay = 0) =
|
||||
Optimiser(ADAM(η, β), WeightDecay(decay))
|
||||
@ -258,9 +442,14 @@ function apply!(o::Optimiser, x, Δ)
|
||||
end
|
||||
|
||||
"""
|
||||
`InvDecay(γ)`
|
||||
InvDecay(γ)
|
||||
|
||||
Apply inverse time decay to an optimiser
|
||||
Applies inverse time decay to an optimiser.
|
||||
|
||||
## Parameters
|
||||
- gamma (γ): Defaults to `0.001`
|
||||
|
||||
## Example
|
||||
```julia
|
||||
Optimiser(InvDecay(..), Opt(..))
|
||||
```
|
||||
@ -281,13 +470,22 @@ function apply!(o::InvDecay, x, Δ)
|
||||
end
|
||||
|
||||
"""
|
||||
`ExpDecay(eta, decay, decay_step, clip)`
|
||||
ExpDecay(eta, decay, decay_step, clip)
|
||||
|
||||
Schedule the learning rate `eta` by `decay` every `decay_step` till a minimum of `clip`.
|
||||
Discount the learning rate `eta` by a factor of `decay` every `decay_step` steps, until a minimum of `clip` is reached.
|
||||
|
||||
## Parameters
|
||||
- Learning Rate (eta): Defaults to `0.001`.
|
||||
- decay: Factor by which the learning rate is discounted. Defaults to `0.1`.
|
||||
- decay_step: Schedules decay operations by setting number of steps between two decay operations. Defaults to `1000`.
|
||||
- clip: Minimum value of learning rate. Defaults to `1e-4`.
|
||||
|
||||
## Example
|
||||
To apply exponential decay to an optimiser:
|
||||
```julia
|
||||
Optimiser(ExpDecay(..), Opt(..))
|
||||
|
||||
opt = Optimiser(ExpDecay(), ADAM())
|
||||
```
|
||||
"""
|
||||
mutable struct ExpDecay
|
||||
@ -311,9 +509,12 @@ function apply!(o::ExpDecay, x, Δ)
|
||||
end
|
||||
|
||||
"""
|
||||
`WeightDecay(wd)`
|
||||
WeightDecay(wd)
|
||||
|
||||
Decay the weight parameter by `wd`
|
||||
Decays the weight by `wd`
|
||||
|
||||
## Parameters
|
||||
- Weight Decay (wd): Defaults to `0`.
|
||||
"""
|
||||
mutable struct WeightDecay
|
||||
wd::Real
|
||||
@ -323,5 +524,5 @@ WeightDecay() = WeightDecay(0)
|
||||
|
||||
function apply!(o::WeightDecay, x, Δ)
|
||||
wd = o.wd
|
||||
@. Δ += wd * data(x)
|
||||
@. Δ += wd * x
|
||||
end
|
||||
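A usage sketch of composing `WeightDecay` with another rule: its `apply!` adds `wd .* x` to the gradient (an L2 penalty) before the wrapped optimiser sees it.

```julia
using Flux
using Flux.Optimise: Optimiser, WeightDecay

opt = Optimiser(WeightDecay(1f-4), ADAM(0.001))
```

Since `Optimiser` chains `apply!` calls in order, placing `WeightDecay` first folds the decay term into the gradient before ADAM's moment estimates are updated.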
|
@ -1,32 +1,29 @@
|
||||
using Juno
|
||||
import Flux.Tracker: Params, gradient, data, update!
|
||||
import Base.depwarn
|
||||
import Zygote: Params, gradient
|
||||
|
||||
function update!(x::AbstractArray, x̄)
|
||||
x .+= x̄
|
||||
return x
|
||||
end
|
||||
|
||||
function update!(opt, x, x̄)
|
||||
update!(x, -apply!(opt, x, data(x̄)))
|
||||
x .-= apply!(opt, x, x̄)
|
||||
end
|
||||
|
||||
function update!(opt, xs::Params, gs)
|
||||
for x in xs
|
||||
gs[x] == nothing && continue
|
||||
update!(opt, x, gs[x])
|
||||
end
|
||||
end
|
||||
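A minimal end-to-end sketch of the new `update!` path above (layer sizes, loss and data are illustrative):

```julia
using Flux
using Flux.Optimise: update!

m = Dense(4, 2)
x, y = rand(Float32, 4, 8), rand(Float32, 2, 8)
opt = Descent(0.1)

ps = Flux.params(m)
gs = gradient(() -> Flux.mse(m(x), y), ps)
update!(opt, ps, gs)   # in place: p .-= apply!(opt, p, gs[p]) for each parameter
```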
|
||||
# Added as an internal API but everyone started using it.
|
||||
function _update_params!(opt, xs)
|
||||
depwarn("`_update_params!` is deprecated, use `update!` instead.", :stop)
|
||||
for x in xs
|
||||
update!(opt, x, Tracker.grad(x))
|
||||
x.tracker.grad = Tracker.zero_grad!(x.tracker.grad)
|
||||
end
|
||||
end
|
||||
|
||||
# Callback niceties
|
||||
call(f, xs...) = f(xs...)
|
||||
runall(f) = f
|
||||
runall(fs::AbstractVector) = () -> foreach(call, fs)
|
||||
|
||||
struct StopException <: Exception end
|
||||
|
||||
"""
|
||||
stop()
|
||||
|
||||
@ -72,10 +69,7 @@ function train!(loss, ps, data, opt; cb = () -> ())
|
||||
loss(d...)
|
||||
end
|
||||
update!(opt, ps, gs)
|
||||
if cb() == :stop
|
||||
depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
|
||||
break
|
||||
end
|
||||
cb()
|
||||
catch ex
|
||||
if ex isa StopException
|
||||
break
|
||||
|
@ -1,87 +0,0 @@
|
||||
import Adapt: adapt, adapt_storage
|
||||
import .Tracker: IdSet
|
||||
|
||||
children(x) = ()
|
||||
mapchildren(f, x) = x
|
||||
|
||||
children(x::Tuple) = x
|
||||
children(x::NamedTuple) = x
|
||||
mapchildren(f, x::Tuple) = map(f, x)
|
||||
mapchildren(f, x::NamedTuple) = map(f, x)
|
||||
|
||||
function treelike(m::Module, T, fs = fieldnames(T))
|
||||
@eval m begin
|
||||
Flux.children(x::$T) = ($([:(x.$f) for f in fs]...),)
|
||||
Flux.mapchildren(f, x::$T) = $T(f.($children(x))...)
|
||||
end
|
||||
end
|
||||
|
||||
macro treelike(T, fs = nothing)
|
||||
fs == nothing || isexpr(fs, :tuple) || error("@treelike T (a, b)")
|
||||
fs = fs == nothing ? [] : [:($(map(QuoteNode, fs.args)...),)]
|
||||
:(treelike(@__MODULE__, $(esc(T)), $(fs...)))
|
||||
end
|
||||
|
||||
isleaf(x) = isempty(children(x))
|
||||
|
||||
function mapleaves(f, x; cache = IdDict())
|
||||
haskey(cache, x) && return cache[x]
|
||||
cache[x] = isleaf(x) ? f(x) : mapchildren(x -> mapleaves(f, x, cache = cache), x)
|
||||
end
|
||||
|
||||
function prefor(f, x; seen = IdSet())
|
||||
x ∈ seen && return
|
||||
f(x)
|
||||
foreach(x -> prefor(f, x, seen = seen), children(x))
|
||||
return
|
||||
end
|
||||
|
||||
function params(m)
|
||||
ps = Params()
|
||||
prefor(p ->
|
||||
Tracker.istracked(p) && Tracker.isleaf(p) &&
|
||||
!any(p′ -> p′ === p, ps) && push!(ps, p),
|
||||
m)
|
||||
return ps
|
||||
end
|
||||
|
||||
params(m...) = params(m)
|
||||
|
||||
function loadparams!(m, xs)
|
||||
for (p, x) in zip(params(m), xs)
|
||||
size(p) == size(x) ||
|
||||
error("Expected param size $(size(p)), got $(size(x))")
|
||||
copyto!(data(p), data(x))
|
||||
end
|
||||
end
|
||||
|
||||
# CPU/GPU movement conveniences
|
||||
|
||||
cpu(m) = mapleaves(x -> adapt(Array, x), m)
|
||||
|
||||
gpu_adaptor = identity
|
||||
|
||||
@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
|
||||
global gpu_adaptor = CuArrays.cu
|
||||
end
|
||||
|
||||
gpu(x) = mapleaves(gpu_adaptor, x)
|
||||
|
||||
# Precision
|
||||
|
||||
adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Real}) = convert.(T, xs)
|
||||
|
||||
paramtype(T::Type{<:Real}, m) = mapleaves(x -> adapt(T, x), m)
|
||||
|
||||
f32(m) = paramtype(Float32, m)
|
||||
f64(m) = paramtype(Float64, m)
|
||||
|
||||
# General parameter map
|
||||
|
||||
function mapparams(f, m)
|
||||
mapleaves(m) do x
|
||||
Tracker.istracked(x) ? param(f(Tracker.data(x))) :
|
||||
x isa Union{AbstractArray,Number} ? f(x) :
|
||||
x
|
||||
end
|
||||
end
|
@ -1,6 +1,11 @@
|
||||
# Arrays
|
||||
glorot_uniform(dims...) = (rand(Float32, dims...) .- 0.5f0) .* sqrt(24.0f0/sum(dims))
|
||||
glorot_normal(dims...) = randn(Float32, dims...) .* sqrt(2.0f0/sum(dims))
|
||||
nfan() = 1, 1 #fan_in, fan_out
|
||||
nfan(n) = 1, n #A vector is treated as a n×1 matrix
|
||||
nfan(n_out, n_in) = n_in, n_out #In case of Dense kernels: arranged as matrices
|
||||
nfan(dims...) = prod(dims[1:end-2]) .* (dims[end-1], dims[end]) #In case of convolution kernels
|
||||
|
||||
glorot_uniform(dims...) = (rand(Float32, dims...) .- 0.5f0) .* sqrt(24.0f0 / sum(nfan(dims...)))
|
||||
glorot_normal(dims...) = randn(Float32, dims...) .* sqrt(2.0f0 / sum(nfan(dims...)))
|
||||
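For reference, a short sketch of what the fan computation above produces; `nfan` and `glorot_uniform` are internal helpers, accessed here as `Flux.nfan` / `Flux.glorot_uniform`:

```julia
using Flux

Flux.nfan(5, 10)         # Dense weight (out=5, in=10)  -> (10, 5): fan_in, fan_out
Flux.nfan(3, 3, 1, 16)   # 3×3 conv kernel, 1 => 16     -> (9, 144)

# glorot_uniform samples from U(-√(6/(fan_in+fan_out)), √(6/(fan_in+fan_out))):
w = Flux.glorot_uniform(5, 10)
maximum(abs, w) <= sqrt(6f0 / sum(Flux.nfan(5, 10)))   # true
```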
|
||||
ones(T::Type, dims...) = Base.ones(T, dims...)
|
||||
zeros(T::Type, dims...) = Base.zeros(T, dims...)
|
||||
|
@ -1,4 +1,5 @@
|
||||
using Flux, Flux.Tracker, CuArrays, Test
|
||||
using Flux, Test
|
||||
using Flux.CuArrays
|
||||
using Flux: gpu
|
||||
|
||||
@info "Testing GPU Support"
|
||||
@ -7,11 +8,11 @@ using Flux: gpu
|
||||
|
||||
CuArrays.allowscalar(false)
|
||||
|
||||
x = param(randn(5, 5))
|
||||
x = randn(5, 5)
|
||||
cx = gpu(x)
|
||||
@test cx isa TrackedArray && cx.data isa CuArray
|
||||
@test cx isa CuArray
|
||||
|
||||
@test Flux.onecold(param(gpu([1.,2.,3.]))) == 3
|
||||
@test Flux.onecold(gpu([1.0, 2.0, 3.0])) == 3
|
||||
|
||||
x = Flux.onehotbatch([1, 2, 3], 1:3)
|
||||
cx = gpu(x)
|
||||
@ -21,24 +22,33 @@ cx = gpu(x)
|
||||
m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)
|
||||
cm = gpu(m)
|
||||
|
||||
@test all(p isa TrackedArray && p.data isa CuArray for p in params(cm))
|
||||
@test cm(gpu(rand(10, 10))) isa TrackedArray{Float32,2,CuArray{Float32,2}}
|
||||
@test all(p isa CuArray for p in params(cm))
|
||||
@test cm(gpu(rand(10, 10))) isa CuArray{Float32,2}
|
||||
|
||||
x = [1,2,3]
|
||||
cx = gpu(x)
|
||||
@test Flux.crossentropy(x,x) ≈ Flux.crossentropy(cx,cx)
|
||||
@test Flux.crossentropy(x,x, weight=1.0) ≈ Flux.crossentropy(cx,cx, weight=1.0)
|
||||
@test Flux.crossentropy(x,x, weight=[1.0;2.0;3.0]) ≈ Flux.crossentropy(cx,cx, weight=cu([1.0;2.0;3.0]))
|
||||
|
||||
xs = param(rand(5,5))
|
||||
x = [-1.1491, 0.8619, 0.3127]
|
||||
y = [1, 1, 0.]
|
||||
@test Flux.binarycrossentropy.(σ.(x),y) ≈ Flux.binarycrossentropy.(cu(σ.(x)),cu(y))
|
||||
@test Flux.logitbinarycrossentropy.(x,y) ≈ Flux.logitbinarycrossentropy.(cu(x),cu(y))
|
||||
|
||||
xs = rand(5, 5)
|
||||
ys = Flux.onehotbatch(1:5,1:5)
|
||||
@test collect(cu(xs) .+ cu(ys)) ≈ collect(xs .+ ys)
|
||||
|
||||
c = gpu(Conv((2,2),3=>4))
|
||||
x = gpu(rand(10, 10, 3, 2))
|
||||
l = c(gpu(rand(10,10,3,2)))
|
||||
Flux.back!(sum(l))
|
||||
@test gradient(x -> sum(c(x)), x)[1] isa CuArray
|
||||
|
||||
c = gpu(CrossCor((2,2),3=>4))
|
||||
x = gpu(rand(10, 10, 3, 2))
|
||||
l = c(gpu(rand(10,10,3,2)))
|
||||
Flux.back!(sum(l))
|
||||
@test gradient(x -> sum(c(x)), x)[1] isa CuArray
|
||||
|
||||
end
|
||||
|
||||
@ -48,10 +58,10 @@ end
|
||||
@test y[3,:] isa CuArray
|
||||
end
|
||||
|
||||
if CuArrays.libcudnn != nothing
|
||||
@info "Testing Flux/CUDNN"
|
||||
include("cudnn.jl")
|
||||
if !haskey(ENV, "CI_DISABLE_CURNN_TEST")
|
||||
include("curnn.jl")
|
||||
end
|
||||
if CuArrays.has_cudnn()
|
||||
@info "Testing Flux/CUDNN"
|
||||
include("cudnn.jl")
|
||||
include("curnn.jl")
|
||||
else
|
||||
@warn "CUDNN unavailable, not testing GPU DNN support"
|
||||
end
|
||||
|
@ -1,48 +1,44 @@
|
||||
using Flux, Flux.Tracker, CuArrays, Test
|
||||
using Flux.Tracker: TrackedArray, data
|
||||
using Flux, CuArrays, Test
|
||||
using Flux: pullback
|
||||
|
||||
@testset "CUDNN BatchNorm" begin
|
||||
@testset "4D Input" begin
|
||||
x = TrackedArray(Float64.(collect(reshape(1:12, 2, 2, 3, 1))))
|
||||
x = Float64.(collect(reshape(1:12, 2, 2, 3, 1)))
|
||||
m = BatchNorm(3)
|
||||
cx = gpu(x)
|
||||
cm = gpu(m)
|
||||
|
||||
y = m(x)
|
||||
cy = cm(cx)
|
||||
y, back = pullback((m, x) -> m(x), m, x)
|
||||
cy, cback = pullback((m, x) -> m(x), cm, cx)
|
||||
|
||||
@test cy isa TrackedArray{Float32,4,CuArray{Float32,4}}
|
||||
@test cpu(cy) ≈ y
|
||||
|
||||
@test cpu(data(cy)) ≈ data(y)
|
||||
Δ = randn(size(y))
|
||||
dm, dx = back(Δ)
|
||||
cdm, cdx = cback(gpu(Δ))
|
||||
|
||||
g = rand(size(y)...)
|
||||
Flux.back!(y, g)
|
||||
Flux.back!(cy, gpu(g))
|
||||
|
||||
@test m.γ.grad ≈ cpu(cm.γ.grad)
|
||||
@test m.β.grad ≈ cpu(cm.β.grad)
|
||||
@test x.grad ≈ cpu(x.grad)
|
||||
@test dm[].γ ≈ cpu(cdm[].γ)
|
||||
@test dm[].β ≈ cpu(cdm[].β)
|
||||
@test dx ≈ cpu(cdx)
|
||||
end
|
||||
|
||||
@testset "2D Input" begin
|
||||
x = TrackedArray(Float64.(collect(reshape(1:12, 3, 4))))
|
||||
x = Float64.(collect(reshape(1:12, 3, 4)))
|
||||
m = BatchNorm(3)
|
||||
cx = gpu(x)
|
||||
cm = gpu(m)
|
||||
|
||||
y = m(x)
|
||||
cy = cm(cx)
|
||||
y, back = pullback((m, x) -> m(x), m, x)
|
||||
cy, cback = pullback((m, x) -> m(x), cm, cx)
|
||||
|
||||
@test cy isa TrackedArray{Float32,2,CuArray{Float32,2}}
|
||||
@test cpu(cy) ≈ y
|
||||
|
||||
@test cpu(data(cy)) ≈ data(y)
|
||||
Δ = randn(size(y))
|
||||
dm, dx = back(Δ)
|
||||
cdm, cdx = cback(gpu(Δ))
|
||||
|
||||
g = rand(size(y)...)
|
||||
Flux.back!(y, g)
|
||||
Flux.back!(cy, gpu(g))
|
||||
|
||||
@test m.γ.grad ≈ cpu(cm.γ.grad)
|
||||
@test m.β.grad ≈ cpu(cm.β.grad)
|
||||
@test x.grad ≈ cpu(x.grad)
|
||||
@test dm[].γ ≈ cpu(cdm[].γ)
|
||||
@test dm[].β ≈ cpu(cdm[].β)
|
||||
@test dx ≈ cpu(cdx)
|
||||
end
|
||||
end
|
||||
|
@ -1,46 +1,63 @@
|
||||
using Flux, CuArrays, Test
|
||||
using Flux: pullback
|
||||
|
||||
@testset for R in [RNN, GRU, LSTM]
|
||||
m = R(10, 5) |> gpu
|
||||
x = gpu(rand(10))
|
||||
(m̄,) = gradient(m -> sum(m(x)), m)
|
||||
Flux.reset!(m)
|
||||
θ = gradient(() -> sum(m(x)), params(m))
|
||||
@test collect(m̄[].cell[].Wi) == collect(θ[m.cell.Wi])
|
||||
end
|
||||
|
||||
@testset "RNN" begin
|
||||
@testset for R in [RNN, GRU, LSTM]
|
||||
@testset for R in [RNN, GRU, LSTM], batch_size in (1, 5)
|
||||
rnn = R(10, 5)
|
||||
curnn = mapleaves(gpu, rnn)
|
||||
@testset for batch_size in (1, 5)
|
||||
Flux.reset!(rnn)
|
||||
Flux.reset!(curnn)
|
||||
x = batch_size == 1 ?
|
||||
param(rand(10)) :
|
||||
param(rand(10,batch_size))
|
||||
cux = gpu(x)
|
||||
y = (rnn(x); rnn(x))
|
||||
cuy = (curnn(cux); curnn(cux))
|
||||
curnn = fmap(gpu, rnn)
|
||||
|
||||
@test y.data ≈ collect(cuy.data)
|
||||
@test haskey(Flux.CUDA.descs, curnn.cell)
|
||||
Flux.reset!(rnn)
|
||||
Flux.reset!(curnn)
|
||||
x = batch_size == 1 ?
|
||||
rand(10) :
|
||||
rand(10, batch_size)
|
||||
cux = gpu(x)
|
||||
|
||||
Δ = randn(size(y))
|
||||
y, back = pullback((r, x) -> r(x), rnn, x)
|
||||
cuy, cuback = pullback((r, x) -> r(x), curnn, cux)
|
||||
|
||||
Flux.back!(y, Δ)
|
||||
Flux.back!(cuy, gpu(Δ))
|
||||
@test y ≈ collect(cuy)
|
||||
@test haskey(Flux.CUDA.descs, curnn.cell)
|
||||
|
||||
@test x.grad ≈ collect(cux.grad)
|
||||
@test rnn.cell.Wi.grad ≈ collect(curnn.cell.Wi.grad)
|
||||
@test rnn.cell.Wh.grad ≈ collect(curnn.cell.Wh.grad)
|
||||
@test rnn.cell.b.grad ≈ collect(curnn.cell.b.grad)
|
||||
@test rnn.cell.h.grad ≈ collect(curnn.cell.h.grad)
|
||||
if isdefined(rnn.cell, :c)
|
||||
@test rnn.cell.c.grad ≈ collect(curnn.cell.c.grad)
|
||||
ȳ = randn(size(y))
|
||||
m̄, x̄ = back(ȳ)
|
||||
cum̄, cux̄ = cuback(gpu(ȳ))
|
||||
|
||||
m̄[].cell[].Wi
|
||||
|
||||
m̄[].state
|
||||
cum̄[].state
|
||||
|
||||
@test x̄ ≈ collect(cux̄)
|
||||
@test m̄[].cell[].Wi ≈ collect(cum̄[].cell[].Wi)
|
||||
@test m̄[].cell[].Wh ≈ collect(cum̄[].cell[].Wh)
|
||||
@test m̄[].cell[].b ≈ collect(cum̄[].cell[].b)
|
||||
if m̄[].state isa Tuple
|
||||
for (x, cx) in zip(m̄[].state, cum̄[].state)
|
||||
@test x ≈ collect(cx)
|
||||
end
|
||||
|
||||
Flux.reset!(rnn)
|
||||
Flux.reset!(curnn)
|
||||
ohx = batch_size == 1 ?
|
||||
Flux.onehot(rand(1:10), 1:10) :
|
||||
Flux.onehotbatch(rand(1:10, batch_size), 1:10)
|
||||
cuohx = gpu(ohx)
|
||||
y = (rnn(ohx); rnn(ohx))
|
||||
cuy = (curnn(cuohx); curnn(cuohx))
|
||||
|
||||
@test y.data ≈ collect(cuy.data)
|
||||
else
|
||||
@test m̄[].state ≈ collect(cum̄[].state)
|
||||
end
|
||||
|
||||
Flux.reset!(rnn)
|
||||
Flux.reset!(curnn)
|
||||
ohx = batch_size == 1 ?
|
||||
Flux.onehot(rand(1:10), 1:10) :
|
||||
Flux.onehotbatch(rand(1:10, batch_size), 1:10)
|
||||
cuohx = gpu(ohx)
|
||||
y = (rnn(ohx); rnn(ohx))
|
||||
cuy = (curnn(cuohx); curnn(cuohx))
|
||||
|
||||
@test y ≈ collect(cuy)
|
||||
end
|
||||
end
|
||||
|
@ -4,11 +4,13 @@ import Flux: activations
|
||||
@testset "basic" begin
|
||||
@testset "helpers" begin
|
||||
@testset "activations" begin
|
||||
dummy_model = Chain(Dense(10,5,σ),Dense(5,2),softmax)
|
||||
x = rand(10)
|
||||
@test activations(Chain(), x) == []
|
||||
@test activations(dummy_model, x)[1] == dummy_model[1](x)
|
||||
@test activations(dummy_model, x)[2] == x |> dummy_model[1] |> dummy_model[2]
|
||||
dummy_model = Chain(x->x.^2, x->x .- 3, x -> tan.(x))
|
||||
x = randn(10)
|
||||
@test activations(dummy_model, x)[1] == x.^2
|
||||
@test activations(dummy_model, x)[2] == (x.^2 .- 3)
|
||||
@test activations(dummy_model, x)[3] == tan.(x.^2 .- 3)
|
||||
|
||||
@test activations(Chain(), x) == ()
|
||||
@test activations(Chain(identity, x->:foo), x)[2] == :foo # results include `Any` type
|
||||
end
|
||||
end
|
||||
@ -19,6 +21,12 @@ import Flux: activations
|
||||
# numeric test should be put into testset of corresponding layer
|
||||
end
|
||||
|
||||
@testset "Activations" begin
|
||||
c = Chain(Dense(3,5,relu), Dense(5,1,relu))
|
||||
X = Float32.([1.0; 1.0; 1.0])
|
||||
@test_nowarn gradient(()->Flux.activations(c, X)[2][1], params(c))
|
||||
end
|
||||
|
||||
@testset "Dense" begin
|
||||
@test length(Dense(10, 5)(randn(10))) == 5
|
||||
@test_throws DimensionMismatch Dense(10, 5)(randn(1))
|
||||
|
@ -1,5 +1,6 @@
|
||||
using Flux, Test
|
||||
using Flux: maxpool, meanpool
|
||||
using Flux: gradient
|
||||
|
||||
@testset "Pooling" begin
|
||||
x = randn(Float32, 10, 10, 3, 2)
|
||||
@ -25,9 +26,9 @@ end
|
||||
@testset "asymmetric padding" begin
|
||||
r = ones(Float32, 28, 28, 1, 1)
|
||||
m = Conv((3, 3), 1=>1, relu; pad=(0,1,1,2))
|
||||
m.weight.data[:] .= 1.0
|
||||
m.bias.data[:] .= 0.0
|
||||
y_hat = Flux.data(m(r))[:,:,1,1]
|
||||
m.weight[:] .= 1.0
|
||||
m.bias[:] .= 0.0
|
||||
y_hat = m(r)[:,:,1,1]
|
||||
@test size(y_hat) == (27, 29)
|
||||
@test y_hat[1, 1] ≈ 6.0
|
||||
@test y_hat[2, 2] ≈ 9.0
|
||||
@ -54,6 +55,10 @@ end
|
||||
y = Conv((3,3), 1 => 1)(x)
|
||||
x_hat = ConvTranspose((3, 3), 1 => 1)(y)
|
||||
@test size(x_hat) == size(x)
|
||||
|
||||
m = ConvTranspose((3,3), 1=>1)
|
||||
# Test that the gradient call does not throw: #900
|
||||
@test gradient(()->sum(m(x)), params(m)) isa Flux.Zygote.Grads
|
||||
end
|
||||
|
||||
@testset "CrossCor" begin
|
||||
@ -102,4 +107,3 @@ end
|
||||
true
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -1,29 +1,29 @@
using Flux: testmode!
using Flux.Tracker: data
using Flux, Test, Statistics
using Zygote: pullback

trainmode(f, x...) = pullback(f, x...)[1]
trainmode(f) = (x...) -> trainmode(f, x...)

@testset "Dropout" begin
x = [1.,2.,3.]
@test x == testmode!(Dropout(0.1))(x)
@test x == Dropout(0)(x)
@test zero(x) == Dropout(1)(x)
@test x == Dropout(0.1)(x)
@test x == trainmode(Dropout(0), x)
@test zero(x) == trainmode(Dropout(1), x)

x = rand(100)
m = Dropout(0.9)
y = m(x)
y = trainmode(m, x)
@test count(a->a==0, y) > 50
testmode!(m)
y = m(x)
@test count(a->a==0, y) == 0
testmode!(m, false)
y = m(x)
y = trainmode(m, x)
@test count(a->a==0, y) > 50

x = rand(100)
x = rand(Float32, 100)
m = Chain(Dense(100,100),
Dropout(0.9))
y = m(x)
y = trainmode(m, x)
@test count(a->a == 0, y) > 50
testmode!(m)
y = m(x)
@test count(a->a == 0, y) == 0

@ -39,18 +39,18 @@ using Flux.Tracker: data
end

@testset "BatchNorm" begin
let m = BatchNorm(2), x = param([1 3 5;
2 4 6])
let m = BatchNorm(2), x = [1.0 3.0 5.0;
2.0 4.0 6.0]

@test m.β.data == [0, 0] # initβ(2)
@test m.γ.data == [1, 1] # initγ(2)
@test length(params(m)) == 2

@test m.β == [0, 0] # initβ(2)
@test m.γ == [1, 1] # initγ(2)
# initial m.σ is 1
# initial m.μ is 0
@test m.active

# @test m(x).data ≈ [-1 -1; 0 0; 1 1]'
m(x)

y = trainmode(m, x)
@test isapprox(y, [-1.22474 0 1.22474; -1.22474 0 1.22474], atol = 1.0e-5)
# julia> x
# 2×3 Array{Float64,2}:
# 1.0 3.0 5.0
@ -69,41 +69,32 @@ end
# 2×1 Array{Float64,2}:
# 1.3
# 1.3
@test m.σ² ≈ .1 .* var(x.data, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
@test m.σ² ≈ .1 .* var(x, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]

testmode!(m)
@test !m.active

x′ = m(x).data
x′ = m(x)
@test isapprox(x′[1], (1 .- 0.3) / sqrt(1.3), atol = 1.0e-5)
end

# with activation function
let m = BatchNorm(2, sigmoid), x = param([1 3 5;
2 4 6])
@test m.active
m(x)

testmode!(m)
@test !m.active

y = m(x).data
@test isapprox(y, data(sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ))), atol = 1.0e-7)
let m = BatchNorm(2, sigmoid), x = [1.0 3.0 5.0;
2.0 4.0 6.0]
y = m(x)
@test isapprox(y, sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ)), atol = 1.0e-7)
end

let m = BatchNorm(2), x = param(reshape(1:6, 3, 2, 1))
let m = trainmode(BatchNorm(2)), x = reshape(Float32.(1:6), 3, 2, 1)
y = reshape(permutedims(x, [2, 1, 3]), 2, :)
y = permutedims(reshape(m(y), 2, 3, 1), [2, 1, 3])
@test m(x) == y
end

let m = BatchNorm(2), x = param(reshape(1:12, 2, 3, 2, 1))
let m = trainmode(BatchNorm(2)), x = reshape(Float32.(1:12), 2, 3, 2, 1)
y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :)
y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4])
@test m(x) == y
end

let m = BatchNorm(2), x = param(reshape(1:24, 2, 2, 3, 2, 1))
let m = trainmode(BatchNorm(2)), x = reshape(Float32.(1:24), 2, 2, 3, 2, 1)
y = reshape(permutedims(x, [4, 1, 2, 3, 5]), 2, :)
y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5])
@test m(x) == y
@ -115,20 +106,18 @@ end
end
end


@testset "InstanceNorm" begin
# helper functions
expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...)
# begin tests
let m = InstanceNorm(2), sizes = (3, 2, 2),
x = param(reshape(collect(1:prod(sizes)), sizes))
x = reshape(collect(1:prod(sizes)), sizes)

@test m.β.data == [0, 0] # initβ(2)
@test m.γ.data == [1, 1] # initγ(2)

@test m.active

m(x)
@test length(params(m)) == 2
x = Float64.(x)
@test m.β == [0, 0] # initβ(2)
@test m.γ == [1, 1] # initγ(2)
y = trainmode(m, x)

#julia> x
#[:, :, 1] =
@ -153,37 +142,28 @@ end
# (1. - .1) * 0 + .1 * (5. + 11.) / 2 = .8
@test m.μ ≈ [0.5, 0.8]
# momentum * var * num_items / (num_items - 1) + (1 - momentum) * sigma_sq
# julia> reshape(mean(.1 .* var(x.data, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1.
# julia> reshape(mean(.1 .* var(x, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1.
# 2-element Array{Float64,1}:
# 1.
# 1.
@test m.σ² ≈ reshape(mean(.1 .* var(x.data, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1.
@test m.σ² ≈ reshape(mean(.1 .* var(x, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1.

testmode!(m)
@test !m.active

x′ = m(x).data
x′ = m(x)
@test isapprox(x′[1], (1 - 0.5) / sqrt(1. + 1f-5), atol = 1.0e-5)
end
# with activation function
let m = InstanceNorm(2, sigmoid), sizes = (3, 2, 2),
x = param(reshape(collect(1:prod(sizes)), sizes))

x = reshape(collect(1:prod(sizes)), sizes)
x = Float64.(x)
affine_shape = collect(sizes)
affine_shape[1] = 1

@test m.active
m(x)

testmode!(m)
@test !m.active

y = m(x).data
@test isapprox(y, data(sigmoid.((x .- expand_inst(m.μ, affine_shape)) ./ sqrt.(expand_inst(m.σ², affine_shape) .+ m.ϵ))), atol = 1.0e-7)
y = m(x)
@test isapprox(y, sigmoid.((x .- expand_inst(m.μ, affine_shape)) ./ sqrt.(expand_inst(m.σ², affine_shape) .+ m.ϵ)), atol = 1.0e-7)
end

let m = InstanceNorm(2), sizes = (2, 4, 1, 2, 3),
x = param(reshape(collect(1:prod(sizes)), sizes))
let m = trainmode(InstanceNorm(2)), sizes = (2, 4, 1, 2, 3),
x = Float32.(reshape(collect(1:prod(sizes)), sizes))
y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
y = reshape(m(y), sizes...)
@test m(x) == y
@ -191,16 +171,16 @@ end

# check that μ, σ², and the output are the correct size for higher rank tensors
let m = InstanceNorm(2), sizes = (5, 5, 3, 4, 2, 6),
x = param(reshape(collect(1:prod(sizes)), sizes))
y = m(x)
x = reshape(Float32.(collect(1:prod(sizes))), sizes)
y = trainmode(m, x)
@test size(m.μ) == (sizes[end - 1], )
@test size(m.σ²) == (sizes[end - 1], )
@test size(y) == sizes
end

# show that instance norm is equal to batch norm when channel and batch dims are squashed
let m_inorm = InstanceNorm(2), m_bnorm = BatchNorm(12), sizes = (5, 5, 3, 4, 2, 6),
x = param(reshape(collect(1:prod(sizes)), sizes))
let m_inorm = trainmode(InstanceNorm(2)), m_bnorm = trainmode(BatchNorm(12)), sizes = (5, 5, 3, 4, 2, 6),
x = reshape(Float32.(collect(1:prod(sizes))), sizes)
@test m_inorm(x) == reshape(m_bnorm(reshape(x, (sizes[1:end - 2]..., :, 1))), sizes)
end

@ -211,19 +191,20 @@ end

end

if VERSION >= v"1.1"
@testset "GroupNorm" begin
# begin tests
squeeze(x) = dropdims(x, dims = tuple(findall(size(x) .== 1)...)) # To remove all singular dimensions

let m = GroupNorm(4,2), sizes = (3,4,2),
x = param(reshape(collect(1:prod(sizes)), sizes))
x = reshape(collect(1:prod(sizes)), sizes)

@test m.β.data == [0, 0, 0, 0] # initβ(32)
@test m.γ.data == [1, 1, 1, 1] # initγ(32)
@test length(params(m)) == 2
x = Float64.(x)
@test m.β == [0, 0, 0, 0] # initβ(32)
@test m.γ == [1, 1, 1, 1] # initγ(32)

@test m.active

m(x)
y = trainmode(m, x)

#julia> x
#[:, :, 1] =
@ -253,21 +234,18 @@ end
@test m.μ ≈ [0.95, 1.55]

# julia> mean(var(reshape(x,3,2,2,2),dims=(1,2)).* .1,dims=2) .+ .9*1.
# 2-element Array{Tracker.TrackedReal{Float64},1}:
# 2-element Array{Float64,1}:
# 1.25
# 1.25
@test m.σ² ≈ mean(squeeze(var(reshape(x,3,2,2,2),dims=(1,2))).*.1,dims=2) .+ .9*1.

testmode!(m)
@test !m.active

x′ = m(x).data
x′ = m(x)
@test isapprox(x′[1], (1 - 0.95) / sqrt(1.25 + 1f-5), atol = 1.0e-5)
end
# with activation function
let m = GroupNorm(4,2, sigmoid), sizes = (3, 4, 2),
x = param(reshape(collect(1:prod(sizes)), sizes))

x = reshape(collect(1:prod(sizes)), sizes)
x = Float64.(x)
μ_affine_shape = ones(Int,length(sizes) + 1)
μ_affine_shape[end-1] = 2 # Number of groups

@ -279,20 +257,14 @@ end

og_shape = size(x)

@test m.active
m(x)

testmode!(m)
@test !m.active

y = m(x)
x_ = reshape(x,affine_shape...)
out = reshape(data(sigmoid.((x_ .- reshape(m.μ,μ_affine_shape...)) ./ sqrt.(reshape(m.σ²,μ_affine_shape...) .+ m.ϵ))),og_shape)
out = reshape(sigmoid.((x_ .- reshape(m.μ,μ_affine_shape...)) ./ sqrt.(reshape(m.σ²,μ_affine_shape...) .+ m.ϵ)),og_shape)
@test isapprox(y, out, atol = 1.0e-7)
end

let m = GroupNorm(2,2), sizes = (2, 4, 1, 2, 3),
x = param(reshape(collect(1:prod(sizes)), sizes))
let m = trainmode(GroupNorm(2,2)), sizes = (2, 4, 1, 2, 3),
x = Float32.(reshape(collect(1:prod(sizes)), sizes))
y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
y = reshape(m(y), sizes...)
@test m(x) == y
@ -300,23 +272,23 @@ end

# check that μ, σ², and the output are the correct size for higher rank tensors
let m = GroupNorm(4,2), sizes = (5, 5, 3, 4, 4, 6),
x = param(reshape(collect(1:prod(sizes)), sizes))
y = m(x)
x = Float32.(reshape(collect(1:prod(sizes)), sizes))
y = trainmode(m, x)
@test size(m.μ) == (m.G,1)
@test size(m.σ²) == (m.G,1)
@test size(y) == sizes
end

# show that group norm is the same as instance norm when the group size is the same as the number of channels
let IN = InstanceNorm(4), GN = GroupNorm(4,4), sizes = (2,2,3,4,5),
x = param(reshape(collect(1:prod(sizes)), sizes))
let IN = trainmode(InstanceNorm(4)), GN = trainmode(GroupNorm(4,4)), sizes = (2,2,3,4,5),
x = Float32.(reshape(collect(1:prod(sizes)), sizes))
@test IN(x) ≈ GN(x)
end

# show that group norm is the same as batch norm for a group of size 1 and batch of size 1
let BN = BatchNorm(4), GN = GroupNorm(4,4), sizes = (2,2,3,4,1),
x = param(reshape(collect(1:prod(sizes)), sizes))
let BN = trainmode(BatchNorm(4)), GN = trainmode(GroupNorm(4,4)), sizes = (2,2,3,4,1),
x = Float32.(reshape(collect(1:prod(sizes)), sizes))
@test BN(x) ≈ GN(x)
end

end
end

@ -51,13 +51,13 @@ const ϵ = 1e-7
end

@testset "no spurious promotions" begin
for T in (Float16, Float32, Float64)
for T in (Float32, Float64)
y = rand(T, 2)
ŷ = rand(T, 2)
for f in (mse, crossentropy, logitcrossentropy)
fwd, back = Flux.Tracker.forward(mse, ŷ, y)
@test typeof(fwd) == Flux.Tracker.TrackedReal{T}
@test eltype(back(one(T))[1]) == Flux.Tracker.TrackedReal{T}
fwd, back = Flux.pullback(f, ŷ, y)
@test fwd isa T
@test eltype(back(one(T))[1]) == T
end
end
end

@ -1,42 +1,44 @@
using Flux.Optimise
using Flux.Optimise: runall
using Flux.Tracker
using Flux: Params, gradient
using Test

@testset "Optimise" begin
w = randn(10, 10)
@testset for opt in [ADAMW(), ADAGrad(0.1), AdaMax(), ADADelta(0.9), AMSGrad(),
NADAM(), Descent(0.1), ADAM(), Nesterov(), RMSProp(),
NADAM(), RADAM(), Descent(0.1), ADAM(), Nesterov(), RMSProp(),
Momentum()]
w′ = param(randn(10, 10))
w′ = randn(10, 10)
loss(x) = Flux.mse(w*x, w′*x)
for t = 1: 10^5
θ = Params([w′])
θ̄ = gradient(() -> loss(rand(10)), θ)
x = rand(10)
θ̄ = gradient(() -> loss(x), θ)
Optimise.update!(opt, θ, θ̄)
end
@test Flux.mse(w, w′) < 0.01
@test loss(rand(10, 10)) < 0.01
end
end

@testset "Optimiser" begin
w = randn(10, 10)
@testset for Opt in [InvDecay, WeightDecay, ExpDecay]
w′ = param(randn(10, 10))
w′ = randn(10, 10)
loss(x) = Flux.mse(w*x, w′*x)
opt = Optimiser(Opt(), ADAM(0.001))
for t = 1:10^5
l = loss(rand(10))
back!(l)
delta = Optimise.apply!(opt, w′.data, w′.grad)
w′.data .-= delta
θ = Params([w′])
x = rand(10)
θ̄ = gradient(() -> loss(x), θ)
Optimise.update!(opt, θ, θ̄)
end
@test Flux.mse(w, w′) < 0.01
@test loss(rand(10, 10)) < 0.01
end
end

@testset "Training Loop" begin
i = 0
l = param(1)
l = 1

Flux.train!(() -> (sleep(0.1); i += 1; l),
(),
@ -57,17 +59,18 @@ end
@testset "ExpDecay" begin
w = randn(10, 10)
o = ExpDecay(0.1, 0.1, 1000, 1e-4)
w1 = param(randn(10,10))
w1 = randn(10,10)
loss(x) = Flux.mse(w*x, w1*x)
flag = 1
decay_steps = []
for t = 1:10^5
l = loss(rand(10))
back!(l)
prev_eta = o.eta
prev_grad = collect(w1.grad)
delta = Optimise.apply!(o, w1.data, w1.grad)
w1.data .-= delta
θ = Params([w1])
x = rand(10)
θ̄ = gradient(() -> loss(x), θ)
prev_grad = collect(θ̄[w1])
delta = Optimise.apply!(o, w1, θ̄[w1])
w1 .-= delta
new_eta = o.eta
if new_eta != prev_eta
push!(decay_steps, t)

@ -1,11 +1,8 @@
using Flux, Test, Random, Statistics
using Flux, Test, Random, Statistics, Documenter
using Random

Random.seed!(0)

# So we can use the system CuArrays
insert!(LOAD_PATH, 2, "@v#.#")

@testset "Flux" begin

@info "Testing Basics"
@ -22,12 +19,14 @@ include("layers/normalisation.jl")
include("layers/stateless.jl")
include("layers/conv.jl")

@info "Running Gradient Checks"

include("tracker.jl")

if Base.find_package("CuArrays") != nothing
if Flux.use_cuda[]
include("cuda/cuda.jl")
else
@warn "CUDA unavailable, not testing GPU support"
end

if VERSION >= v"1.2"
doctest(Flux)
end

end

@ -1,15 +0,0 @@
using Flux, Test
using Tracker: gradcheck

gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)

@testset "Tracker" begin

@test gradtest(Flux.mse, rand(5,5), rand(5, 5))
@test gradtest(Flux.crossentropy, rand(5,5), rand(5, 5))

@test gradtest(x -> Flux.normalise(x), rand(4,3))
@test gradtest(x -> Flux.normalise(x, dims = 2), rand(3,4))

end

@ -1,6 +1,6 @@
using Flux
using Flux: throttle, jacobian, glorot_uniform, glorot_normal, stack, unstack
using StatsBase: std
using Flux: throttle, nfan, glorot_uniform, glorot_normal, stack, unstack
using StatsBase: var
using Random
using Test

@ -52,31 +52,30 @@ using Test
end
end

@testset "Jacobian" begin
A = param(randn(2,2))
x = randn(2)
m(x) = A*x
y = m(x)
J = jacobian(m,x)
@test J ≈ A.data
end

@testset "Initialization" begin
# Set random seed so that these tests don't fail randomly
Random.seed!(0)

# glorot_uniform should yield a kernel with stddev ~= sqrt(6/(n_in + n_out)),
# and glorot_normal should yield a kernel with stddev != 2/(n_in _ n_out)
for (n_in, n_out) in [(100, 100), (100, 400)]
v = glorot_uniform(n_in, n_out)
@test minimum(v) > -1.1*sqrt(6/(n_in + n_out))
@test minimum(v) < -0.9*sqrt(6/(n_in + n_out))
@test maximum(v) > 0.9*sqrt(6/(n_in + n_out))
@test maximum(v) < 1.1*sqrt(6/(n_in + n_out))
@testset "Fan in/out" begin
@test nfan() == (1, 1) #For a constant
@test nfan(100) == (1, 100) #For vector
@test nfan(100, 200) == (200, 100) #For Dense layer
@test nfan(2, 30, 40) == (2 * 30, 2 * 40) #For 1D Conv layer
@test nfan(2, 3, 40, 50) == (2 * 3 * 40, 2 * 3 * 50) #For 2D Conv layer
@test nfan(2, 3, 4, 50, 60) == (2 * 3 * 4 * 50, 2 * 3 * 4 * 60) #For 3D Conv layer
end

v = glorot_normal(n_in, n_out)
@test std(v) > 0.9*sqrt(2/(n_in + n_out))
@test std(v) < 1.1*sqrt(2/(n_in + n_out))
@testset "glorot" begin
# glorot_uniform and glorot_normal should both yield a kernel with
# variance ≈ 2/(fan_in + fan_out)
for dims ∈ [(1000,), (100, 100), (100, 400), (2, 3, 32, 64), (2, 3, 4, 32, 64)]
for init ∈ [glorot_uniform, glorot_normal]
v = init(dims...)
fan_in, fan_out = nfan(dims...)
σ2 = 2 / (fan_in + fan_out)
@test 0.9σ2 < var(v) < 1.1σ2
end
end
end
end

@ -85,6 +84,15 @@ end
@test size.(params(m)) == [(5, 10), (5,)]
m = RNN(10, 5)
@test size.(params(m)) == [(5, 10), (5, 5), (5,), (5,)]

# Layer duplicated in same chain, params just once pls.
c = Chain(m, m)
@test size.(params(c)) == [(5, 10), (5, 5), (5,), (5,)]

# Self-referential array. Just want params, no stack overflow pls.
r = Any[nothing,m]
r[1] = r
@test size.(params(r)) == [(5, 10), (5, 5), (5,), (5,)]
end

@testset "Basic Stacking" begin
@ -96,12 +104,11 @@ end
@testset "Precision" begin
m = Chain(Dense(10, 5, relu), Dense(5, 2))
x = rand(10)
@test eltype(m[1].W.data) == Float32
@test eltype(m(x).data) == Float32
@test eltype(f64(m)(x).data) == Float64
@test eltype(f64(m)[1].W.data) == Float64
@test eltype(f32(f64(m))[1].W.data) == Float32
@test Tracker.isleaf(f32(f64(m))[1].W)
@test eltype(m[1].W) == Float32
@test eltype(m(x)) == Float32
@test eltype(f64(m)(x)) == Float64
@test eltype(f64(m)[1].W) == Float64
@test eltype(f32(f64(m))[1].W) == Float32
end

@testset "Stacking" begin