commit 73c1485927
Avik Pal, 2019-01-24 18:42:28 +05:30

51 changed files with 2391 additions and 855 deletions

2
.gitignore vendored

@@ -3,6 +3,4 @@
*.jl.mem
docs/build/
docs/site/
docs/flux.css
deps
Manifest.toml

.travis.yml

@@ -1,19 +1,25 @@
# Documentation: http://docs.travis-ci.com/user/languages/julia/
language: julia
os:
- linux
# - osx
julia:
- 0.7
- 1.0
- nightly
# uncomment the following lines to override the default test script
# script:
# - if [[ -a .git/shallow ]]; then git fetch --unshallow; fi
# - julia -e 'Pkg.clone(pwd()); Pkg.build("Flux"); Pkg.test("Flux"; coverage=true)'
matrix:
allow_failures:
- julia: nightly
after_success:
- julia -e 'using Pkg; Pkg.add("Documenter"); Pkg.add("NNlib")'
- julia -e 'using Pkg; cd(Pkg.dir("Flux")); include(joinpath("docs", "make.jl"))'
jobs:
include:
- stage: "Documentation"
julia: 1.0
os: linux
script:
- julia --project=docs/ -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd()));
Pkg.instantiate()'
- julia --project=docs/ docs/make.jl
after_success: skip

272
Manifest.toml Normal file

@@ -0,0 +1,272 @@
# This file is machine-generated - editing it directly is not advised
[[AbstractTrees]]
deps = ["Markdown", "Test"]
git-tree-sha1 = "6621d9645702c1c4e6970cc6a3eae440c768000b"
uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
version = "0.2.1"
[[Adapt]]
deps = ["LinearAlgebra", "Test"]
git-tree-sha1 = "53d8fec4f662088c1202530e338a11a919407f3b"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "0.4.2"
[[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
[[BinDeps]]
deps = ["Compat", "Libdl", "SHA", "URIParser"]
git-tree-sha1 = "12093ca6cdd0ee547c39b1870e0c9c3f154d9ca9"
uuid = "9e28174c-4ba2-5203-b857-d8d62c4213ee"
version = "0.8.10"
[[BinaryProvider]]
deps = ["Libdl", "Pkg", "SHA", "Test"]
git-tree-sha1 = "055eb2690182ebc31087859c3dd8598371d3ef9e"
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
version = "0.5.3"
[[CodecZlib]]
deps = ["BinaryProvider", "Libdl", "Test", "TranscodingStreams"]
git-tree-sha1 = "e3df104c84dfc108f0ca203fd7f5bbdc98641ae9"
uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
version = "0.5.1"
[[ColorTypes]]
deps = ["FixedPointNumbers", "Random", "Test"]
git-tree-sha1 = "f73b0e10f2a5756de7019818a41654686da06b09"
uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
version = "0.7.5"
[[Colors]]
deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Printf", "Reexport", "Test"]
git-tree-sha1 = "9f0a0210450acb91c730b730a994f8eef1d3d543"
uuid = "5ae59095-9a9b-59fe-a467-6f913c188581"
version = "0.9.5"
[[CommonSubexpressions]]
deps = ["Test"]
git-tree-sha1 = "efdaf19ab11c7889334ca247ff4c9f7c322817b0"
uuid = "bbf7d656-a473-5ed7-a52c-81e309532950"
version = "0.2.0"
[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "ec61a16eed883ad0cfa002d7489b3ce6d039bb9a"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "1.4.0"
[[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
git-tree-sha1 = "ca971f03e146cf144a9e2f2ce59674f5bf0e8038"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.15.0"
[[Dates]]
deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
[[DelimitedFiles]]
deps = ["Mmap"]
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
[[DiffResults]]
deps = ["Compat", "StaticArrays"]
git-tree-sha1 = "db8acf46717b13d6c48deb7a12007c7f85a70cf7"
uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
version = "0.0.3"
[[DiffRules]]
deps = ["Random", "Test"]
git-tree-sha1 = "c49ec69428ffea0c1d1bbdc63d1a70f5df5860ad"
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "0.0.7"
[[Distributed]]
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
[[FixedPointNumbers]]
deps = ["Test"]
git-tree-sha1 = "b8045033701c3b10bf2324d7203404be7aef88ba"
uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
version = "0.5.3"
[[ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"]
git-tree-sha1 = "e393bd3b9102659fb24fe88caedec41f2bc2e7de"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.2"
[[InteractiveUtils]]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
[[Juno]]
deps = ["Base64", "Logging", "Media", "Profile", "Test"]
git-tree-sha1 = "ce6246e19061e36cbdce954caaae717498daeed8"
uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
version = "0.5.4"
[[LibGit2]]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
[[Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
[[LinearAlgebra]]
deps = ["Libdl"]
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
[[Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
[[MacroTools]]
deps = ["Compat"]
git-tree-sha1 = "c443e1c8d58a4e9f61b708ad0a88286c7042145b"
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
version = "0.4.4"
[[Markdown]]
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
[[Media]]
deps = ["MacroTools", "Test"]
git-tree-sha1 = "75a54abd10709c01f1b86b84ec225d26e840ed58"
uuid = "e89f7d12-3494-54d1-8411-f7d8b9ae1f27"
version = "0.5.0"
[[Missings]]
deps = ["Dates", "InteractiveUtils", "SparseArrays", "Test"]
git-tree-sha1 = "d1d2585677f2bd93a97cfeb8faa7a0de0f982042"
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
version = "0.4.0"
[[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
[[NNlib]]
deps = ["Libdl", "LinearAlgebra", "MacroTools", "Requires", "Test"]
git-tree-sha1 = "51330bb45927379007e089997bf548fbe232589d"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.4.3"
[[NaNMath]]
deps = ["Compat"]
git-tree-sha1 = "ce3b85e484a5d4c71dd5316215069311135fa9f2"
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
version = "0.3.2"
[[OrderedCollections]]
deps = ["Random", "Serialization", "Test"]
git-tree-sha1 = "85619a3f3e17bb4761fe1b1fd47f0e979f964d5b"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.0.2"
[[Pkg]]
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
[[Printf]]
deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
[[Profile]]
deps = ["Printf"]
uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"
[[REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
[[Random]]
deps = ["Serialization"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
[[Reexport]]
deps = ["Pkg"]
git-tree-sha1 = "7b1d07f411bc8ddb7977ec7f377b97b158514fe0"
uuid = "189a3867-3050-52da-a836-e630ba90ab69"
version = "0.2.0"
[[Requires]]
deps = ["Test"]
git-tree-sha1 = "f6fbf4ba64d295e146e49e021207993b6b48c7d1"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "0.5.2"
[[SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
[[Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
[[SharedArrays]]
deps = ["Distributed", "Mmap", "Random", "Serialization"]
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
[[Sockets]]
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
[[SortingAlgorithms]]
deps = ["DataStructures", "Random", "Test"]
git-tree-sha1 = "03f5898c9959f8115e30bc7226ada7d0df554ddd"
uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
version = "0.3.1"
[[SparseArrays]]
deps = ["LinearAlgebra", "Random"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
[[SpecialFunctions]]
deps = ["BinDeps", "BinaryProvider", "Libdl", "Test"]
git-tree-sha1 = "0b45dc2e45ed77f445617b99ff2adf0f5b0f23ea"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "0.7.2"
[[StaticArrays]]
deps = ["InteractiveUtils", "LinearAlgebra", "Random", "Statistics", "Test"]
git-tree-sha1 = "1eb114d6e23a817cd3e99abc3226190876d7c898"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "0.10.2"
[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[[StatsBase]]
deps = ["DataStructures", "DelimitedFiles", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "Test"]
git-tree-sha1 = "7b596062316c7d846b67bf625d5963a832528598"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.27.0"
[[Test]]
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[[TranscodingStreams]]
deps = ["Pkg", "Random", "Test"]
git-tree-sha1 = "a34a2d588e2d2825602bf14a24216d5c8b0921ec"
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
version = "0.8.1"
[[URIParser]]
deps = ["Test", "Unicode"]
git-tree-sha1 = "6ddf8244220dfda2f17539fa8c9de20d6c575b69"
uuid = "30578b45-9adc-5946-b283-645ec420af67"
version = "0.4.0"
[[UUIDs]]
deps = ["Random", "SHA"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
[[Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
[[ZipFile]]
deps = ["BinaryProvider", "Libdl", "Printf", "Test"]
git-tree-sha1 = "4000c633efe994b2e10b31b6d91382c4b7412dac"
uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
version = "0.8.0"

25
Project.toml Normal file

@@ -0,0 +1,25 @@
name = "Flux"
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"

REQUIRE

@@ -1,9 +1,9 @@
julia 0.7
julia 1.0
Juno
MacroTools 0.3.3
NNlib
Requires
Adapt
Adapt 0.4
CodecZlib
Colors
ZipFile

288
docs/Manifest.toml Normal file

@@ -0,0 +1,288 @@
[[AbstractTrees]]
deps = ["Markdown", "Test"]
git-tree-sha1 = "6621d9645702c1c4e6970cc6a3eae440c768000b"
uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
version = "0.2.1"
[[Adapt]]
deps = ["LinearAlgebra", "Test"]
git-tree-sha1 = "04d15700419b6949d76be1428ab6e0277ff43b06"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "0.4.1"
[[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
[[BinDeps]]
deps = ["Compat", "Libdl", "SHA", "URIParser"]
git-tree-sha1 = "12093ca6cdd0ee547c39b1870e0c9c3f154d9ca9"
uuid = "9e28174c-4ba2-5203-b857-d8d62c4213ee"
version = "0.8.10"
[[BinaryProvider]]
deps = ["Libdl", "Pkg", "SHA", "Test"]
git-tree-sha1 = "055eb2690182ebc31087859c3dd8598371d3ef9e"
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
version = "0.5.3"
[[CodecZlib]]
deps = ["BinaryProvider", "Libdl", "Test", "TranscodingStreams"]
git-tree-sha1 = "e3df104c84dfc108f0ca203fd7f5bbdc98641ae9"
uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
version = "0.5.1"
[[ColorTypes]]
deps = ["FixedPointNumbers", "Random", "Test"]
git-tree-sha1 = "f73b0e10f2a5756de7019818a41654686da06b09"
uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
version = "0.7.5"
[[Colors]]
deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Printf", "Reexport", "Test"]
git-tree-sha1 = "9f0a0210450acb91c730b730a994f8eef1d3d543"
uuid = "5ae59095-9a9b-59fe-a467-6f913c188581"
version = "0.9.5"
[[CommonSubexpressions]]
deps = ["Test"]
git-tree-sha1 = "efdaf19ab11c7889334ca247ff4c9f7c322817b0"
uuid = "bbf7d656-a473-5ed7-a52c-81e309532950"
version = "0.2.0"
[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "ec61a16eed883ad0cfa002d7489b3ce6d039bb9a"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "1.4.0"
[[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
git-tree-sha1 = "ca971f03e146cf144a9e2f2ce59674f5bf0e8038"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.15.0"
[[Dates]]
deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
[[DelimitedFiles]]
deps = ["Mmap"]
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
[[DiffResults]]
deps = ["Compat", "StaticArrays"]
git-tree-sha1 = "db8acf46717b13d6c48deb7a12007c7f85a70cf7"
uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
version = "0.0.3"
[[DiffRules]]
deps = ["Random", "Test"]
git-tree-sha1 = "c49ec69428ffea0c1d1bbdc63d1a70f5df5860ad"
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "0.0.7"
[[Distributed]]
deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
[[DocStringExtensions]]
deps = ["LibGit2", "Markdown", "Pkg", "Test"]
git-tree-sha1 = "1df01539a1c952cef21f2d2d1c092c2bcf0177d7"
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
version = "0.6.0"
[[Documenter]]
deps = ["Base64", "DocStringExtensions", "InteractiveUtils", "LibGit2", "Logging", "Markdown", "Pkg", "REPL", "Random", "Test", "Unicode"]
git-tree-sha1 = "a6db1c69925cdc53aafb38caec4446be26e0c617"
uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
version = "0.21.0"
[[FixedPointNumbers]]
deps = ["Test"]
git-tree-sha1 = "b8045033701c3b10bf2324d7203404be7aef88ba"
uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
version = "0.5.3"
[[Flux]]
deps = ["AbstractTrees", "Adapt", "CodecZlib", "Colors", "DiffRules", "ForwardDiff", "Juno", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsBase", "Test", "ZipFile"]
path = ".."
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
version = "0.6.10+"
[[ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"]
git-tree-sha1 = "b91250044374764e7c29af59a774c4b8d6100b6e"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.1"
[[InteractiveUtils]]
deps = ["LinearAlgebra", "Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
[[Juno]]
deps = ["Base64", "Logging", "Media", "Profile", "Test"]
git-tree-sha1 = "3c29a199713e7ec62cfdc11f44d7760219d5f658"
uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
version = "0.5.3"
[[LibGit2]]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
[[Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
[[LinearAlgebra]]
deps = ["Libdl"]
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
[[Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
[[MacroTools]]
deps = ["Compat"]
git-tree-sha1 = "c443e1c8d58a4e9f61b708ad0a88286c7042145b"
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
version = "0.4.4"
[[Markdown]]
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
[[Media]]
deps = ["MacroTools", "Test"]
git-tree-sha1 = "75a54abd10709c01f1b86b84ec225d26e840ed58"
uuid = "e89f7d12-3494-54d1-8411-f7d8b9ae1f27"
version = "0.5.0"
[[Missings]]
deps = ["Dates", "InteractiveUtils", "SparseArrays", "Test"]
git-tree-sha1 = "adc26d2ee85a49c413464110d922cf21efc9d233"
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
version = "0.3.1"
[[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
[[NNlib]]
deps = ["Libdl", "LinearAlgebra", "MacroTools", "Requires", "Test"]
git-tree-sha1 = "51330bb45927379007e089997bf548fbe232589d"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.4.3"
[[NaNMath]]
deps = ["Compat"]
git-tree-sha1 = "ce3b85e484a5d4c71dd5316215069311135fa9f2"
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
version = "0.3.2"
[[OrderedCollections]]
deps = ["Random", "Serialization", "Test"]
git-tree-sha1 = "85619a3f3e17bb4761fe1b1fd47f0e979f964d5b"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.0.2"
[[Pkg]]
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
[[Printf]]
deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
[[Profile]]
deps = ["Printf"]
uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"
[[REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
[[Random]]
deps = ["Serialization"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
[[Reexport]]
deps = ["Pkg"]
git-tree-sha1 = "7b1d07f411bc8ddb7977ec7f377b97b158514fe0"
uuid = "189a3867-3050-52da-a836-e630ba90ab69"
version = "0.2.0"
[[Requires]]
deps = ["Test"]
git-tree-sha1 = "f6fbf4ba64d295e146e49e021207993b6b48c7d1"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "0.5.2"
[[SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
[[Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
[[SharedArrays]]
deps = ["Distributed", "Mmap", "Random", "Serialization"]
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
[[Sockets]]
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
[[SortingAlgorithms]]
deps = ["DataStructures", "Random", "Test"]
git-tree-sha1 = "03f5898c9959f8115e30bc7226ada7d0df554ddd"
uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
version = "0.3.1"
[[SparseArrays]]
deps = ["LinearAlgebra", "Random"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
[[SpecialFunctions]]
deps = ["BinDeps", "BinaryProvider", "Libdl", "Test"]
git-tree-sha1 = "0b45dc2e45ed77f445617b99ff2adf0f5b0f23ea"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "0.7.2"
[[StaticArrays]]
deps = ["InteractiveUtils", "LinearAlgebra", "Random", "Statistics", "Test"]
git-tree-sha1 = "1eb114d6e23a817cd3e99abc3226190876d7c898"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "0.10.2"
[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[[StatsBase]]
deps = ["DataStructures", "DelimitedFiles", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "Test"]
git-tree-sha1 = "7b596062316c7d846b67bf625d5963a832528598"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.27.0"
[[Test]]
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[[TranscodingStreams]]
deps = ["Pkg", "Random", "Test"]
git-tree-sha1 = "a34a2d588e2d2825602bf14a24216d5c8b0921ec"
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
version = "0.8.1"
[[URIParser]]
deps = ["Test", "Unicode"]
git-tree-sha1 = "6ddf8244220dfda2f17539fa8c9de20d6c575b69"
uuid = "30578b45-9adc-5946-b283-645ec420af67"
version = "0.4.0"
[[UUIDs]]
deps = ["Random"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
[[Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
[[ZipFile]]
deps = ["BinaryProvider", "Libdl", "Printf", "Test"]
git-tree-sha1 = "4000c633efe994b2e10b31b6d91382c4b7412dac"
uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
version = "0.8.0"

4
docs/Project.toml Normal file

@@ -0,0 +1,4 @@
[deps]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"

docs/make.jl

@@ -2,10 +2,11 @@ using Documenter, Flux, NNlib
makedocs(modules=[Flux, NNlib],
doctest = false,
format = :html,
analytics = "UA-36890222-9",
sitename = "Flux",
assets = ["../flux.css"],
# Uncomment below for local build
#format = Documenter.HTML(prettyurls = false),
assets = ["assets/flux.css"],
pages = ["Home" => "index.md",
"Building Models" =>
["Basics" => "models/basics.md",
@@ -22,10 +23,4 @@ makedocs(modules=[Flux, NNlib],
["Backpropagation" => "internals/tracker.md"],
"Community" => "community.md"])
deploydocs(
repo = "github.com/FluxML/Flux.jl.git",
target = "build",
osname = "linux",
julia = "1.0",
deps = nothing,
make = nothing)
deploydocs(repo = "github.com/FluxML/Flux.jl.git")

113
docs/src/assets/flux.css Normal file

@@ -0,0 +1,113 @@
@import url('https://fonts.googleapis.com/css?family=Lato:400,400i');
body {
font-family: Lato, "Segoe UI",Roboto,"Helvetica Neue",Arial,sans-serif;
}
nav.toc {
padding-top: 0;
background: rgb(240, 240, 240);
line-height: 2em;
cursor: default;
user-select: none;
}
h1+h2 {
margin-top: 0;
}
/* Green banner in ToC */
nav.toc > h1 {
margin-top: 0;
padding-top: 0.4em;
padding-bottom: 0.5em;
border-bottom: 5px solid white;
box-shadow: 0px -2px 5px rgb(60,60,60);
margin-bottom: 0.5em;
background: rgb(60, 150, 60);
font-style: italic;
font-weight: normal;
font-size: 50pt;
text-transform: lowercase;
text-shadow: 2px 2px 5px rgba(0,0,0,0.2);
color: white;
}
/* Reduce ToC font size */
.toctext {
font-size: 10pt;
}
/* Fade out non-clickable ToC headers */
nav.toc ul span.toctext {
color: rgb(180, 180, 180);
}
nav.toc ul .toctext {
color: rgb(100, 100, 100);
}
nav.toc ul a.toctext:hover {
color: inherit;
background: rgb(220, 220, 220);
cursor: default;
}
nav.toc li.current > .toctext {
background: linear-gradient(90deg, rgb(245,245,245) 0%, white 90%);
font-weight: normal;
}
nav.toc ul.internal li.toplevel {
font-weight: normal;
}
/* Content */
article { max-width: none; }
article > p, article > ul {
max-width: 45em;
}
/* Links */
a, a:visited { color: rgb(0, 120, 0); }
article p a { border-bottom: 1px solid rgb(200, 230, 200); }
a:hover, a:visited:hover { color: rgb(0, 80, 0); }
/* Article Links */
article p a { border-bottom: 1px solid rgb(200, 230, 200); }
article p a:hover, article a:visited:hover { color: rgb(0, 120, 0); }
article p a:hover { border-bottom: 1px solid rgb(150, 200, 150); }
/* Docstrings */
article section.docstring {
padding: 0.5em 0;
border-left: none;
border-right: none;
border-bottom: none;
}
/* Code */
article pre, article p > code {
background: rgb(245, 250, 245);
}
article pre {
border: none;
max-width: none;
padding: 1em;
border-radius: 10px 0px 0px 10px;
margin-left: -1em;
margin-right: -2em;
}
.hljs-comment {
font-style: italic;
}
.hljs-number {
color: rgb(0, 150, 150);
}

docs/src/gpu.md

@@ -4,7 +4,7 @@ Support for array operations on other hardware backends, like GPUs, is provided
For example, we can use `CuArrays` (with the `cu` converter) to run our [basic example](models/basics.md) on an NVIDIA GPU.
(Note that you need to build Julia 0.6 from source and have CUDA available to use CuArrays; please see the [CUDAnative.jl](https://github.com/JuliaGPU/CUDAnative.jl) instructions for more details.)
(Note that you need to have CUDA available to use CuArrays; please see the [CuArrays.jl](https://github.com/JuliaGPU/CuArrays.jl) instructions for more details.)
```julia
using CuArrays
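# The rest of this snippet is truncated in the diff view; the following is a
# minimal sketch of the `cu` converter in use (W, b and x are illustrative names,
# not part of the original hunk):
W = cu(rand(2, 5))        # a 2×5 CuArray
b = cu(rand(2))

predict(x) = W * x .+ b

x = cu(rand(5))           # dummy input, also on the GPU
predict(x)                # the whole computation now runs on the GPU
```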

docs/src/internals/tracker.md

@@ -100,16 +100,16 @@ minus(a, b) = a - b
Firstly, we must tell the tracker system to stop when it sees a call to `minus`, and record it. We can do this using dispatch:
```julia
using Flux.Tracker: TrackedReal, track, @grad
using Flux.Tracker: TrackedArray, track, @grad
minus(a::TrackedArray, b::TrackedArray) = Tracker.track(minus, a, b)
minus(a::TrackedArray, b::TrackedArray) = track(minus, a, b)
```
`track` takes care of building a new `Tracked` object and recording the operation on the tape. We just need to provide a gradient definition.
```julia
@grad function minus(a, b)
return minus(data(a),data(b)), Δ -> (Δ, -Δ)
return minus(data(a), data(b)), Δ -> (Δ, -Δ)
end
```
@@ -121,6 +121,19 @@ Note that in the backpropagator we don't call `data(a)`; we *do* in fact want to
@grad a * b = data(a)*data(b), Δ -> (Δ*b, a*Δ)
```
We can then calculate the first derivative of `minus` as follows:
```julia
a = param([1,2,3])
b = param([3,2,1])
c = minus(a, b) # [-2.0 (tracked), 0.0 (tracked), 2.0 (tracked)]
Tracker.back!(c, 1)
Tracker.grad(a) # [1.00, 1.00, 1.00]
Tracker.grad(b) # [-1.00, -1.00, -1.00]
```
For multi-argument functions with custom gradients, you likely want to catch not just `minus(::TrackedArray, ::TrackedArray)` but also `minus(::Array, ::TrackedArray)` and so on. To do so, just define those extra signatures as needed:
```julia
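# The signatures below are a sketch of that pattern (the original code block is
# truncated in this diff view):
minus(a::TrackedArray, b::AbstractArray) = track(minus, a, b)
minus(a::AbstractArray, b::TrackedArray) = track(minus, a, b)
```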

docs/src/models/basics.md

@@ -28,7 +28,7 @@ When a function has many parameters, we can pass them all in explicitly:
f(W, b, x) = W * x + b
Tracker.gradient(f, 2, 3, 4)
(4.0 (tracked), 1.0, 2.0 (tracked))
(4.0 (tracked), 1.0 (tracked), 2.0 (tracked))
```
But machine learning models can have *hundreds* of parameters! Flux offers a nice way to handle this. We can tell Flux to treat something as a parameter via `param`. Then we can collect these together and tell `gradient` to collect the gradients of all of them at once.
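A minimal sketch of that workflow, reusing `f(W, b, x) = W * x + b` from above (the exact printed output may differ):
```julia
using Flux
using Flux.Tracker: Params, gradient

W = param(2)                      # 2.0 (tracked)
b = param(3)                      # 3.0 (tracked)
f(x) = W * x + b

ps = Params([W, b])               # collect the parameters we care about
grads = gradient(() -> f(4), ps)  # gradients of all of them at once
grads[W]                          # 4.0
grads[b]                          # 1.0
```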
@@ -102,6 +102,8 @@ All deep learning in Flux, however complex, is a simple generalisation of this e
It's common to create more complex models than the linear regression above. For example, we might want to have two linear layers with a nonlinearity like [sigmoid](https://en.wikipedia.org/wiki/Sigmoid_function) (`σ`) in between them. In the above style we could write this as:
```julia
using Flux
W1 = param(rand(3, 5))
b1 = param(rand(3))
layer1(x) = W1 * x .+ b1

docs/src/models/layers.md

@@ -10,6 +10,12 @@ MaxPool
MeanPool
```
## Additional Convolution Layers
```@docs
DepthwiseConv
```
## Recurrent Layers
Much like the core layers above, but can be used to process sequence data (as well as other kinds of structured data).
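For instance (a sketch, not part of this diff), a recurrent layer drops into a `Chain` just like the core layers:
```julia
m = Chain(LSTM(10, 15), Dense(15, 5), softmax)

seq = [rand(10) for i = 1:8]   # a sequence of eight inputs
m.(seq)                        # apply the model to each timestep in order
```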

docs/src/training/optimisers.md

@@ -23,44 +23,30 @@ We want to update each parameter, using the gradient, in order to improve (reduce
```julia
using Flux.Tracker: grad, update!
function sgd()
η = 0.1 # Learning Rate
for p in (W, b)
update!(p, -η * grads[p])
end
η = 0.1 # Learning Rate
for p in (W, b)
update!(p, -η * grads[p])
end
```
If we call `sgd`, the parameters `W` and `b` will change and our loss should go down.
There are two pieces here: one is that we need a list of trainable parameters for the model (`[W, b]` in this case), and the other is the update step. In this case the update is simply gradient descent (`x .-= η .* Δ`), but we might choose to do something more advanced, like adding momentum.
In this case, getting the variables is trivial, but you can imagine it'd be more of a pain with some complex stack of layers.
Running this will alter the parameters `W` and `b` and our loss should go down. Flux provides a more general way to do optimiser updates like this.
```julia
m = Chain(
Dense(10, 5, σ),
Dense(5, 2), softmax)
opt = Descent(0.1) # Gradient descent with learning rate 0.1
for p in (W, b)
update!(opt, p, -η * grads[p])
end
```
Instead of having to write `[m[1].W, m[1].b, ...]`, Flux provides a params function `params(m)` that returns a list of all parameters in the model for you.
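A quick sketch, using the two-layer `m` above:
```julia
ps = Flux.params(m)   # m[1].W, m[1].b, m[2].W, m[2].b, collected for you

for p in ps
  @show size(p)       # (5, 10), (5,), (2, 5), (2,)
end
```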
For the update step, there's nothing whatsoever wrong with writing the loop above (it'll work just fine), but Flux provides various *optimisers* that make it more convenient.
```julia
opt = SGD([W, b], 0.1) # Gradient descent with learning rate 0.1
opt() # Carry out the update, modifying `W` and `b`.
```
An optimiser takes a parameter list and returns a function that does the same thing as `update` above. We can pass either `opt` or `update` to our [training loop](training.md), which will then run the optimiser after every mini-batch of data.
An optimiser `update!` accepts a parameter and a gradient, and updates the parameter according to the chosen rule. We can also pass `opt` to our [training loop](training.md), which will update all parameters of the model in a loop. However, we can now easily replace `Descent` with a more advanced optimiser such as `ADAM`.
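Swapping in such a rule is then a one-line change. A sketch, mirroring the loop shown above (the hyperparameter value here is illustrative):
```julia
opt = ADAM(0.001)             # drop-in replacement for Descent(0.1)

for p in (W, b)
  update!(opt, p, grads[p])   # apply ADAM's update rule to p using its gradient
end
```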
## Optimiser Reference
All optimisers return a function that, when called, will update the parameters passed to it.
All optimisers return an object that, when passed to `train!`, will update the parameters passed to it.
```@docs
SGD
Descent
Momentum
Nesterov
ADAM

docs/src/training/training.md

@@ -9,7 +9,7 @@ To actually train a model we need three things:
With these we can call `Flux.train!`:
```julia
Flux.train!(objective, data, opt)
Flux.train!(objective, params, data, opt)
```
There are plenty of examples in the [model zoo](https://github.com/FluxML/model-zoo).
@@ -24,9 +24,10 @@ m = Chain(
Dense(32, 10), softmax)
loss(x, y) = Flux.mse(m(x), y)
ps = Flux.params(m)
# later
Flux.train!(loss, data, opt)
Flux.train!(loss, ps, data, opt)
```
The objective will almost always be defined in terms of some *cost function* that measures the distance of the prediction `m(x)` from the target `y`. Flux has several of these built in, like `mse` for mean squared error or `crossentropy` for cross entropy loss, but you can calculate it however you want.
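For instance (a sketch; the helper names are illustrative, the loss functions are Flux's built-ins):
```julia
mse_loss(x, y)  = Flux.mse(m(x), y)            # mean squared error, e.g. for regression
xent_loss(x, y) = Flux.crossentropy(m(x), y)   # cross entropy, e.g. with a softmax output
```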
@@ -78,7 +79,7 @@ julia> @epochs 2 Flux.train!(...)
`train!` takes an additional argument, `cb`, that's used for callbacks so that you can observe the training process. For example:
```julia
train!(objective, data, opt, cb = () -> println("training"))
train!(objective, ps, data, opt, cb = () -> println("training"))
```
Callbacks are called for every batch of training data. You can slow this down using `Flux.throttle(f, timeout)` which prevents `f` from being called more than once every `timeout` seconds.
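A sketch of that pattern, reusing `objective`, `ps`, `data` and `opt` from above:
```julia
using Flux: throttle

cb = throttle(() -> println("training"), 10)   # fires at most once every 10 seconds
Flux.train!(objective, ps, data, opt, cb = cb)
```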
@@ -89,6 +90,6 @@ A more typical callback might look like this:
test_x, test_y = # ... create single batch of test data ...
evalcb() = @show(loss(test_x, test_y))
Flux.train!(objective, data, opt,
Flux.train!(objective, ps, data, opt,
cb = throttle(evalcb, 5))
```

src/Flux.jl

@@ -2,11 +2,12 @@ module Flux
# Zero Flux Given
using Base: tail
using MacroTools, Juno, Requires, Reexport, Statistics, Random
using MacroTools: @forward
export Chain, Dense, RNN, LSTM, GRU, Conv, MaxPool, MeanPool,
Dropout, LayerNorm, BatchNorm,
DepthwiseConv, Dropout, LayerNorm, BatchNorm,
params, mapleaves, cpu, gpu
@reexport using NNlib
@@ -19,8 +20,9 @@ export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param
include("optimise/Optimise.jl")
using .Optimise
using .Optimise: @epochs
export SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
export SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
ADAMW, InvDecay, ExpDecay, WeightDecay
include("utils.jl")
include("onehot.jl")

src/cuda/cuda.jl

@@ -1,7 +1,37 @@
module CUDA
using ..CuArrays
using Pkg.TOML
CuArrays.cudnn_available() && include("cudnn.jl")
function version_check()
minor_version = 9
project = joinpath(dirname(pathof(CuArrays)), "../Project.toml")
project = TOML.parse(String(read(project)))
version = VersionNumber(get(project, "version", "0.0.0"))
if !(version.major == 0 && version.minor == minor_version)
@warn """
Flux is only supported with CuArrays v0.$minor_version.
Try running `] pin CuArrays@0.$minor_version`.
"""
end
end
version_check()
if !applicable(CuArray{UInt8}, undef, 1)
(T::Type{<:CuArray})(::UndefInitializer, sz...) = T(sz...)
end
if CuArrays.libcudnn != nothing
if isdefined(CuArrays, :libcudnn_handle)
handle() = CuArrays.libcudnn_handle[]
else
handle() = CuArrays.CUDNN.handle()
end
include("curnn.jl")
include("cudnn.jl")
else
@warn("CUDNN is not installed, some functionality will not be available.")
end
end

src/cuda/cudnn.jl

@@ -1,6 +1,6 @@
using .CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, libcudnn_handle,
cudnnDataType, TensorDesc, FilterDesc
using .CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, cudnnTensorDescriptor_t,
cudnnBatchNormMode_t, cudnnHandle_t, cudnnDataType, TensorDesc, FilterDesc
import ..Flux: data
using LinearAlgebra
mutable struct DropoutDesc
@@ -14,335 +14,215 @@ function DropoutDesc(ρ::Real; seed::Integer=0)
d = [C_NULL]
s = Csize_t[0]
@check ccall((:cudnnCreateDropoutDescriptor,libcudnn), cudnnStatus_t, (Ptr{Ptr{Nothing}},), d)
@check ccall((:cudnnDropoutGetStatesSize,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Csize_t}),libcudnn_handle[],s)
states = CuArray{UInt8}(s[]) # TODO: can we drop this when ρ=0?
@check ccall((:cudnnDropoutGetStatesSize,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Csize_t}),handle(),s)
states = CuArray{UInt8}(undef, s[]) # TODO: can we drop this when ρ=0?
desc = DropoutDesc(d[], states)
@check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Nothing},Cfloat,Ptr{Nothing},Csize_t,Culonglong),
desc,libcudnn_handle[],ρ,states,length(states),seed)
desc,handle(),ρ,states,length(states),seed)
finalizer(desc) do x
@check ccall((:cudnnDestroyDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x)
end
return desc
end
const RNN_RELU = 0 # Stock RNN with ReLu activation
const RNN_TANH = 1 # Stock RNN with tanh activation
const LSTM = 2 # LSTM with no peephole connections
const GRU = 3 # Using h' = tanh(r * Uh(t-1) + Wx) and h = (1 - z) * h' + z * h(t-1)
const BATCHNORM_SPATIAL = 1
const BATCHNORM_ACTIVATION = 0
const BATCHNORM_MIN_EPS = 1e-5
const LINEAR_INPUT = 0
const SKIP_INPUT = 1
@inline _wsize(y) = (map(_ -> 1, size(y)[1:end-2])..., size(y)[end-1], 1)
const UNIDIRECTIONAL = 0
const BIDIRECTIONAL = 1
@inline _reddims(y) = (collect(1:ndims(y)-2)..., ndims(y))
const RNN_ALGO_STANDARD = 0
const RNN_ALGO_PERSIST_STATIC = 1
const RNN_ALGO_PERSIST_DYNAMIC = 2
# param layout:
# RNN: [weight, bias] × [input, hidden]
# GRU: [weight, bias] × [input, hidden] × [reset, update, newmem]
# LSTM: [weight, bias] × [input, hidden] × [input, forget, newmem, output]
function params(w::CuVector, input, hidden, n = 1)
slice(offset, shape) = reshape(w[offset.+(1:prod(shape))], shape)
wx = slice(0, (input, hidden*n))
wh = slice(length(wx), (hidden, hidden*n))
bias = w[length(wx)+length(wh) .+ (1:hidden*n)]
(wx, wh), bias
mutable struct BNCache
mean
ivar
end
mutable struct RNNDesc{T}
mode::Int
input::Int
hidden::Int
params::CuVector{T}
weights::NTuple{2,CuMatrix{T}}
bias::CuVector{T}
ptr::Ptr{Nothing}
BNCache() = BNCache(nothing, nothing)
# NOTE: CuDNN supports only 4D and 5D Tensors for BatchNorm Operations
# so reshape a 2D Tensor into 4D
batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T, 2},
running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
cache = nothing, alpha = T(1), beta = T(0),
eps = T(1e-5), training = true) where T<:Union{Float32, Float64} =
dropdims(batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), running_mean, running_var, momentum,
cache = cache, alpha = alpha, beta = beta, eps = eps, training = training), dims = (1, 2))
function batchnorm(g::CuArray{T}, b::CuArray{T}, x::Union{CuArray{T, 4},CuArray{T,5}},
running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
cache = nothing, alpha = T(1), beta = T(0),
eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
y = similar(x)
cudnnBNForward!(y, g, b, x, running_mean, running_var, momentum, cache = cache,
alpha = alpha, beta = beta, eps = eps, training = training)
y
end
Base.unsafe_convert(::Type{Ptr{Nothing}}, d::RNNDesc) = d.ptr
function rnnParamSize(T, r, input)
size = Csize_t[0]
@check ccall((:cudnnGetRNNParamsSize, libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Ptr{Nothing},Ptr{Csize_t},Cint),
libcudnn_handle[], r, TensorDesc(T, (1,input,1)), size, cudnnDataType(T))
return Int(size[])÷sizeof(T)
end
ngates(mode) = [1, 1, 4, 3][mode+1]
ngates(r::RNNDesc) = ngates(r.mode)
function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T
d = [C_NULL]
@check ccall((:cudnnCreateRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Ptr{Nothing}},),d)
dropoutDesc = DropoutDesc(0)
inputMode = LINEAR_INPUT
direction = UNIDIRECTIONAL
algo = RNN_ALGO_STANDARD
@check ccall((:cudnnSetRNNDescriptor_v6,libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Cint,Cint,Ptr{Nothing},Cint,Cint,Cint,Cint,Cint),
libcudnn_handle[],d[],hidden,layers,dropoutDesc,inputMode,direction,mode,algo,cudnnDataType(T))
w = cuzeros(T, rnnParamSize(T, d[], input))
# TODO: avoid reserve allocation here
rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., d[])
finalizer(rd) do x
@check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x)
function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray{T},
running_mean::CuArray{T}, running_var::CuArray{T},
momentum; cache = nothing,
alpha = T(1), beta = T(0),
eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
dims = _wsize(x)
if eps < BATCHNORM_MIN_EPS
# warn("eps ",eps," is too small for CuDNN so eps has been assigned the value ", BATCHNORM_MIN_EPS)
eps = BATCHNORM_MIN_EPS
end
return rd
end
xd = TensorDesc(x)
yd = TensorDesc(y)
gd = TensorDesc(T, dims)
function rnnWorkspaceSize(r::RNNDesc, seqlen, xdesc)
size = Csize_t[0]
@check ccall((:cudnnGetRNNWorkspaceSize, libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Cint,Ptr{Ptr{Nothing}},Ptr{Csize_t}),
libcudnn_handle[], r, seqlen, xdesc, size)
return Int(size[])
end
if training
const workspace = [CuVector{UInt8}(1)]
if cache !== nothing
mean = zeros(CuArray{T}, dims...)
ivar = ones(CuArray{T}, dims...)
else
mean = C_NULL
ivar = C_NULL
end
getworkspace(bytes) =
length(workspace[]) ≥ bytes ?
workspace[] :
(workspace[] = CuVector{UInt8}(bytes))
getworkspace(r::RNNDesc, seqlen, xdesc) =
getworkspace(rnnWorkspaceSize(r, seqlen, xdesc))
function rnnTrainingReserveSize(r::RNNDesc, seqlen, xdesc)
size = Csize_t[0]
@check ccall((:cudnnGetRNNTrainingReserveSize,libcudnn), cudnnStatus_t, (Ptr{Nothing}, Ptr{Nothing}, Cint, Ptr{Ptr{Nothing}}, Ptr{Csize_t}),
libcudnn_handle[], r, seqlen, xdesc, size)
return Int(size[])
end
function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
workspace, reserve=nothing) where T
if reserve == nothing
@check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t,
(Ptr{Nothing}, Ptr{Nothing}, Cint,
Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T},
@check ccall((:cudnnBatchNormalizationForwardTraining, libcudnn), cudnnStatus_t,
(cudnnHandle_t,cudnnBatchNormMode_t,
Ptr{T}, Ptr{T},
Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Csize_t),
libcudnn_handle[], rnn, seqlen,
xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
workspace, length(workspace))
Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Ptr{T}, Ptr{T},
Cdouble, Ptr{T}, Ptr{T},
Cdouble, Ptr{T}, Ptr{T}),
handle(), BATCHNORM_SPATIAL,
Ref(T(alpha)), Ref(T(beta)),
xd, x,
yd, y,
gd, g, b,
momentum, running_mean, running_var,
eps, mean, ivar)
if cache !== nothing
cache.mean = mean
cache.ivar = ivar
end
else
@check ccall((:cudnnRNNForwardTraining, libcudnn), cudnnStatus_t,
(Ptr{Nothing}, Ptr{Nothing}, Cint,
Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t),
libcudnn_handle[], rnn, seqlen,
xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
workspace, length(workspace), reserve, length(reserve))
@check ccall((:cudnnBatchNormalizationForwardInference, libcudnn), cudnnStatus_t,
(Ptr{cudnnHandle_t},cudnnBatchNormMode_t,
Ptr{T}, Ptr{T},
Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Ptr{T}, Ptr{T},
Ptr{T}, Ptr{T},
Cdouble),
handle(), BATCHNORM_SPATIAL,
Ref(T(alpha)), Ref(T(beta)),
xd, x,
yd, y,
gd, g, b,
running_mean, running_var,
eps)
end
end
xDesc(x) = [TensorDesc(eltype(x), (1, size(x, 1), size(x, 2)))]
hDesc(h::Nothing) = C_NULL, C_NULL
hDesc(x::Integer) = (@assert x == 0; hDesc(nothing))
function hDesc(h::CuArray)
TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h
function ∇batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T, 2}, dy::CuArray{T, 2},
running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
cache = nothing, eps = T(1e-5), alpha = T(1),
beta = T(0), training = true) where T<:Union{Float32, Float64}
dg, db, dx = ∇batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), reshape(dy, 1, 1, size(dy, 1),
size(dy, 2)), running_mean, running_var, momentum, cache = cache, eps = eps,
alpha = alpha, beta = beta, training = training)
(dg, db, dropdims(dx, dims = (1, 2)))
end
# TODO: can we just manipulate strides here?
# TODO: should use repmat, but this isn't implemented.
hBatch(x::AbstractVector, h::CuVector) = h
hBatch(x::AbstractMatrix, h::CuVector) = h .* cuones(1, size(x, 2))
hBatch(x::AbstractMatrix, h::CuMatrix) = h .* cuones(1, size(h,2) == 1 ? size(x,2) : 1)
function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing, train = Val{false}) where T
h = hBatch(x, h_)
c = c_ == nothing ? nothing : hBatch(x, c_)
@assert size(x, 1) == rnn.input
@assert size(h, 1) == rnn.hidden
@assert size(x, 2) == size(h, 2)
seqLength = 1
xdesc = xDesc(x)
y = x isa AbstractVector ? similar(x, rnn.hidden) : similar(x, rnn.hidden, size(x, 2))
ho = similar(h)
ydesc = xDesc(y)
workspace = getworkspace(rnn, seqLength, xdesc)
reserve = train == Val{true} ?
CuVector{UInt8}(rnnTrainingReserveSize(rnn, seqLength, xdesc)) :
nothing
co = c == nothing ? c : similar(c)
cudnnRNNForward(rnn, seqLength,
xdesc, x,
hDesc(h)...,
hDesc(c)...,
FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
ydesc, y,
hDesc(ho)...,
hDesc(co)...,
workspace, reserve)
result = c == nothing ? (y, ho) : (y, ho, co)
return train == Val{true} ? (reserve, result) : result
function ∇batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
cache = nothing, eps = T(1e-5), alpha = T(1),
beta = T(0), training = true) where T<:Union{Float32, Float64}
dg = similar(g)
db = similar(b)
dx = similar(x)
cudnnBNBackward!(dg, g, db, dx, x, dy, running_mean, running_var, T(momentum),
training = training, cache = cache, eps = eps, alpha = alpha, beta = beta)
(dg, db, dx)
end
forwardTrain(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing) where T =
forward(rnn, x, h, c, Val{true})
function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
dx::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
running_mean::CuArray{T}, running_var::CuArray{T},
momentum; cache = nothing, eps = T(1e-5),
alpha = T(1), beta = T(0),
dalpha = T(1), dbeta = T(0), training = true) where T<:Union{Float32, Float64}
if training
xd = TensorDesc(x)
dyd = TensorDesc(dy)
dxd = TensorDesc(dx)
gd = TensorDesc(T, _wsize(x))
if cache !== nothing
mean, ivar = cache.mean, cache.ivar
info("mean and ivar are fetched from the cache")
else
mean, ivar = C_NULL, C_NULL
end
function cudnnRNNBackwardData(rnn::RNNDesc{T}, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, rs) where T
@check ccall((:cudnnRNNBackwardData,libcudnn),cudnnStatus_t,
(Ptr{Nothing}, Ptr{Nothing}, Cint,
Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing},
Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t),
libcudnn_handle[], rnn, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, length(ws), rs, length(rs))
end
if eps < BATCHNORM_MIN_EPS
eps = BATCHNORM_MIN_EPS
end
function backwardData(rnn::RNNDesc{T}, y, dy_, dho, dco, h, c, reserve) where T
# Same as above, any more efficient way?
dy = dy_ isa Integer ? zero(y) : dy_
yd = xDesc(y)
dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2))
dh = similar(h)
dc = c == nothing ? nothing : similar(c)
cudnnRNNBackwardData(rnn, 1,
yd, y, yd, dy, hDesc(dho)..., hDesc(dco)...,
FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
hDesc(h)..., hDesc(c)..., xDesc(dx), dx, hDesc(dh)..., hDesc(dc)...,
workspace[], reserve)
return c == nothing ? (dx, dh) : (dx, dh, dc)
end
backwardData(rnn, y, dy, dho, hx, reserve) =
backwardData(rnn, y, dy, dho, nothing, hx, nothing, reserve)
function cudnnRNNBackwardWeights(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, yd, y, dwd, dw,
workspace, reserve) where T
@check ccall((:cudnnRNNBackwardWeights,libcudnn), cudnnStatus_t,
(Ptr{Nothing}, Ptr{Nothing}, Cint, # handle, rnnDesc, seqLength
Ptr{Ptr{Nothing}}, Ptr{T}, #x
Ptr{Nothing}, Ptr{T}, #hx
Ptr{Ptr{Nothing}}, Ptr{T}, #y
Ptr{Nothing}, Csize_t, #ws
Ptr{Nothing}, Ptr{T}, #dw
Ptr{Nothing}, Csize_t), #rs
libcudnn_handle[], rnn, seqlen, xd, x, hd, h, yd, y,
workspace, length(workspace), dwd, dw, reserve, length(reserve))
end
function backwardWeights(rnn::RNNDesc{T}, x, h, y, reserve) where T
dw = zero(rnn.params)
cudnnRNNBackwardWeights(rnn, 1,
xDesc(x), x, hDesc(h)..., xDesc(y), y,
FilterDesc(T, (1, 1, length(dw))), dw,
workspace[], reserve)
return params(dw, rnn.input, rnn.hidden, ngates(rnn))
end
# Interface
import ..Flux: Flux, relu
import ..Tracker: TrackedArray
using .CuArrays.CUDAnative
using .CuArrays: @cuindex, cudims
function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
function kernel(dst, src)
I = @cuindex dst
dst[I...] = src[reverse(I)...]
return
end
blk, thr = cudims(dst)
@cuda blocks=blk threads=thr kernel(dst, src)
return dst
end
CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}}
CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuParam{T,2},<:CuParam{T,1}}
CuGRU{T} = Flux.GRUCell{<:CuParam{T,2},<:CuParam{T,1}}
CuLSTM{T} = Flux.LSTMCell{<:CuParam{T,2},<:CuParam{T,1}}
CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}}
function copyparams!(m::CuRNNs, d::RNNDesc)
Wi, Wh = d.weights
copy_transpose!(Wi, Flux.data(m.Wi))
copy_transpose!(Wh, Flux.data(m.Wh))
copy_transpose!(d.bias, Flux.data(m.b))
return
end
function RNNDesc(m::CuRNNs{T}) where T
h, i = length(m.h), size(m.Wi, 2)
mode = m isa CuRNN ?
(m.σ == tanh ? RNN_TANH : RNN_RELU) :
m isa CuGRU ? GRU : LSTM
r = RNNDesc{T}(mode, i, h)
return r
end
const descs = WeakKeyDict()
function desc(rnn)
d = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = RNNDesc(rnn))
copyparams!(rnn, d)
return d
end
import Flux.Tracker
import Flux.Tracker: data, istracked, track, unbroadcast, @grad, nobacksies
istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...))
function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
result = istrain(m, h, x) ?
track(m, x, h, m.Wi, m.Wh, m.b) :
forward(desc(m), x, h)
return result[2], result[1]
end
function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
result = istrain(m, h, x) ?
track(m, x, h, m.Wi, m.Wh, m.b) :
forward(desc(m), x, h)
return result[2], result[1]
end
function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64}
result = istrain(m, h, x) ?
track(m, x, h[1], h[2], m.Wi, m.Wh, m.b) :
forward(desc(m), x, h[1], h[2])
return (result[2], result[3]), result[1]
end
(m::CuRNN{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
(m::CuGRU{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
(m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
@grad function (m::Union{CuRNN,CuGRU})(x, h, Wi, Wh, b)
reserve, result = forwardTrain(desc(m), data(x), data(h))
result, function (Δ)
y, ho = result
dy, dho = Δ
h_ = hBatch(x, data(h))
dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
nobacksies(:RNN, (dx, unbroadcast(size(h), dh), transpose(dWi), transpose(dWh), db))
@check ccall((:cudnnBatchNormalizationBackward, libcudnn), cudnnStatus_t,
(cudnnHandle_t,cudnnBatchNormMode_t,
Ptr{T}, Ptr{T},
Ptr{T}, Ptr{T},
Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Ptr{T}, Ptr{T}, Ptr{T},
Cdouble, Ptr{T}, Ptr{T}),
handle(), BATCHNORM_SPATIAL,
Ref(T(alpha)), Ref(T(beta)),
Ref(T(dalpha)), Ref(T(dbeta)),
xd, x,
dyd, dy,
dxd, dx,
gd, g, dg, db,
eps, mean, ivar)
else
ivar = 1 ./ sqrt.(reshape(running_var, _wsize(x)) .+ eps)
dx .= dy .* reshape(g, _wsize(x)) .* ivar
dg .= squeeze(sum(dy .* (x .- reshape(running_mean, _wsize(x))) .* ivar, _reddims(dy)), dims = (1,2,4))
db .= squeeze(sum(dy, _reddims(dy)), dims = (1,2,4))
end
end
@grad function (m::CuLSTM)(x, h, c, Wi, Wh, b)
reserve, result = forwardTrain(desc(m), data.((x, h, c))...)
result, function (Δ)
y, ho = result
dy, dho, dco = Δ
h_ = hBatch(x, data(h))
c_ = hBatch(x, data(c))
dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
nobacksies(:RNN,
(dx, unbroadcast(size(h), dh), unbroadcast(size(c), dc),
transpose(dWi), transpose(dWh), db))
end
end
# Flux Interface
(BN::Flux.BatchNorm)(x::Union{CuParam{T,2},CuParam{T,4},CuParam{T,5}}, cache = nothing) where T<:Union{Float32, Float64} =
batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum; cache = cache, alpha = 1, beta = 0, eps = BN.ϵ, training = BN.active)
batchnorm(g::TrackedArray, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T},
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
batchnorm(g::TrackedArray, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T},
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
batchnorm(g::TrackedArray, b::CuArray{T}, x::TrackedArray, running_mean::CuArray{T},
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
batchnorm(g::CuArray{T}, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T},
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
batchnorm(g::CuArray{T}, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T},
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
batchnorm(g::TrackedArray, b::CuArray{T}, x::CuArray{T}, running_mean::CuArray{T},
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
batchnorm(g::CuArray{T}, b::CuArray{T}, x::TrackedArray, running_mean::CuArray{T},
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
@grad batchnorm(g, b, x, running_mean, running_var, momentum; kw...) =
batchnorm(data.((g, b, x))..., running_mean, running_var, momentum; kw...), Δ -> (nobacksies(:batchnorm, ∇batchnorm(data.((g, b, x, Δ))..., running_mean, running_var, momentum; kw...))..., nothing, nothing, nothing)

325
src/cuda/curnn.jl Normal file

@@ -0,0 +1,325 @@
using .CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, cudnnTensorDescriptor_t,
cudnnBatchNormMode_t, cudnnHandle_t, cudnnDataType, TensorDesc, FilterDesc
using LinearAlgebra
const RNN_RELU = 0 # Stock RNN with ReLu activation
const RNN_TANH = 1 # Stock RNN with tanh activation
const LSTM = 2 # LSTM with no peephole connections
const GRU = 3 # Using h' = tanh(r * Uh(t-1) + Wx) and h = (1 - z) * h' + z * h(t-1)
const LINEAR_INPUT = 0
const SKIP_INPUT = 1
const UNIDIRECTIONAL = 0
const BIDIRECTIONAL = 1
const RNN_ALGO_STANDARD = 0
const RNN_ALGO_PERSIST_STATIC = 1
const RNN_ALGO_PERSIST_DYNAMIC = 2
# param layout:
# RNN: [weight, bias] × [input, hidden]
# GRU: [weight, bias] × [input, hidden] × [reset, update, newmem]
# LSTM: [weight, bias] × [input, hidden] × [input, forget, newmem, output]
function params(w::CuVector, input, hidden, n = 1)
slice(offset, shape) = reshape(view(w, offset.+(1:prod(shape))), shape)
wx = slice(0, (input, hidden*n))
wh = slice(length(wx), (hidden, hidden*n))
bias = view(w, length(wx)+length(wh) .+ (1:hidden*n))
(wx, wh), bias
end
mutable struct RNNDesc{T}
mode::Int
input::Int
hidden::Int
params::CuVector{T}
weights::NTuple{2,CuMatrix{T}}
bias::CuVector{T}
ptr::Ptr{Nothing}
end
Base.unsafe_convert(::Type{Ptr{Nothing}}, d::RNNDesc) = d.ptr
function rnnParamSize(T, r, input)
size = Csize_t[0]
@check ccall((:cudnnGetRNNParamsSize, libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Ptr{Nothing},Ptr{Csize_t},Cint),
handle(), r, TensorDesc(T, (1,input,1)), size, cudnnDataType(T))
return Int(size[])÷sizeof(T)
end
ngates(mode) = [1, 1, 4, 3][mode+1]
ngates(r::RNNDesc) = ngates(r.mode)
function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T
d = [C_NULL]
@check ccall((:cudnnCreateRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Ptr{Nothing}},),d)
dropoutDesc = DropoutDesc(0)
inputMode = LINEAR_INPUT
direction = UNIDIRECTIONAL
algo = RNN_ALGO_STANDARD
@check ccall((:cudnnSetRNNDescriptor_v6,libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Cint,Cint,Ptr{Nothing},Cint,Cint,Cint,Cint,Cint),
handle(),d[],hidden,layers,dropoutDesc,inputMode,direction,mode,algo,cudnnDataType(T))
w = cuzeros(T, rnnParamSize(T, d[], input))
# TODO: avoid reserve allocation here
rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., d[])
finalizer(rd) do x
@check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x)
end
return rd
end
function rnnWorkspaceSize(r::RNNDesc, seqlen, xdesc)
size = Csize_t[0]
@check ccall((:cudnnGetRNNWorkspaceSize, libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Cint,Ptr{Ptr{Nothing}},Ptr{Csize_t}),
handle(), r, seqlen, xdesc, size)
return Int(size[])
end
const workspace = [CuVector{UInt8}(undef, 1)]
getworkspace(bytes) =
length(workspace[]) ≥ bytes ?
workspace[] :
(workspace[] = CuVector{UInt8}(undef, bytes))
getworkspace(r::RNNDesc, seqlen, xdesc) =
getworkspace(rnnWorkspaceSize(r, seqlen, xdesc))
function rnnTrainingReserveSize(r::RNNDesc, seqlen, xdesc)
size = Csize_t[0]
@check ccall((:cudnnGetRNNTrainingReserveSize,libcudnn), cudnnStatus_t, (Ptr{Nothing}, Ptr{Nothing}, Cint, Ptr{Ptr{Nothing}}, Ptr{Csize_t}),
handle(), r, seqlen, xdesc, size)
return Int(size[])
end
function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
workspace, reserve=nothing) where T
if reserve == nothing
@check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t,
(Ptr{Nothing}, Ptr{Nothing}, Cint,
Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Csize_t),
handle(), rnn, seqlen,
xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
workspace, length(workspace))
else
@check ccall((:cudnnRNNForwardTraining, libcudnn), cudnnStatus_t,
(Ptr{Nothing}, Ptr{Nothing}, Cint,
Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t),
handle(), rnn, seqlen,
xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
workspace, length(workspace), reserve, length(reserve))
end
end
xDesc(x) = [TensorDesc(eltype(x), (1, size(x, 1), size(x, 2)))]
hDesc(h::Nothing) = C_NULL, C_NULL
hDesc(x::Integer) = (@assert x == 0; hDesc(nothing))
function hDesc(h::CuArray)
TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h
end
# TODO: can we just manipulate strides here?
# TODO: should use repmat, but this isn't implemented.
hBatch(x::AbstractVector, h::CuVector) = h
hBatch(x::AbstractMatrix, h::CuVector) = h .* cuones(1, size(x, 2))
hBatch(x::AbstractMatrix, h::CuMatrix) = h .* cuones(1, size(h,2) == 1 ? size(x,2) : 1)
function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing, train = Val{false}) where T
h = hBatch(x, h_)
c = c_ == nothing ? nothing : hBatch(x, c_)
@assert size(x, 1) == rnn.input
@assert size(h, 1) == rnn.hidden
@assert size(x, 2) == size(h, 2)
seqLength = 1
xdesc = xDesc(x)
y = x isa AbstractVector ? similar(x, rnn.hidden) : similar(x, rnn.hidden, size(x, 2))
ho = similar(h)
ydesc = xDesc(y)
workspace = getworkspace(rnn, seqLength, xdesc)
reserve = train == Val{true} ?
CuVector{UInt8}(undef, rnnTrainingReserveSize(rnn, seqLength, xdesc)) :
nothing
co = c == nothing ? c : similar(c)
cudnnRNNForward(rnn, seqLength,
xdesc, x,
hDesc(h)...,
hDesc(c)...,
FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
ydesc, y,
hDesc(ho)...,
hDesc(co)...,
workspace, reserve)
result = c == nothing ? (y, ho) : (y, ho, co)
return train == Val{true} ? (reserve, result) : result
end
forwardTrain(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing) where T =
forward(rnn, x, h, c, Val{true})
function cudnnRNNBackwardData(rnn::RNNDesc{T}, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, rs) where T
@check ccall((:cudnnRNNBackwardData,libcudnn),cudnnStatus_t,
(Ptr{Nothing}, Ptr{Nothing}, Cint,
Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing},
Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t),
handle(), rnn, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, length(ws), rs, length(rs))
end
function backwardData(rnn::RNNDesc{T}, y, dy_, dho, dco, h, c, reserve) where T
# Same as above, any more efficient way?
dy = dy_ isa Integer ? zero(y) : dy_
yd = xDesc(y)
dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2))
dh = similar(h)
dc = c == nothing ? nothing : similar(c)
cudnnRNNBackwardData(rnn, 1,
yd, y, yd, dy, hDesc(dho)..., hDesc(dco)...,
FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
hDesc(h)..., hDesc(c)..., xDesc(dx), dx, hDesc(dh)..., hDesc(dc)...,
workspace[], reserve)
return c == nothing ? (dx, dh) : (dx, dh, dc)
end
backwardData(rnn, y, dy, dho, hx, reserve) =
backwardData(rnn, y, dy, dho, nothing, hx, nothing, reserve)
function cudnnRNNBackwardWeights(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, yd, y, dwd, dw,
workspace, reserve) where T
@check ccall((:cudnnRNNBackwardWeights,libcudnn), cudnnStatus_t,
(Ptr{Nothing}, Ptr{Nothing}, Cint, # handle, rnnDesc, seqLength
Ptr{Ptr{Nothing}}, Ptr{T}, #x
Ptr{Nothing}, Ptr{T}, #hx
Ptr{Ptr{Nothing}}, Ptr{T}, #y
Ptr{Nothing}, Csize_t, #ws
Ptr{Nothing}, Ptr{T}, #dw
Ptr{Nothing}, Csize_t), #rs
handle(), rnn, seqlen, xd, x, hd, h, yd, y,
workspace, length(workspace), dwd, dw, reserve, length(reserve))
end
function backwardWeights(rnn::RNNDesc{T}, x, h, y, reserve) where T
dw = zero(rnn.params)
cudnnRNNBackwardWeights(rnn, 1,
xDesc(x), x, hDesc(h)..., xDesc(y), y,
FilterDesc(T, (1, 1, length(dw))), dw,
workspace[], reserve)
return params(dw, rnn.input, rnn.hidden, ngates(rnn))
end
# Interface
import ..Flux: Flux, relu
import ..Tracker: TrackedArray
using .CuArrays.CUDAnative
using .CuArrays: @cuindex, cudims
function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
function kernel(dst, src)
I = @cuindex dst
dst[I...] = src[reverse(I)...]
return
end
blk, thr = cudims(dst)
@cuda blocks=blk threads=thr kernel(dst, src)
return dst
end
CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}}
CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuParam{T,2},<:CuParam{T,1}}
CuGRU{T} = Flux.GRUCell{<:CuParam{T,2},<:CuParam{T,1}}
CuLSTM{T} = Flux.LSTMCell{<:CuParam{T,2},<:CuParam{T,1}}
CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}}
function copyparams!(m::CuRNNs, d::RNNDesc)
Wi, Wh = d.weights
copy_transpose!(Wi, Flux.data(m.Wi))
copy_transpose!(Wh, Flux.data(m.Wh))
copy_transpose!(d.bias, Flux.data(m.b))
return
end
function RNNDesc(m::CuRNNs{T}) where T
h, i = length(m.h), size(m.Wi, 2)
mode = m isa CuRNN ?
(m.σ == tanh ? RNN_TANH : RNN_RELU) :
m isa CuGRU ? GRU : LSTM
r = RNNDesc{T}(mode, i, h)
return r
end
const descs = WeakKeyDict()
function desc(rnn)
d = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = RNNDesc(rnn))
copyparams!(rnn, d)
return d
end
import Flux.Tracker
import Flux.Tracker: data, istracked, track, unbroadcast, @grad, nobacksies
istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...))
function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
result = istrain(m, h, x) ?
track(m, x, h, m.Wi, m.Wh, m.b) :
forward(desc(m), x, h)
return result[2], result[1]
end
function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
result = istrain(m, h, x) ?
track(m, x, h, m.Wi, m.Wh, m.b) :
forward(desc(m), x, h)
return result[2], result[1]
end
function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64}
result = istrain(m, h, x) ?
track(m, x, h[1], h[2], m.Wi, m.Wh, m.b) :
forward(desc(m), x, h[1], h[2])
return (result[2], result[3]), result[1]
end
(m::CuRNN{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
(m::CuGRU{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
(m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
@grad function (m::Union{CuRNN,CuGRU})(x, h, Wi, Wh, b)
reserve, result = forwardTrain(desc(m), data(x), data(h))
result, function (Δ)
y, ho = result
dy, dho = Δ
h_ = hBatch(x, data(h))
dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
nobacksies(:RNN, (dx, unbroadcast(h, dh), transpose(dWi), transpose(dWh), db))
end
end
@grad function (m::CuLSTM)(x, h, c, Wi, Wh, b)
reserve, result = forwardTrain(desc(m), data.((x, h, c))...)
result, function (Δ)
y, ho = result
dy, dho, dco = Δ
h_ = hBatch(x, data(h))
c_ = hBatch(x, data(c))
dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
nobacksies(:RNN,
(dx, unbroadcast(h, dh), unbroadcast(c, dc),
transpose(dWi), transpose(dWh), db))
end
end
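For orientation, a minimal usage sketch of the GPU RNN path wired up above, mirroring the pattern in the new `test/cuda/curnn.jl` tests further down; it assumes a working CUDA/CuArrays setup and is not part of the diff itself:

```julia
using Flux, CuArrays

rnn   = LSTM(10, 5)
curnn = mapleaves(gpu, rnn)           # CuLSTM cell underneath; params become CuArrays
x     = param(rand(10))
cux   = gpu(x)                        # tracked CuArray input
y     = curnn(cux)                    # forward pass through the cudnnRNNForward wrapper above
Flux.back!(y, gpu(randn(size(y))))    # backward pass via backwardData / backwardWeights
```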


@ -13,6 +13,9 @@ end
include("mnist.jl")
export MNIST
include("fashion-mnist.jl")
export FashionMNIST
include("cmudict.jl")
using .CMUDict

src/data/fashion-mnist.jl Normal file

@ -0,0 +1,64 @@
module FashionMNIST
using ..MNIST: gzopen, imageheader, rawimage, labelheader, rawlabel
const dir = joinpath(@__DIR__, "../../deps/fashion-mnist")
function load()
mkpath(dir)
cd(dir) do
for file in ["train-images-idx3-ubyte",
"train-labels-idx1-ubyte",
"t10k-images-idx3-ubyte",
"t10k-labels-idx1-ubyte"]
isfile(file) && continue
@info "Downloading Fashion-MNIST dataset"
download("http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/$file.gz", "$file.gz")
open(file, "w") do io
write(io, gzopen(read, "$file.gz"))
end
end
end
end
const TRAINIMAGES = joinpath(dir, "train-images-idx3-ubyte")
const TRAINLABELS = joinpath(dir, "train-labels-idx1-ubyte")
const TESTIMAGES = joinpath(dir, "t10k-images-idx3-ubyte")
const TESTLABELS = joinpath(dir, "t10k-labels-idx1-ubyte")
"""
images()
images(:test)
Load the Fashion-MNIST images.
Each image is a 28×28 array of `Gray` colour values (see Colors.jl).
Returns the 60,000 training images by default; pass `:test` to retrieve the
10,000 test images.
"""
function images(set = :train)
load()
io = IOBuffer(read(set == :train ? TRAINIMAGES : TESTIMAGES))
_, N, nrows, ncols = imageheader(io)
[rawimage(io) for _ in 1:N]
end
"""
labels()
labels(:test)
Load the labels corresponding to each of the images returned from `images()`.
Each label is a number from 0-9.
Returns the 60,000 training labels by default; pass `:test` to retrieve the
10,000 test labels.
"""
function labels(set = :train)
load()
io = IOBuffer(read(set == :train ? TRAINLABELS : TESTLABELS))
_, N = labelheader(io)
[rawlabel(io) for _ = 1:N]
end
end
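A minimal usage sketch of the new loader (assuming it is reachable as `Flux.Data.FashionMNIST`, as the `include`/`export` in `data.jl` above suggests; the first call downloads the dataset):

```julia
using Flux

imgs  = Flux.Data.FashionMNIST.images()        # 60,000 training images, each 28×28
lbls  = Flux.Data.FashionMNIST.labels()        # matching labels in 0:9
timgs = Flux.Data.FashionMNIST.images(:test)   # 10,000 held-out test images
```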


@ -16,19 +16,21 @@ m(x) == m[2](m[1](x))
`Chain` also supports indexing and slicing, e.g. `m[2]` or `m[1:end-1]`.
`m[1:3](x)` will calculate the output of the first three layers.
"""
struct Chain
layers::Vector{Any}
Chain(xs...) = new([xs...])
struct Chain{T<:Tuple}
layers::T
Chain(xs...) = new{typeof(xs)}(xs)
end
@forward Chain.layers Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!
@forward Chain.layers Base.iterate
@forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
Base.iterate, Base.lastindex
children(c::Chain) = c.layers
mapchildren(f, c::Chain) = Chain(f.(c.layers)...)
adapt(T, c::Chain) = Chain(map(x -> adapt(T, x), c.layers)...)
(c::Chain)(x) = foldl((x, m) -> m(x), c.layers; init = x)
applychain(::Tuple{}, x) = x
applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x))
(c::Chain)(x) = applychain(c.layers, x)
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
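To make the recursion concrete, here is a self-contained sketch of the `applychain` pattern on plain functions (toy stand-ins for layers, not part of the diff):

```julia
applychain(::Tuple{}, x) = x
applychain(fs::Tuple, x) = applychain(Base.tail(fs), first(fs)(x))

layers = (x -> 2 .* x, x -> x .+ 1, sum)   # toy "layers"
applychain(layers, [1.0, 2.0, 3.0])        # sum(2 .* x .+ 1) == 15.0
```

Because the layers now live in a tuple, this recursion can be unrolled and type-inferred by the compiler, which the old `Vector{Any}` field prevented.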
@ -75,7 +77,7 @@ end
@treelike Dense
function (a::Dense)(x)
function (a::Dense)(x::AbstractArray)
W, b, σ = a.W, a.b, a.σ
σ.(W*x .+ b)
end
@ -114,3 +116,11 @@ end
function Base.show(io::IO, l::Diagonal)
print(io, "Diagonal(", length(l.α), ")")
end
# Try to avoid hitting generic matmul in some simple cases
# Base's matmul is so slow that it's worth the extra conversion to hit BLAS
(a::Dense{<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
invoke(a, Tuple{AbstractArray}, x)
(a::Dense{<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
a(T.(x))


@ -1,4 +1,4 @@
using NNlib: conv
using NNlib: conv, depthwiseconv
@generated sub2(::Val{N}) where N = :(Val($(N-2)))
@ -30,14 +30,14 @@ Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N} =
Conv(σ, w, b, expand.(sub2(Val(N)), (stride, pad, dilation))...)
Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
stride = 1, pad = 0, dilation = 1) where N =
Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ,
stride = stride, pad = pad, dilation = dilation)
@treelike Conv
function (c::Conv)(x)
function (c::Conv)(x::AbstractArray)
# TODO: breaks gpu broadcast :(
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
@ -51,6 +51,62 @@ function Base.show(io::IO, l::Conv)
print(io, ")")
end
(a::Conv{<:Any,<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
invoke(a, Tuple{AbstractArray}, x)
(a::Conv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
a(T.(x))
"""
DepthwiseConv(size, in)
DepthwiseConv(size, in=>mul)
DepthwiseConv(size, in=>mul, relu)
Depthwise convolutional layer. `size` should be a tuple like `(2, 2)`.
`in` and `mul` specify the number of input channels and channel multiplier respectively.
If `mul` is not specified, it defaults to 1.
Data should be stored in WHCN order. In other words, a 100×100 RGB image would
be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
Takes the keyword arguments `pad` and `stride`.
"""
struct DepthwiseConv{N,F,A,V}
σ::F
weight::A
bias::V
stride::NTuple{N,Int}
pad::NTuple{N,Int}
end
DepthwiseConv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0) where {T,N} =
DepthwiseConv(σ, w, b, expand.(sub2(Val(N)), (stride, pad))...)
DepthwiseConv(k::NTuple{N,Integer}, ch::Integer, σ = identity; init = initn,
stride = 1, pad = 0) where N =
DepthwiseConv(param(init(k..., 1, ch)), param(zeros(ch)), σ,
stride = stride, pad = pad)
DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
stride::NTuple{N,Integer} = map(_->1,k),
pad::NTuple{N,Integer} = map(_->0,k)) where N =
DepthwiseConv(param(init(k..., ch[2], ch[1])), param(zeros(ch[2]*ch[1])), σ,
stride = stride, pad = pad)
@treelike DepthwiseConv
function (c::DepthwiseConv)(x)
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
σ.(depthwiseconv(x, c.weight, stride = c.stride, pad = c.pad) .+ b)
end
function Base.show(io::IO, l::DepthwiseConv)
print(io, "DepthwiseConv(", size(l.weight)[1:ndims(l.weight)-2])
print(io, ", ", size(l.weight, ndims(l.weight)), "=>", size(l.weight, ndims(l.weight)-1))
l.σ == identity || print(io, ", ", l.σ)
print(io, ")")
end
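A hedged usage sketch of the layer defined above (forward pass only; the output shape assumes the default unit stride and no padding):

```julia
using Flux

x = rand(28, 28, 3, 1)                    # WHCN: one 28×28 three-channel image
m = DepthwiseConv((3, 3), 3 => 2, relu)   # 3 input channels, channel multiplier 2
size(m(x))                                # (26, 26, 6, 1): 3 * 2 = 6 output channels
```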
"""
MaxPool(k)
@ -60,9 +116,9 @@ Max pooling layer. `k` stands for the size of the window for each dimension of t
Takes the keyword arguments `pad` and `stride`.
"""
struct MaxPool{N}
k::NTuple{N,Int}
pad::NTuple{N,Int}
stride::NTuple{N,Int}
k::NTuple{N,Int}
pad::NTuple{N,Int}
stride::NTuple{N,Int}
end
MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =


@ -44,7 +44,6 @@ end
_testmode!(a::Dropout, test) = (a.active = !test)
"""
LayerNorm(h::Integer)
A [normalisation layer](https://arxiv.org/pdf/1607.06450.pdf) designed to be
@ -86,7 +85,6 @@ See [Batch Normalization: Accelerating Deep Network Training by Reducing
Internal Covariate Shift](https://arxiv.org/pdf/1502.03167.pdf).
Example:
```julia
m = Chain(
Dense(28^2, 64),
@ -101,14 +99,14 @@ mutable struct BatchNorm{F,V,W,N}
β::V # bias
γ::V # scale
μ::W # moving mean
σ::W # moving std
σ²::W # moving variance
ϵ::N
momentum::N
active::Bool
end
BatchNorm(chs::Integer, λ = identity;
initβ = (i) -> zeros(i), initγ = (i) -> ones(i), ϵ = 1e-8, momentum = .1) =
initβ = (i) -> zeros(i), initγ = (i) -> ones(i), ϵ = 1e-5, momentum = .1) =
BatchNorm(λ, param(initβ(chs)), param(initγ(chs)),
zeros(chs), ones(chs), ϵ, momentum, true)
@ -124,31 +122,31 @@ function (BN::BatchNorm)(x)
if !BN.active
μ = reshape(BN.μ, affine_shape...)
σ = reshape(BN.σ, affine_shape...)
σ² = reshape(BN.σ², affine_shape...)
else
T = eltype(x)
ϵ = data(convert(T, BN.ϵ))
axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
μ = mean(x, dims = axes)
σ = sqrt.(mean((x .- μ).^2, dims = axes) .+ ϵ)
σ² = sum((x .- μ) .^ 2, dims = axes) ./ m
# update moving mean/std
mtm = data(convert(T, BN.momentum))
BN.μ = (1 - mtm) .* BN.μ .+ mtm .* dropdims(data(μ), dims = (axes...,))
BN.σ = (1 - mtm) .* BN.σ .+ mtm .* dropdims(data(σ), dims = (axes...,)) .* m ./ (m - 1)
BN.μ = (1 - mtm) .* BN.μ .+ mtm .* reshape(data(μ), :)
BN.σ² = ((1 - mtm) .* BN.σ² .+ mtm .* reshape(data(σ²), :) .* m ./ (m - 1))
end
let λ = BN.λ
λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ σ) .+ reshape(β, affine_shape...))
λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ sqrt.(σ² .+ BN.ϵ)) .+ reshape(β, affine_shape...))
end
end
children(BN::BatchNorm) =
(BN.λ, BN.β, BN.γ, BN.μ, BN.σ, BN.ϵ, BN.momentum, BN.active)
(BN.λ, BN.β, BN.γ, BN.μ, BN.σ², BN.ϵ, BN.momentum, BN.active)
mapchildren(f, BN::BatchNorm) = # e.g. mapchildren(cu, BN)
BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ), BN.ϵ, BN.momentum, BN.active)
BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ²), BN.ϵ, BN.momentum, BN.active)
_testmode!(BN::BatchNorm, test) = (BN.active = !test)
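A self-contained numeric sketch of the change above: the layer now tracks the moving *variance* `σ²` and only takes the square root at normalisation time (the names here are illustrative, not Flux API):

```julia
x  = [1.0 3.0 5.0; 2.0 4.0 6.0]                       # 2 channels × 3 samples
μ  = vec(sum(x, dims = 2) ./ size(x, 2))              # per-channel mean
σ² = vec(sum((x .- μ) .^ 2, dims = 2) ./ size(x, 2))  # per-channel (biased) variance
ϵ  = 1e-5
x̂  = (x .- μ) ./ sqrt.(σ² .+ ϵ)                       # what BatchNorm now computes
```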


@ -148,7 +148,7 @@ Base.show(io::IO, l::LSTMCell) =
print(io, "LSTMCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷4, ")")
"""
LSTM(in::Integer, out::Integer, σ = tanh)
LSTM(in::Integer, out::Integer)
Long Short Term Memory recurrent layer. Behaves like an RNN but generally
exhibits a longer memory span over sequences.
@ -189,7 +189,7 @@ Base.show(io::IO, l::GRUCell) =
print(io, "GRUCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷3, ")")
"""
GRU(in::Integer, out::Integer, σ = tanh)
GRU(in::Integer, out::Integer)
Gated Recurrent Unit layer. Behaves like an RNN but generally
exhibits a longer memory span over sequences.


@ -2,16 +2,16 @@ using NNlib: logsoftmax, logσ
# Cost functions
mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)
mse(ŷ, y) = sum((ŷ .- y).^2) * 1 // length(y)
function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
-sum(y .* log.(ŷ) .* weight) / size(y, 2)
-sum(y .* log.(ŷ) .* weight) * 1 // size(y, 2)
end
@deprecate logloss(x, y) crossentropy(x, y)
function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
return -sum(y .* logsoftmax(logŷ) .* weight) / size(y, 2)
return -sum(y .* logsoftmax(logŷ) .* weight) * 1 // size(y, 2)
end
"""


@ -28,9 +28,9 @@ Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs
batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(length(first(xs)), xs)
import Adapt.adapt
import Adapt: adapt, adapt_structure
adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
import .CuArrays: CuArray, cudaconvert
@ -68,3 +68,6 @@ end
a::TrackedMatrix * b::OneHotVector = invoke(*, Tuple{AbstractMatrix,OneHotVector}, a, b)
a::TrackedMatrix * b::OneHotMatrix = invoke(*, Tuple{AbstractMatrix,OneHotMatrix}, a, b)
onecold(x::TrackedVector, l...) = onecold(data(x), l...)
onecold(x::TrackedMatrix, l...) = onecold(data(x), l...)


@ -1,23 +1,12 @@
module Optimise
export train!,
SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM, stop, StopException
struct Param{T}
x::T
Δ::T
end
Param(x::AbstractArray) = Param(x, zero(x))
SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,
InvDecay, ExpDecay, WeightDecay, stop, Optimiser
include("optimisers.jl")
include("interface.jl")
include("train.jl")
using Flux.Tracker: TrackedArray
Param(x::TrackedArray) = Param(x.data, x.grad)
# Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad)
include("deprecations.jl")
end


@ -0,0 +1,126 @@
using Base: depwarn
using Flux: Params
check_decay(opt, decay) = decay == 0 ? opt : Optimiser(opt, InvDecay(decay))
# legacy update rule
updaterule(opt, ps) = () -> update!(opt, ps)
function SGD(params::Union{AbstractArray, Params}, η = 0.1; decay = 0.)
depwarn("SGD(params) is deprecated; use Descent(η::Float64) instead", :SGD)
ps = params
opt = Descent(η)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function Momentum(params::Union{AbstractArray, Params}, η = 0.01; ρ = 0.9, decay = 0.)
depwarn("Momentum(params) is deprecated; use Momentum(η::Float64) instead", :Momentum)
ps = params
opt = Momentum(η, ρ)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function Nesterov(params::Union{AbstractArray, Params}, η = 0.001; ρ = 0.9, decay = 0.)
depwarn("Nesterov(params) is deprecated; use Nesterov(η::Float64) instead", :Nesterov)
ps = params
opt = Nesterov(η, ρ)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function RMSProp(params::Union{AbstractArray, Params}, η = 0.001; ρ = 0.9, decay = 0.)
depwarn("RMSProp(params) is deprecated; use RMSProp(η::Float64) instead", :RMSProp)
ps = params
opt = RMSProp(η, ρ)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function ADAM(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("ADAM(params) is deprecated; use ADAM(η::Float64) instead", :ADAM)
ps = params
β = (β1, β2)
opt = ADAM(η, β)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function ADAGrad(params::Union{AbstractArray, Params}, η::Float64 = 0.1; decay = 0.)
depwarn("ADAGrad(params) is deprecated; use ADAGrad(η::Float64) instead", :ADAGrad)
ps = params
opt = ADAGrad(η)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function ADADelta(params::Union{AbstractArray, Params}, ρ::Float64 = 0.9; decay = 0.)
depwarn("ADADelta(params) is deprecated; use ADADelta(η::Float64) instead", :ADADelta)
ps = params
opt = ADADelta(ρ)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function AdaMax(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("AdaMax(params) is deprecated; use AdaMax(η::Float64) instead", :AdaMax)
ps = params
β = (β1, β2)
opt = AdaMax(η, β)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function AMSGrad(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("AMSGrad(params) is deprecated; use AMSGrad(η::Float64) instead", :AMSGrad)
ps = params
β = (β1, β2)
opt = AMSGrad(η, β)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function NADAM(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("NADAM(params) is deprecated; use NADAM(η::Float64) instead", :NADAM)
ps = params
β = (β1, β2)
opt = NADAM(η, β)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function ADAMW(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("ADAMW(params) is deprecated; use ADAMW(η::Float64) instead", :ADAMW)
ps = params
β = (β1, β2)
opt = ADAMW(η, β)
opt = check_decay(opt, decay)
decay != 0 && (opt = Optimiser(opt, WeightDecay(decay)))
updaterule(opt, ps)
end
# Old training loop
struct OldOptimiser
func
end
update!(opt::OldOptimiser, ps) = opt.func()
# Train function
function train!(loss, data, opt; cb = () -> ())
depwarn("train!(loss, data, opt) is deprecated; use train!(loss, params, data, opt) instead", :train!)
train!(loss, (), data, OldOptimiser(opt); cb = cb)
end
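A minimal sketch of the migration these deprecation shims point at; the model, loss, and data here are illustrative only:

```julia
using Flux, Flux.Optimise

m       = Dense(10, 2)
dataset = [(rand(10), rand(2))]
loss(x, y) = Flux.mse(m(x), y)

# Old, deprecated style: the optimiser captured the parameters.
#   opt = SGD(params(m), 0.1); Flux.train!(loss, dataset, opt)

# New style: optimisers are plain rules; parameters are passed to train! instead.
opt = Descent(0.1)
Flux.train!(loss, params(m), dataset, opt)
```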


@ -1,110 +0,0 @@
call(f, xs...) = f(xs...)
# note for optimisers: set to zero
# p.Δ at the end of the weights update
function optimiser(ps, fs...)
ps = [Param(p) for p in ps]
fs = map(ps) do p
os = map(f -> f(p), fs)
() -> foreach(call, os)
end
() -> foreach(call, fs)
end
"""
SGD(params, η = 0.1; decay = 0)
Classic gradient descent optimiser with learning rate `η`.
For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`.
Supports inverse decaying learning rate if the `decay` argument is provided.
"""
SGD(ps, η = 0.1; decay = 0) =
optimiser(ps, p -> invdecay(p, decay), p -> descent(p,η))
"""
Momentum(params, η = 0.01; ρ = 0.9, decay = 0)
SGD with learning rate `η`, momentum `ρ` and optional learning rate inverse decay.
"""
Momentum(ps, η = 0.01; ρ = 0.9, decay = 0) =
optimiser(ps, p->invdecay(p,decay), p->momentum(p, ρ, η), p->descent(p,1))
"""
Nesterov(params, η = 0.01; ρ = 0.9, decay = 0)
SGD with learning rate `η`, Nesterov momentum `ρ` and optional learning rate inverse decay.
"""
Nesterov(ps, η = 0.01; ρ = 0.9, decay = 0) =
optimiser(ps, p->invdecay(p,decay), p->nesterov(p, ρ, η), p->descent(p,1))
"""
RMSProp(params, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0)
[RMSProp](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
optimiser. Parameters other than learning rate don't need tuning. Often a good
choice for recurrent networks.
"""
RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) =
optimiser(ps, p->rmsprop(p; η=η, ρ=ρ, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
"""
ADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
"""
ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
"""
ADAMW(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam.
"""
ADAMW(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->descentweightdecay(p,1,decay))
"""
AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser. Variant of ADAM based on
the ∞-norm.
"""
AdaMax(ps, η = 0.002; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p->adamax(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
"""
ADAGrad(params, η = 0.01; ϵ = 1e-8, decay = 0)
[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
Parameters don't need tuning.
"""
ADAGrad(ps, η = 0.01; ϵ = 1e-8, decay = 0) =
optimiser(ps, p->adagrad(p; η=η, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
"""
ADADelta(params; ρ = 0.9, ϵ = 1e-8, decay = 0)
[ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
tuning.
"""
ADADelta(ps; ρ = 0.9, ϵ = 1e-8, decay = 0) =
optimiser(ps, p->adadelta(p; ρ=ρ, ϵ=ϵ), p->descent(p,1))
"""
AMSGrad(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need
tuning.
"""
AMSGrad(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p -> amsgrad(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
"""
NADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[NADAM](https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ) optimiser. Parameters other
than learning rate don't need tuning.
"""
NADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p->nadam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))


@ -1,130 +1,327 @@
function descent(p::Param, η::Real)
function ()
@. p.x -= η * p.Δ
@. p.Δ = 0
using Flux
using Base: @get!
using MacroTools: @forward
const ϵ = 1e-8
# TODO: should use weak refs
"""
Descent(η)
Classic gradient descent optimiser with learning rate `η`.
For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`.
"""
mutable struct Descent
eta::Float64
end
Descent() = Descent(0.1)
function update!(o::Descent, x, Δ)
Δ .*= o.eta
end
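A hedged sketch of the new optimiser contract (`update!` is internal to `Flux.Optimise`, so it is imported explicitly here): `update!(opt, x, Δ)` rescales the gradient in place and returns the step to subtract from the parameter.

```julia
using Flux.Optimise: Descent, update!

x   = rand(3)              # a parameter
Δ   = [0.5, -1.0, 2.0]     # its gradient
opt = Descent(0.1)
x .-= update!(opt, x, Δ)   # Δ now holds 0.1 .* the original gradient
```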
"""
Momentum(params, η = 0.01; ρ = 0.9)
Gradient descent with learning rate `η` and momentum `ρ`.
"""
mutable struct Momentum
eta::Float64
rho::Float64
velocity::IdDict
end
Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict())
function update!(o::Momentum, x, Δ)
η, ρ = o.eta, o.rho
v = get!(o.velocity, x, zero(x))::typeof(x)
@. v = ρ * v - η * Δ
@. Δ = -v
end
"""
Nesterov(eta, ρ = 0.9)
Gradient descent with learning rate `η` and Nesterov momentum `ρ`.
"""
mutable struct Nesterov
eta::Float64
rho::Float64
velocity::IdDict
end
Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict())
function update!(o::Nesterov, x, Δ)
η, ρ = o.eta, o.rho
v = get!(o.velocity, x, zero(x))::typeof(x)
d = @. ρ^2 * v - (1+ρ) * η * Δ
@. v = ρ*v - η*Δ
@. Δ = -d
end
"""
RMSProp(η = 0.001, ρ = 0.9)
[RMSProp](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
optimiser. Parameters other than learning rate don't need tuning. Often a good
choice for recurrent networks.
"""
mutable struct RMSProp
eta::Float64
rho::Float64
acc::IdDict
end
RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict())
function update!(o::RMSProp, x, Δ)
η, ρ = o.eta, o.rho
acc = get!(o.acc, x, zero(x))::typeof(x)
@. acc = ρ * acc + (1 - ρ) * Δ^2
@. Δ *= η / (√acc + ϵ)
end
"""
ADAM(η = 0.001, β = (0.9, 0.999))
[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
"""
mutable struct ADAM
eta::Float64
beta::Tuple{Float64,Float64}
state::IdDict
end
ADAM(η = 0.001, β = (0.9, 0.999)) = ADAM(η, β, IdDict())
function update!(o::ADAM, x, Δ)
η, β = o.eta, o.beta
mt, vt, βp = get!(o.state, x, (zero(x), zero(x), β))
@. mt = β[1] * mt + (1 - β[1]) * Δ
@. vt = β[2] * vt + (1 - β[2]) * Δ^2
@. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η
o.state[x] = (mt, vt, βp .* β)
return Δ
end
"""
AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08)
[AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser. Variant of ADAM based on
the ∞-norm.
"""
mutable struct AdaMax
eta::Float64
beta::Tuple{Float64,Float64}
state::IdDict
end
AdaMax(η = 0.001, β = (0.9, 0.999)) = AdaMax(η, β, IdDict())
function update!(o::AdaMax, x, Δ)
η, β = o.eta, o.beta
mt, ut, βp = get!(o.state, x, (zero(x), zero(x), β))
@. mt = β[1] * mt + (1 - β[1]) * Δ
@. ut = max(β[2] * ut, abs(Δ))
@. Δ = (η/(1 - βp[1])) * mt/(ut + ϵ)
o.state[x] = (mt, ut, βp .* β)
return Δ
end
"""
ADAGrad(η = 0.1; ϵ = 1e-8)
[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
Parameters don't need tuning.
"""
mutable struct ADAGrad
eta::Float64
acc::IdDict
end
ADAGrad(η = 0.1) = ADAGrad(η, IdDict())
function update!(o::ADAGrad, x, Δ)
η = o.eta
acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(x)
@. acc += Δ^2
@. Δ *= η / (√acc + ϵ)
end
"""
ADADelta(ρ = 0.9, ϵ = 1e-8)
[ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
tuning.
"""
mutable struct ADADelta
rho::Float64
state::IdDict
end
ADADelta(ρ = 0.9) = ADADelta(ρ, IdDict())
function update!(o::ADADelta, x, Δ)
ρ = o.rho
acc, Δacc = get!(o.state, x, (zero(x), zero(x)))
@. acc = ρ * acc + (1 - ρ) * Δ^2
@. Δ *= √Δacc/ (√acc + ϵ)
@. Δacc = ρ * Δacc + (1 - ρ) * Δ^2
return Δ
end
"""
AMSGrad(η = 0.001, β = (0.9, 0.999))
[AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need
tuning.
"""
mutable struct AMSGrad
eta::Float64
beta::Tuple{Float64, Float64}
state::IdDict
end
AMSGrad(η = 0.001, β = (0.9, 0.999)) = AMSGrad(η, β, IdDict())
function update!(o::AMSGrad, x, Δ)
η, β = o.eta, o.beta
mt, vt, v̂t = get!(o.state, x, (fill(ϵ, size(x)), fill(ϵ, size(x)), fill(ϵ, size(x))))
@. mt = β[1] * mt + (1 - β[1]) * Δ
@. vt = β[2] * vt + (1 - β[2]) * Δ ^ 2
@. v̂t = max.(v̂t, vt)
@. Δ = η * mt / (√v̂t + ϵ)
end
"""
NADAM(η = 0.001, β = (0.9, 0.999))
[NADAM](http://cs229.stanford.edu/proj2015/054_report.pdf) optimiser. Parameters don't need
tuning.
"""
mutable struct NADAM
eta::Float64
beta::Tuple{Float64, Float64}
state::IdDict
end
NADAM(η = 0.001, β = (0.9, 0.999)) = NADAM(η, β, IdDict())
function update!(o::NADAM, x, Δ)
η, β = o.eta, o.beta
β1p, β2p = o.beta
mt, vt = get!(o.state, x, (zero(x), zero(x)))
@. mt = β[1] * mt + (1 - β[1]) * Δ
@. vt = β[2] * vt + (1 - β[2]) * Δ^2
@. Δ = (β[1] * mt / (1 - β[1] * β1p) + (1 - β[1]) * Δ / (1 - β1p)) / (√(vt * β[2] / (1 - β2p)) + ϵ) * η
o.state[x] = (mt, vt, (β1p * β[1], β2p * β[2]))
return Δ
end
"""
ADAMW(η = 0.001, β = (0.9, 0.999), decay = 0)
[ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam.
"""
ADAMW(η = 0.001, β = (0.9, 0.999), decay = 0) =
Optimiser(ADAM(η, β), WeightDecay(decay))
# Compose optimizers
"""
Optimiser(a, b, c...)
Combine several optimisers into one; each optimiser produces a modified gradient
that will be fed into the next, and this is finally applied to the parameter as
usual.
"""
mutable struct Optimiser
os::Vector{Any}
end
Optimiser(o...) = Optimiser(Any[o...])
@forward Optimiser.os Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!, Base.setindex!
@forward Optimiser.os Base.iterate
Base.getindex(c::Optimiser, i::AbstractArray) = Optimiser(c.os[i]...)
function update!(o::Optimiser, x, Δ)
for opt in o.os
Δ = update!(opt, x, Δ)
end
return Δ
end
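A hedged usage sketch of composing rules, as the docstring above describes; the gradient flows through each rule in order before being applied:

```julia
using Flux.Optimise: Optimiser, InvDecay, Descent

opt = Optimiser(InvDecay(0.001), Descent(0.1))
# Each step, a gradient Δ is first scaled by the inverse-time decay factor,
# then by the learning rate 0.1, and the result is what gets subtracted.
```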
# Ref: https://arxiv.org/abs/1711.05101.pdf
function descentweightdecay(p::Param, η::Real, γ::Real)
function ()
@. p.x = p.x - η * (p.Δ + γ * p.x)
@. p.Δ = 0
"""
`InvDecay(γ)`
Apply inverse time decay to an optimiser
```julia
Optimiser(InvDecay(..), Opt(..))
```
"""
mutable struct InvDecay
gamma::Float64
state::IdDict
end
InvDecay(γ = 0.001) = InvDecay(γ, IdDict())
function update!(o::InvDecay, x, Δ)
γ = o.gamma
n = get!(o.state, x, 1)
Δ .*= 1 / (1 + γ * n)
o.state[x] = n + 1
return Δ
end
"""
`ExpDecay(eta, decay, decay_step, clip)`
Decay the learning rate `eta` by a factor of `decay` every `decay_step` steps, down to a minimum of `clip`.
To apply exponential decay to an optimiser:
```julia
Optimiser(ExpDecay(..), Opt(..))
```
"""
mutable struct ExpDecay
eta::Float64
decay::Float64
step::Int64
clip::Float64
current::IdDict
end
ExpDecay(opt = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4) = ExpDecay(opt, decay, decay_step, clip, IdDict())
function update!(o::ExpDecay, x, Δ)
η, s, decay = o.eta, o.step, o.decay
n = o.current[x] = get(o.current, x, 0) + 1
if o.current[x]%s == 0 && count(x -> x%s == 0, values(o.current)) == 1
η = max(η * decay^(s / n), o.clip)
o.eta = η
end
@. Δ *= decay
end
function momentum(p::Param, ρ, η)
v = zero(p.x)
function ()
@. v = ρ * v - η * p.Δ
@. p.Δ = -v
end
"""
`WeightDecay(wd)`
Decay the weight parameter by `wd`
"""
mutable struct WeightDecay
wd::Real
end
# Ref. https://arxiv.org/pdf/1212.0901.pdf
function nesterov(p::Param, ρ, η)
v = zero(p.x)
function ()
d = @. ρ^2 * v - (1+ρ) * η * p.Δ
@. v = ρ*v - η*p.Δ
@. p.Δ = -d
end
end
WeightDecay() = WeightDecay(0)
function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
acc = zero(p.x)
function ()
@. acc = ρ * acc + (1 - ρ) * p.Δ^2
@. p.Δ *= η / √(acc + ϵ)
end
end
function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
acc = zero(p.x) .+ ϵ
function ()
@. acc += p.Δ^2
@. p.Δ *= η / √(acc + ϵ)
end
end
function adadelta(p::Param; ρ::Real = 0.9, ϵ::Real = 1e-8)
acc = zero(p.x)
Δacc = zero(p.x)
function ()
@. acc = ρ * acc + (1 - ρ) * p.Δ^2
@. p.Δ *= √(Δacc + ϵ) / √(acc + ϵ)
@. Δacc = ρ * Δacc + (1 - ρ) * p.Δ^2
end
end
function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
mt = zero(p.x)
vt = zero(p.x)
β1p, β2p = β1, β2
function ()
@. mt = β1 * mt + (1 - β1) * p.Δ
@. vt = β2 * vt + (1 - β2) * p.Δ^2
@. p.Δ = mt / (1 - β1p) / √(vt / (1 - β2p) + ϵ) * η
β1p *= β1
β2p *= β2
end
end
function adamax(p::Param; η::Real = 0.002, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
mt = zero(p.x)
ut = zero(p.x)
β1p = β1
function ()
@. mt = β1 * mt + (1 - β1) * p.Δ
@. ut = max(β2 * ut, abs(p.Δ))
@. p.Δ = (η/(1 - β1p)) * mt/(ut + ϵ)
β1p *= β1
end
end
function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
mt = zero(p.x)
vt = zero(p.x) .+ ϵ
v̂t = zero(p.x) .+ ϵ
function ()
@. mt = β1 * mt + (1 - β1) * p.Δ
@. vt = β2 * vt + (1 - β2) * p.Δ ^ 2
@. v̂t = max.(v̂t, vt)
@. p.Δ = η * mt / √v̂t
end
end
function nadam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
mt = zero(p.x)
vt = zero(p.x)
β1p, β2p = β1, β2
function ()
@. mt = β1 * mt + (1 - β1) * p.Δ
@. vt = β2 * vt + (1 - β2) * p.Δ^2
@. p.Δ = (β1 * mt / (1 - β1 * β1p) + (1 - β1) * p.Δ / (1 - β1p)) / √(vt * β2 / (1 - β2p) + ϵ) * η
β1p *= β1
β2p *= β2
end
end
clip(p::Param, thresh::Real) = () -> clamp!(p.Δ, -thresh, thresh)
function expdecay(p::Param, γ::Real)
if γ != 0
return () -> p.Δ .+= γ .* p.x
else
return () -> nothing
end
end
function invdecay(p::Param, γ::Real)
if γ != 0
n = 0
return () -> begin
p.Δ .*= 1 / (1 + γ * n)
n += 1
end
else
return () -> nothing
end
function update!(o::WeightDecay, x, Δ)
wd = o.wd
@. Δ += wd * x
end


@ -1,7 +1,17 @@
using Juno
using Flux.Tracker: back!
using Flux.Tracker: data, grad, back!
import Base.depwarn
function update!(opt, xs)
for x in xs
Δ = update!(opt, x.data, x.grad)
x.data .-= Δ
Δ .= 0
end
end
# Callback niceties
call(f, xs...) = f(xs...)
runall(f) = f
runall(fs::AbstractVector) = () -> foreach(call, fs)
@ -35,7 +45,7 @@ function stop()
end
"""
train!(loss, data, opt)
train!(loss, params, data, opt; cb)
For each datapoint `d` in `data`, this computes the gradient of `loss(d...)` through
backpropagation and calls the optimiser `opt`.
@ -44,22 +54,22 @@ Takes a callback as keyword argument `cb`. For example, this will print "trainin
every 10 seconds:
```julia
Flux.train!(loss, data, opt,
Flux.train!(loss, params, data, opt,
cb = throttle(() -> println("training"), 10))
```
The callback can return `:stop` to interrupt the training loop.
The callback can call `Flux.stop()` to interrupt the training loop.
Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
"""
function train!(loss, data, opt; cb = () -> ())
function train!(loss, ps, data, opt; cb = () -> ())
cb = runall(cb)
opt = runall(opt)
@progress for d in data
try
l = loss(d...)
@interrupts back!(l)
opt()
update!(opt, ps)
if cb() == :stop
depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
break


@ -5,7 +5,8 @@ using MacroTools: @q, @forward
import Base: ==
export TrackedArray, TrackedVector, TrackedMatrix, Params, param, back!
export TrackedArray, TrackedVector, TrackedMatrix, Params, gradient,
param, back!
tracker(x) = nothing
@ -60,17 +61,11 @@ macro grad(ex)
@q(Tracker._forward($(args...)) where $(T...) = $body) |> esc
end
function update!(x, Δ)
x.data .+= data(Δ)
tracker(x).grad .= 0
return x
end
include("idset.jl")
include("back.jl")
include("scalar.jl")
include("array.jl")
include("numeric.jl")
include("lib/real.jl")
include("lib/array.jl")
"""
hook(f, x) -> x
@ -99,7 +94,8 @@ end
nobacksies(f, x) = track(nobacksies, f, x)
nobacksies(f, xs::Tuple) = map(x -> nobacksies(f, x), xs)
@grad nobacksies(f, x) = data(x), Δ -> error("Nested AD not defined for $f")
@grad nobacksies(f::Symbol, x) = data(x), Δ -> error("Nested AD not defined for $f")
@grad nobacksies(f::String, x) = data(x), Δ -> error(f)
param(x::Number) = TrackedReal(float(x))
param(xs::AbstractArray) = TrackedArray(float.(xs))
@ -108,10 +104,8 @@ param(xs::AbstractArray) = TrackedArray(float.(xs))
param(x::TrackedReal) = track(identity, x)
param(x::TrackedArray) = track(identity, x)
import NNlib.cudata
import Adapt.adapt
import Adapt: adapt, adapt_structure
cudata(x::TrackedArray) = data(x)
adapt(T, xs::TrackedArray) = param(adapt(T, data(xs)))
adapt_structure(T, xs::TrackedArray) = param(adapt(T, data(xs)))
end


@ -19,62 +19,87 @@ function scan(x)
return
end
function back_(c::Call, Δ)
function back_(c::Call, Δ, once)
Δs = c.func(Δ)
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
error("Gradient is not a tuple of length $(length(c.args))")
foreach(back, c.args, data.(Δs))
foreach((x, d) -> back(x, d, once), c.args, data.(Δs))
end
back_(::Call{Nothing}, Δ) = nothing
back_(::Call{Nothing}, Δ, once) = nothing
back_(::Call{Missing}, Δ, once) = error("`back!` was already used")
accum!(x, Δ) = x .+ Δ
accum!(x::AbstractArray, Δ) = (x .+= Δ)
function back(x::Tracked, Δ)
function back(x::Tracked, Δ, once)
x.isleaf && (x.grad = accum!(x.grad, Δ); return)
ref = x.ref -= 1
if ref > 0 || isdefined(x, :grad)
if isdefined(x, :grad)
x.grad = accum!(x.grad, Δ)
else
x.grad = Δ
end
ref == 0 && back_(x.f, x.grad)
grad = if isdefined(x, :grad)
x.grad = accum!(x.grad, Δ)
elseif ref > 0
x.grad = Δ
else
ref == 0 && back_(x.f, Δ)
Δ
end
if ref == 0
back_(x.f, grad, once)
once && !x.isleaf && (x.f = Call(missing, ()))
end
return
end
back(::Nothing, _) = return
back(::Nothing, Δ, once) = return
# Interface methods
# TODO: if an error occurs in `back` the refcounts will be broken
# and `back` will silently fail to update.
# (but only if you re-use intermediate values between passes)
# Refcounts are also probably not safe in some situations (e.g. back called
# from within a backpropagator)
function back!(x, Δ)
function back!(x, Δ; once = true)
istracked(x) || return
scan(x)
back(tracker(x), Δ)
back(tracker(x), Δ, once)
return
end
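A hedged sketch of the new `once` keyword: by default a graph can only be backpropagated through once, and reuse has to be requested explicitly (as `jacobian` now does):

```julia
using Flux
using Flux.Tracker: back!, param

x = param([1.0, 2.0])
y = sum(x .^ 2)
back!(y)                 # x.grad is now [2.0, 4.0]
# back!(y)               # a second call would error: "`back!` was already used"
z = sum(x .^ 2)
back!(z, once = false)   # keeps this graph reusable for further passes
```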
function gradient_(f, xs...)
xs = param.(data.(xs))
l = f(xs...)
losscheck(l)
back!(l)
nobacksies("Use `gradient(...; nest = true)` for nested derivatives",
grad.(xs))
end
# Out-of-place gradients
struct Params
params::IdSet
Params(xs) = new(IdSet(xs))
order::Vector{Any}
params::IdSet{Any}
Params() = new([], IdSet())
end
@forward Params.params Base.iterate, Base.length
@forward Params.order Base.iterate, Base.length
function Base.push!(ps::Params, x)
if !(x in ps.params)
push!(ps.order, x)
push!(ps.params, x)
end
return ps
end
Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps)
Params(xs) = push!(Params(), xs...)
function Base.show(io::IO, ps::Params)
print(io, "Params([")
join(io, ps.params, ", ")
join(io, ps.order, ", ")
print(io, "])")
end
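A hedged sketch of the ordered `Params` behaviour above: insertion order is preserved and duplicates are dropped.

```julia
using Flux.Tracker: Params, param

W, b = param(rand(2, 2)), param(rand(2))
ps = Params()
push!(ps, W, b, W)    # W is only recorded once
length(ps) == 2       # true; iteration yields W, then b
```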
@ -91,12 +116,12 @@ Grads() = Grads(IdDict())
Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps))
Base.getindex(g::Grads, x::Tracked) = g.grads[x]
function Base.getindex(g::Grads, x)
istracked(x) || error("Object not tracked: $x")
g[tracker(x)]
end
accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
function back_(g::Grads, c::Call, Δ)
@ -146,20 +171,13 @@ function losscheck(x)
isnan(x) && error("Loss is NaN")
end
function gradient(f, args...)
function gradient_nested(f, args...)
y, back = forward(f, args...)
losscheck(y)
return back(1)
end
derivative(f, x) = gradient(f, x)[1]
gradient(f, xs...; nest = false) =
nest ? gradient_nested(f, xs...) : gradient_(f, xs...)
# Non-nesting versions
function gradient_(f, xs...)
xs = param.(xs)
l = f(xs...)
losscheck(l)
back!(l)
grad.(xs)
end
gradient(f, ps::Params) = gradient_nested(f, ps)
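A hedged sketch of the two modes wired up above: the default non-nested `gradient`, and `nest = true` when the result itself has to stay differentiable.

```julia
using Flux.Tracker: gradient

f(x, y) = sum(x .* y)
gx, gy = gradient(f, [1.0, 2.0], [3.0, 4.0])   # ∂f/∂x == y, ∂f/∂y == x
# gradient(x -> sum(x .^ 2), [1.0, 2.0]; nest = true) keeps the result differentiable
```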


@ -7,6 +7,7 @@ Base.eltype(::IdSet{T}) where T = T
IdSet() = IdSet{Any}()
Base.push!(s::IdSet) = s
Base.push!(s::IdSet{T}, x::T) where T = (s.dict[x] = nothing; s)
Base.delete!(s::IdSet{T}, x::T) where T = (delete!(s.dict, x); s)
Base.in(x, s::IdSet) = haskey(s.dict, x)


@ -33,8 +33,18 @@ TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, zero(x))
Base.eltype(x::Type{<:TrackedArray{T}}) where T <: Real = TrackedReal{T}
Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
print(io, "TrackedArray{…,$A}")
Base.convert(::Type{T}, x::S) where {T<:TrackedArray,S<:T} = x
Base.convert(::Type{<:TrackedArray}, x::TrackedArray) =
error("Not implemented: convert $(typeof(x)) to $T")
Base.convert(::Type{<:TrackedArray{T,N,A}}, x::AbstractArray) where {T,N,A} =
TrackedArray(convert(A, x))
Base.show(io::IO, t::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
@isdefined(A) ?
print(io, "TrackedArray{…,$A}") :
invoke(show, Tuple{IO,DataType}, io, t)
function Base.summary(io::IO, x::TrackedArray)
print(io, "Tracked ")
@ -43,11 +53,24 @@ end
Base.print_array(io::IO, x::TrackedArray) = Base.print_array(io, data(x))
function Base.show(io::IO, x::TrackedArray)
show(io, data(x))
print(io, " (tracked)")
end
Base.copy(x::TrackedArray) = x
Base.setindex!(xs::TrackedArray, v, i...) =
error("Can't differentiate `setindex!`")
back!(::TrackedArray) = error("Value is not scalar; use `back!(sum(x))` or `back!(x, Δ)`")
function update!(x::TrackedArray, Δ)
x.data .+= data(Δ)
tracker(x).grad .= 0
return x
end
# Fallthrough methods
for f in :[Base.size, Base.ndims, Base.collect].args
@ -80,6 +103,17 @@ Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...)
end
end
Base.view(x::TrackedArray, inds...) = track(Base.view, x, inds...)
@grad function view(x::AbstractArray, inds...)
view(data(x), inds...), function (Δ)
grad_output = zero(x)
subgrad = view(grad_output, inds...)
subgrad[:] = data(Δ)
(nobacksies(:view, grad_output), map(_->nothing, inds)...)
end
end
Base.:-(xs::TrackedArray) = track(-, xs)
@grad -(xs) = -data(xs), Δ -> (-Δ,)
@ -87,8 +121,8 @@ Base.:-(xs::TrackedArray) = track(-, xs)
Base.transpose(xs::TrackedArray) = track(transpose, xs)
Base.adjoint(xs::TrackedArray) = track(adjoint, xs)
@grad transpose(xs) = transpose(data(xs)), Δ -> (reshape(transpose(Δ), size(xs)),)
@grad adjoint(xs) = data(xs)', Δ -> (reshape(Δ', size(xs)),)
@grad transpose(xs) = transpose(data(xs)), Δ -> (trim(xs, transpose(Δ)),)
@grad adjoint(xs) = data(xs)', Δ -> (trim(xs, Δ'),)
Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...)
@ -108,30 +142,28 @@ Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...)
end
end
for f in [:vcat, :hcat]
UArray = :(Union{TrackedArray,Vector,Matrix,Adjoint,Transpose})
@eval begin
# This section is a bit of a hack since julia doesn't have a standardised
# promotion mechanism for concatenation yet
# https://github.com/JuliaLang/julia/pull/20815
function combinations(xs, n)
n < 1 && return [[]]
cs = combinations(xs, n-1)
[[x, c...] for x in xs, c in cs]
end
# It should support tracked concatenation with rank ∈ (1,2) with a
# TrackedArray anywhere among the arguments. This works as long as Base has
# other functions that capture `(::Union{Vector,RowVector,Matrix}...)`.
Base.$f(a::$UArray...) = track($f, a...)
for i = 0:2, c = combinations([:AbstractArray, :TrackedArray], i), f = [:hcat, :vcat]
cnames = map(_ -> gensym(), c)
@eval Base.$f($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::TrackedArray, xs::AbstractArray...) =
track($f, $(cnames...), x, xs...)
end
# It should support tracked concatenation with rank>2 if the TrackedArray is
# first
Base.$f(a::TrackedArray, b::AbstractArray...) = track($f, a, b...)
Base.$f(a::TrackedArray, b::$UArray...) = track($f, a, b...) # resolves ambiguity introduced by previous row
for i = 0:2, c = combinations([:AbstractVecOrMat, :TrackedVecOrMat], i), f = [:hcat, :vcat]
cnames = map(_ -> gensym(), c)
@eval Base.$f($([:($x::$c{T}) for (x, c) in zip(cnames, c)]...), x::TrackedVecOrMat{T}, xs::AbstractVecOrMat{T}...) where T =
track($f, $(cnames...), x, xs...)
end
# It should support tracked concatenation with rank>2 if the TrackedArray is
# second
Base.$f(a::Array, b::TrackedArray, c::AbstractArray...) = track($f, a, b, c...)
Base.$f(a::Union{Vector,Matrix,Adjoint,Transpose}, b::TrackedArray,
c::$UArray...) =
track($f, a, b, c...) # resolves ambiguity introduced by previous row
end
for i = 0:2, c = combinations([:AbstractVector, :TrackedVector], i), f = [:hcat, :vcat]
cnames = map(_ -> gensym(), c)
@eval Base.$f($([:($x::$c{T}) for (x, c) in zip(cnames, c)]...), x::TrackedVector{T}, xs::AbstractVector{T}...) where T =
track($f, $(cnames...), x, xs...)
end
@grad function vcat(xs...)
@ -164,10 +196,11 @@ end
end
end
Base.cat(a::TrackedArray; dims) = track(cat, a, dims = dims)
Base.cat(a::TrackedArray, b::TrackedArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims)
Base.cat(a::TrackedArray, b::AbstractArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims)
Base.cat(a::AbstractArray, b::TrackedArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims)
for i = 0:2, c = combinations([:AbstractArray, :TrackedArray], i)
cnames = map(_ -> gensym(), c)
@eval Base.cat($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::TrackedArray, xs::AbstractArray...; dims) =
track(cat, $(cnames...), x, xs..., dims = dims)
end
@grad function cat(Xs...; dims)
cat(data.(Xs)..., dims = dims), function (Δ)
@ -307,8 +340,8 @@ end
# BLAS
LinearAlgebra.diagm(x::TrackedVector) = track(diagm, x)
@grad diagm(x) = diagm(data(x)), Δ -> (diag(Δ),)
LinearAlgebra.diagm(x::Pair{<:Integer, <:TrackedVector}) = track(diagm, x...)
@grad diagm(i, x) = diagm(i => data(x)), Δ -> (nothing, diag(Δ, i))
x::TrackedMatrix * y::AbstractMatrix = track(*, x, y)
x::AbstractMatrix * y::TrackedMatrix = track(*, x, y)
@ -328,7 +361,7 @@ x::TrackedVector * y::TrackedVector = track(*, x, y)
# NNlib
using NNlib
import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, maxpool, meanpool
import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, depthwiseconv, maxpool, meanpool
softmax(xs::TrackedArray) = track(softmax, xs)
@ -338,6 +371,16 @@ logsoftmax(xs::TrackedArray) = track(logsoftmax, xs)
@grad logsoftmax(xs) = logsoftmax(data(xs)), Δ -> (nobacksies(:logsoftmax, ∇logsoftmax(data(Δ), data(xs))),)
depthwiseconv(x::TrackedArray, w::TrackedArray; kw...) = track(depthwiseconv, x, w; kw...)
depthwiseconv(x::AbstractArray, w::TrackedArray; kw...) = track(depthwiseconv, x, w; kw...)
depthwiseconv(x::TrackedArray, w::AbstractArray; kw...) = track(depthwiseconv, x, w; kw...)
@grad depthwiseconv(x, w; kw...) =
depthwiseconv(data(x), data(w); kw...),
Δ -> nobacksies(:depthwiseconv,
(NNlib.∇depthwiseconv_data(data.((Δ, x, w))...; kw...),
NNlib.∇depthwiseconv_filter(data.((Δ, x, w))...; kw...)))
conv(x::TrackedArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
conv(x::AbstractArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...)
@ -374,8 +417,7 @@ unbroadcast(x::AbstractArray, Δ) =
trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
unbroadcast(x::Number, Δ) = sum(Δ)
unbroadcast(x::Base.RefValue{<:Function}, _) = nothing
unbroadcast(x::Base.RefValue{<:Val}, _) = nothing
unbroadcast(x::Base.RefValue, _) = nothing
dual(x, p) = x
dual(x::Real, p) = Dual(x, p)
@ -423,26 +465,28 @@ end
using Requires
# https://github.com/FluxML/Flux.jl/issues/353
@init Requires.isprecompiling() || @eval Base.Broadcast begin
function flatten(bc::Broadcasted{Style}) where {Style}
isflat(bc) && return bc
args = cat_nested(bc)
let makeargs = make_makeargs(bc), f = bc.f
newf = @inline function(args::Vararg{Any,N}) where N
f(makeargs(args...)...)
if VERSION < v"1.1.0-DEV.548"
@init Requires.isprecompiling() || @eval Base.Broadcast begin
function flatten(bc::Broadcasted{Style}) where {Style}
isflat(bc) && return bc
args = cat_nested(bc)
let makeargs = make_makeargs(bc), f = bc.f
newf = @inline function(args::Vararg{Any,N}) where N
f(makeargs(args...)...)
end
return Broadcasted{Style}(newf, args, bc.axes)
end
return Broadcasted{Style}(newf, args, bc.axes)
end
end
@inline function make_makeargs(makeargs, t::Tuple{<:Broadcasted,Vararg{Any}})
bc = t[1]
let makeargs = make_makeargs(makeargs, tail(t)), f = bc.f
let makeargs = make_makeargs(makeargs, bc.args)
headargs, tailargs = make_headargs(bc.args), make_tailargs(bc.args)
return @inline function(args::Vararg{Any,N}) where N
args1 = makeargs(args...)
a, b = headargs(args1...), tailargs(args1...)
(f(a...), b...)
@inline function make_makeargs(makeargs, t::Tuple{<:Broadcasted,Vararg{Any}})
bc = t[1]
let makeargs = make_makeargs(makeargs, tail(t)), f = bc.f
let makeargs = make_makeargs(makeargs, bc.args)
headargs, tailargs = make_headargs(bc.args), make_tailargs(bc.args)
return @inline function(args::Vararg{Any,N}) where N
args1 = makeargs(args...)
a, b = headargs(args1...), tailargs(args1...)
(f(a...), b...)
end
end
end
end


@ -1,4 +1,4 @@
struct TrackedReal{T<:Real} <: Real
mutable struct TrackedReal{T<:Real} <: Real
data::T
tracker::Tracked{T}
end
@ -10,19 +10,28 @@ tracker(x::TrackedReal) = x.tracker
track(f::Call, x::Real) = TrackedReal(x, Tracked{typeof(x)}(f, zero(x)))
function back!(x::TrackedReal)
function back!(x::TrackedReal; once = true)
isinf(x) && error("Loss is Inf")
isnan(x) && error("Loss is NaN")
return back!(x, 1)
return back!(x, 1, once = once)
end
function update!(x::TrackedReal, Δ)
x.data += data(Δ)
tracker(x).grad = 0
return x
end
function Base.show(io::IO, x::TrackedReal)
T = get(io, :typeinfo, Any)
show(io, data(x))
print(io, " (tracked)")
T <: TrackedReal || print(io, " (tracked)")
end
Base.decompose(x::TrackedReal) = Base.decompose(data(x))
Base.copy(x::TrackedReal) = x
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{T}) where T = x
Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x))
@ -30,23 +39,32 @@ Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} =
error("Not implemented: convert tracked $S to tracked $T")
for op in [:(==), :≈, :<]
for op in [:(==), :≈, :<, :(<=)]
@eval Base.$op(x::TrackedReal, y::Real) = Base.$op(data(x), y)
@eval Base.$op(x::Real, y::TrackedReal) = Base.$op(x, data(y))
@eval Base.$op(x::TrackedReal, y::TrackedReal) = Base.$op(data(x), data(y))
end
Base.eps(x::TrackedReal) = eps(data(x))
Base.eps(::Type{TrackedReal{T}}) where T = eps(T)
for f in :[isinf, isnan, isfinite].args
@eval Base.$f(x::TrackedReal) = Base.$f(data(x))
end
Base.Printf.fix_dec(x::TrackedReal, n::Int) = Base.Printf.fix_dec(data(x), n)
Base.Printf.fix_dec(x::TrackedReal, n::Int, a...) = Base.Printf.fix_dec(data(x), n, a...)
Base.float(x::TrackedReal) = x
Base.promote_rule(::Type{TrackedReal{S}},::Type{T}) where {S,T} =
TrackedReal{promote_type(S,T)}
using Random
for f in :[rand, randn, randexp].args
@eval Random.$f(rng::AbstractRNG,::Type{TrackedReal{T}}) where {T} = param(rand(rng,T))
end
using DiffRules, SpecialFunctions, NaNMath
for (M, f, arity) in DiffRules.diffrules()
@ -58,12 +76,18 @@ for (M, f, arity) in DiffRules.diffrules()
end
end
# Work around zero(π) not working, for some reason
_zero(::Irrational) = nothing
_zero(x) = zero(x)
for (M, f, arity) in DiffRules.diffrules()
arity == 2 || continue
da, db = DiffRules.diffrule(M, f, :a, :b)
f = :($M.$f)
@eval begin
@grad $f(a::Real, b::Real) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db)
@grad $f(a::TrackedReal, b::TrackedReal) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db)
@grad $f(a::TrackedReal, b::Real) = $f(data(a), b), Δ -> (Δ * $da, _zero(b))
@grad $f(a::Real, b::TrackedReal) = $f(a, data(b)), Δ -> (_zero(a), Δ * $db)
$f(a::TrackedReal, b::TrackedReal) = track($f, a, b)
$f(a::TrackedReal, b::Real) = track($f, a, b)
$f(a::Real, b::TrackedReal) = track($f, a, b)
@ -75,6 +99,12 @@ import Base:^
^(a::TrackedReal, b::Integer) = track(^, a, b)
# Hack for conversions
using ForwardDiff: Dual
(T::Type{<:Real})(x::Dual) = Dual(T(x.value), map(T, x.partials.values))
# Tuples
struct TrackedTuple{T<:Tuple}
@ -115,8 +145,8 @@ function scan(c::Call{typeof(collect)})
foreach(scan, c.args[1])
end
function back_(c::Call{typeof(collect)}, Δ)
foreach(back, c.args[1], data(Δ))
function back_(c::Call{typeof(collect)}, Δ, once)
foreach((x, d) -> back(x, d, once), c.args[1], data(Δ))
end
function back_(g::Grads, c::Call{typeof(collect)}, Δ)


@ -40,7 +40,7 @@ function prefor(f, x; seen = IdSet())
end
function params(m)
ps = []
ps = Params()
prefor(p ->
Tracker.istracked(p) && Tracker.isleaf(p) &&
!any(p′ -> p′ === p, ps) && push!(ps, p),


@ -1,8 +1,12 @@
# Arrays
glorot_uniform(dims...) = (rand(Float32, dims...) .- 0.5f0) .* sqrt(24.0f0/sum(dims))
glorot_normal(dims...) = randn(Float32, dims...) .* sqrt(2.0f0/sum(dims))
initn(dims...) = randn(dims...)/100
glorot_uniform(dims...) = (rand(dims...) .- 0.5) .* sqrt(24.0/(sum(dims)))
glorot_normal(dims...) = randn(dims...) .* sqrt(2.0/sum(dims))
ones(T::Type, dims...) = Base.ones(T, dims...)
zeros(T::Type, dims...) = Base.zeros(T, dims...)
ones(dims...) = Base.ones(Float32, dims...)
zeros(dims...) = Base.zeros(Float32, dims...)
unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
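For reference, the constant in `glorot_uniform` above: `rand() - 0.5` is uniform on `[-0.5, 0.5]`, so scaling by `sqrt(24 / sum(dims))` gives the usual Glorot range `±sqrt(6 / (fan_in + fan_out))`. A one-line check:

```julia
0.5 * sqrt(24.0 / (10 + 5)) ≈ sqrt(6.0 / (10 + 5))   # true for any fan_in + fan_out
```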
@ -24,7 +28,7 @@ julia> chunk(1:10, 3)
"""
chunk(xs, n) = collect(Iterators.partition(xs, ceil(Int, length(xs)/n)))
batchindex(xs, i) = (reverse(Base.tail(reverse(indices(xs))))..., i)
batchindex(xs, i) = (reverse(Base.tail(reverse(axes(xs))))..., i)
"""
frequencies(xs)
@ -66,7 +70,7 @@ julia> batch([[1,2,3],[4,5,6]])
function batch(xs)
data = first(xs) isa AbstractArray ?
similar(first(xs), size(first(xs))..., length(xs)) :
Vector{eltype(xs)}(length(xs))
Vector{eltype(xs)}(undef, length(xs))
for (i, x) in enumerate(xs)
data[batchindex(data, i)...] = x
end
@ -147,9 +151,24 @@ function jacobian(m,x)
n = length(x)
J = Matrix{eltype(x)}(undef,n,k)
for i = 1:k
Flux.back!(y[i]) # Populate gradient accumulator
Flux.back!(y[i], once = false) # Populate gradient accumulator
J[:,i] = xp.grad
xp.grad .*= 0 # Reset gradient accumulator
xp.grad .= 0 # Reset gradient accumulator
end
J'
end
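A hedged usage sketch of `jacobian` with the repeated-backprop fix above; `once = false` lets the same graph be backpropagated once per output:

```julia
using Flux

m = Dense(3, 2)
J = Flux.jacobian(m, rand(3))   # 2×3 matrix of ∂yᵢ/∂xⱼ
```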
"""
@jit ...
The `@jit` annotation can be applied to any code, and the code will be compiled
for performance.
@jit f(x) = @jit(x) + @jit(x)
Note that compilation happens regardless of the `@jit` macro, so it should only
be used for aesthetic purposes, or by recovering Python users.
"""
macro jit(ex)
esc(ex)
end


@ -11,6 +11,8 @@ x = param(randn(5, 5))
cx = gpu(x)
@test cx isa TrackedArray && cx.data isa CuArray
@test Flux.onecold(param(gpu([1.,2.,3.]))) == 3
x = Flux.onehotbatch([1, 2, 3], 1:3)
cx = gpu(x)
@test cx isa Flux.OneHotMatrix && cx.data isa CuArray
@ -36,4 +38,8 @@ Flux.back!(sum(l))
end
CuArrays.cudnn_available() && include("cudnn.jl")
if CuArrays.libcudnn != nothing
@info "Testing Flux/CUDNN"
include("cudnn.jl")
include("curnn.jl")
end


@ -1,48 +1,48 @@
using Flux, CuArrays, Test
using Flux, Flux.Tracker, CuArrays, Test
using Flux.Tracker: TrackedArray, data
@info "Testing Flux/CUDNN"
@testset "CUDNN BatchNorm" begin
@testset "4D Input" begin
x = TrackedArray(Float64.(collect(reshape(1:12, 2, 2, 3, 1))))
m = BatchNorm(3)
cx = gpu(x)
cm = gpu(m)
@testset "RNN" begin
@testset for R in [RNN, GRU, LSTM]
rnn = R(10, 5)
curnn = mapleaves(gpu, rnn)
@testset for batch_size in (1, 5)
Flux.reset!(rnn)
Flux.reset!(curnn)
x = batch_size == 1 ?
param(rand(10)) :
param(rand(10,batch_size))
cux = gpu(x)
y = (rnn(x); rnn(x))
cuy = (curnn(cux); curnn(cux))
y = m(x)
cy = cm(cx)
@test y.data ≈ collect(cuy.data)
@test haskey(Flux.CUDA.descs, curnn.cell)
@test cy isa TrackedArray{Float32,4,CuArray{Float32,4}}
Δ = randn(size(y))
@test cpu(data(cy)) ≈ data(y)
Flux.back!(y, Δ)
Flux.back!(cuy, gpu(Δ))
g = rand(size(y)...)
Flux.back!(y, g)
Flux.back!(cy, gpu(g))
@test x.grad ≈ collect(cux.grad)
@test rnn.cell.Wi.grad ≈ collect(curnn.cell.Wi.grad)
@test rnn.cell.Wh.grad ≈ collect(curnn.cell.Wh.grad)
@test rnn.cell.b.grad ≈ collect(curnn.cell.b.grad)
@test rnn.cell.h.grad ≈ collect(curnn.cell.h.grad)
if isdefined(rnn.cell, :c)
@test rnn.cell.c.grad ≈ collect(curnn.cell.c.grad)
end
Flux.reset!(rnn)
Flux.reset!(curnn)
ohx = batch_size == 1 ?
Flux.onehot(rand(1:10), 1:10) :
Flux.onehotbatch(rand(1:10, batch_size), 1:10)
cuohx = gpu(ohx)
y = (rnn(ohx); rnn(ohx))
cuy = (curnn(cuohx); curnn(cuohx))
@test y.data ≈ collect(cuy.data)
@test m.γ.grad ≈ cpu(cm.γ.grad)
@test m.β.grad ≈ cpu(cm.β.grad)
@test x.grad ≈ cpu(x.grad)
end
@testset "2D Input" begin
x = TrackedArray(Float64.(collect(reshape(1:12, 3, 4))))
m = BatchNorm(3)
cx = gpu(x)
cm = gpu(m)
y = m(x)
cy = cm(cx)
@test cy isa TrackedArray{Float32,2,CuArray{Float32,2}}
@test cpu(data(cy)) ≈ data(y)
g = rand(size(y)...)
Flux.back!(y, g)
Flux.back!(cy, gpu(g))
@test m.γ.grad ≈ cpu(cm.γ.grad)
@test m.β.grad ≈ cpu(cm.β.grad)
@test x.grad ≈ cpu(x.grad)
end
end
end

test/cuda/curnn.jl Normal file

@ -0,0 +1,46 @@
using Flux, CuArrays, Test
@testset "RNN" begin
@testset for R in [RNN, GRU, LSTM]
rnn = R(10, 5)
curnn = mapleaves(gpu, rnn)
@testset for batch_size in (1, 5)
Flux.reset!(rnn)
Flux.reset!(curnn)
x = batch_size == 1 ?
param(rand(10)) :
param(rand(10,batch_size))
cux = gpu(x)
y = (rnn(x); rnn(x))
cuy = (curnn(cux); curnn(cux))
@test y.data ≈ collect(cuy.data)
@test haskey(Flux.CUDA.descs, curnn.cell)
Δ = randn(size(y))
Flux.back!(y, Δ)
Flux.back!(cuy, gpu(Δ))
@test x.grad ≈ collect(cux.grad)
@test rnn.cell.Wi.grad ≈ collect(curnn.cell.Wi.grad)
@test rnn.cell.Wh.grad ≈ collect(curnn.cell.Wh.grad)
@test rnn.cell.b.grad ≈ collect(curnn.cell.b.grad)
@test rnn.cell.h.grad ≈ collect(curnn.cell.h.grad)
if isdefined(rnn.cell, :c)
@test rnn.cell.c.grad ≈ collect(curnn.cell.c.grad)
end
Flux.reset!(rnn)
Flux.reset!(curnn)
ohx = batch_size == 1 ?
Flux.onehot(rand(1:10), 1:10) :
Flux.onehotbatch(rand(1:10, batch_size), 1:10)
cuohx = gpu(ohx)
y = (rnn(ohx); rnn(ohx))
cuy = (curnn(cuohx); curnn(cuohx))
@test y.data ≈ collect(cuy.data)
end
end
end

View File

@ -10,4 +10,7 @@ using Test
@test MNIST.images()[1] isa Matrix
@test MNIST.labels() isa Vector{Int64}
@test FashionMNIST.images()[1] isa Matrix
@test FashionMNIST.labels() isa Vector{Int64}
@test Data.Sentiment.train() isa Vector{Data.Tree{Any}}

33
test/layers/basic.jl Normal file
View File

@ -0,0 +1,33 @@
using Test, Random
@testset "basic" begin
@testset "Chain" begin
@test_nowarn Chain(Dense(10, 5, σ), Dense(5, 2))(randn(10))
@test_throws DimensionMismatch Chain(Dense(10, 5, σ),Dense(2, 1))(randn(10))
# numeric test should be put into testset of corresponding layer
end
@testset "Dense" begin
@test length(Dense(10, 5)(randn(10))) == 5
@test_throws DimensionMismatch Dense(10, 5)(randn(1))
@test_throws MethodError Dense(10, 5)(1) # avoid broadcasting
@test_throws MethodError Dense(10, 5).(randn(10)) # avoid broadcasting
@test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(1, 1)
@test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,2)) == 10*ones(1, 2)
@test Dense(10, 2, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(2, 1)
@test Dense(10, 2, identity, initW = ones, initb = zeros)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20]
end
@testset "Diagonal" begin
@test length(Flux.Diagonal(10)(randn(10))) == 10
@test length(Flux.Diagonal(10)(1)) == 10
@test length(Flux.Diagonal(10)(randn(1))) == 10
@test_throws DimensionMismatch Flux.Diagonal(10)(randn(2))
@test Flux.Diagonal(2)([1 2]) == [1 2; 1 2]
@test Flux.Diagonal(2)([1,2]) == [1,2]
@test Flux.Diagonal(2)([1 2; 3 4]) == [1 2; 3 4]
end
end

View File

@ -2,7 +2,7 @@ using Flux, Test
using Flux: maxpool, meanpool
@testset "Pooling" begin
x = randn(10, 10, 3, 2)
x = randn(Float32, 10, 10, 3, 2)
mp = MaxPool((2, 2))
@test mp(x) == maxpool(x, (2,2))
mp = MeanPool((2, 2))
@ -10,7 +10,7 @@ using Flux: maxpool, meanpool
end
@testset "CNN" begin
r = zeros(28, 28, 1, 5)
r = zeros(Float32, 28, 28, 1, 5)
m = Chain(
Conv((2, 2), 1=>16, relu),
MaxPool((2,2)),

View File

@ -1,4 +1,5 @@
using Flux: testmode!
using Flux.Tracker: data
@testset "Dropout" begin
x = [1.,2.,3.]
@ -28,7 +29,8 @@ using Flux: testmode!
end
@testset "BatchNorm" begin
let m = BatchNorm(2), x = param([1 2; 3 4; 5 6]')
let m = BatchNorm(2), x = param([1 3 5;
2 4 6])
@test m.β.data == [0, 0] # initβ(2)
@test m.γ.data == [1, 1] # initγ(2)
@ -53,29 +55,30 @@ end
# .1 * 4 + 0 = .4
@test m.μ ≈ reshape([0.3, 0.4], 2, 1)
# julia> .1 .* std(x, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
# julia> .1 .* var(x, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
# 2×1 Array{Float64,2}:
# 1.14495
# 1.14495
@test m.σ ≈ .1 .* std(x.data, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
# 1.3
# 1.3
@test m.σ² ≈ .1 .* var(x.data, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
testmode!(m)
@test !m.active
x = m(x).data
@test x[1] ≈ (1 .- 0.3) / 1.1449489742783179
@test isapprox(x[1], (1 .- 0.3) / sqrt(1.3), atol = 1.0e-5)
end
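The hand-computed constants in the comments above follow from the running-statistics update with momentum 0.1; as a standalone check of that arithmetic (plain Julia, using the same 2×3 batch as the test and the default unit initial variance):

using Statistics
x = [1 3 5;
     2 4 6]
momentum = 0.1
μ  = momentum .* mean(x, dims = 2)          # 0.9 * 0 .+ 0.1 .* [3, 4] == [0.3, 0.4]
σ² = (1 - momentum) .* 1 .+                 # running variance starts at 1
     momentum .* var(x, dims = 2, corrected = false) .* (3 / 2)   # 3/2 is the n/(n-1) correction for n = 3
# gives μ ≈ [0.3, 0.4] and σ² ≈ [1.3, 1.3], matching the expected values tested above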
# with activation function
let m = BatchNorm(2, σ), x = param([1 2; 3 4; 5 6]')
let m = BatchNorm(2, sigmoid), x = param([1 3 5;
2 4 6])
@test m.active
m(x)
testmode!(m)
@test !m.active
x = m(x).data
@test x[1] ≈ σ((1 - 0.3) / 1.1449489742783179)
y = m(x).data
@test isapprox(y, data(sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ))), atol = 1.0e-7)
end
let m = BatchNorm(2), x = param(reshape(1:6, 3, 2, 1))
@ -85,7 +88,7 @@ end
end
let m = BatchNorm(2), x = param(reshape(1:12, 2, 3, 2, 1))
y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :)
y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :)
y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4])
@test m(x) == y
end

View File

@ -49,4 +49,16 @@ const ϵ = 1e-7
@testset "logitbinarycrossentropy" begin
@test logitbinarycrossentropy.(logŷ, y) ≈ binarycrossentropy.(σ.(logŷ), y; ϵ=0)
end
@testset "no spurious promotions" begin
for T in (Float16, Float32, Float64)
y = rand(T, 2)
ŷ = rand(T, 2)
for f in (mse, crossentropy, logitcrossentropy)
fwd, back = Flux.Tracker.forward(mse, ŷ, y)
@test typeof(fwd) == Flux.Tracker.TrackedReal{T}
@test eltype(back(one(T))[1]) == Flux.Tracker.TrackedReal{T}
end
end
end
end

View File

@ -1,16 +1,40 @@
using Flux.Optimise
using Flux.Optimise: runall
using Flux.Tracker
using Test
@testset "Optimise" begin
w = randn(10, 10)
@testset for Opt in [SGD, Nesterov, Momentum, ADAM, AdaMax, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad, NADAM]
@testset for Opt in [ADAMW, ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, Descent, ADAM, Nesterov, RMSProp, Momentum]
w′ = param(randn(10, 10))
loss(x) = Flux.mse(w*x, w′*x)
opt = Opt([w′])
for t=1:10^5
opt = Opt(0.001)
if opt isa Descent || opt isa ADAGrad
opt = Opt(0.1)
end
if opt isa ADADelta
opt = Opt(0.9)
end
for t = 1: 10^5
l = loss(rand(10))
back!(l)
opt()
delta = Optimise.update!(opt, w′.data, w′.grad)
w′.data .-= delta
end
@test Flux.mse(w, w′) < 0.01
end
end
@testset "Optimiser" begin
w = randn(10, 10)
@testset for Opt in [InvDecay, WeightDecay, ExpDecay]
w′ = param(randn(10, 10))
loss(x) = Flux.mse(w*x, w′*x)
opt = Optimiser(Opt(), ADAM(0.001))
for t = 1:10^5
l = loss(rand(10))
back!(l)
delta = Optimise.update!(opt, w′.data, w′.grad)
w′.data .-= delta
end
@test Flux.mse(w, w′) < 0.01
end
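Both testsets above drive the reworked optimiser interface by hand. A minimal standalone sketch of that same pattern (Tracker-era API, mirroring the tests rather than documenting a stable interface):

using Flux, Flux.Tracker
using Flux.Optimise

w  = randn(10, 10)            # fixed target weights
w′ = param(randn(10, 10))     # trainable (tracked) weights
loss(x) = Flux.mse(w*x, w′*x)

opt = Descent(0.1)
for _ = 1:1000
  l = loss(rand(10))
  back!(l)                                          # accumulate gradients into w′.grad
  delta = Optimise.update!(opt, w′.data, w′.grad)   # optimiser returns the step...
  w′.data .-= delta                                 # ...which the caller applies
end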
@ -21,9 +45,17 @@ end
l = param(1)
Flux.train!(() -> (sleep(0.1); i += 1; l),
(),
Iterators.repeated((), 100),
()->(),
cb = Flux.throttle(() -> (i > 3 && stop()), 1))
Descent(),
cb = Flux.throttle(() -> (i > 3 && Flux.stop()), 1))
@test 3 < i < 50
# Test multiple callbacks
x = 0
fs = [() -> (), () -> x = 1]
cbs = runall(fs)
cbs()
@test x == 1
end
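The callback test above uses a dummy loss; for orientation, a hedged sketch of a more typical `train!` call with the new explicit parameter and optimiser arguments (the model and data here are hypothetical toys, not from this commit):

using Flux
using Flux.Optimise: Descent

m = Chain(Dense(10, 5, relu), Dense(5, 2))
loss(x, y) = Flux.mse(m(x), y)
data = [(rand(10), rand(2)) for _ = 1:100]   # hypothetical toy dataset

Flux.train!(loss, params(m), data, Descent(0.1),
            cb = Flux.throttle(() -> @show(loss(data[1]...)), 5))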

View File

@ -1,19 +1,4 @@
# Pkg.test runs with --check_bounds=1, forcing all bounds checks.
# This is incompatible with CUDAnative (see JuliaGPU/CUDAnative.jl#98)
if Base.JLOptions().check_bounds == 1
file = @__FILE__
run(```
$(Base.julia_cmd())
--color=$(Base.have_color ? "yes" : "no")
--compiled-modules=$(Bool(Base.JLOptions().use_compiled_modules) ? "yes" : "no")
--startup-file=$(Base.JLOptions().startupfile != 2 ? "yes" : "no")
--code-coverage=$(["none", "user", "all"][1+Base.JLOptions().code_coverage])
$(file)
```)
exit()
end
using Flux, Test, Random
using Flux, Test, Random, Statistics
using Random
Random.seed!(0)
@ -32,6 +17,7 @@ include("data.jl")
@info "Testing Layers"
include("layers/basic.jl")
include("layers/normalisation.jl")
include("layers/stateless.jl")
include("layers/conv.jl")

View File

@ -1,9 +1,9 @@
using Flux
using Flux.Tracker, Test, NNlib
using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint
using NNlib: conv
using Flux.Tracker: TrackedReal, gradient, gradcheck, grad, checkpoint
using NNlib: conv, depthwiseconv
using Printf: @sprintf
using LinearAlgebra: Diagonal, dot, LowerTriangular, norm
using LinearAlgebra: diagm, dot, LowerTriangular, norm
using Statistics: mean, std
using Random
# using StatsBase
@ -33,16 +33,16 @@ gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
@test gradtest(Flux.crossentropy, rand(5,5), rand(5, 5))
@test gradtest(x -> x', rand(5))
@testset "indexing & slicing" begin
gradtest(x->view(x, 1:2, 1:2), rand(4, 4))
end
function promotiontest(f, A, B, C)
r0 = f(A, B, C)
r1 = f(param(A), B, C)
r2 = f(A, param(B), C)
if all(ndims.((A,B,C)) .≤ 2) && f ∈ [hcat, vcat]
r3 = f(A, B, param(C))
else
@test_throws MethodError f(A, B, param(C)) # until julia#20815 is resolved
r3 = r2
end
r3 = f(A, B, param(C))
r4 = f(param(A), param(B), param(C))
@test !isa(r0, TrackedArray)
@ -127,7 +127,7 @@ end
@test gradtest(kron, rand(5,1), rand(3,1), rand(8,1))
@test gradtest(kron, rand(5,2), rand(3,2), rand(8,2))
@test gradtest(f-> Matrix(Diagonal(f)), rand(3))
@test gradtest(x -> diagm(0 => x), rand(3))
@test gradtest(W -> inv(log.(W * W)), (5,5))
@test gradtest((A, B) -> A / B , (1,5), (5,5))
@ -181,12 +181,16 @@ end
@test gradtest(conv, rand(10, 10, 3, 2), randn(Float64,2, 2, 3, 2))
@test gradtest(conv, rand(10, 10, 10, 3, 2), randn(Float64,2, 2, 2, 3, 2))
@test gradtest(depthwiseconv, rand(10,10,3,2), randn(2, 2, 2, 3))
@test gradtest(x -> maxpool(x, (2,2)), rand(10, 10, 3, 2))
@test gradtest(x -> maxpool(x, (2,2,2)), rand(10, 10, 10, 3, 2))
@test gradtest(x -> meanpool(x, (2,2)), rand(10, 10, 3, 2))
@test gradtest(x -> meanpool(x, (2,2,2)), rand(5, 5, 5, 3, 2))
@test gradtest(x -> Float64.(x), 5)
@testset "equality & order" begin
# TrackedReal
@test param(2)^2 == param(4)
@ -230,10 +234,10 @@ end
@testset "Intermediates" begin
x = param([1])
l = sum((x .+ x).^2)
Flux.back!(l)
Flux.back!(l, once = false)
@test x.grad == [8]
x.grad .= 0
Flux.back!(l)
Flux.back!(l, once = false)
@test x.grad == [8]
end
@ -258,7 +262,7 @@ Tracker.back!(b)
back!(z)
@test grad.((x,y)) == (3, 2)
@test Tracker.gradient(2, 3) do x, y
@test gradient(2, 3) do x, y
xy = Tracker.collect([x, y])
xy[1]*xy[2]
end == (3, 2)
@ -278,10 +282,27 @@ end
count += 1
a * b
end
@test derivative(x -> mul(5, x), 3) == 5
@test gradient(x -> mul(5, x), 3)[1] == 5
@test count == 1
@test derivative(x -> checkpoint(mul, 5, x), 3) == 5
@test gradient(x -> checkpoint(mul, 5, x), 3)[1] == 5
@test count == 3
end
@testset "Updates" begin
xs = param([1, 2, 3])
Tracker.update!(xs, param([4, 5, 6]))
@test xs == [5, 7, 9]
x = param(3)
Tracker.update!(x, param(4))
@test x == 7
end
@testset "Params" begin
W = param(randn(5, 10))
x = rand(10)
dW = gradient(W -> sum(W*x), W)[1]
gs = gradient(() -> sum(W*x), Tracker.Params([W]))
@test gs[W] == dW
end
end #testset
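The two new testsets above ("Updates" and "Params") pair naturally; a brief sketch combining them, assuming only the Tracker API exercised in the tests:

using Flux, Flux.Tracker

W = param(randn(5, 10))
x = rand(10)
gs = Tracker.gradient(() -> sum(W*x), Tracker.Params([W]))   # implicit-parameter gradients
Tracker.update!(W, -0.1 .* gs[W])   # update! adds the given step to the tracked data, as tested above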

View File

@ -1,5 +1,5 @@
using Flux
using Flux: throttle, jacobian, initn, glorot_uniform, glorot_normal
using Flux: throttle, jacobian, glorot_uniform, glorot_normal
using StatsBase: std
using Random
using Test
@ -64,10 +64,6 @@ end
@testset "Initialization" begin
# Set random seed so that these tests don't fail randomly
Random.seed!(0)
# initn() should yield a kernel with stddev ~= 1e-2
v = initn(10, 10)
@test std(v) > 0.9*1e-2
@test std(v) < 1.1*1e-2
# glorot_uniform should yield a kernel with stddev ~= sqrt(6/(n_in + n_out)),
# and glorot_normal should yield a kernel with stddev ~= sqrt(2/(n_in + n_out))