commit 73c1485927
Avik Pal, 2019-01-24 18:42:28 +05:30
51 changed files with 2391 additions and 855 deletions

.gitignore

@@ -3,6 +3,4 @@
 *.jl.mem
 docs/build/
 docs/site/
-docs/flux.css
 deps
-Manifest.toml

.travis.yml
@@ -1,19 +1,25 @@
 # Documentation: http://docs.travis-ci.com/user/languages/julia/
 language: julia
 os:
   - linux
 # - osx
 julia:
-  - 0.7
   - 1.0
   - nightly
-# uncomment the following lines to override the default test script
-# script:
-#  - if [[ -a .git/shallow ]]; then git fetch --unshallow; fi
-#  - julia -e 'Pkg.clone(pwd()); Pkg.build("Flux"); Pkg.test("Flux"; coverage=true)'
 matrix:
   allow_failures:
     - julia: nightly
-after_success:
-  - julia -e 'using Pkg; Pkg.add("Documenter"); Pkg.add("NNlib")'
-  - julia -e 'using Pkg; cd(Pkg.dir("Flux")); include(joinpath("docs", "make.jl"))'
+jobs:
+  include:
+    - stage: "Documentation"
+      julia: 1.0
+      os: linux
+      script:
+        - julia --project=docs/ -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd()));
+                                    Pkg.instantiate()'
+        - julia --project=docs/ docs/make.jl
+      after_success: skip
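The new stage drives Documenter through the `docs/` project. A rough local equivalent, as a sketch (running it outside CI is an assumption, not something this config promises):

```julia
# Approximates the CI "Documentation" stage locally.
# Assumes the repository root as working directory and the docs/Project.toml
# environment added in this commit.
using Pkg
Pkg.activate("docs")                  # use docs/Project.toml
Pkg.develop(PackageSpec(path=pwd()))  # develop the local Flux checkout into it
Pkg.instantiate()                     # install Documenter and friends
include(joinpath("docs", "make.jl"))  # build the docs (deploydocs skips deployment off CI)
```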

Manifest.toml (new file)

@@ -0,0 +1,272 @@
# This file is machine-generated - editing it directly is not advised
[[AbstractTrees]]
deps = ["Markdown", "Test"]
git-tree-sha1 = "6621d9645702c1c4e6970cc6a3eae440c768000b"
uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
version = "0.2.1"
[[Adapt]]
deps = ["LinearAlgebra", "Test"]
git-tree-sha1 = "53d8fec4f662088c1202530e338a11a919407f3b"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "0.4.2"
[[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
[[BinDeps]]
deps = ["Compat", "Libdl", "SHA", "URIParser"]
git-tree-sha1 = "12093ca6cdd0ee547c39b1870e0c9c3f154d9ca9"
uuid = "9e28174c-4ba2-5203-b857-d8d62c4213ee"
version = "0.8.10"
[[BinaryProvider]]
deps = ["Libdl", "Pkg", "SHA", "Test"]
git-tree-sha1 = "055eb2690182ebc31087859c3dd8598371d3ef9e"
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
version = "0.5.3"
[[CodecZlib]]
deps = ["BinaryProvider", "Libdl", "Test", "TranscodingStreams"]
git-tree-sha1 = "e3df104c84dfc108f0ca203fd7f5bbdc98641ae9"
uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
version = "0.5.1"
[[ColorTypes]]
deps = ["FixedPointNumbers", "Random", "Test"]
git-tree-sha1 = "f73b0e10f2a5756de7019818a41654686da06b09"
uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
version = "0.7.5"
[[Colors]]
deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Printf", "Reexport", "Test"]
git-tree-sha1 = "9f0a0210450acb91c730b730a994f8eef1d3d543"
uuid = "5ae59095-9a9b-59fe-a467-6f913c188581"
version = "0.9.5"
[[CommonSubexpressions]]
deps = ["Test"]
git-tree-sha1 = "efdaf19ab11c7889334ca247ff4c9f7c322817b0"
uuid = "bbf7d656-a473-5ed7-a52c-81e309532950"
version = "0.2.0"
[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "ec61a16eed883ad0cfa002d7489b3ce6d039bb9a"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "1.4.0"
[[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
git-tree-sha1 = "ca971f03e146cf144a9e2f2ce59674f5bf0e8038"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.15.0"
[[Dates]]
deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
[[DelimitedFiles]]
deps = ["Mmap"]
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
[[DiffResults]]
deps = ["Compat", "StaticArrays"]
git-tree-sha1 = "db8acf46717b13d6c48deb7a12007c7f85a70cf7"
uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
version = "0.0.3"
[[DiffRules]]
deps = ["Random", "Test"]
git-tree-sha1 = "c49ec69428ffea0c1d1bbdc63d1a70f5df5860ad"
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "0.0.7"
[[Distributed]]
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
[[FixedPointNumbers]]
deps = ["Test"]
git-tree-sha1 = "b8045033701c3b10bf2324d7203404be7aef88ba"
uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
version = "0.5.3"
[[ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"]
git-tree-sha1 = "e393bd3b9102659fb24fe88caedec41f2bc2e7de"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.2"
[[InteractiveUtils]]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
[[Juno]]
deps = ["Base64", "Logging", "Media", "Profile", "Test"]
git-tree-sha1 = "ce6246e19061e36cbdce954caaae717498daeed8"
uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
version = "0.5.4"
[[LibGit2]]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
[[Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
[[LinearAlgebra]]
deps = ["Libdl"]
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
[[Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
[[MacroTools]]
deps = ["Compat"]
git-tree-sha1 = "c443e1c8d58a4e9f61b708ad0a88286c7042145b"
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
version = "0.4.4"
[[Markdown]]
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
[[Media]]
deps = ["MacroTools", "Test"]
git-tree-sha1 = "75a54abd10709c01f1b86b84ec225d26e840ed58"
uuid = "e89f7d12-3494-54d1-8411-f7d8b9ae1f27"
version = "0.5.0"
[[Missings]]
deps = ["Dates", "InteractiveUtils", "SparseArrays", "Test"]
git-tree-sha1 = "d1d2585677f2bd93a97cfeb8faa7a0de0f982042"
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
version = "0.4.0"
[[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
[[NNlib]]
deps = ["Libdl", "LinearAlgebra", "MacroTools", "Requires", "Test"]
git-tree-sha1 = "51330bb45927379007e089997bf548fbe232589d"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.4.3"
[[NaNMath]]
deps = ["Compat"]
git-tree-sha1 = "ce3b85e484a5d4c71dd5316215069311135fa9f2"
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
version = "0.3.2"
[[OrderedCollections]]
deps = ["Random", "Serialization", "Test"]
git-tree-sha1 = "85619a3f3e17bb4761fe1b1fd47f0e979f964d5b"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.0.2"
[[Pkg]]
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
[[Printf]]
deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
[[Profile]]
deps = ["Printf"]
uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"
[[REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
[[Random]]
deps = ["Serialization"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
[[Reexport]]
deps = ["Pkg"]
git-tree-sha1 = "7b1d07f411bc8ddb7977ec7f377b97b158514fe0"
uuid = "189a3867-3050-52da-a836-e630ba90ab69"
version = "0.2.0"
[[Requires]]
deps = ["Test"]
git-tree-sha1 = "f6fbf4ba64d295e146e49e021207993b6b48c7d1"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "0.5.2"
[[SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
[[Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
[[SharedArrays]]
deps = ["Distributed", "Mmap", "Random", "Serialization"]
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
[[Sockets]]
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
[[SortingAlgorithms]]
deps = ["DataStructures", "Random", "Test"]
git-tree-sha1 = "03f5898c9959f8115e30bc7226ada7d0df554ddd"
uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
version = "0.3.1"
[[SparseArrays]]
deps = ["LinearAlgebra", "Random"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
[[SpecialFunctions]]
deps = ["BinDeps", "BinaryProvider", "Libdl", "Test"]
git-tree-sha1 = "0b45dc2e45ed77f445617b99ff2adf0f5b0f23ea"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "0.7.2"
[[StaticArrays]]
deps = ["InteractiveUtils", "LinearAlgebra", "Random", "Statistics", "Test"]
git-tree-sha1 = "1eb114d6e23a817cd3e99abc3226190876d7c898"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "0.10.2"
[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[[StatsBase]]
deps = ["DataStructures", "DelimitedFiles", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "Test"]
git-tree-sha1 = "7b596062316c7d846b67bf625d5963a832528598"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.27.0"
[[Test]]
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[[TranscodingStreams]]
deps = ["Pkg", "Random", "Test"]
git-tree-sha1 = "a34a2d588e2d2825602bf14a24216d5c8b0921ec"
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
version = "0.8.1"
[[URIParser]]
deps = ["Test", "Unicode"]
git-tree-sha1 = "6ddf8244220dfda2f17539fa8c9de20d6c575b69"
uuid = "30578b45-9adc-5946-b283-645ec420af67"
version = "0.4.0"
[[UUIDs]]
deps = ["Random", "SHA"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
[[Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
[[ZipFile]]
deps = ["BinaryProvider", "Libdl", "Printf", "Test"]
git-tree-sha1 = "4000c633efe994b2e10b31b6d91382c4b7412dac"
uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
version = "0.8.0"

Project.toml (new file)

@@ -0,0 +1,25 @@
name = "Flux"
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"

REQUIRE
@@ -1,9 +1,9 @@
-julia 0.7
+julia 1.0
 Juno
 MacroTools 0.3.3
 NNlib
 Requires
-Adapt
+Adapt 0.4
 CodecZlib
 Colors
 ZipFile

docs/Manifest.toml (new file)

@@ -0,0 +1,288 @@
[[AbstractTrees]]
deps = ["Markdown", "Test"]
git-tree-sha1 = "6621d9645702c1c4e6970cc6a3eae440c768000b"
uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
version = "0.2.1"
[[Adapt]]
deps = ["LinearAlgebra", "Test"]
git-tree-sha1 = "04d15700419b6949d76be1428ab6e0277ff43b06"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "0.4.1"
[[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
[[BinDeps]]
deps = ["Compat", "Libdl", "SHA", "URIParser"]
git-tree-sha1 = "12093ca6cdd0ee547c39b1870e0c9c3f154d9ca9"
uuid = "9e28174c-4ba2-5203-b857-d8d62c4213ee"
version = "0.8.10"
[[BinaryProvider]]
deps = ["Libdl", "Pkg", "SHA", "Test"]
git-tree-sha1 = "055eb2690182ebc31087859c3dd8598371d3ef9e"
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
version = "0.5.3"
[[CodecZlib]]
deps = ["BinaryProvider", "Libdl", "Test", "TranscodingStreams"]
git-tree-sha1 = "e3df104c84dfc108f0ca203fd7f5bbdc98641ae9"
uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
version = "0.5.1"
[[ColorTypes]]
deps = ["FixedPointNumbers", "Random", "Test"]
git-tree-sha1 = "f73b0e10f2a5756de7019818a41654686da06b09"
uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
version = "0.7.5"
[[Colors]]
deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Printf", "Reexport", "Test"]
git-tree-sha1 = "9f0a0210450acb91c730b730a994f8eef1d3d543"
uuid = "5ae59095-9a9b-59fe-a467-6f913c188581"
version = "0.9.5"
[[CommonSubexpressions]]
deps = ["Test"]
git-tree-sha1 = "efdaf19ab11c7889334ca247ff4c9f7c322817b0"
uuid = "bbf7d656-a473-5ed7-a52c-81e309532950"
version = "0.2.0"
[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "ec61a16eed883ad0cfa002d7489b3ce6d039bb9a"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "1.4.0"
[[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
git-tree-sha1 = "ca971f03e146cf144a9e2f2ce59674f5bf0e8038"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.15.0"
[[Dates]]
deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
[[DelimitedFiles]]
deps = ["Mmap"]
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
[[DiffResults]]
deps = ["Compat", "StaticArrays"]
git-tree-sha1 = "db8acf46717b13d6c48deb7a12007c7f85a70cf7"
uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
version = "0.0.3"
[[DiffRules]]
deps = ["Random", "Test"]
git-tree-sha1 = "c49ec69428ffea0c1d1bbdc63d1a70f5df5860ad"
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "0.0.7"
[[Distributed]]
deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
[[DocStringExtensions]]
deps = ["LibGit2", "Markdown", "Pkg", "Test"]
git-tree-sha1 = "1df01539a1c952cef21f2d2d1c092c2bcf0177d7"
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
version = "0.6.0"
[[Documenter]]
deps = ["Base64", "DocStringExtensions", "InteractiveUtils", "LibGit2", "Logging", "Markdown", "Pkg", "REPL", "Random", "Test", "Unicode"]
git-tree-sha1 = "a6db1c69925cdc53aafb38caec4446be26e0c617"
uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
version = "0.21.0"
[[FixedPointNumbers]]
deps = ["Test"]
git-tree-sha1 = "b8045033701c3b10bf2324d7203404be7aef88ba"
uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
version = "0.5.3"
[[Flux]]
deps = ["AbstractTrees", "Adapt", "CodecZlib", "Colors", "DiffRules", "ForwardDiff", "Juno", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsBase", "Test", "ZipFile"]
path = ".."
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
version = "0.6.10+"
[[ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"]
git-tree-sha1 = "b91250044374764e7c29af59a774c4b8d6100b6e"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.1"
[[InteractiveUtils]]
deps = ["LinearAlgebra", "Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
[[Juno]]
deps = ["Base64", "Logging", "Media", "Profile", "Test"]
git-tree-sha1 = "3c29a199713e7ec62cfdc11f44d7760219d5f658"
uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
version = "0.5.3"
[[LibGit2]]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
[[Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
[[LinearAlgebra]]
deps = ["Libdl"]
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
[[Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
[[MacroTools]]
deps = ["Compat"]
git-tree-sha1 = "c443e1c8d58a4e9f61b708ad0a88286c7042145b"
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
version = "0.4.4"
[[Markdown]]
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
[[Media]]
deps = ["MacroTools", "Test"]
git-tree-sha1 = "75a54abd10709c01f1b86b84ec225d26e840ed58"
uuid = "e89f7d12-3494-54d1-8411-f7d8b9ae1f27"
version = "0.5.0"
[[Missings]]
deps = ["Dates", "InteractiveUtils", "SparseArrays", "Test"]
git-tree-sha1 = "adc26d2ee85a49c413464110d922cf21efc9d233"
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
version = "0.3.1"
[[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
[[NNlib]]
deps = ["Libdl", "LinearAlgebra", "MacroTools", "Requires", "Test"]
git-tree-sha1 = "51330bb45927379007e089997bf548fbe232589d"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.4.3"
[[NaNMath]]
deps = ["Compat"]
git-tree-sha1 = "ce3b85e484a5d4c71dd5316215069311135fa9f2"
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
version = "0.3.2"
[[OrderedCollections]]
deps = ["Random", "Serialization", "Test"]
git-tree-sha1 = "85619a3f3e17bb4761fe1b1fd47f0e979f964d5b"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.0.2"
[[Pkg]]
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
[[Printf]]
deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
[[Profile]]
deps = ["Printf"]
uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"
[[REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
[[Random]]
deps = ["Serialization"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
[[Reexport]]
deps = ["Pkg"]
git-tree-sha1 = "7b1d07f411bc8ddb7977ec7f377b97b158514fe0"
uuid = "189a3867-3050-52da-a836-e630ba90ab69"
version = "0.2.0"
[[Requires]]
deps = ["Test"]
git-tree-sha1 = "f6fbf4ba64d295e146e49e021207993b6b48c7d1"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "0.5.2"
[[SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
[[Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
[[SharedArrays]]
deps = ["Distributed", "Mmap", "Random", "Serialization"]
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
[[Sockets]]
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
[[SortingAlgorithms]]
deps = ["DataStructures", "Random", "Test"]
git-tree-sha1 = "03f5898c9959f8115e30bc7226ada7d0df554ddd"
uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
version = "0.3.1"
[[SparseArrays]]
deps = ["LinearAlgebra", "Random"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
[[SpecialFunctions]]
deps = ["BinDeps", "BinaryProvider", "Libdl", "Test"]
git-tree-sha1 = "0b45dc2e45ed77f445617b99ff2adf0f5b0f23ea"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "0.7.2"
[[StaticArrays]]
deps = ["InteractiveUtils", "LinearAlgebra", "Random", "Statistics", "Test"]
git-tree-sha1 = "1eb114d6e23a817cd3e99abc3226190876d7c898"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "0.10.2"
[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[[StatsBase]]
deps = ["DataStructures", "DelimitedFiles", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "Test"]
git-tree-sha1 = "7b596062316c7d846b67bf625d5963a832528598"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.27.0"
[[Test]]
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[[TranscodingStreams]]
deps = ["Pkg", "Random", "Test"]
git-tree-sha1 = "a34a2d588e2d2825602bf14a24216d5c8b0921ec"
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
version = "0.8.1"
[[URIParser]]
deps = ["Test", "Unicode"]
git-tree-sha1 = "6ddf8244220dfda2f17539fa8c9de20d6c575b69"
uuid = "30578b45-9adc-5946-b283-645ec420af67"
version = "0.4.0"
[[UUIDs]]
deps = ["Random"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
[[Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
[[ZipFile]]
deps = ["BinaryProvider", "Libdl", "Printf", "Test"]
git-tree-sha1 = "4000c633efe994b2e10b31b6d91382c4b7412dac"
uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
version = "0.8.0"

docs/Project.toml (new file)

@@ -0,0 +1,4 @@
[deps]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"

docs/make.jl
@@ -2,10 +2,11 @@ using Documenter, Flux, NNlib

 makedocs(modules=[Flux, NNlib],
          doctest = false,
-         format = :html,
          analytics = "UA-36890222-9",
          sitename = "Flux",
-         assets = ["../flux.css"],
+         # Uncomment below for local build
+         #format = Documenter.HTML(prettyurls = false),
+         assets = ["assets/flux.css"],
          pages = ["Home" => "index.md",
                   "Building Models" =>
                     ["Basics" => "models/basics.md",
@@ -22,10 +23,4 @@ makedocs(modules=[Flux, NNlib],
                     ["Backpropagation" => "internals/tracker.md"],
                   "Community" => "community.md"])

-deploydocs(
-  repo = "github.com/FluxML/Flux.jl.git",
-  target = "build",
-  osname = "linux",
-  julia = "1.0",
-  deps = nothing,
-  make = nothing)
+deploydocs(repo = "github.com/FluxML/Flux.jl.git")
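A sketch of a local preview build using the commented-out option above (Documenter 0.21 as pinned in docs/Manifest.toml; the exact keyword set is an assumption):

```julia
using Documenter, Flux, NNlib

# Same call as above, but with pretty URLs disabled so docs/build/ can be
# browsed straight from the filesystem.
makedocs(modules = [Flux, NNlib],
         doctest = false,
         sitename = "Flux",
         format = Documenter.HTML(prettyurls = false))
```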

docs/src/assets/flux.css (new file)

@@ -0,0 +1,113 @@
@import url('https://fonts.googleapis.com/css?family=Lato:400,400i');
body {
font-family: Lato, "Segoe UI",Roboto,"Helvetica Neue",Arial,sans-serif;
}
nav.toc {
padding-top: 0;
background: rgb(240, 240, 240);
line-height: 2em;
cursor: default;
user-select: none;
}
h1+h2 {
margin-top: 0;
}
/* Green banner in ToC */
nav.toc > h1 {
margin-top: 0;
padding-top: 0.4em;
padding-bottom: 0.5em;
border-bottom: 5px solid white;
box-shadow: 0px -2px 5px rgb(60,60,60);
margin-bottom: 0.5em;
background: rgb(60, 150, 60);
font-style: italic;
font-weight: normal;
font-size: 50pt;
text-transform: lowercase;
text-shadow: 2px 2px 5px rgba(0,0,0,0.2);
color: white;
}
/* Reduce ToC font size */
.toctext {
font-size: 10pt;
}
/* Fade out non-clickable ToC headers */
nav.toc ul span.toctext {
color: rgb(180, 180, 180);
}
nav.toc ul .toctext {
color: rgb(100, 100, 100);
}
nav.toc ul a.toctext:hover {
color: inherit;
background: rgb(220, 220, 220);
cursor: default;
}
nav.toc li.current > .toctext {
background: linear-gradient(90deg, rgb(245,245,245) 0%, white 90%);
font-weight: normal;
}
nav.toc ul.internal li.toplevel {
font-weight: normal;
}
/* Content */
article { max-width: none; }
article > p, article > ul {
max-width: 45em;
}
/* Links */
a, a:visited { color: rgb(0, 120, 0); }
article p a { border-bottom: 1px solid rgb(200, 230, 200); }
a:hover, a:visited:hover { color: rgb(0, 80, 0); }
/* Article Links */
article p a { border-bottom: 1px solid rgb(200, 230, 200); }
article p a:hover, article a:visited:hover { color: rgb(0, 120, 0); }
article p a:hover { border-bottom: 1px solid rgb(150, 200, 150); }
/* Doctstrings */
article section.docstring {
padding: 0.5em 0;
border-left: none;
border-right: none;
border-bottom: none;
}
/* Code */
article pre, article p > code {
background: rgb(245, 250, 245);
}
article pre {
border: none;
max-width: none;
padding: 1em;
border-radius: 10px 0px 0px 10px;
margin-left: -1em;
margin-right: -2em;
}
.hljs-comment {
font-style: italic;
}
.hljs-number {
color: rgb(0, 150, 150);
}

docs/src/gpu.md
@@ -4,7 +4,7 @@ Support for array operations on other hardware backends, like GPUs, is provided
 For example, we can use `CuArrays` (with the `cu` converter) to run our [basic example](models/basics.md) on an NVIDIA GPU.

-(Note that you need to build Julia 0.6 from source and have CUDA available to use CuArrays; please see the [CUDAnative.jl](https://github.com/JuliaGPU/CUDAnative.jl) instructions for more details.)
+(Note that you need to have CUDA available to use CuArrays; please see the [CuArrays.jl](https://github.com/JuliaGPU/CuArrays.jl) instructions for more details.)

 ```julia
 using CuArrays

docs/src/internals/tracker.md
@@ -100,16 +100,16 @@ minus(a, b) = a - b
 Firstly, we must tell the tracker system to stop when it sees a call to `minus`, and record it. We can do this using dispatch:

 ```julia
-using Flux.Tracker: TrackedReal, track, @grad
+using Flux.Tracker: TrackedArray, track, @grad

-minus(a::TrackedArray, b::TrackedArray) = Tracker.track(minus, a, b)
+minus(a::TrackedArray, b::TrackedArray) = track(minus, a, b)
 ```

 `track` takes care of building a new `Tracked` object and recording the operation on the tape. We just need to provide a gradient definition.

 ```julia
 @grad function minus(a, b)
-  return minus(data(a),data(b)), Δ -> (Δ, -Δ)
+  return minus(data(a), data(b)), Δ -> (Δ, -Δ)
 end
 ```
@@ -121,6 +121,19 @@ Note that in the backpropagator we don't call `data(a)`; we *do* in fact want to
 @grad a * b = data(a)*data(b), Δ -> (Δ*b, a*Δ)
 ```

+We can then calculate the first derivative of `minus` as follows:
+
+```julia
+a = param([1,2,3])
+b = param([3,2,1])
+c = minus(a, b)  # [-2.0 (tracked), 0.0 (tracked), 2.0 (tracked)]
+
+Tracker.back!(c, 1)
+Tracker.grad(a)  # [1.00, 1.00, 1.00]
+Tracker.grad(b)  # [-1.00, -1.00, -1.00]
+```
+
 For multi-argument functions with custom gradients, you likely want to catch not just `minus(::TrackedArray, ::TrackedArray)` but also `minus(::Array, TrackedArray)` and so on. To do so, just define those extra signatures as needed:

 ```julia
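A minimal sketch of such mixed signatures, assuming the same `minus` and `track` as above (these concrete methods are an assumption, not part of the visible hunk):

```julia
# Hypothetical mixed-argument methods: fall back to `track` whenever at least
# one argument is tracked.
minus(a::TrackedArray, b::AbstractArray) = track(minus, a, b)
minus(a::AbstractArray, b::TrackedArray) = track(minus, a, b)
```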

docs/src/models/basics.md
@@ -28,7 +28,7 @@ When a function has many parameters, we can pass them all in explicitly:
 f(W, b, x) = W * x + b

 Tracker.gradient(f, 2, 3, 4)
-(4.0 (tracked), 1.0, 2.0 (tracked))
+(4.0 (tracked), 1.0 (tracked), 2.0 (tracked))
 ```

 But machine learning models can have *hundreds* of parameters! Flux offers a nice way to handle this. We can tell Flux to treat something as a parameter via `param`. Then we can collect these together and tell `gradient` to collect the gradients of all of them at once.
@@ -102,6 +102,8 @@ All deep learning in Flux, however complex, is a simple generalisation of this example
 It's common to create more complex models than the linear regression above. For example, we might want to have two linear layers with a nonlinearity like [sigmoid](https://en.wikipedia.org/wiki/Sigmoid_function) (`σ`) in between them. In the above style we could write this as:

 ```julia
+using Flux
+
 W1 = param(rand(3, 5))
 b1 = param(rand(3))
 layer1(x) = W1 * x .+ b1
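A sketch of how this example typically continues (the second layer and the composed model are assumptions following the same pattern, not part of the visible hunk):

```julia
W2 = param(rand(2, 3))
b2 = param(rand(2))
layer2(x) = W2 * x .+ b2

# σ is applied elementwise between the two affine layers.
model(x) = layer2(σ.(layer1(x)))

model(rand(5)) # => 2-element array of tracked values
```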

docs/src/models/layers.md
@@ -10,6 +10,12 @@ MaxPool
 MeanPool
 ```

+## Additional Convolution Layers
+
+```@docs
+DepthwiseConv
+```
+
 ## Recurrent Layers

 Much like the core layers above, but can be used to process sequence data (as well as other kinds of structured data).
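A hedged usage sketch of the newly documented layer, assuming its constructor mirrors `Conv`'s `(filter, in => out, activation)` form (the argument semantics here are an assumption, not confirmed by this diff):

```julia
using Flux

# Hypothetical example: one 3×3 filter per input channel (depthwise).
d = DepthwiseConv((3, 3), 3 => 3, relu)
x = rand(Float32, 32, 32, 3, 1)  # width × height × channels × batch
y = d(x)
```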

docs/src/training/optimisers.md
@@ -23,44 +23,30 @@ We want to update each parameter, using the gradient, in order to improve (reduce) the loss.

 ```julia
 using Flux.Tracker: grad, update!

-function sgd()
-  η = 0.1 # Learning Rate
-  for p in (W, b)
-    update!(p, -η * grads[p])
-  end
+η = 0.1 # Learning Rate
+for p in (W, b)
+  update!(p, -η * grads[p])
 end
 ```

-If we call `sgd`, the parameters `W` and `b` will change and our loss should go down.
-
-There are two pieces here: one is that we need a list of trainable parameters for the model (`[W, b]` in this case), and the other is the update step. In this case the update is simply gradient descent (`x .-= η .* Δ`), but we might choose to do something more advanced, like adding momentum.
-
-In this case, getting the variables is trivial, but you can imagine it'd be more of a pain with some complex stack of layers.
+Running this will alter the parameters `W` and `b` and our loss should go down. Flux provides a more general way to do optimiser updates like this.

 ```julia
-m = Chain(
-  Dense(10, 5, σ),
-  Dense(5, 2), softmax)
+opt = Descent(0.1) # Gradient descent with learning rate 0.1
+
+for p in (W, b)
+  update!(opt, p, -η * grads[p])
+end
 ```

-Instead of having to write `[m[1].W, m[1].b, ...]`, Flux provides a params function `params(m)` that returns a list of all parameters in the model for you.
-
-For the update step, there's nothing whatsoever wrong with writing the loop above, it'll work just fine, but Flux provides various *optimisers* that make it more convenient.
-
-```julia
-opt = SGD([W, b], 0.1) # Gradient descent with learning rate 0.1
-
-opt() # Carry out the update, modifying `W` and `b`.
-```
-
-An optimiser takes a parameter list and returns a function that does the same thing as `update` above. We can pass either `opt` or `update` to our [training loop](training.md), which will then run the optimiser after every mini-batch of data.
+An optimiser `update!` accepts a parameter and a gradient, and updates the parameter according to the chosen rule. We can also pass `opt` to our [training loop](training.md), which will update all parameters of the model in a loop. However, we can now easily replace `Descent` with a more advanced optimiser such as `ADAM`.

 ## Optimiser Reference

-All optimisers return a function that, when called, will update the parameters passed to it.
+All optimisers return an object that, when passed to `train!`, will update the parameters passed to it.

 ```@docs
-SGD
+Descent
 Momentum
 Nesterov
 ADAM
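Following the page's convention above, swapping in a more advanced rule is a one-line change. A sketch, reusing `W`, `b`, `grads` and `η` as defined earlier on the page:

```julia
opt = ADAM(0.001)  # adaptive moment estimation in place of plain descent

for p in (W, b)
  update!(opt, p, -η * grads[p])  # same update loop as shown above
end
```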

docs/src/training/training.md
@@ -9,7 +9,7 @@ To actually train a model we need three things:
 With these we can call `Flux.train!`:

 ```julia
-Flux.train!(objective, data, opt)
+Flux.train!(objective, params, data, opt)
 ```

 There are plenty of examples in the [model zoo](https://github.com/FluxML/model-zoo).
@@ -24,9 +24,10 @@ m = Chain(
   Dense(32, 10), softmax)

 loss(x, y) = Flux.mse(m(x), y)
+ps = Flux.params(m)

 # later
-Flux.train!(loss, data, opt)
+Flux.train!(loss, ps, data, opt)
 ```

 The objective will almost always be defined in terms of some *cost function* that measures the distance of the prediction `m(x)` from the target `y`. Flux has several of these built in, like `mse` for mean squared error or `crossentropy` for cross entropy loss, but you can calculate it however you want.
@@ -78,7 +79,7 @@ julia> @epochs 2 Flux.train!(...)
 `train!` takes an additional argument, `cb`, that's used for callbacks so that you can observe the training process. For example:

 ```julia
-train!(objective, data, opt, cb = () -> println("training"))
+train!(objective, ps, data, opt, cb = () -> println("training"))
 ```

 Callbacks are called for every batch of training data. You can slow this down using `Flux.throttle(f, timeout)` which prevents `f` from being called more than once every `timeout` seconds.
@@ -89,6 +90,6 @@ A more typical callback might look like this:
 test_x, test_y = # ... create single batch of test data ...
 evalcb() = @show(loss(test_x, test_y))

-Flux.train!(objective, data, opt,
+Flux.train!(objective, ps, data, opt,
             cb = throttle(evalcb, 5))
 ```

src/Flux.jl
@@ -2,11 +2,12 @@ module Flux

 # Zero Flux Given

+using Base: tail
 using MacroTools, Juno, Requires, Reexport, Statistics, Random
 using MacroTools: @forward

 export Chain, Dense, RNN, LSTM, GRU, Conv, MaxPool, MeanPool,
-       Dropout, LayerNorm, BatchNorm, DepthwiseConv,
+       Dropout, LayerNorm, BatchNorm,
        params, mapleaves, cpu, gpu

 @reexport using NNlib
@@ -19,8 +20,9 @@ export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param
 include("optimise/Optimise.jl")
 using .Optimise
 using .Optimise: @epochs
-export SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
-       RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
+export SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
+       ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
+       ADAMW, InvDecay, ExpDecay, WeightDecay

 include("utils.jl")
 include("onehot.jl")

src/cuda/cuda.jl
@@ -1,7 +1,37 @@
 module CUDA

 using ..CuArrays
+using Pkg.TOML

-CuArrays.cudnn_available() && include("cudnn.jl")
+function version_check()
+  minor_version = 9
+  project = joinpath(dirname(pathof(CuArrays)), "../Project.toml")
+  project = TOML.parse(String(read(project)))
+  version = VersionNumber(get(project, "version", "0.0.0"))
+  if !(version.major == 0 && version.minor == minor_version)
+    @warn """
+      Flux is only supported with CuArrays v0.$minor_version.
+      Try running `] pin CuArrays@0.$minor_version`.
+      """
+  end
+end
+
+version_check()
+
+if !applicable(CuArray{UInt8}, undef, 1)
+  (T::Type{<:CuArray})(::UndefInitializer, sz...) = T(sz...)
+end
+
+if CuArrays.libcudnn != nothing
+  if isdefined(CuArrays, :libcudnn_handle)
+    handle() = CuArrays.libcudnn_handle[]
+  else
+    handle() = CuArrays.CUDNN.handle()
+  end
+  include("curnn.jl")
+  include("cudnn.jl")
+else
+  @warn("CUDNN is not installed, some functionality will not be available.")
+end

 end
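The `applicable` shim above backfills the Julia 1.0-style `undef` constructor on older CuArrays releases. A sketch of the call it enables (assumes a working CUDA setup; not part of this commit):

```julia
using CuArrays

buf = CuArray{UInt8}(undef, 16)  # uninitialized 16-element GPU buffer
```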

src/cuda/cudnn.jl

@@ -1,6 +1,6 @@
-using .CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, libcudnn_handle,
-  cudnnDataType, TensorDesc, FilterDesc
-
+using .CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, cudnnTensorDescriptor_t,
+  cudnnBatchNormMode_t, cudnnHandle_t, cudnnDataType, TensorDesc, FilterDesc
+import ..Flux: data
 using LinearAlgebra

 mutable struct DropoutDesc
@@ -14,335 +14,215 @@ function DropoutDesc(ρ::Real; seed::Integer=0)
   d = [C_NULL]
   s = Csize_t[0]
   @check ccall((:cudnnCreateDropoutDescriptor,libcudnn), cudnnStatus_t, (Ptr{Ptr{Nothing}},), d)
-  @check ccall((:cudnnDropoutGetStatesSize,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Csize_t}),libcudnn_handle[],s)
-  states = CuArray{UInt8}(s[]) # TODO: can we drop this when ρ=0?
+  @check ccall((:cudnnDropoutGetStatesSize,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Csize_t}),handle(),s)
+  states = CuArray{UInt8}(undef, s[]) # TODO: can we drop this when ρ=0?
   desc = DropoutDesc(d[], states)
   @check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Nothing},Cfloat,Ptr{Nothing},Csize_t,Culonglong),
-               desc,libcudnn_handle[],ρ,states,length(states),seed)
+               desc,handle(),ρ,states,length(states),seed)
   finalizer(desc) do x
     @check ccall((:cudnnDestroyDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x)
   end
   return desc
 end

-const RNN_RELU = 0 # Stock RNN with ReLu activation
-const RNN_TANH = 1 # Stock RNN with tanh activation
-const LSTM = 2     # LSTM with no peephole connections
-const GRU = 3      # Using h' = tanh(r * Uh(t-1) + Wx) and h = (1 - z) * h' + z * h(t-1)
-
-const LINEAR_INPUT = 0
-const SKIP_INPUT = 1
-
-const UNIDIRECTIONAL = 0
-const BIDIRECTIONAL = 1
-
-const RNN_ALGO_STANDARD = 0
-const RNN_ALGO_PERSIST_STATIC = 1
-const RNN_ALGO_PERSIST_DYNAMIC = 2
-
-# param layout:
-#  RNN: [weight, bias] × [input, hidden]
-#  GRU: [weight, bias] × [input, hidden] × [reset, update, newmem]
-#  LSTM: [weight, bias] × [input, hidden] × [input, forget, newmem, output]
-
-function params(w::CuVector, input, hidden, n = 1)
-  slice(offset, shape) = reshape(w[offset.+(1:prod(shape))], shape)
-  wx = slice(0, (input, hidden*n))
-  wh = slice(length(wx), (hidden, hidden*n))
-  bias = w[length(wx)+length(wh) .+ (1:hidden*n)]
-  (wx, wh), bias
-end
-
-mutable struct RNNDesc{T}
-  mode::Int
-  input::Int
-  hidden::Int
-  params::CuVector{T}
-  weights::NTuple{2,CuMatrix{T}}
-  bias::CuVector{T}
-  ptr::Ptr{Nothing}
-end
-
-Base.unsafe_convert(::Type{Ptr{Nothing}}, d::RNNDesc) = d.ptr
-
-function rnnParamSize(T, r, input)
-  size = Csize_t[0]
-  @check ccall((:cudnnGetRNNParamsSize, libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Ptr{Nothing},Ptr{Csize_t},Cint),
-    libcudnn_handle[], r, TensorDesc(T, (1,input,1)), size, cudnnDataType(T))
-  return Int(size[])÷sizeof(T)
-end
-
-ngates(mode) = [1, 1, 4, 3][mode+1]
-ngates(r::RNNDesc) = ngates(r.mode)
-
-function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T
-  d = [C_NULL]
-  @check ccall((:cudnnCreateRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Ptr{Nothing}},),d)
-
-  dropoutDesc = DropoutDesc(0)
-  inputMode = LINEAR_INPUT
-  direction = UNIDIRECTIONAL
-  algo = RNN_ALGO_STANDARD
-  @check ccall((:cudnnSetRNNDescriptor_v6,libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Cint,Cint,Ptr{Nothing},Cint,Cint,Cint,Cint,Cint),
-    libcudnn_handle[],d[],hidden,layers,dropoutDesc,inputMode,direction,mode,algo,cudnnDataType(T))
-
-  w = cuzeros(T, rnnParamSize(T, d[], input))
-  # TODO: avoid reserve allocation here
-  rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., d[])
-  finalizer(rd) do x
-    @check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x)
-  end
-  return rd
-end
-
-function rnnWorkspaceSize(r::RNNDesc, seqlen, xdesc)
-  size = Csize_t[0]
-  @check ccall((:cudnnGetRNNWorkspaceSize, libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Cint,Ptr{Ptr{Nothing}},Ptr{Csize_t}),
-    libcudnn_handle[], r, seqlen, xdesc, size)
-  return Int(size[])
-end
-
-const workspace = [CuVector{UInt8}(1)]
-
-getworkspace(bytes) =
-  length(workspace[]) ≥ bytes ?
-    workspace[] :
-    (workspace[] = CuVector{UInt8}(bytes))
-
-getworkspace(r::RNNDesc, seqlen, xdesc) =
-  getworkspace(rnnWorkspaceSize(r, seqlen, xdesc))
-
-function rnnTrainingReserveSize(r::RNNDesc, seqlen, xdesc)
-  size = Csize_t[0]
-  @check ccall((:cudnnGetRNNTrainingReserveSize,libcudnn), cudnnStatus_t, (Ptr{Nothing}, Ptr{Nothing}, Cint, Ptr{Ptr{Nothing}}, Ptr{Csize_t}),
-    libcudnn_handle[], r, seqlen, xdesc, size)
-  return Int(size[])
-end
-
-function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
-                         workspace, reserve=nothing) where T
-  if reserve == nothing
-    @check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t,
-                 (Ptr{Nothing}, Ptr{Nothing}, Cint,
-                  Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
-                  Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T},
-                  Ptr{Nothing}, Ptr{T},
-                  Ptr{Nothing}, Csize_t),
-                 libcudnn_handle[], rnn, seqlen,
-                 xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
-                 workspace, length(workspace))
-  else
-    @check ccall((:cudnnRNNForwardTraining, libcudnn), cudnnStatus_t,
-                 (Ptr{Nothing}, Ptr{Nothing}, Cint,
-                  Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
-                  Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t),
-                 libcudnn_handle[], rnn, seqlen,
-                 xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
-                 workspace, length(workspace), reserve, length(reserve))
-  end
-end
-
-xDesc(x) = [TensorDesc(eltype(x), (1, size(x, 1), size(x, 2)))]
-
-hDesc(h::Nothing) = C_NULL, C_NULL
-hDesc(x::Integer) = (@assert x == 0; hDesc(nothing))
-function hDesc(h::CuArray)
-  TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h
-end
-
-# TODO: can we just manipulate strides here?
-# TODO: should use repmat, but this isn't implemented.
-hBatch(x::AbstractVector, h::CuVector) = h
-hBatch(x::AbstractMatrix, h::CuVector) = h .* cuones(1, size(x, 2))
-hBatch(x::AbstractMatrix, h::CuMatrix) = h .* cuones(1, size(h,2) == 1 ? size(x,2) : 1)
-
-function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing, train = Val{false}) where T
-  h = hBatch(x, h_)
-  c = c_ == nothing ? nothing : hBatch(x, c_)
-  @assert size(x, 1) == rnn.input
-  @assert size(h, 1) == rnn.hidden
-  @assert size(x, 2) == size(h, 2)
-  seqLength = 1
-  xdesc = xDesc(x)
-  y = x isa AbstractVector ? similar(x, rnn.hidden) : similar(x, rnn.hidden, size(x, 2))
-  ho = similar(h)
-  ydesc = xDesc(y)
-  workspace = getworkspace(rnn, seqLength, xdesc)
-  reserve = train == Val{true} ?
-    CuVector{UInt8}(rnnTrainingReserveSize(rnn, seqLength, xdesc)) :
-    nothing
-  co = c == nothing ? c : similar(c)
-  cudnnRNNForward(rnn, seqLength,
-                  xdesc, x,
-                  hDesc(h)...,
-                  hDesc(c)...,
-                  FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
-                  ydesc, y,
-                  hDesc(ho)...,
-                  hDesc(co)...,
-                  workspace, reserve)
-  result = c == nothing ? (y, ho) : (y, ho, co)
-  return train == Val{true} ? (reserve, result) : result
-end
-
-forwardTrain(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing) where T =
-  forward(rnn, x, h, c, Val{true})
-
-function cudnnRNNBackwardData(rnn::RNNDesc{T}, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
-                              wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, rs) where T
-  @check ccall((:cudnnRNNBackwardData,libcudnn),cudnnStatus_t,
-               (Ptr{Nothing}, Ptr{Nothing}, Cint,
-                Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T},
-                Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing},
-                Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
-                Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t),
-               libcudnn_handle[], rnn, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
-               wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, length(ws), rs, length(rs))
-end
-
-function backwardData(rnn::RNNDesc{T}, y, dy_, dho, dco, h, c, reserve) where T
-  # Same as above, any more efficient way?
-  dy = dy_ isa Integer ? zero(y) : dy_
-  yd = xDesc(y)
-  dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2))
-  dh = similar(h)
-  dc = c == nothing ? nothing : similar(c)
-  cudnnRNNBackwardData(rnn, 1,
-    yd, y, yd, dy, hDesc(dho)..., hDesc(dco)...,
-    FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
-    hDesc(h)..., hDesc(c)..., xDesc(dx), dx, hDesc(dh)..., hDesc(dc)...,
-    workspace[], reserve)
-  return c == nothing ? (dx, dh) : (dx, dh, dc)
-end
-
-backwardData(rnn, y, dy, dho, hx, reserve) =
-  backwardData(rnn, y, dy, dho, nothing, hx, nothing, reserve)
-
-function cudnnRNNBackwardWeights(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, yd, y, dwd, dw,
-                                 workspace, reserve) where T
-  @check ccall((:cudnnRNNBackwardWeights,libcudnn), cudnnStatus_t,
-               (Ptr{Nothing}, Ptr{Nothing}, Cint,  # handle, rnnDesc, seqLength
-                Ptr{Ptr{Nothing}}, Ptr{T}, #x
-                Ptr{Nothing}, Ptr{T}, #hx
-                Ptr{Ptr{Nothing}}, Ptr{T}, #y
-                Ptr{Nothing}, Csize_t, #ws
-                Ptr{Nothing}, Ptr{T}, #dw
-                Ptr{Nothing}, Csize_t), #rs
-               libcudnn_handle[], rnn, seqlen, xd, x, hd, h, yd, y,
-               workspace, length(workspace), dwd, dw, reserve, length(reserve))
-end
-
-function backwardWeights(rnn::RNNDesc{T}, x, h, y, reserve) where T
-  dw = zero(rnn.params)
-  cudnnRNNBackwardWeights(rnn, 1,
-    xDesc(x), x, hDesc(h)..., xDesc(y), y,
-    FilterDesc(T, (1, 1, length(dw))), dw,
-    workspace[], reserve)
-  return params(dw, rnn.input, rnn.hidden, ngates(rnn))
-end
-
-# Interface
-
-import ..Flux: Flux, relu
-import ..Tracker: TrackedArray
-using .CuArrays.CUDAnative
-using .CuArrays: @cuindex, cudims
-
-function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
-  function kernel(dst, src)
-    I = @cuindex dst
-    dst[I...] = src[reverse(I)...]
-    return
-  end
-  blk, thr = cudims(dst)
-  @cuda blocks=blk threads=thr kernel(dst, src)
-  return dst
-end
-
-CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}}
-CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuParam{T,2},<:CuParam{T,1}}
-CuGRU{T} = Flux.GRUCell{<:CuParam{T,2},<:CuParam{T,1}}
-CuLSTM{T} = Flux.LSTMCell{<:CuParam{T,2},<:CuParam{T,1}}
-CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}}
-
-function copyparams!(m::CuRNNs, d::RNNDesc)
-  Wi, Wh = d.weights
-  copy_transpose!(Wi, Flux.data(m.Wi))
-  copy_transpose!(Wh, Flux.data(m.Wh))
-  copy_transpose!(d.bias, Flux.data(m.b))
-  return
-end
-
-function RNNDesc(m::CuRNNs{T}) where T
-  h, i = length(m.h), size(m.Wi, 2)
-  mode = m isa CuRNN ?
-    (m.σ == tanh ? RNN_TANH : RNN_RELU) :
-    m isa CuGRU ? GRU : LSTM
-  r = RNNDesc{T}(mode, i, h)
-  return r
-end
-
-const descs = WeakKeyDict()
-
-function desc(rnn)
-  d = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = RNNDesc(rnn))
-  copyparams!(rnn, d)
-  return d
-end
-
-import Flux.Tracker
-import Flux.Tracker: data, istracked, track, unbroadcast, @grad, nobacksies
-
-istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...))
-
-function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
-  result = istrain(m, h, x) ?
-    track(m, x, h, m.Wi, m.Wh, m.b) :
-    forward(desc(m), x, h)
-  return result[2], result[1]
-end
-
-function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
-  result = istrain(m, h, x) ?
-    track(m, x, h, m.Wi, m.Wh, m.b) :
-    forward(desc(m), x, h)
-  return result[2], result[1]
-end
-
-function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64}
-  result = istrain(m, h, x) ?
-    track(m, x, h[1], h[2], m.Wi, m.Wh, m.b) :
-    forward(desc(m), x, h[1], h[2])
-  return (result[2], result[3]), result[1]
-end
-
-(m::CuRNN{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
-(m::CuGRU{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
-(m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
-
-@grad function (m::Union{CuRNN,CuGRU})(x, h, Wi, Wh, b)
-  reserve, result = forwardTrain(desc(m), data(x), data(h))
-  result, function (Δ)
-    y, ho = result
-    dy, dho = Δ
-    h_ = hBatch(x, data(h))
-    dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
-    (dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
-    nobacksies(:RNN, (dx, unbroadcast(size(h), dh), transpose(dWi), transpose(dWh), db))
-  end
-end
-
-@grad function (m::CuLSTM)(x, h, c, Wi, Wh, b)
-  reserve, result = forwardTrain(desc(m), data.((x, h, c))...)
-  result, function (Δ)
-    y, ho = result
-    dy, dho, dco = Δ
-    h_ = hBatch(x, data(h))
-    c_ = hBatch(x, data(c))
-    dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
-    (dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
-    nobacksies(:RNN,
-      (dx, unbroadcast(size(h), dh), unbroadcast(size(c), dc),
-       transpose(dWi), transpose(dWh), db))
-  end
-end
+const BATCHNORM_SPATIAL = 1
+const BATCHNORM_ACTIVATION = 0
+const BATCHNORM_MIN_EPS = 1e-5
+
+@inline _wsize(y) = (map(_ -> 1, size(y)[1:end-2])..., size(y)[end-1], 1)
+
+@inline _reddims(y) = (collect(1:ndims(y)-2)..., ndims(y))
+
+mutable struct BNCache
+  mean
+  ivar
+end
+
+BNCache() = BNCache(nothing, nothing)
+
+# NOTE: CuDNN supports only 4D and 5D Tensors for BatchNorm Operations
+# so reshape a 2D Tensor into 4D
+batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T, 2},
+          running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
+          cache = nothing, alpha = T(1), beta = T(0),
+          eps = T(1e-5), training = true) where T<:Union{Float32, Float64} =
+  dropdims(batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), running_mean, running_var, momentum,
+    cache = cache, alpha = alpha, beta = beta, eps = eps, training = training), dims = (1, 2))
+
+function batchnorm(g::CuArray{T}, b::CuArray{T}, x::Union{CuArray{T, 4},CuArray{T,5}},
+                   running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
+                   cache = nothing, alpha = T(1), beta = T(0),
+                   eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
+  y = similar(x)
+  cudnnBNForward!(y, g, b, x, running_mean, running_var, momentum, cache = cache,
+    alpha = alpha, beta = beta, eps = eps, training = training)
+  y
+end
+
+function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray{T},
+                         running_mean::CuArray{T}, running_var::CuArray{T},
+                         momentum; cache = nothing,
+                         alpha = T(1), beta = T(0),
+                         eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
+  dims = _wsize(x)
+  if eps < BATCHNORM_MIN_EPS
+    # warn("eps ",eps," is too small for CuDNN so eps has been assigned the value ", BATCHNORM_MIN_EPS)
+    eps = BATCHNORM_MIN_EPS
+  end
+  xd = TensorDesc(x)
+  yd = TensorDesc(y)
+  gd = TensorDesc(T, dims)
+
+  if training
+
+    if cache !== nothing
+      mean = zeros(CuArray{T}, dims...)
+      ivar = ones(CuArray{T}, dims...)
+    else
+      mean = C_NULL
+      ivar = C_NULL
+    end
+
+    @check ccall((:cudnnBatchNormalizationForwardTraining, libcudnn), cudnnStatus_t,
+                 (cudnnHandle_t,cudnnBatchNormMode_t,
+                  Ptr{T}, Ptr{T},
+                  Ptr{Nothing}, Ptr{T},
+                  Ptr{Nothing}, Ptr{T},
+                  Ptr{Nothing}, Ptr{T}, Ptr{T},
+                  Cdouble, Ptr{T}, Ptr{T},
+                  Cdouble, Ptr{T}, Ptr{T}),
+                 handle(), BATCHNORM_SPATIAL,
+                 Ref(T(alpha)), Ref(T(beta)),
+                 xd, x,
+                 yd, y,
+                 gd, g, b,
+                 momentum, running_mean, running_var,
+                 eps, mean, ivar)
+
+    if cache !== nothing
+      cache.mean = mean
+      cache.ivar = ivar
+    end
+  else
+    @check ccall((:cudnnBatchNormalizationForwardInference, libcudnn), cudnnStatus_t,
+                 (Ptr{cudnnHandle_t},cudnnBatchNormMode_t,
+                  Ptr{T}, Ptr{T},
+                  Ptr{Nothing}, Ptr{T},
+                  Ptr{Nothing}, Ptr{T},
+                  Ptr{Nothing}, Ptr{T}, Ptr{T},
+                  Ptr{T}, Ptr{T},
+                  Cdouble),
+                 handle(), BATCHNORM_SPATIAL,
+                 Ref(T(alpha)), Ref(T(beta)),
+                 xd, x,
+                 yd, y,
+                 gd, g, b,
+                 running_mean, running_var,
+                 eps)
+  end
+end
+
+function ∇batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T, 2}, dy::CuArray{T, 2},
+                    running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
+                    cache = nothing, eps = T(1e-5), alpha = T(1),
+                    beta = T(0), training = true) where T<:Union{Float32, Float64}
+  dg, db, dx = ∇batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), reshape(dy, 1, 1, size(dy, 1),
+    size(dy, 2)), running_mean, running_var, momentum, cache = cache, eps = eps,
+    alpha = alpha, beta = beta, training = training)
+  (dg, db, dropdims(dx, dims = (1, 2)))
+end
+
+function ∇batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
+                    running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
+                    cache = nothing, eps = T(1e-5), alpha = T(1),
+                    beta = T(0), training = true) where T<:Union{Float32, Float64}
+  dg = similar(g)
+  db = similar(b)
+  dx = similar(x)
+  cudnnBNBackward!(dg, g, db, dx, x, dy, running_mean, running_var, T(momentum),
+    training = training, cache = cache, eps = eps, alpha = alpha, beta = beta)
+  (dg, db, dx)
+end
+
+function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
+                          dx::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
+                          running_mean::CuArray{T}, running_var::CuArray{T},
+                          momentum; cache = nothing, eps = T(1e-5),
+                          alpha = T(1), beta = T(0),
+                          dalpha = T(1), dbeta = T(0), training = true) where T<:Union{Float32, Float64}
+  if training
+    xd = TensorDesc(x)
+    dyd = TensorDesc(dy)
+    dxd = TensorDesc(dx)
+    gd = TensorDesc(T, _wsize(x))
+    if cache !== nothing
+      mean, ivar = cache.mean, cache.ivar
+      info("mean and ivar are fetched from the cache")
+    else
+      mean, ivar = C_NULL, C_NULL
+    end
+
+    if eps < BATCHNORM_MIN_EPS
+      eps = BATCHNORM_MIN_EPS
+    end
+
+    @check ccall((:cudnnBatchNormalizationBackward, libcudnn), cudnnStatus_t,
+                 (cudnnHandle_t,cudnnBatchNormMode_t,
+                  Ptr{T}, Ptr{T},
+                  Ptr{T}, Ptr{T},
+                  Ptr{Nothing}, Ptr{T},
+                  Ptr{Nothing}, Ptr{T},
+                  Ptr{Nothing}, Ptr{T},
+                  Ptr{Nothing}, Ptr{T}, Ptr{T}, Ptr{T},
+                  Cdouble, Ptr{T}, Ptr{T}),
+                 handle(), BATCHNORM_SPATIAL,
+                 Ref(T(alpha)), Ref(T(beta)),
+                 Ref(T(dalpha)), Ref(T(dbeta)),
+                 xd, x,
+                 dyd, dy,
+                 dxd, dx,
+                 gd, g, dg, db,
+                 eps, mean, ivar)
+  else
+    ivar = 1 ./ sqrt.(reshape(running_var, _wsize(x)) .+ eps)
+    dx .= dy .* reshape(g, _wsize(x)) .* ivar
+    dg .= squeeze(sum(dy .* (x .- reshape(running_mean, _wsize(x))) .* ivar, _reddims(dy)), dims = (1,2,4))
+    db .= squeeze(sum(dy, _reddims(dy)), dims = (1,2,4))
+  end
+end
+
+# Flux Interface
+
+(BN::Flux.BatchNorm)(x::Union{CuParam{T,2},CuParam{T,4},CuParam{T,5}}, cache = nothing) where T<:Union{Float32, Float64} =
+  batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum; cache = cache, alpha = 1, beta = 0, eps = BN.ϵ, training = BN.active)
+
+batchnorm(g::TrackedArray, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T},
+          running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
+  track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
+
+batchnorm(g::TrackedArray, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T},
+          running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
+  track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
+
+batchnorm(g::TrackedArray, b::CuArray{T}, x::TrackedArray, running_mean::CuArray{T},
+          running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
+  track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
+
+batchnorm(g::CuArray{T}, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T},
+          running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
+  track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
+
+batchnorm(g::CuArray{T}, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T},
+          running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
+  track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
+
+batchnorm(g::TrackedArray, b::CuArray{T}, x::CuArray{T}, running_mean::CuArray{T},
+          running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
+  track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
+
+batchnorm(g::CuArray{T}, b::CuArray{T}, x::TrackedArray, running_mean::CuArray{T},
+          running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
+  track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
+
+@grad batchnorm(g, b, x, running_mean, running_var, momentum; kw...) =
+  batchnorm(data.((g, b, x))..., running_mean, running_var, momentum; kw...), Δ -> (nobacksies(:batchnorm, ∇batchnorm(data.((g, b, x, Δ))..., running_mean, running_var, momentum; kw...))..., nothing, nothing, nothing)
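A usage sketch of the Flux interface defined at the end of the new file (hypothetical; assumes a CUDA-capable GPU and the pinned CuArrays version):

```julia
using Flux, CuArrays

bn = BatchNorm(10) |> gpu       # move γ, β and the running stats to the GPU
x = cu(rand(Float32, 10, 32))   # 10 features × batch of 32
y = bn(x)                       # dispatches to the CuDNN-backed `batchnorm`
```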

src/cuda/curnn.jl (new file)

@@ -0,0 +1,325 @@
using .CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, cudnnTensorDescriptor_t,
cudnnBatchNormMode_t, cudnnHandle_t, cudnnDataType, TensorDesc, FilterDesc
using LinearAlgebra
const RNN_RELU = 0 # Stock RNN with ReLu activation
const RNN_TANH = 1 # Stock RNN with tanh activation
const LSTM = 2 # LSTM with no peephole connections
const GRU = 3 # Using h' = tanh(r * Uh(t-1) + Wx) and h = (1 - z) * h' + z * h(t-1)
const LINEAR_INPUT = 0
const SKIP_INPUT = 1
const UNIDIRECTIONAL = 0
const BIDIRECTIONAL = 1
const RNN_ALGO_STANDARD = 0
const RNN_ALGO_PERSIST_STATIC = 1
const RNN_ALGO_PERSIST_DYNAMIC = 2
# param layout:
# RNN: [weight, bias] × [input, hidden]
# GRU: [weight, bias] × [input, hidden] × [reset, update, newmem]
# LSTM: [weight, bias] × [input, hidden] × [input, forget, newmem, output]
function params(w::CuVector, input, hidden, n = 1)
slice(offset, shape) = reshape(view(w, offset.+(1:prod(shape))), shape)
wx = slice(0, (input, hidden*n))
wh = slice(length(wx), (hidden, hidden*n))
bias = view(w, length(wx)+length(wh) .+ (1:hidden*n))
(wx, wh), bias
end
mutable struct RNNDesc{T}
mode::Int
input::Int
hidden::Int
params::CuVector{T}
weights::NTuple{2,CuMatrix{T}}
bias::CuVector{T}
ptr::Ptr{Nothing}
end
Base.unsafe_convert(::Type{Ptr{Nothing}}, d::RNNDesc) = d.ptr
function rnnParamSize(T, r, input)
size = Csize_t[0]
@check ccall((:cudnnGetRNNParamsSize, libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Ptr{Nothing},Ptr{Csize_t},Cint),
handle(), r, TensorDesc(T, (1,input,1)), size, cudnnDataType(T))
return Int(size[])÷sizeof(T)
end
ngates(mode) = [1, 1, 4, 3][mode+1]
ngates(r::RNNDesc) = ngates(r.mode)
function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T
d = [C_NULL]
@check ccall((:cudnnCreateRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Ptr{Nothing}},),d)
dropoutDesc = DropoutDesc(0)
inputMode = LINEAR_INPUT
direction = UNIDIRECTIONAL
algo = RNN_ALGO_STANDARD
@check ccall((:cudnnSetRNNDescriptor_v6,libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Cint,Cint,Ptr{Nothing},Cint,Cint,Cint,Cint,Cint),
handle(),d[],hidden,layers,dropoutDesc,inputMode,direction,mode,algo,cudnnDataType(T))
w = cuzeros(T, rnnParamSize(T, d[], input))
# TODO: avoid reserve allocation here
rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., d[])
finalizer(rd) do x
@check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x)
end
return rd
end
function rnnWorkspaceSize(r::RNNDesc, seqlen, xdesc)
size = Csize_t[0]
@check ccall((:cudnnGetRNNWorkspaceSize, libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Cint,Ptr{Ptr{Nothing}},Ptr{Csize_t}),
handle(), r, seqlen, xdesc, size)
return Int(size[])
end
const workspace = [CuVector{UInt8}(undef, 1)]
getworkspace(bytes) =
length(workspace[]) ≥ bytes ?
workspace[] :
(workspace[] = CuVector{UInt8}(undef, bytes))
getworkspace(r::RNNDesc, seqlen, xdesc) =
getworkspace(rnnWorkspaceSize(r, seqlen, xdesc))
function rnnTrainingReserveSize(r::RNNDesc, seqlen, xdesc)
size = Csize_t[0]
@check ccall((:cudnnGetRNNTrainingReserveSize,libcudnn), cudnnStatus_t, (Ptr{Nothing}, Ptr{Nothing}, Cint, Ptr{Ptr{Nothing}}, Ptr{Csize_t}),
handle(), r, seqlen, xdesc, size)
return Int(size[])
end
function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
workspace, reserve=nothing) where T
if reserve == nothing
@check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t,
(Ptr{Nothing}, Ptr{Nothing}, Cint,
Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Csize_t),
handle(), rnn, seqlen,
xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
workspace, length(workspace))
else
@check ccall((:cudnnRNNForwardTraining, libcudnn), cudnnStatus_t,
(Ptr{Nothing}, Ptr{Nothing}, Cint,
Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t),
handle(), rnn, seqlen,
xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
workspace, length(workspace), reserve, length(reserve))
end
end
xDesc(x) = [TensorDesc(eltype(x), (1, size(x, 1), size(x, 2)))]
hDesc(h::Nothing) = C_NULL, C_NULL
hDesc(x::Integer) = (@assert x == 0; hDesc(nothing))
function hDesc(h::CuArray)
TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h
end
# TODO: can we just manipulate strides here?
# TODO: should use repmat, but this isn't implemented.
hBatch(x::AbstractVector, h::CuVector) = h
hBatch(x::AbstractMatrix, h::CuVector) = h .* cuones(1, size(x, 2))
hBatch(x::AbstractMatrix, h::CuMatrix) = h .* cuones(1, size(h,2) == 1 ? size(x,2) : 1)
function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing, train = Val{false}) where T
h = hBatch(x, h_)
c = c_ == nothing ? nothing : hBatch(x, c_)
@assert size(x, 1) == rnn.input
@assert size(h, 1) == rnn.hidden
@assert size(x, 2) == size(h, 2)
seqLength = 1
xdesc = xDesc(x)
y = x isa AbstractVector ? similar(x, rnn.hidden) : similar(x, rnn.hidden, size(x, 2))
ho = similar(h)
ydesc = xDesc(y)
workspace = getworkspace(rnn, seqLength, xdesc)
reserve = train == Val{true} ?
CuVector{UInt8}(undef, rnnTrainingReserveSize(rnn, seqLength, xdesc)) :
nothing
co = c == nothing ? c : similar(c)
cudnnRNNForward(rnn, seqLength,
xdesc, x,
hDesc(h)...,
hDesc(c)...,
FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
ydesc, y,
hDesc(ho)...,
hDesc(co)...,
workspace, reserve)
result = c == nothing ? (y, ho) : (y, ho, co)
return train == Val{true} ? (reserve, result) : result
end
forwardTrain(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing) where T =
forward(rnn, x, h, c, Val{true})
function cudnnRNNBackwardData(rnn::RNNDesc{T}, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, rs) where T
@check ccall((:cudnnRNNBackwardData,libcudnn),cudnnStatus_t,
(Ptr{Nothing}, Ptr{Nothing}, Cint,
Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing},
Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t),
handle(), rnn, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, length(ws), rs, length(rs))
end
function backwardData(rnn::RNNDesc{T}, y, dy_, dho, dco, h, c, reserve) where T
# Same as above, any more efficient way?
dy = dy_ isa Integer ? zero(y) : dy_
yd = xDesc(y)
dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2))
dh = similar(h)
dc = c == nothing ? nothing : similar(c)
cudnnRNNBackwardData(rnn, 1,
yd, y, yd, dy, hDesc(dho)..., hDesc(dco)...,
FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
hDesc(h)..., hDesc(c)..., xDesc(dx), dx, hDesc(dh)..., hDesc(dc)...,
workspace[], reserve)
return c == nothing ? (dx, dh) : (dx, dh, dc)
end
backwardData(rnn, y, dy, dho, hx, reserve) =
backwardData(rnn, y, dy, dho, nothing, hx, nothing, reserve)
function cudnnRNNBackwardWeights(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, yd, y, dwd, dw,
workspace, reserve) where T
@check ccall((:cudnnRNNBackwardWeights,libcudnn), cudnnStatus_t,
(Ptr{Nothing}, Ptr{Nothing}, Cint, # handle, rnnDesc, seqLength
Ptr{Ptr{Nothing}}, Ptr{T}, #x
Ptr{Nothing}, Ptr{T}, #hx
Ptr{Ptr{Nothing}}, Ptr{T}, #y
Ptr{Nothing}, Csize_t, #ws
Ptr{Nothing}, Ptr{T}, #dw
Ptr{Nothing}, Csize_t), #rs
handle(), rnn, seqlen, xd, x, hd, h, yd, y,
workspace, length(workspace), dwd, dw, reserve, length(reserve))
end
function backwardWeights(rnn::RNNDesc{T}, x, h, y, reserve) where T
dw = zero(rnn.params)
cudnnRNNBackwardWeights(rnn, 1,
xDesc(x), x, hDesc(h)..., xDesc(y), y,
FilterDesc(T, (1, 1, length(dw))), dw,
workspace[], reserve)
return params(dw, rnn.input, rnn.hidden, ngates(rnn))
end
# Interface
import ..Flux: Flux, relu
import ..Tracker: TrackedArray
using .CuArrays.CUDAnative
using .CuArrays: @cuindex, cudims
function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
function kernel(dst, src)
I = @cuindex dst
dst[I...] = src[reverse(I)...]
return
end
blk, thr = cudims(dst)
@cuda blocks=blk threads=thr kernel(dst, src)
return dst
end
CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}}
CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuParam{T,2},<:CuParam{T,1}}
CuGRU{T} = Flux.GRUCell{<:CuParam{T,2},<:CuParam{T,1}}
CuLSTM{T} = Flux.LSTMCell{<:CuParam{T,2},<:CuParam{T,1}}
CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}}
function copyparams!(m::CuRNNs, d::RNNDesc)
Wi, Wh = d.weights
copy_transpose!(Wi, Flux.data(m.Wi))
copy_transpose!(Wh, Flux.data(m.Wh))
copy_transpose!(d.bias, Flux.data(m.b))
return
end
function RNNDesc(m::CuRNNs{T}) where T
h, i = length(m.h), size(m.Wi, 2)
mode = m isa CuRNN ?
(m.σ == tanh ? RNN_TANH : RNN_RELU) :
m isa CuGRU ? GRU : LSTM
r = RNNDesc{T}(mode, i, h)
return r
end
const descs = WeakKeyDict()
function desc(rnn)
d = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = RNNDesc(rnn))
copyparams!(rnn, d)
return d
end
import Flux.Tracker
import Flux.Tracker: data, istracked, track, unbroadcast, @grad, nobacksies
istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...))
function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
result = istrain(m, h, x) ?
track(m, x, h, m.Wi, m.Wh, m.b) :
forward(desc(m), x, h)
return result[2], result[1]
end
function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
result = istrain(m, h, x) ?
track(m, x, h, m.Wi, m.Wh, m.b) :
forward(desc(m), x, h)
return result[2], result[1]
end
function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64}
result = istrain(m, h, x) ?
track(m, x, h[1], h[2], m.Wi, m.Wh, m.b) :
forward(desc(m), x, h[1], h[2])
return (result[2], result[3]), result[1]
end
(m::CuRNN{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
(m::CuGRU{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
(m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
@grad function (m::Union{CuRNN,CuGRU})(x, h, Wi, Wh, b)
reserve, result = forwardTrain(desc(m), data(x), data(h))
result, function (Δ)
y, ho = result
dy, dho = Δ
h_ = hBatch(x, data(h))
dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
nobacksies(:RNN, (dx, unbroadcast(h, dh), transpose(dWi), transpose(dWh), db))
end
end
@grad function (m::CuLSTM)(x, h, c, Wi, Wh, b)
reserve, result = forwardTrain(desc(m), data.((x, h, c))...)
result, function (Δ)
y, ho = result
dy, dho, dco = Δ
h_ = hBatch(x, data(h))
c_ = hBatch(x, data(c))
dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
nobacksies(:RNN,
(dx, unbroadcast(h, dh), unbroadcast(c, dc),
transpose(dWi), transpose(dWh), db))
end
end
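A sketch of how this path is hit end to end (assumes a working GPU; `LSTM` and `gpu` are Flux's own):

```julia
using Flux, CuArrays

m = LSTM(10, 5) |> gpu     # a Recur wrapping an LSTMCell with CuArray parameters
x = cu(rand(Float32, 10))
m(x)                       # Recur calls (m::CuLSTM)(h, x) defined above
```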


@@ -13,6 +13,9 @@ end
include("mnist.jl")
export MNIST

include("fashion-mnist.jl")
export FashionMNIST

include("cmudict.jl")
using .CMUDict

src/data/fashion-mnist.jl Normal file

@@ -0,0 +1,64 @@
module FashionMNIST
using ..MNIST: gzopen, imageheader, rawimage, labelheader, rawlabel
const dir = joinpath(@__DIR__, "../../deps/fashion-mnist")
function load()
mkpath(dir)
cd(dir) do
for file in ["train-images-idx3-ubyte",
"train-labels-idx1-ubyte",
"t10k-images-idx3-ubyte",
"t10k-labels-idx1-ubyte"]
isfile(file) && continue
@info "Downloading Fashion-MNIST dataset"
download("http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/$file.gz", "$file.gz")
open(file, "w") do io
write(io, gzopen(read, "$file.gz"))
end
end
end
end
const TRAINIMAGES = joinpath(dir, "train-images-idx3-ubyte")
const TRAINLABELS = joinpath(dir, "train-labels-idx1-ubyte")
const TESTIMAGES = joinpath(dir, "t10k-images-idx3-ubyte")
const TESTLABELS = joinpath(dir, "t10k-labels-idx1-ubyte")
"""
images()
images(:test)
Load the Fashion-MNIST images.
Each image is a 28×28 array of `Gray` colour values (see Colors.jl).
Returns the 60,000 training images by default; pass `:test` to retrieve the
10,000 test images.
"""
function images(set = :train)
load()
io = IOBuffer(read(set == :train ? TRAINIMAGES : TESTIMAGES))
_, N, nrows, ncols = imageheader(io)
[rawimage(io) for _ in 1:N]
end
"""
labels()
labels(:test)
Load the labels corresponding to each of the images returned from `images()`.
Each label is a number from 0-9.
Returns the 60,000 training labels by default; pass `:test` to retrieve the
10,000 test labels.
"""
function labels(set = :train)
load()
io = IOBuffer(read(set == :train ? TRAINLABELS : TESTLABELS))
_, N = labelheader(io)
[rawlabel(io) for _ = 1:N]
end
end
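A usage sketch for the new dataset, assuming it is reachable as `Flux.Data.FashionMNIST` like the existing MNIST module (the first call downloads the files):

```julia
using Flux

imgs   = Flux.Data.FashionMNIST.images()   # 60,000 28×28 Gray images
labels = Flux.Data.FashionMNIST.labels()   # 60,000 labels in 0:9
length(imgs) == length(labels)             # true
```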


@@ -16,19 +16,21 @@ m(x) == m[2](m[1](x))
`Chain` also supports indexing and slicing, e.g. `m[2]` or `m[1:end-1]`.
`m[1:3](x)` will calculate the output of the first three layers.
"""
-struct Chain
-  layers::Vector{Any}
-  Chain(xs...) = new([xs...])
struct Chain{T<:Tuple}
  layers::T
  Chain(xs...) = new{typeof(xs)}(xs)
end

-@forward Chain.layers Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!
-@forward Chain.layers Base.iterate
@forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
  Base.iterate, Base.lastindex

children(c::Chain) = c.layers
mapchildren(f, c::Chain) = Chain(f.(c.layers)...)
-adapt(T, c::Chain) = Chain(map(x -> adapt(T, x), c.layers)...)

-(c::Chain)(x) = foldl((x, m) -> m(x), c.layers; init = x)
applychain(::Tuple{}, x) = x
applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x))

(c::Chain)(x) = applychain(c.layers, x)

Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
@@ -75,7 +77,7 @@ end

@treelike Dense

-function (a::Dense)(x)
function (a::Dense)(x::AbstractArray)
  W, b, σ = a.W, a.b, a.σ
  σ.(W*x .+ b)
end
@@ -114,3 +116,11 @@ end
function Base.show(io::IO, l::Diagonal)
  print(io, "Diagonal(", length(l.α), ")")
end

# Try to avoid hitting generic matmul in some simple cases
# Base's matmul is so slow that it's worth the extra conversion to hit BLAS

(a::Dense{<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
  invoke(a, Tuple{AbstractArray}, x)

(a::Dense{<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
  a(T.(x))
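The tuple-typed `Chain` lets the compiler specialise on the exact layer sequence; behaviour is unchanged. A quick sketch:

```julia
using Flux

m = Chain(Dense(10, 5, relu), Dense(5, 2), softmax)
x = rand(10)
m(x) == m[3](m[2](m[1](x)))   # true; applychain recurses through the layer tuple
```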


@@ -1,4 +1,4 @@
-using NNlib: conv
using NNlib: conv, depthwiseconv

@generated sub2(::Val{N}) where N = :(Val($(N-2)))
@@ -30,14 +30,14 @@ Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
      stride = 1, pad = 0, dilation = 1) where {T,N} =
  Conv(σ, w, b, expand.(sub2(Val(N)), (stride, pad, dilation))...)

-Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
-     stride = 1, pad = 0, dilation = 1) where N =
Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
     init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
  Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ,
       stride = stride, pad = pad, dilation = dilation)

@treelike Conv

-function (c::Conv)(x)
function (c::Conv)(x::AbstractArray)
  # TODO: breaks gpu broadcast :(
  # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
  σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
@@ -51,6 +51,62 @@ function Base.show(io::IO, l::Conv)
  print(io, ")")
end

(a::Conv{<:Any,<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
  invoke(a, Tuple{AbstractArray}, x)

(a::Conv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
  a(T.(x))

"""
    DepthwiseConv(size, in)
    DepthwiseConv(size, in=>mul)
    DepthwiseConv(size, in=>mul, relu)

Depthwise convolutional layer. `size` should be a tuple like `(2, 2)`.
`in` and `mul` specify the number of input channels and the channel multiplier respectively.
If `mul` is not specified it defaults to 1.

Data should be stored in WHCN order. In other words, a 100×100 RGB image would
be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.

Takes the keyword arguments `pad` and `stride`.
"""
struct DepthwiseConv{N,F,A,V}
  σ::F
  weight::A
  bias::V
  stride::NTuple{N,Int}
  pad::NTuple{N,Int}
end

DepthwiseConv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
              stride = 1, pad = 0) where {T,N} =
  DepthwiseConv(σ, w, b, expand.(sub2(Val(N)), (stride, pad))...)

DepthwiseConv(k::NTuple{N,Integer}, ch::Integer, σ = identity; init = initn,
              stride = 1, pad = 0) where N =
  DepthwiseConv(param(init(k..., 1, ch)), param(zeros(ch)), σ,
                stride = stride, pad = pad)

DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
              stride::NTuple{N,Integer} = map(_->1,k),
              pad::NTuple{N,Integer} = map(_->0,k)) where N =
  DepthwiseConv(param(init(k..., ch[2], ch[1])), param(zeros(ch[2]*ch[1])), σ,
                stride = stride, pad = pad)

@treelike DepthwiseConv

function (c::DepthwiseConv)(x)
  σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
  σ.(depthwiseconv(x, c.weight, stride = c.stride, pad = c.pad) .+ b)
end

function Base.show(io::IO, l::DepthwiseConv)
  print(io, "DepthwiseConv(", size(l.weight)[1:ndims(l.weight)-2])
  print(io, ", ", size(l.weight, ndims(l.weight)), "=>", size(l.weight, ndims(l.weight)-1))
  l.σ == identity || print(io, ", ", l.σ)
  print(io, ")")
end

"""
    MaxPool(k)
@@ -60,9 +116,9 @@ Max pooling layer. `k` stands for the size of the window for each dimension of t

Takes the keyword arguments `pad` and `stride`.
"""
struct MaxPool{N}
  k::NTuple{N,Int}
  pad::NTuple{N,Int}
  stride::NTuple{N,Int}
end

MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =
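A shape check for the new layer (a sketch relying on NNlib's `depthwiseconv`):

```julia
using Flux

d = DepthwiseConv((3, 3), 3 => 2, relu)   # 3 input channels, channel multiplier 2
x = rand(28, 28, 3, 1)                    # WHCN
size(d(x))                                # (26, 26, 6, 1): 3 channels × multiplier 2
```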


@@ -44,7 +44,6 @@ end
_testmode!(a::Dropout, test) = (a.active = !test)

"""
    LayerNorm(h::Integer)
-
A [normalisation layer](https://arxiv.org/pdf/1607.06450.pdf) designed to be
@@ -86,7 +85,6 @@ See [Batch Normalization: Accelerating Deep Network Training by Reducing
Internal Covariate Shift](https://arxiv.org/pdf/1502.03167.pdf).

Example:
-
```julia
m = Chain(
  Dense(28^2, 64),
@@ -101,14 +99,14 @@ mutable struct BatchNorm{F,V,W,N}
  β::V  # bias
  γ::V  # scale
  μ::W  # moving mean
-  σ::W  # moving std
  σ²::W # moving variance
  ϵ::N
  momentum::N
  active::Bool
end

BatchNorm(chs::Integer, λ = identity;
-          initβ = (i) -> zeros(i), initγ = (i) -> ones(i), ϵ = 1e-8, momentum = .1) =
          initβ = (i) -> zeros(i), initγ = (i) -> ones(i), ϵ = 1e-5, momentum = .1) =
  BatchNorm(λ, param(initβ(chs)), param(initγ(chs)),
            zeros(chs), ones(chs), ϵ, momentum, true)
@@ -124,31 +122,31 @@ function (BN::BatchNorm)(x)
  if !BN.active
    μ = reshape(BN.μ, affine_shape...)
-    σ = reshape(BN.σ, affine_shape...)
    σ² = reshape(BN.σ², affine_shape...)
  else
    T = eltype(x)
    ϵ = data(convert(T, BN.ϵ))
    axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
    μ = mean(x, dims = axes)
-    σ = sqrt.(mean((x .- μ).^2, dims = axes) .+ ϵ)
    σ² = sum((x .- μ) .^ 2, dims = axes) ./ m

    # update moving mean/variance
    mtm = data(convert(T, BN.momentum))
-    BN.μ = (1 - mtm) .* BN.μ .+ mtm .* dropdims(data(μ), dims = (axes...,))
-    BN.σ = (1 - mtm) .* BN.σ .+ mtm .* dropdims(data(σ), dims = (axes...,)) .* m ./ (m - 1)
    BN.μ = (1 - mtm) .* BN.μ .+ mtm .* reshape(data(μ), :)
    BN.σ² = ((1 - mtm) .* BN.σ² .+ mtm .* reshape(data(σ²), :) .* m ./ (m - 1))
  end

  let λ = BN.λ
-    λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ σ) .+ reshape(β, affine_shape...))
    λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ sqrt.(σ² .+ BN.ϵ)) .+ reshape(β, affine_shape...))
  end
end

children(BN::BatchNorm) =
-  (BN.λ, BN.β, BN.γ, BN.μ, BN.σ, BN.ϵ, BN.momentum, BN.active)
  (BN.λ, BN.β, BN.γ, BN.μ, BN.σ², BN.ϵ, BN.momentum, BN.active)

mapchildren(f, BN::BatchNorm) =  # e.g. mapchildren(cu, BN)
-  BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ), BN.ϵ, BN.momentum, BN.active)
  BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ²), BN.ϵ, BN.momentum, BN.active)

_testmode!(BN::BatchNorm, test) = (BN.active = !test)
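With the switch to a stored variance, inference normalises with `sqrt.(σ² .+ ϵ)`. A sketch of the two modes:

```julia
using Flux

bn = BatchNorm(2)
x = rand(2, 5)        # 2 features, batch of 5
y = bn(x)             # training mode: batch statistics, running μ/σ² updated
Flux.testmode!(bn)
ŷ = bn(x)             # test mode: uses the stored μ and σ² instead
```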


@@ -148,7 +148,7 @@ Base.show(io::IO, l::LSTMCell) =
  print(io, "LSTMCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷4, ")")

"""
-    LSTM(in::Integer, out::Integer, σ = tanh)
    LSTM(in::Integer, out::Integer)

Long Short Term Memory recurrent layer. Behaves like an RNN but generally
exhibits a longer memory span over sequences.
@@ -189,7 +189,7 @@ Base.show(io::IO, l::GRUCell) =
  print(io, "GRUCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷3, ")")

"""
-    GRU(in::Integer, out::Integer, σ = tanh)
    GRU(in::Integer, out::Integer)

Gated Recurrent Unit layer. Behaves like an RNN but generally
exhibits a longer memory span over sequences.


@@ -2,16 +2,16 @@ using NNlib: logsoftmax, logσ

# Cost functions

-mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)
mse(ŷ, y) = sum((ŷ .- y).^2) * 1 // length(y)

function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
-  -sum(y .* log.(ŷ) .* weight) / size(y, 2)
  -sum(y .* log.(ŷ) .* weight) * 1 // size(y, 2)
end

@deprecate logloss(x, y) crossentropy(x, y)

function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
-  return -sum(y .* logsoftmax(logŷ) .* weight) / size(y, 2)
  return -sum(y .* logsoftmax(logŷ) .* weight) * 1 // size(y, 2)
end
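The `* 1 // n` factor computes the same mean but multiplies by an exact rational, which avoids accidental element-type promotion. Illustrative values:

```julia
using Flux: mse

ŷ, y = Float32[0.9, 0.1], Float32[1, 0]
mse(ŷ, y)    # ≈ 0.01f0, still Float32 after the * 1 // length(y) scaling
```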
""" """


@@ -28,9 +28,9 @@ Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs

batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(length(first(xs)), xs)

-import Adapt.adapt
import Adapt: adapt, adapt_structure

-adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))

@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
  import .CuArrays: CuArray, cudaconvert
@@ -68,3 +68,6 @@ end
a::TrackedMatrix * b::OneHotVector = invoke(*, Tuple{AbstractMatrix,OneHotVector}, a, b)
a::TrackedMatrix * b::OneHotMatrix = invoke(*, Tuple{AbstractMatrix,OneHotMatrix}, a, b)

onecold(x::TrackedVector, l...) = onecold(data(x), l...)
onecold(x::TrackedMatrix, l...) = onecold(data(x), l...)
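`onecold` on tracked output now simply drops the tracking and falls through to the plain-array method. A sketch:

```julia
using Flux
using Flux: onecold, param

onecold(param([0.1, 0.7, 0.2]), [:a, :b, :c])   # :b (no gradient flows through)
```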


@@ -1,23 +1,12 @@
module Optimise

export train!,
-  SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
-  RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM, stop, StopException
  SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
  ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,
  InvDecay, ExpDecay, WeightDecay, stop, Optimiser

-struct Param{T}
-  x::T
-  Δ::T
-end
-
-Param(x::AbstractArray) = Param(x, zero(x))

include("optimisers.jl")
-include("interface.jl")
include("train.jl")
include("deprecations.jl")

-using Flux.Tracker: TrackedArray
-
-Param(x::TrackedArray) = Param(x.data, x.grad)
-
-# Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad)

end


@@ -0,0 +1,126 @@
using Base: depwarn
using Flux: Params
check_decay(opt, decay) = decay == 0 ? opt : Optimiser(opt, InvDecay(decay))
# legacy update rule
updaterule(opt, ps) = () -> update!(opt, ps)
function SGD(params::Union{AbstractArray, Params}, η = 0.1; decay = 0.)
depwarn("SGD(params) is deprecated; use Descent(η::Float64) instead", :SGD)
ps = params
opt = Descent(η)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function Momentum(params::Union{AbstractArray, Params}, η = 0.01; ρ = 0.9, decay = 0.)
depwarn("Momentum(params) is deprecated; use Momentum(η::Float64) instead", :Momentum)
ps = params
opt = Momentum(η, ρ)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function Nesterov(params::Union{AbstractArray, Params}, η = 0.001; ρ = 0.9, decay = 0.)
depwarn("Nesterov(params) is deprecated; use Nesterov(η::Float64) instead", :Nesterov)
ps = params
opt = Nesterov(η, ρ)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function RMSProp(params::Union{AbstractArray, Params}, η = 0.001; ρ = 0.9, decay = 0.)
depwarn("RMSProp(params) is deprecated; use RMSProp(η::Float64) instead", :RMSProp)
ps = params
opt = RMSProp(η, ρ)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function ADAM(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("ADAM(params) is deprecated; use ADAM(η::Float64) instead", :ADAM)
ps = params
β = (β1, β2)
opt = ADAM(η, β)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function ADAGrad(params::Union{AbstractArray, Params}, η::Float64 = 0.1; decay = 0.)
depwarn("ADAGrad(params) is deprecated; use ADAGrad(η::Float64) instead", :ADAGrad)
ps = params
opt = ADAGrad(η)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function ADADelta(params::Union{AbstractArray, Params}, ρ::Float64 = 0.9; decay = 0.)
depwarn("ADADelta(params) is deprecated; use ADADelta(η::Float64) instead", :ADADelta)
ps = params
opt = ADADelta(ρ)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function AdaMax(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("AdaMax(params) is deprecated; use AdaMax(η::Float64) instead", :AdaMax)
ps = params
β = (β1, β2)
opt = AdaMax(η, β)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function AMSGrad(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("AMSGrad(params) is deprecated; use AMSGrad(η::Float64) instead", :AMSGrad)
ps = params
β = (β1, β2)
opt = AMSGrad(η, β)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function NADAM(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("NADAM(params) is deprecated; use NADAM(η::Float64) instead", :NADAM)
ps = params
β = (β1, β2)
opt = NADAM(η, β)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function ADAMW(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("ADAMW(params) is deprecated; use ADAMW(η::Float64) instead", :ADAMW)
ps = params
β = (β1, β2)
opt = ADAMW(η, β)
opt = check_decay(opt, decay)
decay != 0 && (opt = Optimiser(opt, WeightDecay(decay)))
updaterule(opt, ps)
end
# Old training loop
struct OldOptimiser
func
end
update!(opt::OldOptimiser, ps) = opt.func()
# Train function
function train!(loss, data, opt; cb = () -> ())
depwarn("train!(loss, data, opt) is deprecated; use train!(loss, params, data, opt) instead", :train!)
train!(loss, (), data, OldOptimiser(opt); cb = cb)
end
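To put the shims in context: old-style calls keep working with a warning, while new code constructs the optimiser without params and passes them to `train!`. A sketch:

```julia
using Flux

m = Dense(10, 2)
loss(x, y) = Flux.mse(m(x), y)
data = [(rand(10), rand(2))]

opt = ADAM(0.001)                         # new style: no params at construction
Flux.train!(loss, params(m), data, opt)   # new four-argument train!
```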


@ -1,110 +0,0 @@
call(f, xs...) = f(xs...)
# note for optimisers: set to zero
# p.Δ at the end of the weights update
function optimiser(ps, fs...)
ps = [Param(p) for p in ps]
fs = map(ps) do p
os = map(f -> f(p), fs)
() -> foreach(call, os)
end
() -> foreach(call, fs)
end
"""
SGD(params, η = 0.1; decay = 0)
Classic gradient descent optimiser with learning rate `η`.
For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`.
Supports inverse decaying learning rate if the `decay` argument is provided.
"""
SGD(ps, η = 0.1; decay = 0) =
optimiser(ps, p -> invdecay(p, decay), p -> descent(p,η))
"""
Momentum(params, η = 0.01; ρ = 0.9, decay = 0)
SGD with learning rate `η`, momentum `ρ` and optional learning rate inverse decay.
"""
Momentum(ps, η = 0.01; ρ = 0.9, decay = 0) =
optimiser(ps, p->invdecay(p,decay), p->momentum(p, ρ, η), p->descent(p,1))
"""
Nesterov(params, η = 0.01; ρ = 0.9, decay = 0)
SGD with learning rate `η`, Nesterov momentum `ρ` and optional learning rate inverse decay.
"""
Nesterov(ps, η = 0.01; ρ = 0.9, decay = 0) =
optimiser(ps, p->invdecay(p,decay), p->nesterov(p, ρ, η), p->descent(p,1))
"""
RMSProp(params, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0)
[RMSProp](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
optimiser. Parameters other than learning rate don't need tuning. Often a good
choice for recurrent networks.
"""
RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) =
optimiser(ps, p->rmsprop(p; η=η, ρ=ρ, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
"""
ADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
"""
ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
"""
ADAMW(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam.
"""
ADAMW(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->descentweightdecay(p,1,decay))
"""
AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser. Variant of ADAM based on
the ∞-norm.
"""
AdaMax(ps, η = 0.002; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p->adamax(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
"""
ADAGrad(params, η = 0.01; ϵ = 1e-8, decay = 0)
[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
Parameters don't need tuning.
"""
ADAGrad(ps, η = 0.01; ϵ = 1e-8, decay = 0) =
optimiser(ps, p->adagrad(p; η=η, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
"""
ADADelta(params; ρ = 0.9, ϵ = 1e-8, decay = 0)
[ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
tuning.
"""
ADADelta(ps; ρ = 0.9, ϵ = 1e-8, decay = 0) =
optimiser(ps, p->adadelta(p; ρ=ρ, ϵ=ϵ), p->descent(p,1))
"""
AMSGrad(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need
tuning.
"""
AMSGrad(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p -> amsgrad(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
"""
NADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[NADAM](https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ) optimiser. Parameters other
than learning rate don't need tuning.
"""
NADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p->nadam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))


@@ -1,130 +1,327 @@
-function descent(p::Param, η::Real)
-  function ()
-    @. p.x -= η * p.Δ
-    @. p.Δ = 0
-  end
-end
using Flux
using Base: @get!
using MacroTools: @forward
const ϵ = 1e-8
# TODO: should use weak refs
"""
Descent(η)
Classic gradient descent optimiser with learning rate `η`.
For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`.
"""
mutable struct Descent
eta::Float64
end
Descent() = Descent(0.1)
function update!(o::Descent, x, Δ)
Δ .*= o.eta
end
"""
Momentum(η = 0.01, ρ = 0.9)
Gradient descent with learning rate `η` and momentum `ρ`.
"""
mutable struct Momentum
eta::Float64
rho::Float64
velocity::IdDict
end
Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict())
function update!(o::Momentum, x, Δ)
η, ρ = o.eta, o.rho
v = get!(o.velocity, x, zero(x))::typeof(x)
@. v = ρ * v - η * Δ
@. Δ = -v
end
"""
Nesterov(eta, ρ = 0.9)
Gradient descent with learning rate `η` and Nesterov momentum `ρ`.
"""
mutable struct Nesterov
eta::Float64
rho::Float64
velocity::IdDict
end
Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict())
function update!(o::Nesterov, x, Δ)
η, ρ = o.eta, o.rho
v = get!(o.velocity, x, zero(x))::typeof(x)
d = @. ρ^2 * v - (1+ρ) * η * Δ
@. v = ρ*v - η*Δ
@. Δ = -d
end
"""
RMSProp(η = 0.001, ρ = 0.9)
[RMSProp](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
optimiser. Parameters other than learning rate don't need tuning. Often a good
choice for recurrent networks.
"""
mutable struct RMSProp
eta::Float64
rho::Float64
acc::IdDict
end
RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict())
function update!(o::RMSProp, x, Δ)
η, ρ = o.eta, o.rho
acc = get!(o.acc, x, zero(x))::typeof(x)
@. acc = ρ * acc + (1 - ρ) * Δ^2
@. Δ *= η / (√acc + ϵ)
end
"""
ADAM(η = 0.001, β = (0.9, 0.999))
[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
"""
mutable struct ADAM
eta::Float64
beta::Tuple{Float64,Float64}
state::IdDict
end
ADAM(η = 0.001, β = (0.9, 0.999)) = ADAM(η, β, IdDict())
function update!(o::ADAM, x, Δ)
η, β = o.eta, o.beta
mt, vt, βp = get!(o.state, x, (zero(x), zero(x), β))
@. mt = β[1] * mt + (1 - β[1]) * Δ
@. vt = β[2] * vt + (1 - β[2]) * Δ^2
@. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η
o.state[x] = (mt, vt, βp .* β)
return Δ
end
"""
AdaMax(η = 0.001, β = (0.9, 0.999))
[AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser. Variant of ADAM based on
the ∞-norm.
"""
mutable struct AdaMax
eta::Float64
beta::Tuple{Float64,Float64}
state::IdDict
end
AdaMax(η = 0.001, β = (0.9, 0.999)) = AdaMax(η, β, IdDict())
function update!(o::AdaMax, x, Δ)
η, β = o.eta, o.beta
mt, ut, βp = get!(o.state, x, (zero(x), zero(x), β))
@. mt = β[1] * mt + (1 - β[1]) * Δ
@. ut = max(β[2] * ut, abs(Δ))
@. Δ = (η/(1 - βp[1])) * mt/(ut + ϵ)
o.state[x] = (mt, ut, βp .* β)
return Δ
end
"""
ADAGrad(η = 0.1)
[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
Parameters don't need tuning.
"""
mutable struct ADAGrad
eta::Float64
acc::IdDict
end
ADAGrad(η = 0.1) = ADAGrad(η, IdDict())
function update!(o::ADAGrad, x, Δ)
η = o.eta
acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(x)
@. acc += Δ^2
@. Δ *= η / (√acc + ϵ)
end
"""
ADADelta(ρ = 0.9)
[ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
tuning.
"""
mutable struct ADADelta
rho::Float64
state::IdDict
end
ADADelta(ρ = 0.9) = ADADelta(ρ, IdDict())
function update!(o::ADADelta, x, Δ)
ρ = o.rho
acc, Δacc = get!(o.state, x, (zero(x), zero(x)))
@. acc = ρ * acc + (1 - ρ) * Δ^2
@. Δ *= √Δacc / (√acc + ϵ)
@. Δacc = ρ * Δacc + (1 - ρ) * Δ^2
return Δ
end
"""
AMSGrad(η = 0.001, β = (0.9, 0.999))
[AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need
tuning.
"""
mutable struct AMSGrad
eta::Float64
beta::Tuple{Float64, Float64}
state::IdDict
end
AMSGrad(η = 0.001, β = (0.9, 0.999)) = AMSGrad(η, β, IdDict())
function update!(o::AMSGrad, x, Δ)
η, β = o.eta, o.beta
mt, vt, v̂t = get!(o.state, x, (fill(ϵ, size(x)), fill(ϵ, size(x)), fill(ϵ, size(x))))
@. mt = β[1] * mt + (1 - β[1]) * Δ
@. vt = β[2] * vt + (1 - β[2]) * Δ ^ 2
@. v̂t = max.(v̂t, vt)
@. Δ = η * mt / (√v̂t + ϵ)
end
"""
NADAM(η = 0.001, β = (0.9, 0.999))
[NADAM](http://cs229.stanford.edu/proj2015/054_report.pdf) optimiser. Parameters don't need
tuning.
"""
mutable struct NADAM
eta::Float64
beta::Tuple{Float64, Float64}
state::IdDict
end
NADAM(η = 0.001, β = (0.9, 0.999)) = NADAM(η, β, IdDict())
function update!(o::NADAM, x, Δ)
η, β = o.eta, o.beta
β1p, β2p = o.beta
mt, vt = get!(o.state, x, (zero(x), zero(x)))
@. mt = β[1] * mt + (1 - β[1]) * Δ
@. vt = β[2] * vt + (1 - β[2]) * Δ^2
@. Δ = (β[1] * mt / (1 - β[1] * β1p) + (1 - β[1]) * Δ / (1 - β1p)) / (√(vt * β[2] / (1 - β2p)) + ϵ) * η
o.state[x] = (mt, vt, (β1p * β[1], β2p * β[2]))
return Δ
end
"""
ADAMW(η = 0.001, β = (0.9, 0.999), decay = 0)
[ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam.
"""
ADAMW(η = 0.001, β = (0.9, 0.999), decay = 0) =
Optimiser(ADAM(η, β), WeightDecay(decay))
# Compose optimizers
"""
Optimiser(a, b, c...)
Combine several optimisers into one; each optimiser produces a modified gradient
that will be fed into the next, and this is finally applied to the parameter as
usual.
"""
mutable struct Optimiser
os::Vector{Any}
end
Optimiser(o...) = Optimiser(Any[o...])
@forward Optimiser.os Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!, Base.setindex!
@forward Optimiser.os Base.iterate
Base.getindex(c::Optimiser, i::AbstractArray) = Optimiser(c.os[i]...)
function update!(o::Optimiser, x, Δ)
for opt in o.os
Δ = update!(opt, x, Δ)
  end
  return Δ
end
-# Ref: https://arxiv.org/abs/1711.05101.pdf
-function descentweightdecay(p::Param, η::Real, γ::Real)
-  function ()
-    @. p.x = p.x - η * (p.Δ + γ * p.x)
-    @. p.Δ = 0
-  end
-end

"""
`InvDecay(γ)`

Apply inverse time decay to an optimiser
```julia
Optimiser(InvDecay(..), Opt(..))
```
"""
mutable struct InvDecay
gamma::Float64
state::IdDict
end
InvDecay(γ = 0.001) = InvDecay(γ, IdDict())
function update!(o::InvDecay, x, Δ)
γ = o.gamma
n = get!(o.state, x, 1)
Δ .*= 1 / (1 + γ * n)
o.state[x] = n + 1
return Δ
end
"""
`ExpDecay(eta, decay, decay_step, clip)`
Schedule the learning rate `eta` by `decay` every `decay_step` till a minimum of `clip`.
To apply exponential decay to an optimiser:
```julia
Optimiser(ExpDecay(..), Opt(..))
```
"""
mutable struct ExpDecay
eta::Float64
decay::Float64
step::Int64
clip::Float64
current::IdDict
end
ExpDecay(opt = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4) = ExpDecay(opt, decay, decay_step, clip, IdDict())
function update!(o::ExpDecay, x, Δ)
η, s, decay = o.eta, o.step, o.decay
n = o.current[x] = get(o.current, x, 0) + 1
if o.current[x]%s == 0 && count(x -> x%s == 0, values(o.current)) == 1
η = max(η * decay^(s / n), o.clip)
o.eta = η
  end
  @. Δ *= decay
end
-function momentum(p::Param, ρ, η)
-  v = zero(p.x)
-  function ()
-    @. v = ρ * v - η * p.Δ
-    @. p.Δ = -v
-  end
-end

"""
`WeightDecay(wd)`

Decay the weight parameter by `wd`
"""
mutable struct WeightDecay
  wd::Real
end

WeightDecay() = WeightDecay(0)
-# Ref. https://arxiv.org/pdf/1212.0901.pdf
-function nesterov(p::Param, ρ, η)
-  v = zero(p.x)
-  function ()
-    d = @. ρ^2 * v - (1+ρ) * η * p.Δ
-    @. v = ρ*v - η*p.Δ
-    @. p.Δ = -d
-  end
-end
-function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
-  acc = zero(p.x)
-  function ()
-    @. acc = ρ * acc + (1 - ρ) * p.Δ^2
-    @. p.Δ *= η / (√acc + ϵ)
-  end
-end

function update!(o::WeightDecay, x, Δ)
  wd = o.wd
  @. Δ += wd * x
end
-function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
-  acc = zero(p.x) .+ ϵ
-  function ()
-    @. acc += p.Δ^2
-    @. p.Δ *= η / (√acc + ϵ)
-  end
-end
-
-function adadelta(p::Param; ρ::Real = 0.9, ϵ::Real = 1e-8)
-  acc = zero(p.x)
-  Δacc = zero(p.x)
-  function ()
-    @. acc = ρ * acc + (1 - ρ) * p.Δ^2
-    @. p.Δ *= √(Δacc + ϵ) / √(acc + ϵ)
-    @. Δacc = ρ * Δacc + (1 - ρ) * p.Δ^2
-  end
-end
-
-function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
-  mt = zero(p.x)
-  vt = zero(p.x)
-  β1p, β2p = β1, β2
-  function ()
-    @. mt = β1 * mt + (1 - β1) * p.Δ
-    @. vt = β2 * vt + (1 - β2) * p.Δ^2
-    @. p.Δ = mt / (1 - β1p) / (√(vt / (1 - β2p)) + ϵ) * η
-    β1p *= β1
-    β2p *= β2
-  end
-end
-
-function adamax(p::Param; η::Real = 0.002, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
-  mt = zero(p.x)
-  ut = zero(p.x)
-  β1p = β1
-  function ()
-    @. mt = β1 * mt + (1 - β1) * p.Δ
-    @. ut = max(β2 * ut, abs(p.Δ))
-    @. p.Δ = (η/(1 - β1p)) * mt/(ut + ϵ)
-    β1p *= β1
-  end
-end
-
-function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
-  mt = zero(p.x)
-  vt = zero(p.x) .+ ϵ
-  v̂t = zero(p.x) .+ ϵ
-  function ()
-    @. mt = β1 * mt + (1 - β1) * p.Δ
-    @. vt = β2 * vt + (1 - β2) * p.Δ ^ 2
-    @. v̂t = max.(v̂t, vt)
-    @. p.Δ = η * mt / √v̂t
-  end
-end
-
-function nadam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
-  mt = zero(p.x)
-  vt = zero(p.x)
-  β1p, β2p = β1, β2
-  function ()
-    @. mt = β1 * mt + (1 - β1) * p.Δ
-    @. vt = β2 * vt + (1 - β2) * p.Δ^2
-    @. p.Δ = (β1 * mt / (1 - β1 * β1p) + (1 - β1) * p.Δ / (1 - β1p)) / (√(vt * β2 / (1 - β2p)) + ϵ) * η
-    β1p *= β1
-    β2p *= β2
-  end
-end
-
-clip(p::Param, thresh::Real) = () -> clamp!(p.Δ, -thresh, thresh)
-
-function expdecay(p::Param, γ::Real)
-  if γ != 0
-    return () -> p.Δ .+= γ .* p.x
-  else
-    return () -> nothing
-  end
-end
-
-function invdecay(p::Param, γ::Real)
-  if γ != 0
-    n = 0
-    return () -> begin
-      p.Δ .*= 1 / (1 + γ * n)
-      n += 1
-    end
-  else
-    return () -> nothing
-  end
-end


@@ -1,7 +1,17 @@
using Juno
-using Flux.Tracker: back!
using Flux.Tracker: data, grad, back!
import Base.depwarn

function update!(opt, xs)
  for x in xs
    Δ = update!(opt, x.data, x.grad)
    x.data .-= Δ
    Δ .= 0
  end
end

# Callback niceties
call(f, xs...) = f(xs...)
runall(f) = f
runall(fs::AbstractVector) = () -> foreach(call, fs)
@@ -35,7 +45,7 @@ function stop()
end

"""
-    train!(loss, data, opt)
    train!(loss, params, data, opt; cb)

For each datapoint `d` in `data` computes the gradient of `loss(d...)` through
backpropagation and calls the optimizer `opt`.
@@ -44,22 +54,22 @@ Takes a callback as keyword argument `cb`. For example, this will print "trainin
every 10 seconds:

```julia
-Flux.train!(loss, data, opt,
Flux.train!(loss, params, data, opt,
            cb = throttle(() -> println("training"), 10))
```

-The callback can return `:stop` to interrupt the training loop.
The callback can call `Flux.stop()` to interrupt the training loop.

Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
"""
-function train!(loss, data, opt; cb = () -> ())
function train!(loss, ps, data, opt; cb = () -> ())
  cb = runall(cb)
  opt = runall(opt)
  @progress for d in data
    try
      l = loss(d...)
      @interrupts back!(l)
-      opt()
      update!(opt, ps)
      if cb() == :stop
        depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
        break

@@ -5,7 +5,8 @@ using MacroTools: @q, @forward

import Base: ==

-export TrackedArray, TrackedVector, TrackedMatrix, Params, param, back!
export TrackedArray, TrackedVector, TrackedMatrix, Params, gradient,
  param, back!

tracker(x) = nothing
@@ -60,17 +61,11 @@ macro grad(ex)
  @q(Tracker._forward($(args...)) where $(T...) = $body) |> esc
end

-function update!(x, Δ)
-  x.data .+= data(Δ)
-  tracker(x).grad .= 0
-  return x
-end

include("idset.jl")
include("back.jl")
-include("scalar.jl")
-include("array.jl")
include("numeric.jl")
include("lib/real.jl")
include("lib/array.jl")

"""
    hook(f, x) -> x
@@ -99,7 +94,8 @@ end

nobacksies(f, x) = track(nobacksies, f, x)
nobacksies(f, xs::Tuple) = map(x -> nobacksies(f, x), xs)

-@grad nobacksies(f, x) = data(x), Δ -> error("Nested AD not defined for $f")
@grad nobacksies(f::Symbol, x) = data(x), Δ -> error("Nested AD not defined for $f")
@grad nobacksies(f::String, x) = data(x), Δ -> error(f)

param(x::Number) = TrackedReal(float(x))
param(xs::AbstractArray) = TrackedArray(float.(xs))
@@ -108,10 +104,8 @@ param(xs::AbstractArray) = TrackedArray(float.(xs))

param(x::TrackedReal) = track(identity, x)
param(x::TrackedArray) = track(identity, x)

-import NNlib.cudata
-import Adapt.adapt
import Adapt: adapt, adapt_structure

-cudata(x::TrackedArray) = data(x)
-adapt(T, xs::TrackedArray) = param(adapt(T, data(xs)))
adapt_structure(T, xs::TrackedArray) = param(adapt(T, data(xs)))

end


@@ -19,62 +19,87 @@ function scan(x)
  return
end

-function back_(c::Call, Δ)
function back_(c::Call, Δ, once)
  Δs = c.func(Δ)
  (Δs isa Tuple && length(Δs) >= length(c.args)) ||
    error("Gradient is not a tuple of length $(length(c.args))")
-  foreach(back, c.args, data.(Δs))
  foreach((x, d) -> back(x, d, once), c.args, data.(Δs))
end

-back_(::Call{Nothing}, Δ) = nothing
back_(::Call{Nothing}, Δ, once) = nothing
back_(::Call{Missing}, Δ, once) = error("`back!` was already used")

accum!(x, Δ) = x .+ Δ
accum!(x::AbstractArray, Δ) = (x .+= Δ)

-function back(x::Tracked, Δ)
-  x.isleaf && (x.grad = accum!(x.grad, Δ); return)
-  ref = x.ref -= 1
-  if ref > 0 || isdefined(x, :grad)
-    if isdefined(x, :grad)
-      x.grad = accum!(x.grad, Δ)
-    else
-      x.grad = Δ
-    end
-    ref == 0 && back_(x.f, x.grad)
-  else
-    ref == 0 && back_(x.f, Δ)
-  end
-  return
-end
function back(x::Tracked, Δ, once)
  x.isleaf && (x.grad = accum!(x.grad, Δ); return)
  ref = x.ref -= 1
  grad = if isdefined(x, :grad)
    x.grad = accum!(x.grad, Δ)
  elseif ref > 0
    x.grad = Δ
  else
    Δ
  end
  if ref == 0
    back_(x.f, grad, once)
    once && !x.isleaf && (x.f = Call(missing, ()))
  end
  return
end

-back(::Nothing, _) = return
back(::Nothing, Δ, once) = return

# Interface methods

# TODO: if an error occurs in `back` the refcounts will be broken
# and `back` will silently fail to update.
# (but only if you re-use intermediate values between passes)
# Refcounts are also probably not safe in some situations (e.g. back called
# from within a backpropagator)

-function back!(x, Δ)
function back!(x, Δ; once = true)
  istracked(x) || return
  scan(x)
-  back(tracker(x), Δ)
  back(tracker(x), Δ, once)
  return
end

function gradient_(f, xs...)
  xs = param.(data.(xs))
  l = f(xs...)
  losscheck(l)
  back!(l)
  nobacksies("Use `gradient(...; nest = true)` for nested derivatives",
             grad.(xs))
end

# Out-of-place gradients

-struct Params
-  params::IdSet
-  Params(xs) = new(IdSet(xs))
-end
struct Params
  order::Vector{Any}
  params::IdSet{Any}
  Params() = new([], IdSet())
end

-@forward Params.params Base.iterate, Base.length
@forward Params.order Base.iterate, Base.length

function Base.push!(ps::Params, x)
  if !(x in ps.params)
    push!(ps.order, x)
    push!(ps.params, x)
  end
  return ps
end

Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps)

Params(xs) = push!(Params(), xs...)

function Base.show(io::IO, ps::Params)
  print(io, "Params([")
-  join(io, ps.params, ", ")
  join(io, ps.order, ", ")
  print(io, "])")
end
@@ -91,12 +116,12 @@ Grads() = Grads(IdDict())

Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps))

Base.getindex(g::Grads, x::Tracked) = g.grads[x]

function Base.getindex(g::Grads, x)
  istracked(x) || error("Object not tracked: $x")
  g[tracker(x)]
end

accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ

function back_(g::Grads, c::Call, Δ)
@@ -146,20 +171,13 @@ function losscheck(x)
  isnan(x) && error("Loss is NaN")
end

-function gradient(f, args...)
function gradient_nested(f, args...)
  y, back = forward(f, args...)
  losscheck(y)
  return back(1)
end

-derivative(f, x) = gradient(f, x)[1]
-
-# Non-nesting versions
-
-function gradient_(f, xs...)
-  xs = param.(xs)
-  l = f(xs...)
-  losscheck(l)
-  back!(l)
-  grad.(xs)
-end
gradient(f, xs...; nest = false) =
  nest ? gradient_nested(f, xs...) : gradient_(f, xs...)

gradient(f, ps::Params) = gradient_nested(f, ps)
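The reworked API in one place (a sketch): non-nested by default, `nest = true` for higher-order use, and `Params` for implicit parameters:

```julia
using Flux.Tracker

f(x) = 3x^2 + 2x + 1
Tracker.gradient(f, 2)    # (14.0 (tracked),)

W, b = param(2), param(3)
g = Tracker.gradient(() -> W*3 + b, Params([W, b]))
g[W], g[b]                # (3.0, 1.0)
```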


@@ -7,6 +7,7 @@ Base.eltype(::IdSet{T}) where T = T

IdSet() = IdSet{Any}()

Base.push!(s::IdSet) = s
Base.push!(s::IdSet{T}, x::T) where T = (s.dict[x] = nothing; s)
Base.delete!(s::IdSet{T}, x::T) where T = (delete!(s.dict, x); s)
Base.in(x, s::IdSet) = haskey(s.dict, x)


@ -33,8 +33,18 @@ TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, zero(x))
Base.eltype(x::Type{<:TrackedArray{T}}) where T <: Real = TrackedReal{T} Base.eltype(x::Type{<:TrackedArray{T}}) where T <: Real = TrackedReal{T}
Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} = Base.convert(::Type{T}, x::S) where {T<:TrackedArray,S<:T} = x
print(io, "TrackedArray{…,$A}")
Base.convert(::Type{<:TrackedArray}, x::TrackedArray) =
error("Not implemented: convert $(typeof(x)) to $T")
Base.convert(::Type{<:TrackedArray{T,N,A}}, x::AbstractArray) where {T,N,A} =
TrackedArray(convert(A, x))
Base.show(io::IO, t::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
@isdefined(A) ?
print(io, "TrackedArray{…,$A}") :
invoke(show, Tuple{IO,DataType}, io, t)
function Base.summary(io::IO, x::TrackedArray) function Base.summary(io::IO, x::TrackedArray)
print(io, "Tracked ") print(io, "Tracked ")
@ -43,11 +53,24 @@ end
Base.print_array(io::IO, x::TrackedArray) = Base.print_array(io, data(x)) Base.print_array(io::IO, x::TrackedArray) = Base.print_array(io, data(x))
function Base.show(io::IO, x::TrackedArray)
show(io, data(x))
print(io, " (tracked)")
end
Base.copy(x::TrackedArray) = x
Base.setindex!(xs::TrackedArray, v, i...) = Base.setindex!(xs::TrackedArray, v, i...) =
error("Can't differentiate `setindex!`") error("Can't differentiate `setindex!`")
back!(::TrackedArray) = error("Value is not scalar; use `back!(sum(x))` or `back!(x, Δ)`") back!(::TrackedArray) = error("Value is not scalar; use `back!(sum(x))` or `back!(x, Δ)`")
function update!(x::TrackedArray, Δ)
x.data .+= data(Δ)
tracker(x).grad .= 0
return x
end
# Fallthrough methods # Fallthrough methods
for f in :[Base.size, Base.ndims, Base.collect].args for f in :[Base.size, Base.ndims, Base.collect].args
@ -80,6 +103,17 @@ Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...)
end end
end end
Base.view(x::TrackedArray, inds...) = track(Base.view, x, inds...)
@grad function view(x::AbstractArray, inds...)
view(data(x), inds...), function (Δ)
grad_output = zero(x)
subgrad = view(grad_output, inds...)
subgrad[:] = data(Δ)
(nobacksies(:view, grad_output), map(_->nothing, inds)...)
end
end
Base.:-(xs::TrackedArray) = track(-, xs) Base.:-(xs::TrackedArray) = track(-, xs)
@grad -(xs) = -data(xs), Δ -> (-Δ,) @grad -(xs) = -data(xs), Δ -> (-Δ,)
@ -87,8 +121,8 @@ Base.:-(xs::TrackedArray) = track(-, xs)
Base.transpose(xs::TrackedArray) = track(transpose, xs) Base.transpose(xs::TrackedArray) = track(transpose, xs)
Base.adjoint(xs::TrackedArray) = track(adjoint, xs) Base.adjoint(xs::TrackedArray) = track(adjoint, xs)
@grad transpose(xs) = transpose(data(xs)), Δ -> (reshape(transpose(Δ), size(xs)),) @grad transpose(xs) = transpose(data(xs)), Δ -> (trim(xs, transpose(Δ)),)
@grad adjoint(xs) = data(xs)', Δ -> (reshape(Δ', size(xs)),) @grad adjoint(xs) = data(xs)', Δ -> (trim(xs, Δ'),)
Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...) Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...)
@ -108,30 +142,28 @@ Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...)
end end
end end
for f in [:vcat, :hcat] function combinations(xs, n)
UArray = :(Union{TrackedArray,Vector,Matrix,Adjoint,Transpose}) n < 1 && return [[]]
@eval begin cs = combinations(xs, n-1)
# This section is a bit of a hack since julia doesn't have a standardised [[x, c...] for x in xs, c in cs]
# promotion mechanism for concatenation yet end
# https://github.com/JuliaLang/julia/pull/20815
# It should support tracked concatenation with rank ∈ (1,2) with a for i = 0:2, c = combinations([:AbstractArray, :TrackedArray], i), f = [:hcat, :vcat]
# TrackedArray anywhere among the arguments This works as long as base has cnames = map(_ -> gensym(), c)
# other functions that captures `(::Union{Vector,RowVector,Matrix}...)`. @eval Base.$f($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::TrackedArray, xs::AbstractArray...) =
Base.$f(a::$UArray...) = track($f, a...) track($f, $(cnames...), x, xs...)
end
# It should support tracked concatenation with rank>2 if the TrackedArray is for i = 0:2, c = combinations([:AbstractVecOrMat, :TrackedVecOrMat], i), f = [:hcat, :vcat]
# first cnames = map(_ -> gensym(), c)
Base.$f(a::TrackedArray, b::AbstractArray...) = track($f, a, b...) @eval Base.$f($([:($x::$c{T}) for (x, c) in zip(cnames, c)]...), x::TrackedVecOrMat{T}, xs::AbstractVecOrMat{T}...) where T =
Base.$f(a::TrackedArray, b::$UArray...) = track($f, a, b...) # resolves ambiguity introduced by previous row track($f, $(cnames...), x, xs...)
end
# It should support tracked concatenation with rank>2 if the TrackedArray is for i = 0:2, c = combinations([:AbstractVector, :TrackedVector], i), f = [:hcat, :vcat]
# second cnames = map(_ -> gensym(), c)
Base.$f(a::Array, b::TrackedArray, c::AbstractArray...) = track($f, a, b, c...) @eval Base.$f($([:($x::$c{T}) for (x, c) in zip(cnames, c)]...), x::TrackedVector{T}, xs::AbstractVector{T}...) where T =
Base.$f(a::Union{Vector,Matrix,Adjoint,Transpose}, b::TrackedArray, track($f, $(cnames...), x, xs...)
c::$UArray...) =
track($f, a, b, c...) # resolves ambiguity introduced by previous row
end
end end
@grad function vcat(xs...) @grad function vcat(xs...)
@ -164,10 +196,11 @@ end
end end
end end
Base.cat(a::TrackedArray; dims) = track(cat, a, dims = dims) for i = 0:2, c = combinations([:AbstractArray, :TrackedArray], i)
Base.cat(a::TrackedArray, b::TrackedArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims) cnames = map(_ -> gensym(), c)
Base.cat(a::TrackedArray, b::AbstractArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims) @eval Base.cat($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::TrackedArray, xs::AbstractArray...; dims) =
Base.cat(a::AbstractArray, b::TrackedArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims) track(cat, $(cnames...), x, xs..., dims = dims)
end
@grad function cat(Xs...; dims) @grad function cat(Xs...; dims)
cat(data.(Xs)..., dims = dims), function (Δ) cat(data.(Xs)..., dims = dims), function (Δ)
@ -307,8 +340,8 @@ end
# BLAS # BLAS
LinearAlgebra.diagm(x::TrackedVector) = track(diagm, x) LinearAlgebra.diagm(x::Pair{<:Integer, <:TrackedVector}) = track(diagm, x...)
@grad diagm(x) = diagm(data(x)), Δ -> (diag(Δ),) @grad diagm(i, x) = diagm(i => data(x)), Δ -> (nothing, diag(Δ, i))
x::TrackedMatrix * y::AbstractMatrix = track(*, x, y) x::TrackedMatrix * y::AbstractMatrix = track(*, x, y)
x::AbstractMatrix * y::TrackedMatrix = track(*, x, y) x::AbstractMatrix * y::TrackedMatrix = track(*, x, y)
@ -328,7 +361,7 @@ x::TrackedVector * y::TrackedVector = track(*, x, y)
# NNlib # NNlib
using NNlib using NNlib
import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, maxpool, meanpool import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, depthwiseconv, maxpool, meanpool
softmax(xs::TrackedArray) = track(softmax, xs) softmax(xs::TrackedArray) = track(softmax, xs)
@ -338,6 +371,16 @@ logsoftmax(xs::TrackedArray) = track(logsoftmax, xs)
@grad logsoftmax(xs) = logsoftmax(data(xs)), Δ -> (nobacksies(:logsoftmax, ∇logsoftmax(data(Δ), data(xs))),) @grad logsoftmax(xs) = logsoftmax(data(xs)), Δ -> (nobacksies(:logsoftmax, ∇logsoftmax(data(Δ), data(xs))),)
depthwiseconv(x::TrackedArray, w::TrackedArray; kw...) = track(depthwiseconv, x, w; kw...)
depthwiseconv(x::AbstractArray, w::TrackedArray; kw...) = track(depthwiseconv, x, w; kw...)
depthwiseconv(x::TrackedArray, w::AbstractArray; kw...) = track(depthwiseconv, x, w; kw...)
@grad depthwiseconv(x, w; kw...) =
depthwiseconv(data(x), data(w); kw...),
Δ -> nobacksies(:depthwiseconv,
(NNlib.∇depthwiseconv_data(data.((Δ, x, w))...; kw...),
NNlib.∇depthwiseconv_filter(data.((Δ, x, w))...; kw...)))
conv(x::TrackedArray, w::TrackedArray; kw...) = track(conv, x, w; kw...) conv(x::TrackedArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
conv(x::AbstractArray, w::TrackedArray; kw...) = track(conv, x, w; kw...) conv(x::AbstractArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...) conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...)
@@ -374,8 +417,7 @@ unbroadcast(x::AbstractArray, Δ) =
   trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))

 unbroadcast(x::Number, Δ) = sum(Δ)
-unbroadcast(x::Base.RefValue{<:Function}, _) = nothing
-unbroadcast(x::Base.RefValue{<:Val}, _) = nothing
+unbroadcast(x::Base.RefValue, _) = nothing

 dual(x, p) = x
 dual(x::Real, p) = Dual(x, p)
@@ -423,26 +465,28 @@ end
 using Requires

 # https://github.com/FluxML/Flux.jl/issues/353
-@init Requires.isprecompiling() || @eval Base.Broadcast begin
-  function flatten(bc::Broadcasted{Style}) where {Style}
-    isflat(bc) && return bc
-    args = cat_nested(bc)
-    let makeargs = make_makeargs(bc), f = bc.f
-      newf = @inline function(args::Vararg{Any,N}) where N
-        f(makeargs(args...)...)
-      end
-      return Broadcasted{Style}(newf, args, bc.axes)
-    end
-  end
-  @inline function make_makeargs(makeargs, t::Tuple{<:Broadcasted,Vararg{Any}})
-    bc = t[1]
-    let makeargs = make_makeargs(makeargs, tail(t)), f = bc.f
-      let makeargs = make_makeargs(makeargs, bc.args)
-        headargs, tailargs = make_headargs(bc.args), make_tailargs(bc.args)
-        return @inline function(args::Vararg{Any,N}) where N
-          args1 = makeargs(args...)
-          a, b = headargs(args1...), tailargs(args1...)
-          (f(a...), b...)
-        end
-      end
-    end
-  end
-end
+if VERSION < v"1.1.0-DEV.548"
+  @init Requires.isprecompiling() || @eval Base.Broadcast begin
+    function flatten(bc::Broadcasted{Style}) where {Style}
+      isflat(bc) && return bc
+      args = cat_nested(bc)
+      let makeargs = make_makeargs(bc), f = bc.f
+        newf = @inline function(args::Vararg{Any,N}) where N
+          f(makeargs(args...)...)
+        end
+        return Broadcasted{Style}(newf, args, bc.axes)
+      end
+    end
+    @inline function make_makeargs(makeargs, t::Tuple{<:Broadcasted,Vararg{Any}})
+      bc = t[1]
+      let makeargs = make_makeargs(makeargs, tail(t)), f = bc.f
+        let makeargs = make_makeargs(makeargs, bc.args)
+          headargs, tailargs = make_headargs(bc.args), make_tailargs(bc.args)
+          return @inline function(args::Vararg{Any,N}) where N
+            args1 = makeargs(args...)
+            a, b = headargs(args1...), tailargs(args1...)
+            (f(a...), b...)
+          end
+        end
+      end
+    end
+  end
+end
@@ -1,4 +1,4 @@
-struct TrackedReal{T<:Real} <: Real
+mutable struct TrackedReal{T<:Real} <: Real
   data::T
   tracker::Tracked{T}
 end
@@ -10,19 +10,28 @@ tracker(x::TrackedReal) = x.tracker

 track(f::Call, x::Real) = TrackedReal(x, Tracked{typeof(x)}(f, zero(x)))

-function back!(x::TrackedReal)
+function back!(x::TrackedReal; once = true)
   isinf(x) && error("Loss is Inf")
   isnan(x) && error("Loss is NaN")
-  return back!(x, 1)
+  return back!(x, 1, once = once)
+end
+
+function update!(x::TrackedReal, Δ)
+  x.data += data(Δ)
+  tracker(x).grad = 0
+  return x
 end
 function Base.show(io::IO, x::TrackedReal)
+  T = get(io, :typeinfo, Any)
   show(io, data(x))
-  print(io, " (tracked)")
+  T <: TrackedReal || print(io, " (tracked)")
 end

 Base.decompose(x::TrackedReal) = Base.decompose(data(x))

+Base.copy(x::TrackedReal) = x
+
 Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{T}) where T = x

 Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x))
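A sketch of the new scalar semantics (illustrative values): back! now frees the tape after one pass unless once = false, and update! applies a step while zeroing the stored gradient.

using Flux
using Flux.Tracker: back!, grad, update!, data

x = param(2.0)
y = x^2
back!(y, once = false)   # keep the tape alive for another backward pass
grad(x)                  # == 4.0
update!(x, 0.5)          # x.data += 0.5 and the gradient is reset to 0
data(x)                  # == 2.5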
@@ -30,23 +39,32 @@ Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x))
 Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} =
   error("Not implemented: convert tracked $S to tracked $T")

-for op in [:(==), :≈, :<]
+for op in [:(==), :≈, :<, :(<=)]
   @eval Base.$op(x::TrackedReal, y::Real) = Base.$op(data(x), y)
   @eval Base.$op(x::Real, y::TrackedReal) = Base.$op(x, data(y))
   @eval Base.$op(x::TrackedReal, y::TrackedReal) = Base.$op(data(x), data(y))
 end

 Base.eps(x::TrackedReal) = eps(data(x))
+Base.eps(::Type{TrackedReal{T}}) where T = eps(T)

 for f in :[isinf, isnan, isfinite].args
   @eval Base.$f(x::TrackedReal) = Base.$f(data(x))
 end

-Base.Printf.fix_dec(x::TrackedReal, n::Int) = Base.Printf.fix_dec(data(x), n)
+Base.Printf.fix_dec(x::TrackedReal, n::Int, a...) = Base.Printf.fix_dec(data(x), n, a...)
+
+Base.float(x::TrackedReal) = x

 Base.promote_rule(::Type{TrackedReal{S}},::Type{T}) where {S,T} =
   TrackedReal{promote_type(S,T)}

+using Random
+
+for f in :[rand, randn, randexp].args
+  @eval Random.$f(rng::AbstractRNG,::Type{TrackedReal{T}}) where {T} = param(rand(rng,T))
+end
+
 using DiffRules, SpecialFunctions, NaNMath

 for (M, f, arity) in DiffRules.diffrules()
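Quick checks of the additions above (a sketch; illustrative values only):

using Random
using Flux.Tracker: TrackedReal

eps(TrackedReal{Float64}) == eps(Float64)            # type-level eps now works
r = rand(MersenneTwister(0), TrackedReal{Float64})   # built via param(rand(rng, T))
r isa TrackedReal{Float64}                           # true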
@@ -58,12 +76,18 @@ for (M, f, arity) in DiffRules.diffrules()
   end
 end

+# Work around zero(π) not working, for some reason
+_zero(::Irrational) = nothing
+_zero(x) = zero(x)
+
 for (M, f, arity) in DiffRules.diffrules()
   arity == 2 || continue
   da, db = DiffRules.diffrule(M, f, :a, :b)
   f = :($M.$f)
   @eval begin
-    @grad $f(a::Real, b::Real) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db)
+    @grad $f(a::TrackedReal, b::TrackedReal) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db)
+    @grad $f(a::TrackedReal, b::Real) = $f(data(a), b), Δ -> (Δ * $da, _zero(b))
+    @grad $f(a::Real, b::TrackedReal) = $f(a, data(b)), Δ -> (_zero(a), Δ * $db)
     $f(a::TrackedReal, b::TrackedReal) = track($f, a, b)
     $f(a::TrackedReal, b::Real) = track($f, a, b)
     $f(a::Real, b::TrackedReal) = track($f, a, b)
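With the rules split this way, mixed tracked/plain scalar calls produce a gradient only on the tracked side; a sketch using the Base./ diffrule (illustrative values):

using Flux
using Flux.Tracker: back!, grad

a = param(6.0)
y = a / 3.0          # hits the generated (a::TrackedReal, b::Real) method
back!(y)
grad(a)              # == 1/3; the plain 3.0 side gets _zero(3.0) == 0.0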
@@ -75,6 +99,12 @@ import Base:^

 ^(a::TrackedReal, b::Integer) = track(^, a, b)

+# Hack for conversions
+
+using ForwardDiff: Dual
+
+(T::Type{<:Real})(x::Dual) = Dual(T(x.value), map(T, x.partials.values))
+
 # Tuples

 struct TrackedTuple{T<:Tuple}
@@ -115,8 +145,8 @@ function scan(c::Call{typeof(collect)})
   foreach(scan, c.args[1])
 end

-function back_(c::Call{typeof(collect)}, Δ)
-  foreach(back, c.args[1], data(Δ))
+function back_(c::Call{typeof(collect)}, Δ, once)
+  foreach((x, d) -> back(x, d, once), c.args[1], data(Δ))
 end

 function back_(g::Grads, c::Call{typeof(collect)}, Δ)
@@ -40,7 +40,7 @@ function prefor(f, x; seen = IdSet())
 end

 function params(m)
-  ps = []
+  ps = Params()
   prefor(p ->
     Tracker.istracked(p) && Tracker.isleaf(p) &&
       !any(p′ -> p′ === p, ps) && push!(ps, p),
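params(m) now returns a Tracker.Params collection, which feeds straight into the gradient API exercised in the tests below; a sketch assuming Dense's W field as in this version of Flux:

using Flux

m = Dense(10, 5)
ps = Flux.params(m)                       # Params([W, b]) rather than a plain Vector
x = rand(10)
gs = Flux.Tracker.gradient(() -> sum(m(x)), ps)
gs[m.W]                                   # gradient looked up by the parameter itself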
@@ -1,8 +1,12 @@
 # Arrays
-initn(dims...) = randn(dims...)/100
-glorot_uniform(dims...) = (rand(dims...) .- 0.5) .* sqrt(24.0/(sum(dims)))
-glorot_normal(dims...) = randn(dims...) .* sqrt(2.0/sum(dims))
+glorot_uniform(dims...) = (rand(Float32, dims...) .- 0.5f0) .* sqrt(24.0f0/sum(dims))
+glorot_normal(dims...) = randn(Float32, dims...) .* sqrt(2.0f0/sum(dims))
+
+ones(T::Type, dims...) = Base.ones(T, dims...)
+zeros(T::Type, dims...) = Base.zeros(T, dims...)
+
+ones(dims...) = Base.ones(Float32, dims...)
+zeros(dims...) = Base.zeros(Float32, dims...)

 unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
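The scaling can be sanity-checked: a uniform on (-0.5, 0.5) has variance 1/12, so the sqrt(24/sum(dims)) factor yields variance 2/(fan_in + fan_out). A quick check (approximate, due to sampling noise):

using Statistics
using Flux: glorot_uniform

W = glorot_uniform(100, 200)
eltype(W) == Float32                          # now Float32 by default
isapprox(std(W), sqrt(2 / 300), rtol = 0.1)   # ≈ 0.0816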
@@ -24,7 +28,7 @@ julia> chunk(1:10, 3)
 """
 chunk(xs, n) = collect(Iterators.partition(xs, ceil(Int, length(xs)/n)))

-batchindex(xs, i) = (reverse(Base.tail(reverse(indices(xs))))..., i)
+batchindex(xs, i) = (reverse(Base.tail(reverse(axes(xs))))..., i)
""" """
frequencies(xs) frequencies(xs)
@ -66,7 +70,7 @@ julia> batch([[1,2,3],[4,5,6]])
 function batch(xs)
   data = first(xs) isa AbstractArray ?
     similar(first(xs), size(first(xs))..., length(xs)) :
-    Vector{eltype(xs)}(length(xs))
+    Vector{eltype(xs)}(undef, length(xs))
   for (i, x) in enumerate(xs)
     data[batchindex(data, i)...] = x
   end
@@ -147,9 +151,24 @@ function jacobian(m,x)
   n = length(x)
   J = Matrix{eltype(x)}(undef,n,k)
   for i = 1:k
-    Flux.back!(y[i]) # Populate gradient accumulator
+    Flux.back!(y[i], once = false) # Populate gradient accumulator
     J[:,i] = xp.grad
-    xp.grad .*= 0 # Reset gradient accumulator
+    xp.grad .= 0 # Reset gradient accumulator
   end
   J'
 end
+
+"""
+    @jit ...
+
+The `@jit` annotation can be applied to any code, and the code will be compiled
+for performance.
+
+    @jit f(x) = @jit(x) + @jit(x)
+
+Note that compilation happens regardless of the `@jit` macro, so it should only
+be used for aesthetic purposes, or by recovering Python users.
+"""
+macro jit(ex)
+  esc(ex)
+end
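jacobian backpropagates once per output component, which is why it needs once = false above. A usage sketch for a linear map, whose Jacobian is the matrix itself:

using Flux: jacobian

W = [1.0 2.0; 3.0 4.0]
J = jacobian(x -> W * x, [1.0, 1.0])
J == W                  # true: the Jacobian of x -> W*x is W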
@@ -11,6 +11,8 @@ x = param(randn(5, 5))
 cx = gpu(x)
 @test cx isa TrackedArray && cx.data isa CuArray

+@test Flux.onecold(param(gpu([1.,2.,3.]))) == 3
+
 x = Flux.onehotbatch([1, 2, 3], 1:3)
 cx = gpu(x)
 @test cx isa Flux.OneHotMatrix && cx.data isa CuArray

@@ -36,4 +38,8 @@ Flux.back!(sum(l))
 end

-CuArrays.cudnn_available() && include("cudnn.jl")
+if CuArrays.libcudnn != nothing
+  @info "Testing Flux/CUDNN"
+  include("cudnn.jl")
+  include("curnn.jl")
+end
@@ -1,48 +1,48 @@
-using Flux, CuArrays, Test
-
-@info "Testing Flux/CUDNN"
-
-@testset "RNN" begin
-  @testset for R in [RNN, GRU, LSTM]
-    rnn = R(10, 5)
-    curnn = mapleaves(gpu, rnn)
-    @testset for batch_size in (1, 5)
-      Flux.reset!(rnn)
-      Flux.reset!(curnn)
-      x = batch_size == 1 ?
-        param(rand(10)) :
-        param(rand(10,batch_size))
-      cux = gpu(x)
-      y = (rnn(x); rnn(x))
-      cuy = (curnn(cux); curnn(cux))
-
-      @test y.data ≈ collect(cuy.data)
-      @test haskey(Flux.CUDA.descs, curnn.cell)
-
-      Δ = randn(size(y))
-
-      Flux.back!(y, Δ)
-      Flux.back!(cuy, gpu(Δ))
-
-      @test x.grad ≈ collect(cux.grad)
-      @test rnn.cell.Wi.grad ≈ collect(curnn.cell.Wi.grad)
-      @test rnn.cell.Wh.grad ≈ collect(curnn.cell.Wh.grad)
-      @test rnn.cell.b.grad ≈ collect(curnn.cell.b.grad)
-      @test rnn.cell.h.grad ≈ collect(curnn.cell.h.grad)
-      if isdefined(rnn.cell, :c)
-        @test rnn.cell.c.grad ≈ collect(curnn.cell.c.grad)
-      end
-
-      Flux.reset!(rnn)
-      Flux.reset!(curnn)
-      ohx = batch_size == 1 ?
-        Flux.onehot(rand(1:10), 1:10) :
-        Flux.onehotbatch(rand(1:10, batch_size), 1:10)
-      cuohx = gpu(ohx)
-      y = (rnn(ohx); rnn(ohx))
-      cuy = (curnn(cuohx); curnn(cuohx))
-
-      @test y.data ≈ collect(cuy.data)
-    end
-  end
-end
+using Flux, Flux.Tracker, CuArrays, Test
+using Flux.Tracker: TrackedArray, data
+
+@testset "CUDNN BatchNorm" begin
+  @testset "4D Input" begin
+    x = TrackedArray(Float64.(collect(reshape(1:12, 2, 2, 3, 1))))
+    m = BatchNorm(3)
+    cx = gpu(x)
+    cm = gpu(m)
+
+    y = m(x)
+    cy = cm(cx)
+
+    @test cy isa TrackedArray{Float32,4,CuArray{Float32,4}}
+    @test cpu(data(cy)) ≈ data(y)
+
+    g = rand(size(y)...)
+    Flux.back!(y, g)
+    Flux.back!(cy, gpu(g))
+
+    @test m.γ.grad ≈ cpu(cm.γ.grad)
+    @test m.β.grad ≈ cpu(cm.β.grad)
+    @test x.grad ≈ cpu(x.grad)
+  end
+
+  @testset "2D Input" begin
+    x = TrackedArray(Float64.(collect(reshape(1:12, 3, 4))))
+    m = BatchNorm(3)
+    cx = gpu(x)
+    cm = gpu(m)
+
+    y = m(x)
+    cy = cm(cx)
+
+    @test cy isa TrackedArray{Float32,2,CuArray{Float32,2}}
+    @test cpu(data(cy)) ≈ data(y)
+
+    g = rand(size(y)...)
+    Flux.back!(y, g)
+    Flux.back!(cy, gpu(g))
+
+    @test m.γ.grad ≈ cpu(cm.γ.grad)
+    @test m.β.grad ≈ cpu(cm.β.grad)
+    @test x.grad ≈ cpu(x.grad)
+  end
+end
test/cuda/curnn.jl (new file, +46)

@@ -0,0 +1,46 @@
+using Flux, CuArrays, Test
+
+@testset "RNN" begin
+  @testset for R in [RNN, GRU, LSTM]
+    rnn = R(10, 5)
+    curnn = mapleaves(gpu, rnn)
+    @testset for batch_size in (1, 5)
+      Flux.reset!(rnn)
+      Flux.reset!(curnn)
+      x = batch_size == 1 ?
+        param(rand(10)) :
+        param(rand(10,batch_size))
+      cux = gpu(x)
+      y = (rnn(x); rnn(x))
+      cuy = (curnn(cux); curnn(cux))
+
+      @test y.data ≈ collect(cuy.data)
+      @test haskey(Flux.CUDA.descs, curnn.cell)
+
+      Δ = randn(size(y))
+
+      Flux.back!(y, Δ)
+      Flux.back!(cuy, gpu(Δ))
+
+      @test x.grad ≈ collect(cux.grad)
+      @test rnn.cell.Wi.grad ≈ collect(curnn.cell.Wi.grad)
+      @test rnn.cell.Wh.grad ≈ collect(curnn.cell.Wh.grad)
+      @test rnn.cell.b.grad ≈ collect(curnn.cell.b.grad)
+      @test rnn.cell.h.grad ≈ collect(curnn.cell.h.grad)
+      if isdefined(rnn.cell, :c)
+        @test rnn.cell.c.grad ≈ collect(curnn.cell.c.grad)
+      end
+
+      Flux.reset!(rnn)
+      Flux.reset!(curnn)
+      ohx = batch_size == 1 ?
+        Flux.onehot(rand(1:10), 1:10) :
+        Flux.onehotbatch(rand(1:10, batch_size), 1:10)
+      cuohx = gpu(ohx)
+      y = (rnn(ohx); rnn(ohx))
+      cuy = (curnn(cuohx); curnn(cuohx))
+
+      @test y.data ≈ collect(cuy.data)
+    end
+  end
+end
@@ -10,4 +10,7 @@ using Test
 @test MNIST.images()[1] isa Matrix
 @test MNIST.labels() isa Vector{Int64}

+@test FashionMNIST.images()[1] isa Matrix
+@test FashionMNIST.labels() isa Vector{Int64}
+
 @test Data.Sentiment.train() isa Vector{Data.Tree{Any}}
test/layers/basic.jl (new file, +33)

@@ -0,0 +1,33 @@
+using Test, Random
+
+@testset "basic" begin
+  @testset "Chain" begin
+    @test_nowarn Chain(Dense(10, 5, σ), Dense(5, 2))(randn(10))
+    @test_throws DimensionMismatch Chain(Dense(10, 5, σ),Dense(2, 1))(randn(10))
+    # numeric test should be put into testset of corresponding layer
+  end
+
+  @testset "Dense" begin
+    @test length(Dense(10, 5)(randn(10))) == 5
+    @test_throws DimensionMismatch Dense(10, 5)(randn(1))
+    @test_throws MethodError Dense(10, 5)(1) # avoid broadcasting
+    @test_throws MethodError Dense(10, 5).(randn(10)) # avoid broadcasting
+
+    @test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(1, 1)
+    @test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,2)) == 10*ones(1, 2)
+    @test Dense(10, 2, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(2, 1)
+    @test Dense(10, 2, identity, initW = ones, initb = zeros)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20]
+  end
+
+  @testset "Diagonal" begin
+    @test length(Flux.Diagonal(10)(randn(10))) == 10
+    @test length(Flux.Diagonal(10)(1)) == 10
+    @test length(Flux.Diagonal(10)(randn(1))) == 10
+    @test_throws DimensionMismatch Flux.Diagonal(10)(randn(2))
+
+    @test Flux.Diagonal(2)([1 2]) == [1 2; 1 2]
+    @test Flux.Diagonal(2)([1,2]) == [1,2]
+    @test Flux.Diagonal(2)([1 2; 3 4]) == [1 2; 3 4]
+  end
+end
@@ -2,7 +2,7 @@ using Flux, Test
 using Flux: maxpool, meanpool

 @testset "Pooling" begin
-  x = randn(10, 10, 3, 2)
+  x = randn(Float32, 10, 10, 3, 2)
   mp = MaxPool((2, 2))
   @test mp(x) == maxpool(x, (2,2))
   mp = MeanPool((2, 2))

@@ -10,7 +10,7 @@ using Flux: maxpool, meanpool
 end

 @testset "CNN" begin
-  r = zeros(28, 28, 1, 5)
+  r = zeros(Float32, 28, 28, 1, 5)
   m = Chain(
     Conv((2, 2), 1=>16, relu),
     MaxPool((2,2)),
@@ -1,4 +1,5 @@
 using Flux: testmode!
+using Flux.Tracker: data

 @testset "Dropout" begin
   x = [1.,2.,3.]

@@ -28,7 +29,8 @@ using Flux: testmode!
 end

 @testset "BatchNorm" begin
-  let m = BatchNorm(2), x = param([1 2; 3 4; 5 6]')
+  let m = BatchNorm(2), x = param([1 3 5;
+                                   2 4 6])
     @test m.β.data == [0, 0] # initβ(2)
     @test m.γ.data == [1, 1] # initγ(2)
@@ -53,29 +55,30 @@ end
     # .1 * 4 + 0 = .4
     @test m.μ ≈ reshape([0.3, 0.4], 2, 1)

-    # julia> .1 .* std(x, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
+    # julia> .1 .* var(x, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
     # 2×1 Array{Float64,2}:
-    #  1.14495
-    #  1.14495
-    @test m.σ ≈ .1 .* std(x.data, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
+    #  1.3
+    #  1.3
+    @test m.σ² ≈ .1 .* var(x.data, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]

     testmode!(m)
     @test !m.active

     x = m(x).data
-    @test x[1] ≈ (1 .- 0.3) / 1.1449489742783179
+    @test isapprox(x[1], (1 .- 0.3) / sqrt(1.3), atol = 1.0e-5)
   end

   # with activation function
-  let m = BatchNorm(2, σ), x = param([1 2; 3 4; 5 6]')
+  let m = BatchNorm(2, sigmoid), x = param([1 3 5;
+                                            2 4 6])
     @test m.active
     m(x)

     testmode!(m)
     @test !m.active

-    x = m(x).data
-    @test x[1] ≈ σ((1 - 0.3) / 1.1449489742783179)
+    y = m(x).data
+    @test isapprox(y, data(sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ))), atol = 1.0e-7)
   end

   let m = BatchNorm(2), x = param(reshape(1:6, 3, 2, 1))
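The running-statistics numbers in the comments above can be reproduced directly (momentum 0.1, three columns per batch):

using Statistics

x = [1 3 5; 2 4 6]
mean(x, dims = 2)                          # [3.0, 4.0]
0.1 .* mean(x, dims = 2)                   # running mean -> [0.3, 0.4]
v = var(x, dims = 2, corrected = false)    # [8/3, 8/3]
0.1 .* v .* (3 / 2) .+ 0.9                 # running var -> [1.3, 1.3]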
@@ -85,7 +88,7 @@ end
   end

   let m = BatchNorm(2), x = param(reshape(1:12, 2, 3, 2, 1))
     y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :)
     y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4])
     @test m(x) == y
   end
@@ -49,4 +49,16 @@ const ϵ = 1e-7
 @testset "logitbinarycrossentropy" begin
   @test logitbinarycrossentropy.(logŷ, y) ≈ binarycrossentropy.(σ.(logŷ), y; ϵ=0)
 end

+@testset "no spurious promotions" begin
+  for T in (Float16, Float32, Float64)
+    y = rand(T, 2)
+    ŷ = rand(T, 2)
+    for f in (mse, crossentropy, logitcrossentropy)
+      fwd, back = Flux.Tracker.forward(mse, ŷ, y)
+      @test typeof(fwd) == Flux.Tracker.TrackedReal{T}
+      @test eltype(back(one(T))[1]) == Flux.Tracker.TrackedReal{T}
+    end
+  end
+end
 end
@@ -1,16 +1,40 @@
 using Flux.Optimise
+using Flux.Optimise: runall
 using Flux.Tracker
 using Test

 @testset "Optimise" begin
   w = randn(10, 10)
-  @testset for Opt in [SGD, Nesterov, Momentum, ADAM, AdaMax, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad, NADAM]
+  @testset for Opt in [ADAMW, ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, Descent, ADAM, Nesterov, RMSProp, Momentum]
     w′ = param(randn(10, 10))
     loss(x) = Flux.mse(w*x, w′*x)
-    opt = Opt([w′])
-    for t=1:10^5
+    opt = Opt(0.001)
+    if opt isa Descent || opt isa ADAGrad
+      opt = Opt(0.1)
+    end
+    if opt isa ADADelta
+      opt = Opt(0.9)
+    end
+    for t = 1: 10^5
       l = loss(rand(10))
       back!(l)
-      opt()
+      delta = Optimise.update!(opt, w′.data, w′.grad)
+      w′.data .-= delta
     end
     @test Flux.mse(w, w′) < 0.01
   end
 end
+
+@testset "Optimiser" begin
+  w = randn(10, 10)
+  @testset for Opt in [InvDecay, WeightDecay, ExpDecay]
+    w′ = param(randn(10, 10))
+    loss(x) = Flux.mse(w*x, w′*x)
+    opt = Optimiser(Opt(), ADAM(0.001))
+    for t = 1:10^5
+      l = loss(rand(10))
+      back!(l)
+      delta = Optimise.update!(opt, w′.data, w′.grad)
+      w′.data .-= delta
+    end
+    @test Flux.mse(w, w′) < 0.01
+  end
+end
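Outside the test loop, the same interface in a minimal training step (a sketch mirroring the tests above: in this commit update! returns the step, which the caller applies by hand):

using Flux, Flux.Optimise

W = param(randn(5, 2))
x = rand(2)
opt = ADAM(0.001)

l = sum(W * x)            # toy loss
Flux.back!(l)
delta = Optimise.update!(opt, W.data, W.grad)
W.data .-= delta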
@@ -21,9 +45,17 @@ end
   l = param(1)

   Flux.train!(() -> (sleep(0.1); i += 1; l),
-              Iterators.repeated((), 100),
-              ()->(),
-              cb = Flux.throttle(() -> (i > 3 && stop()), 1))
+              (),
+              Iterators.repeated((), 100),
+              Descent(),
+              cb = Flux.throttle(() -> (i > 3 && Flux.stop()), 1))
   @test 3 < i < 50
+
+  # Test multiple callbacks
+  x = 0
+  fs = [() -> (), () -> x = 1]
+  cbs = runall(fs)
+  cbs()
+  @test x == 1
 end
@@ -1,19 +1,4 @@
-# Pkg.test runs with --check_bounds=1, forcing all bounds checks.
-# This is incompatible with CUDAnative (see JuliaGPU/CUDAnative.jl#98)
-if Base.JLOptions().check_bounds == 1
-  file = @__FILE__
-  run(```
-    $(Base.julia_cmd())
-    --color=$(Base.have_color ? "yes" : "no")
-    --compiled-modules=$(Bool(Base.JLOptions().use_compiled_modules) ? "yes" : "no")
-    --startup-file=$(Base.JLOptions().startupfile != 2 ? "yes" : "no")
-    --code-coverage=$(["none", "user", "all"][1+Base.JLOptions().code_coverage])
-    $(file)
-    ```)
-  exit()
-end
-using Flux, Test, Random
+using Flux, Test, Random, Statistics
 using Random

 Random.seed!(0)
@@ -32,6 +17,7 @@ include("data.jl")
 @info "Testing Layers"

+include("layers/basic.jl")
 include("layers/normalisation.jl")
 include("layers/stateless.jl")
 include("layers/conv.jl")
@@ -1,9 +1,9 @@
 using Flux
 using Flux.Tracker, Test, NNlib
-using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint
+using Flux.Tracker: TrackedReal, gradient, gradcheck, grad, checkpoint
-using NNlib: conv
+using NNlib: conv, depthwiseconv
 using Printf: @sprintf
-using LinearAlgebra: Diagonal, dot, LowerTriangular, norm
+using LinearAlgebra: diagm, dot, LowerTriangular, norm
 using Statistics: mean, std
 using Random
 # using StatsBase
@@ -33,16 +33,16 @@ gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
 @test gradtest(Flux.crossentropy, rand(5,5), rand(5, 5))

 @test gradtest(x -> x', rand(5))

+@testset "indexing & slicing" begin
+  gradtest(x->view(x, 1:2, 1:2), rand(4, 4))
+end
 function promotiontest(f, A, B, C)
   r0 = f(A, B, C)
   r1 = f(param(A), B, C)
   r2 = f(A, param(B), C)
-  if all(ndims.((A,B,C)) .≤ 2) && f ∈ [hcat, vcat]
-    r3 = f(A, B, param(C))
-  else
-    @test_throws MethodError f(A, B, param(C)) # until julia#20815 is resolved
-    r3 = r2
-  end
+  r3 = f(A, B, param(C))
   r4 = f(param(A), param(B), param(C))

   @test !isa(r0, TrackedArray)
@@ -127,7 +127,7 @@ end
 @test gradtest(kron, rand(5,1), rand(3,1), rand(8,1))
 @test gradtest(kron, rand(5,2), rand(3,2), rand(8,2))

-@test gradtest(f-> Matrix(Diagonal(f)), rand(3))
+@test gradtest(x -> diagm(0 => x), rand(3))

 @test gradtest(W -> inv(log.(W * W)), (5,5))
 @test gradtest((A, B) -> A / B , (1,5), (5,5))
@@ -181,12 +181,16 @@ end
 @test gradtest(conv, rand(10, 10, 3, 2), randn(Float64,2, 2, 3, 2))
 @test gradtest(conv, rand(10, 10, 10, 3, 2), randn(Float64,2, 2, 2, 3, 2))

+@test gradtest(depthwiseconv, rand(10,10,3,2), randn(2, 2, 2, 3))
+
 @test gradtest(x -> maxpool(x, (2,2)), rand(10, 10, 3, 2))
 @test gradtest(x -> maxpool(x, (2,2,2)), rand(10, 10, 10, 3, 2))
 @test gradtest(x -> meanpool(x, (2,2)), rand(10, 10, 3, 2))
 @test gradtest(x -> meanpool(x, (2,2,2)), rand(5, 5, 5, 3, 2))

+@test gradtest(x -> Float64.(x), 5)
+
 @testset "equality & order" begin
   # TrackedReal
   @test param(2)^2 == param(4)
@@ -230,10 +234,10 @@ end
 @testset "Intermediates" begin
   x = param([1])
   l = sum((x .+ x).^2)
-  Flux.back!(l)
+  Flux.back!(l, once = false)
   @test x.grad == [8]
   x.grad .= 0
-  Flux.back!(l)
+  Flux.back!(l, once = false)
   @test x.grad == [8]
 end
@@ -258,7 +262,7 @@ Tracker.back!(b)
 back!(z)
 @test grad.((x,y)) == (3, 2)

-@test Tracker.gradient(2, 3) do x, y
+@test gradient(2, 3) do x, y
   xy = Tracker.collect([x, y])
   xy[1]*xy[2]
 end == (3, 2)
@@ -278,10 +282,27 @@ end
     count += 1
     a * b
   end
-  @test derivative(x -> mul(5, x), 3) == 5
+  @test gradient(x -> mul(5, x), 3)[1] == 5
   @test count == 1
-  @test derivative(x -> checkpoint(mul, 5, x), 3) == 5
+  @test gradient(x -> checkpoint(mul, 5, x), 3)[1] == 5
   @test count == 3
 end
@testset "Updates" begin
xs = param([1, 2, 3])
Tracker.update!(xs, param([4, 5, 6]))
@test xs == [5, 7, 9]
x = param(3)
Tracker.update!(x, param(4))
@test x == 7
end
@testset "Params" begin
W = param(randn(5, 10))
x = rand(10)
dW = gradient(W -> sum(W*x), W)[1]
gs = gradient(() -> sum(W*x), Tracker.Params([W]))
@test gs[W] == dW
end
end #testset end #testset
@@ -1,5 +1,5 @@
 using Flux
-using Flux: throttle, jacobian, initn, glorot_uniform, glorot_normal
+using Flux: throttle, jacobian, glorot_uniform, glorot_normal
 using StatsBase: std
 using Random
 using Test
@@ -64,10 +64,6 @@ end
 @testset "Initialization" begin
   # Set random seed so that these tests don't fail randomly
   Random.seed!(0)
-  # initn() should yield a kernel with stddev ~= 1e-2
-  v = initn(10, 10)
-  @test std(v) > 0.9*1e-2
-  @test std(v) < 1.1*1e-2

   # glorot_uniform should yield a kernel with values in ±sqrt(6/(n_in + n_out)),
   # and glorot_normal should yield a kernel with stddev ~= sqrt(2/(n_in + n_out))