Merge branch 'master' into conv_transpose

Authored by Tejan Karmali on 2019-02-02 10:20:45 +05:30, committed via GitHub.
commit e54df2de06
36 changed files with 834 additions and 220 deletions

.gitignore (vendored): 1 line changed

@@ -3,5 +3,4 @@
 *.jl.mem
 docs/build/
 docs/site/
-docs/flux.css
 deps


@@ -1,18 +1,29 @@
 # Documentation: http://docs.travis-ci.com/user/languages/julia/
 language: julia
 os:
   - linux
 # - osx
 julia:
   - 1.0
   - nightly
-# uncomment the following lines to override the default test script
-# script:
-# - if [[ -a .git/shallow ]]; then git fetch --unshallow; fi
-# - julia -e 'Pkg.clone(pwd()); Pkg.build("Flux"); Pkg.test("Flux"; coverage=true)'
 matrix:
   allow_failures:
     - julia: nightly
-after_success:
-  - julia -e 'using Pkg; ps=Pkg.PackageSpec(name="Documenter", version="0.19"); Pkg.add(ps); Pkg.pin(ps); Pkg.add("NNlib")'
-  - julia -e 'using Pkg; cd(Pkg.dir("Flux")); include(joinpath("docs", "make.jl"))'
+jobs:
+  include:
+    - stage: "Documentation"
+      julia: 1.0
+      os: linux
+      script:
+        - julia --project=docs/ -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd()));
+                                     Pkg.instantiate()'
+        - julia --project=docs/ docs/make.jl
+      after_success: skip
+
+## uncomment the following lines to override the default test script
+script:
+  - julia --color=yes -e 'using Pkg; Pkg.activate(); Pkg.instantiate(); Pkg.test()'


@ -1,14 +1,16 @@
# This file is machine-generated - editing it directly is not advised
[[AbstractTrees]] [[AbstractTrees]]
deps = ["Markdown", "Test"] deps = ["Markdown", "Test"]
git-tree-sha1 = "feb8b2c99359901e295443c9d0c7e711604acf39" git-tree-sha1 = "6621d9645702c1c4e6970cc6a3eae440c768000b"
uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
version = "0.2.0" version = "0.2.1"
[[Adapt]] [[Adapt]]
deps = ["LinearAlgebra", "Test"] deps = ["LinearAlgebra", "Test"]
git-tree-sha1 = "04d15700419b6949d76be1428ab6e0277ff43b06" git-tree-sha1 = "53d8fec4f662088c1202530e338a11a919407f3b"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "0.4.1" version = "0.4.2"
[[Base64]] [[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
@ -21,9 +23,9 @@ version = "0.8.10"
[[BinaryProvider]] [[BinaryProvider]]
deps = ["Libdl", "Pkg", "SHA", "Test"] deps = ["Libdl", "Pkg", "SHA", "Test"]
git-tree-sha1 = "9930c1a6cd49d9fcd7218df6be417e6ae4f1468a" git-tree-sha1 = "055eb2690182ebc31087859c3dd8598371d3ef9e"
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232" uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
version = "0.5.2" version = "0.5.3"
[[CodecZlib]] [[CodecZlib]]
deps = ["BinaryProvider", "Libdl", "Test", "TranscodingStreams"] deps = ["BinaryProvider", "Libdl", "Test", "TranscodingStreams"]
@ -51,15 +53,15 @@ version = "0.2.0"
[[Compat]] [[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "2d9e14d19bad3f9ad5cc5e4cffabc3cfa59de825" git-tree-sha1 = "ec61a16eed883ad0cfa002d7489b3ce6d039bb9a"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "1.3.0" version = "1.4.0"
[[DataStructures]] [[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"] deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
git-tree-sha1 = "8fc6e166e24fda04b2b648d4260cdad241788c54" git-tree-sha1 = "ca971f03e146cf144a9e2f2ce59674f5bf0e8038"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.14.0" version = "0.15.0"
[[Dates]] [[Dates]]
deps = ["Printf"] deps = ["Printf"]
@ -77,12 +79,12 @@ version = "0.0.3"
[[DiffRules]] [[DiffRules]]
deps = ["Random", "Test"] deps = ["Random", "Test"]
git-tree-sha1 = "c49ec69428ffea0c1d1bbdc63d1a70f5df5860ad" git-tree-sha1 = "09d69da75967ec48a8b1ad0897ec9144ee052bf9"
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "0.0.7" version = "0.0.8"
[[Distributed]] [[Distributed]]
deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"] deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
[[FixedPointNumbers]] [[FixedPointNumbers]]
@ -93,19 +95,19 @@ version = "0.5.3"
[[ForwardDiff]] [[ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"] deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"]
git-tree-sha1 = "b91250044374764e7c29af59a774c4b8d6100b6e" git-tree-sha1 = "e393bd3b9102659fb24fe88caedec41f2bc2e7de"
uuid = "f6369f11-7733-5829-9624-2563aa707210" uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.1" version = "0.10.2"
[[InteractiveUtils]] [[InteractiveUtils]]
deps = ["LinearAlgebra", "Markdown"] deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
[[Juno]] [[Juno]]
deps = ["Base64", "Logging", "Media", "Profile", "Test"] deps = ["Base64", "Logging", "Media", "Profile", "Test"]
git-tree-sha1 = "3c29a199713e7ec62cfdc11f44d7760219d5f658" git-tree-sha1 = "ce6246e19061e36cbdce954caaae717498daeed8"
uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d" uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
version = "0.5.3" version = "0.5.4"
[[LibGit2]] [[LibGit2]]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
@ -138,18 +140,18 @@ version = "0.5.0"
[[Missings]] [[Missings]]
deps = ["Dates", "InteractiveUtils", "SparseArrays", "Test"] deps = ["Dates", "InteractiveUtils", "SparseArrays", "Test"]
git-tree-sha1 = "adc26d2ee85a49c413464110d922cf21efc9d233" git-tree-sha1 = "d1d2585677f2bd93a97cfeb8faa7a0de0f982042"
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
version = "0.3.1" version = "0.4.0"
[[Mmap]] [[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804" uuid = "a63ad114-7e13-5084-954f-fe012c677804"
[[NNlib]] [[NNlib]]
deps = ["Libdl", "LinearAlgebra", "MacroTools", "Requires", "Test"] deps = ["Libdl", "LinearAlgebra", "MacroTools", "Requires", "Test"]
git-tree-sha1 = "d7f65ad9734adea3c5a4c473bc65b365f8afbb2b" git-tree-sha1 = "51330bb45927379007e089997bf548fbe232589d"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.4.2" version = "0.4.3"
[[NaNMath]] [[NaNMath]]
deps = ["Compat"] deps = ["Compat"]
@ -226,19 +228,19 @@ version = "0.7.2"
[[StaticArrays]] [[StaticArrays]]
deps = ["InteractiveUtils", "LinearAlgebra", "Random", "Statistics", "Test"] deps = ["InteractiveUtils", "LinearAlgebra", "Random", "Statistics", "Test"]
git-tree-sha1 = "ebc5c2a27d91d5ec611a9861168182e2168effd3" git-tree-sha1 = "1eb114d6e23a817cd3e99abc3226190876d7c898"
uuid = "90137ffa-7385-5640-81b9-e52037218182" uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "0.9.2" version = "0.10.2"
[[Statistics]] [[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"] deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[[StatsBase]] [[StatsBase]]
deps = ["DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "Test"] deps = ["DataStructures", "DelimitedFiles", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "Test"]
git-tree-sha1 = "723193a13e8078cec6dcd0b8fe245c8bfd81690e" git-tree-sha1 = "7b596062316c7d846b67bf625d5963a832528598"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.25.0" version = "0.27.0"
[[Test]] [[Test]]
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
@ -257,14 +259,14 @@ uuid = "30578b45-9adc-5946-b283-645ec420af67"
version = "0.4.0" version = "0.4.0"
[[UUIDs]] [[UUIDs]]
deps = ["Random"] deps = ["Random", "SHA"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
[[Unicode]] [[Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
[[ZipFile]] [[ZipFile]]
deps = ["Printf", "Test"] deps = ["BinaryProvider", "Libdl", "Printf", "Test"]
git-tree-sha1 = "c191e56c849b1784cacbf7cd5e52cc672f1ae2db" git-tree-sha1 = "4000c633efe994b2e10b31b6d91382c4b7412dac"
uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
version = "0.7.0" version = "0.8.0"


@ -13,6 +13,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Reexport = "189a3867-3050-52da-a836-e630ba90ab69"

docs/Manifest.toml (new file, 288 lines)

@ -0,0 +1,288 @@
[[AbstractTrees]]
deps = ["Markdown", "Test"]
git-tree-sha1 = "6621d9645702c1c4e6970cc6a3eae440c768000b"
uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
version = "0.2.1"
[[Adapt]]
deps = ["LinearAlgebra", "Test"]
git-tree-sha1 = "04d15700419b6949d76be1428ab6e0277ff43b06"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "0.4.1"
[[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
[[BinDeps]]
deps = ["Compat", "Libdl", "SHA", "URIParser"]
git-tree-sha1 = "12093ca6cdd0ee547c39b1870e0c9c3f154d9ca9"
uuid = "9e28174c-4ba2-5203-b857-d8d62c4213ee"
version = "0.8.10"
[[BinaryProvider]]
deps = ["Libdl", "Pkg", "SHA", "Test"]
git-tree-sha1 = "055eb2690182ebc31087859c3dd8598371d3ef9e"
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
version = "0.5.3"
[[CodecZlib]]
deps = ["BinaryProvider", "Libdl", "Test", "TranscodingStreams"]
git-tree-sha1 = "e3df104c84dfc108f0ca203fd7f5bbdc98641ae9"
uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
version = "0.5.1"
[[ColorTypes]]
deps = ["FixedPointNumbers", "Random", "Test"]
git-tree-sha1 = "f73b0e10f2a5756de7019818a41654686da06b09"
uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
version = "0.7.5"
[[Colors]]
deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Printf", "Reexport", "Test"]
git-tree-sha1 = "9f0a0210450acb91c730b730a994f8eef1d3d543"
uuid = "5ae59095-9a9b-59fe-a467-6f913c188581"
version = "0.9.5"
[[CommonSubexpressions]]
deps = ["Test"]
git-tree-sha1 = "efdaf19ab11c7889334ca247ff4c9f7c322817b0"
uuid = "bbf7d656-a473-5ed7-a52c-81e309532950"
version = "0.2.0"
[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "ec61a16eed883ad0cfa002d7489b3ce6d039bb9a"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "1.4.0"
[[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
git-tree-sha1 = "ca971f03e146cf144a9e2f2ce59674f5bf0e8038"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.15.0"
[[Dates]]
deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
[[DelimitedFiles]]
deps = ["Mmap"]
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
[[DiffResults]]
deps = ["Compat", "StaticArrays"]
git-tree-sha1 = "db8acf46717b13d6c48deb7a12007c7f85a70cf7"
uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
version = "0.0.3"
[[DiffRules]]
deps = ["Random", "Test"]
git-tree-sha1 = "c49ec69428ffea0c1d1bbdc63d1a70f5df5860ad"
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "0.0.7"
[[Distributed]]
deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
[[DocStringExtensions]]
deps = ["LibGit2", "Markdown", "Pkg", "Test"]
git-tree-sha1 = "1df01539a1c952cef21f2d2d1c092c2bcf0177d7"
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
version = "0.6.0"
[[Documenter]]
deps = ["Base64", "DocStringExtensions", "InteractiveUtils", "LibGit2", "Logging", "Markdown", "Pkg", "REPL", "Random", "Test", "Unicode"]
git-tree-sha1 = "a6db1c69925cdc53aafb38caec4446be26e0c617"
uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
version = "0.21.0"
[[FixedPointNumbers]]
deps = ["Test"]
git-tree-sha1 = "b8045033701c3b10bf2324d7203404be7aef88ba"
uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
version = "0.5.3"
[[Flux]]
deps = ["AbstractTrees", "Adapt", "CodecZlib", "Colors", "DiffRules", "ForwardDiff", "Juno", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsBase", "Test", "ZipFile"]
path = ".."
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
version = "0.6.10+"
[[ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"]
git-tree-sha1 = "b91250044374764e7c29af59a774c4b8d6100b6e"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.1"
[[InteractiveUtils]]
deps = ["LinearAlgebra", "Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
[[Juno]]
deps = ["Base64", "Logging", "Media", "Profile", "Test"]
git-tree-sha1 = "3c29a199713e7ec62cfdc11f44d7760219d5f658"
uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
version = "0.5.3"
[[LibGit2]]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
[[Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
[[LinearAlgebra]]
deps = ["Libdl"]
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
[[Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
[[MacroTools]]
deps = ["Compat"]
git-tree-sha1 = "c443e1c8d58a4e9f61b708ad0a88286c7042145b"
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
version = "0.4.4"
[[Markdown]]
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
[[Media]]
deps = ["MacroTools", "Test"]
git-tree-sha1 = "75a54abd10709c01f1b86b84ec225d26e840ed58"
uuid = "e89f7d12-3494-54d1-8411-f7d8b9ae1f27"
version = "0.5.0"
[[Missings]]
deps = ["Dates", "InteractiveUtils", "SparseArrays", "Test"]
git-tree-sha1 = "adc26d2ee85a49c413464110d922cf21efc9d233"
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
version = "0.3.1"
[[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
[[NNlib]]
deps = ["Libdl", "LinearAlgebra", "MacroTools", "Requires", "Test"]
git-tree-sha1 = "51330bb45927379007e089997bf548fbe232589d"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.4.3"
[[NaNMath]]
deps = ["Compat"]
git-tree-sha1 = "ce3b85e484a5d4c71dd5316215069311135fa9f2"
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
version = "0.3.2"
[[OrderedCollections]]
deps = ["Random", "Serialization", "Test"]
git-tree-sha1 = "85619a3f3e17bb4761fe1b1fd47f0e979f964d5b"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.0.2"
[[Pkg]]
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
[[Printf]]
deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
[[Profile]]
deps = ["Printf"]
uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"
[[REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
[[Random]]
deps = ["Serialization"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
[[Reexport]]
deps = ["Pkg"]
git-tree-sha1 = "7b1d07f411bc8ddb7977ec7f377b97b158514fe0"
uuid = "189a3867-3050-52da-a836-e630ba90ab69"
version = "0.2.0"
[[Requires]]
deps = ["Test"]
git-tree-sha1 = "f6fbf4ba64d295e146e49e021207993b6b48c7d1"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "0.5.2"
[[SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
[[Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
[[SharedArrays]]
deps = ["Distributed", "Mmap", "Random", "Serialization"]
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
[[Sockets]]
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
[[SortingAlgorithms]]
deps = ["DataStructures", "Random", "Test"]
git-tree-sha1 = "03f5898c9959f8115e30bc7226ada7d0df554ddd"
uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
version = "0.3.1"
[[SparseArrays]]
deps = ["LinearAlgebra", "Random"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
[[SpecialFunctions]]
deps = ["BinDeps", "BinaryProvider", "Libdl", "Test"]
git-tree-sha1 = "0b45dc2e45ed77f445617b99ff2adf0f5b0f23ea"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "0.7.2"
[[StaticArrays]]
deps = ["InteractiveUtils", "LinearAlgebra", "Random", "Statistics", "Test"]
git-tree-sha1 = "1eb114d6e23a817cd3e99abc3226190876d7c898"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "0.10.2"
[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[[StatsBase]]
deps = ["DataStructures", "DelimitedFiles", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "Test"]
git-tree-sha1 = "7b596062316c7d846b67bf625d5963a832528598"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.27.0"
[[Test]]
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[[TranscodingStreams]]
deps = ["Pkg", "Random", "Test"]
git-tree-sha1 = "a34a2d588e2d2825602bf14a24216d5c8b0921ec"
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
version = "0.8.1"
[[URIParser]]
deps = ["Test", "Unicode"]
git-tree-sha1 = "6ddf8244220dfda2f17539fa8c9de20d6c575b69"
uuid = "30578b45-9adc-5946-b283-645ec420af67"
version = "0.4.0"
[[UUIDs]]
deps = ["Random"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
[[Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
[[ZipFile]]
deps = ["BinaryProvider", "Libdl", "Printf", "Test"]
git-tree-sha1 = "4000c633efe994b2e10b31b6d91382c4b7412dac"
uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
version = "0.8.0"

docs/Project.toml (new file, 4 lines)

@ -0,0 +1,4 @@
[deps]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"


@@ -2,10 +2,11 @@ using Documenter, Flux, NNlib

 makedocs(modules=[Flux, NNlib],
          doctest = false,
-         format = :html,
          analytics = "UA-36890222-9",
          sitename = "Flux",
-         assets = ["../flux.css"],
+         # Uncomment below for local build
+         #format = Documenter.HTML(prettyurls = false),
+         assets = ["assets/flux.css"],
          pages = ["Home" => "index.md",
                   "Building Models" =>
                     ["Basics" => "models/basics.md",
@@ -22,10 +23,4 @@ makedocs(modules=[Flux, NNlib],
                     ["Backpropagation" => "internals/tracker.md"],
                   "Community" => "community.md"])

-deploydocs(
-   repo = "github.com/FluxML/Flux.jl.git",
-   target = "build",
-   osname = "linux",
-   julia = "1.0",
-   deps = nothing,
-   make = nothing)
+deploydocs(repo = "github.com/FluxML/Flux.jl.git")

docs/src/assets/flux.css (new file, 113 lines)

@ -0,0 +1,113 @@
@import url('https://fonts.googleapis.com/css?family=Lato:400,400i');
body {
font-family: Lato, "Segoe UI",Roboto,"Helvetica Neue",Arial,sans-serif;
}
nav.toc {
padding-top: 0;
background: rgb(240, 240, 240);
line-height: 2em;
cursor: default;
user-select: none;
}
h1+h2 {
margin-top: 0;
}
/* Green banner in ToC */
nav.toc > h1 {
margin-top: 0;
padding-top: 0.4em;
padding-bottom: 0.5em;
border-bottom: 5px solid white;
box-shadow: 0px -2px 5px rgb(60,60,60);
margin-bottom: 0.5em;
background: rgb(60, 150, 60);
font-style: italic;
font-weight: normal;
font-size: 50pt;
text-transform: lowercase;
text-shadow: 2px 2px 5px rgba(0,0,0,0.2);
color: white;
}
/* Reduce ToC font size */
.toctext {
font-size: 10pt;
}
/* Fade out non-clickable ToC headers */
nav.toc ul span.toctext {
color: rgb(180, 180, 180);
}
nav.toc ul .toctext {
color: rgb(100, 100, 100);
}
nav.toc ul a.toctext:hover {
color: inherit;
background: rgb(220, 220, 220);
cursor: default;
}
nav.toc li.current > .toctext {
background: linear-gradient(90deg, rgb(245,245,245) 0%, white 90%);
font-weight: normal;
}
nav.toc ul.internal li.toplevel {
font-weight: normal;
}
/* Content */
article { max-width: none; }
article > p, article > ul {
max-width: 45em;
}
/* Links */
a, a:visited { color: rgb(0, 120, 0); }
article p a { border-bottom: 1px solid rgb(200, 230, 200); }
a:hover, a:visited:hover { color: rgb(0, 80, 0); }
/* Article Links */
article p a { border-bottom: 1px solid rgb(200, 230, 200); }
article p a:hover, article a:visited:hover { color: rgb(0, 120, 0); }
article p a:hover { border-bottom: 1px solid rgb(150, 200, 150); }
/* Doctstrings */
article section.docstring {
padding: 0.5em 0;
border-left: none;
border-right: none;
border-bottom: none;
}
/* Code */
article pre, article p > code {
background: rgb(245, 250, 245);
}
article pre {
border: none;
max-width: none;
padding: 1em;
border-radius: 10px 0px 0px 10px;
margin-left: -1em;
margin-right: -2em;
}
.hljs-comment {
font-style: italic;
}
.hljs-number {
color: rgb(0, 150, 150);
}


@@ -1,10 +1,22 @@
 # GPU Support

+## Installation
+
+To get GPU support for NVIDIA graphics cards, you need to install `CuArrays.jl`.
+
+**Steps needed**
+
+1. Install the [NVIDIA toolkit](https://developer.nvidia.com/cuda-downloads)
+2. Install the [NVIDIA cuDNN library](https://developer.nvidia.com/cudnn)
+3. In Julia's terminal run `]add CuArrays`
+
+## GPU Usage
+
 Support for array operations on other hardware backends, like GPUs, is provided by external packages like [CuArrays](https://github.com/JuliaGPU/CuArrays.jl). Flux is agnostic to array types, so we simply need to move model weights and data to the GPU and Flux will handle it.

 For example, we can use `CuArrays` (with the `cu` converter) to run our [basic example](models/basics.md) on an NVIDIA GPU.

-(Note that you need to build Julia 0.6 from source and have CUDA available to use CuArrays; please see the [CUDAnative.jl](https://github.com/JuliaGPU/CUDAnative.jl) instructions for more details.)
+(Note that you need to have CUDA available to use CuArrays; please see the [CuArrays.jl](https://github.com/JuliaGPU/CuArrays.jl) instructions for more details.)

 ```julia
 using CuArrays
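As a concrete illustration of the workflow this page describes, here is a minimal sketch of moving a model and data to the GPU (assuming CuArrays is installed and a CUDA device is present; the model and input here are made up for the example):

```julia
using Flux, CuArrays

m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax) |> gpu  # move the weights to the GPU
x = rand(Float32, 10) |> gpu                              # move the input to the GPU
m(x)                                                      # forward pass runs on the GPU
```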


@@ -10,12 +10,12 @@ using Flux.Tracker

 f(x) = 3x^2 + 2x + 1

 # df/dx = 6x + 2
-df(x) = Tracker.gradient(f, x)[1]
+df(x) = Tracker.gradient(f, x; nest = true)[1]

 df(2) # 14.0 (tracked)

 # d²f/dx² = 6
-d2f(x) = Tracker.gradient(df, x)[1]
+d2f(x) = Tracker.gradient(df, x; nest = true)[1]

 d2f(2) # 6.0 (tracked)
 ```
@@ -28,10 +28,10 @@ When a function has many parameters, we can pass them all in explicitly:

 f(W, b, x) = W * x + b

 Tracker.gradient(f, 2, 3, 4)
-(4.0 (tracked), 1.0, 2.0 (tracked))
+# (4.0 (tracked), 1.0 (tracked), 2.0 (tracked))
 ```

-But machine learning models can have *hundreds* of parameters! Flux offers a nice way to handle this. We can tell Flux to treat something as a parameter via `param`. Then we can collect these together and tell `gradient` to collect the gradients of all of them at once.
+But machine learning models can have *hundreds* of parameters! Flux offers a nice way to handle this. We can tell Flux to treat something as a parameter via `param`. Then we can collect these together and tell `gradient` to collect the gradients of all `params` at once.

 ```julia
 W = param(2) # 2.0 (tracked)
@@ -39,14 +39,13 @@ b = param(3) # 3.0 (tracked)

 f(x) = W * x + b

-params = Params([W, b])
-grads = Tracker.gradient(() -> f(4), params)
+grads = Tracker.gradient(() -> f(4), params(W, b))

 grads[W] # 4.0
 grads[b] # 1.0
 ```

-There are a few things to notice here. Firstly, `W` and `b` now show up as *tracked*. Tracked things behave like normal numbers or arrays, but keep records of everything you do with them, allowing Flux to calculate their gradients. `gradient` takes a zero-argument function; no arguments are necessary because the `Params` tell it what to differentiate.
+There are a few things to notice here. Firstly, `W` and `b` now show up as *tracked*. Tracked things behave like normal numbers or arrays, but keep records of everything you do with them, allowing Flux to calculate their gradients. `gradient` takes a zero-argument function; no arguments are necessary because the `params` tell it what to differentiate.

 This will come in really handy when dealing with big, complicated models. For now, though, let's start with something simple.
@@ -77,7 +76,7 @@ using Flux.Tracker

 W = param(W)
 b = param(b)

-gs = Tracker.gradient(() -> loss(x, y), Params([W, b]))
+gs = Tracker.gradient(() -> loss(x, y), params(W, b))
 ```

 Now that we have gradients, we can pull them out and update `W` to train the model. The `update!(W, Δ)` function applies `W = W + Δ`, which we can use for gradient descent.
@@ -102,6 +101,8 @@ All deep learning in Flux, however complex, is a simple generalisation of this e

 It's common to create more complex models than the linear regression above. For example, we might want to have two linear layers with a nonlinearity like [sigmoid](https://en.wikipedia.org/wiki/Sigmoid_function) (`σ`) in between them. In the above style we could write this as:

 ```julia
+using Flux
+
 W1 = param(rand(3, 5))
 b1 = param(rand(3))
 layer1(x) = W1 * x .+ b1
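The hunk above cuts off at `layer1`; in the docs this example continues by composing two such layers. A sketch of that continuation (surrounding docs text, not part of this diff):

```julia
W2 = param(rand(2, 3))
b2 = param(rand(2))
layer2(x) = W2 * x .+ b2

model(x) = layer2(σ.(layer1(x)))
model(rand(5))   # 2-element tracked output
```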


@@ -3,7 +3,7 @@

 Consider a [simple linear regression](../models/basics.md). We create some dummy data, calculate a loss, and backpropagate to calculate gradients for the parameters `W` and `b`.

 ```julia
-using Flux.Tracker
+using Flux, Flux.Tracker

 W = param(rand(2, 5))
 b = param(rand(2))
@@ -14,8 +14,8 @@ loss(x, y) = sum((predict(x) .- y).^2)

 x, y = rand(5), rand(2) # Dummy data
 l = loss(x, y) # ~ 3

-params = Params([W, b])
-grads = Tracker.gradient(() -> loss(x, y), params)
+θ = Params([W, b])
+grads = Tracker.gradient(() -> loss(x, y), θ)
 ```

 We want to update each parameter, using the gradient, in order to improve (reduce) the loss. Here's one way to do that:
@@ -23,44 +23,30 @@ We want to update each parameter, using the gradient, in order to improve (reduc

 ```julia
 using Flux.Tracker: grad, update!

-function sgd()
-  η = 0.1 # Learning Rate
-  for p in (W, b)
-    update!(p, -η * grads[p])
-  end
-end
+η = 0.1 # Learning Rate
+for p in (W, b)
+  update!(p, -η * grads[p])
+end
 ```

-If we call `sgd`, the parameters `W` and `b` will change and our loss should go down.
-
-There are two pieces here: one is that we need a list of trainable parameters for the model (`[W, b]` in this case), and the other is the update step. In this case the update is simply gradient descent (`x .-= η .* Δ`), but we might choose to do something more advanced, like adding momentum.
-
-In this case, getting the variables is trivial, but you can imagine it'd be more of a pain with some complex stack of layers.
+Running this will alter the parameters `W` and `b` and our loss should go down. Flux provides a more general way to do optimiser updates like this.

 ```julia
-m = Chain(
-  Dense(10, 5, σ),
-  Dense(5, 2), softmax)
+opt = Descent(0.1) # Gradient descent with learning rate 0.1
+
+for p in (W, b)
+  update!(opt, p, grads[p])
+end
 ```

-Instead of having to write `[m[1].W, m[1].b, ...]`, Flux provides a params function `params(m)` that returns a list of all parameters in the model for you.
-
-For the update step, there's nothing whatsoever wrong with writing the loop above; it'll work just fine, but Flux provides various *optimisers* that make it more convenient.
-
-```julia
-opt = SGD([W, b], 0.1) # Gradient descent with learning rate 0.1
-
-opt() # Carry out the update, modifying `W` and `b`.
-```
-
-An optimiser takes a parameter list and returns a function that does the same thing as `update` above. We can pass either `opt` or `update` to our [training loop](training.md), which will then run the optimiser after every mini-batch of data.
+An optimiser `update!` accepts a parameter and a gradient, and updates the parameter according to the chosen rule. We can also pass `opt` to our [training loop](training.md), which will update all parameters of the model in a loop. However, we can now easily replace `Descent` with a more advanced optimiser such as `ADAM`.
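To make the new interface concrete, here is a minimal sketch of the loop described above, reusing `W`, `b`, and `grads` from the snippet earlier on this page; any optimiser object can stand in for `Descent`:

```julia
using Flux.Tracker: update!

opt = ADAM(0.001)            # drop-in replacement for Descent(0.1)

for p in (W, b)
  update!(opt, p, grads[p])  # the optimiser decides how to apply the gradient
end
```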
 ## Optimiser Reference

-All optimisers return a function that, when called, will update the parameters passed to it.
+All optimisers return an object that, when passed to `train!`, will update the parameters passed to it.

 ```@docs
-SGD
+Descent
 Momentum
 Nesterov
 ADAM


@@ -9,7 +9,7 @@ To actually train a model we need three things:

 With these we can call `Flux.train!`:

 ```julia
-Flux.train!(objective, data, opt)
+Flux.train!(objective, params, data, opt)
 ```

 There are plenty of examples in the [model zoo](https://github.com/FluxML/model-zoo).
@@ -24,9 +24,10 @@ m = Chain(

   Dense(32, 10), softmax)

 loss(x, y) = Flux.mse(m(x), y)
+ps = Flux.params(m)

 # later
-Flux.train!(loss, data, opt)
+Flux.train!(loss, ps, data, opt)
 ```

 The objective will almost always be defined in terms of some *cost function* that measures the distance of the prediction `m(x)` from the target `y`. Flux has several of these built in, like `mse` for mean squared error or `crossentropy` for cross entropy loss, but you can calculate it however you want.
@@ -78,7 +79,7 @@ julia> @epochs 2 Flux.train!(...)

 `train!` takes an additional argument, `cb`, that's used for callbacks so that you can observe the training process. For example:

 ```julia
-train!(objective, data, opt, cb = () -> println("training"))
+train!(objective, ps, data, opt, cb = () -> println("training"))
 ```

 Callbacks are called for every batch of training data. You can slow this down using `Flux.throttle(f, timeout)` which prevents `f` from being called more than once every `timeout` seconds.
@@ -89,6 +90,6 @@ A more typical callback might look like this:

 test_x, test_y = # ... create single batch of test data ...
 evalcb() = @show(loss(test_x, test_y))

-Flux.train!(objective, data, opt,
+Flux.train!(objective, ps, data, opt,
             cb = throttle(evalcb, 5))
 ```
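Putting the new `train!` signature together, a self-contained sketch (the model, data, and callback below are made up purely for illustration):

```julia
using Flux

m = Chain(Dense(10, 32, relu), Dense(32, 10), softmax)
loss(x, y) = Flux.crossentropy(m(x), y)
ps = Flux.params(m)
opt = ADAM()

# 100 random batches of 16 examples each, just to have something to iterate over
data = [(rand(Float32, 10, 16), Flux.onehotbatch(rand(1:10, 16), 1:10)) for _ in 1:100]

evalcb() = @show(loss(data[1]...))
Flux.train!(loss, ps, data, opt, cb = Flux.throttle(evalcb, 5))
```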


@ -8,7 +8,7 @@ using MacroTools: @forward
 export Chain, Dense, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
        DepthwiseConv, Dropout, LayerNorm, BatchNorm,
-       params, mapleaves, cpu, gpu
+       params, mapleaves, cpu, gpu, f32, f64
@reexport using NNlib @reexport using NNlib


@ -1,6 +1,22 @@
module CUDA module CUDA
using ..CuArrays using ..CuArrays
using Pkg.TOML
function version_check()
minor_version = 9
project = joinpath(dirname(pathof(CuArrays)), "../Project.toml")
project = TOML.parse(String(read(project)))
version = VersionNumber(get(project, "version", "0.0.0"))
if !(version.major == 0 && version.minor == minor_version)
@warn """
Flux is only supported with CuArrays v0.$minor_version.
Try running `] pin CuArrays@0.$minor_version`.
"""
end
end
version_check()
if !applicable(CuArray{UInt8}, undef, 1) if !applicable(CuArray{UInt8}, undef, 1)
(T::Type{<:CuArray})(::UndefInitializer, sz...) = T(sz...) (T::Type{<:CuArray})(::UndefInitializer, sz...) = T(sz...)


@ -21,8 +21,8 @@ struct Chain{T<:Tuple}
Chain(xs...) = new{typeof(xs)}(xs) Chain(xs...) = new{typeof(xs)}(xs)
end end
@forward Chain.layers Base.getindex, Base.first, Base.last, Base.lastindex @forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
@forward Chain.layers Base.iterate Base.iterate, Base.lastindex
children(c::Chain) = c.layers children(c::Chain) = c.layers
mapchildren(f, c::Chain) = Chain(f.(c.layers)...) mapchildren(f, c::Chain) = Chain(f.(c.layers)...)


@ -132,12 +132,12 @@ DepthwiseConv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0) where {T,N} = stride = 1, pad = 0) where {T,N} =
DepthwiseConv(σ, w, b, expand.(sub2(Val(N)), (stride, pad))...) DepthwiseConv(σ, w, b, expand.(sub2(Val(N)), (stride, pad))...)
DepthwiseConv(k::NTuple{N,Integer}, ch::Integer, σ = identity; init = initn, DepthwiseConv(k::NTuple{N,Integer}, ch::Integer, σ = identity; init = glorot_uniform,
stride = 1, pad = 0) where N = stride = 1, pad = 0) where N =
DepthwiseConv(param(init(k..., 1, ch)), param(zeros(ch)), σ, DepthwiseConv(param(init(k..., 1, ch)), param(zeros(ch)), σ,
stride = stride, pad = pad) stride = stride, pad = pad)
DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn, DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = glorot_uniform,
stride::NTuple{N,Integer} = map(_->1,k), stride::NTuple{N,Integer} = map(_->1,k),
pad::NTuple{N,Integer} = map(_->0,k)) where N = pad::NTuple{N,Integer} = map(_->0,k)) where N =
DepthwiseConv(param(init(k..., ch[2], ch[1])), param(zeros(ch[2]*ch[1])), σ, DepthwiseConv(param(init(k..., ch[2], ch[1])), param(zeros(ch[2]*ch[1])), σ,


@ -106,7 +106,7 @@ mutable struct BatchNorm{F,V,W,N}
end end
BatchNorm(chs::Integer, λ = identity; BatchNorm(chs::Integer, λ = identity;
initβ = (i) -> zeros(i), initγ = (i) -> ones(i), ϵ = 1e-5, momentum = .1) = initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
BatchNorm(λ, param(initβ(chs)), param(initγ(chs)), BatchNorm(λ, param(initβ(chs)), param(initγ(chs)),
zeros(chs), ones(chs), ϵ, momentum, true) zeros(chs), ones(chs), ϵ, momentum, true)


@@ -2,16 +2,14 @@ using NNlib: logsoftmax, logσ

 # Cost functions

-mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)
+mse(ŷ, y) = sum((ŷ .- y).^2) * 1 // length(y)

 function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
-  -sum(y .* log.(ŷ) .* weight) / size(y, 2)
+  -sum(y .* log.(ŷ) .* weight) * 1 // size(y, 2)
 end

-@deprecate logloss(x, y) crossentropy(x, y)
-
 function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
-  return -sum(y .* logsoftmax(logŷ) .* weight) / size(y, 2)
+  return -sum(y .* logsoftmax(logŷ) .* weight) * 1 // size(y, 2)
 end

 """


@ -68,3 +68,6 @@ end
a::TrackedMatrix * b::OneHotVector = invoke(*, Tuple{AbstractMatrix,OneHotVector}, a, b) a::TrackedMatrix * b::OneHotVector = invoke(*, Tuple{AbstractMatrix,OneHotVector}, a, b)
a::TrackedMatrix * b::OneHotMatrix = invoke(*, Tuple{AbstractMatrix,OneHotMatrix}, a, b) a::TrackedMatrix * b::OneHotMatrix = invoke(*, Tuple{AbstractMatrix,OneHotMatrix}, a, b)
onecold(x::TrackedVector, l...) = onecold(data(x), l...)
onecold(x::TrackedMatrix, l...) = onecold(data(x), l...)


@ -4,7 +4,7 @@ using Flux: Params
check_decay(opt, decay) = decay == 0 ? opt : Optimiser(opt, InvDecay(decay)) check_decay(opt, decay) = decay == 0 ? opt : Optimiser(opt, InvDecay(decay))
# legacy update rule # legacy update rule
updaterule(opt, ps) = () -> update!(opt, ps) updaterule(opt, ps) = () -> _update_params!(opt, ps)
function SGD(params::Union{AbstractArray, Params}, η = 0.1; decay = 0.) function SGD(params::Union{AbstractArray, Params}, η = 0.1; decay = 0.)
depwarn("SGD(params) is deprecated; use Descent(η::Float64) instead", :SGD) depwarn("SGD(params) is deprecated; use Descent(η::Float64) instead", :SGD)
@ -117,7 +117,7 @@ struct OldOptimiser
func func
end end
update!(opt::OldOptimiser, ps) = opt.func() _update_params!(opt::OldOptimiser, ps) = opt.func()
# Train function # Train function
function train!(loss, data, opt; cb = () -> ()) function train!(loss, data, opt; cb = () -> ())


@ -18,7 +18,7 @@ end
Descent() = Descent(0.1) Descent() = Descent(0.1)
function update!(o::Descent, x, Δ) function apply!(o::Descent, x, Δ)
Δ .*= o.eta Δ .*= o.eta
end end
@ -35,7 +35,7 @@ end
Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict()) Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict())
function update!(o::Momentum, x, Δ) function apply!(o::Momentum, x, Δ)
η, ρ = o.eta, o.rho η, ρ = o.eta, o.rho
v = get!(o.velocity, x, zero(x))::typeof(x) v = get!(o.velocity, x, zero(x))::typeof(x)
@. v = ρ * v - η * Δ @. v = ρ * v - η * Δ
@ -55,7 +55,7 @@ end
Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict()) Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict())
function update!(o::Nesterov, x, Δ) function apply!(o::Nesterov, x, Δ)
η, ρ = o.eta, o.rho η, ρ = o.eta, o.rho
v = get!(o.velocity, x, zero(x))::typeof(x) v = get!(o.velocity, x, zero(x))::typeof(x)
d = @. ρ^2 * v - (1+ρ) * η * Δ d = @. ρ^2 * v - (1+ρ) * η * Δ
@ -78,7 +78,7 @@ end
RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict()) RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict())
function update!(o::RMSProp, x, Δ) function apply!(o::RMSProp, x, Δ)
η, ρ = o.eta, o.rho η, ρ = o.eta, o.rho
acc = get!(o.acc, x, zero(x))::typeof(x) acc = get!(o.acc, x, zero(x))::typeof(x)
@. acc = ρ * acc + (1 - ρ) * Δ^2 @. acc = ρ * acc + (1 - ρ) * Δ^2
@ -98,7 +98,7 @@ end
ADAM(η = 0.001, β = (0.9, 0.999)) = ADAM(η, β, IdDict()) ADAM(η = 0.001, β = (0.9, 0.999)) = ADAM(η, β, IdDict())
function update!(o::ADAM, x, Δ) function apply!(o::ADAM, x, Δ)
η, β = o.eta, o.beta η, β = o.eta, o.beta
mt, vt, βp = get!(o.state, x, (zero(x), zero(x), β)) mt, vt, βp = get!(o.state, x, (zero(x), zero(x), β))
@. mt = β[1] * mt + (1 - β[1]) * Δ @. mt = β[1] * mt + (1 - β[1]) * Δ
@ -122,7 +122,7 @@ end
AdaMax(η = 0.001, β = (0.9, 0.999)) = AdaMax(η, β, IdDict()) AdaMax(η = 0.001, β = (0.9, 0.999)) = AdaMax(η, β, IdDict())
function update!(o::AdaMax, x, Δ) function apply!(o::AdaMax, x, Δ)
η, β = o.eta, o.beta η, β = o.eta, o.beta
mt, ut, βp = get!(o.state, x, (zero(x), zero(x), β)) mt, ut, βp = get!(o.state, x, (zero(x), zero(x), β))
@. mt = β[1] * mt + (1 - β[1]) * Δ @. mt = β[1] * mt + (1 - β[1]) * Δ
@ -145,7 +145,7 @@ end
ADAGrad(η = 0.1) = ADAGrad(η, IdDict()) ADAGrad(η = 0.1) = ADAGrad(η, IdDict())
function update!(o::ADAGrad, x, Δ) function apply!(o::ADAGrad, x, Δ)
η = o.eta η = o.eta
acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(x) acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(x)
@. acc += Δ^2 @. acc += Δ^2
@ -165,7 +165,7 @@ end
ADADelta(ρ = 0.9) = ADADelta(ρ, IdDict()) ADADelta(ρ = 0.9) = ADADelta(ρ, IdDict())
function update!(o::ADADelta, x, Δ) function apply!(o::ADADelta, x, Δ)
ρ = o.rho ρ = o.rho
acc, Δacc = get!(o.state, x, (zero(x), zero(x))) acc, Δacc = get!(o.state, x, (zero(x), zero(x)))
@. acc = ρ * acc + (1 - ρ) * Δ^2 @. acc = ρ * acc + (1 - ρ) * Δ^2
@ -188,7 +188,7 @@ end
AMSGrad(η = 0.001, β = (0.9, 0.999)) = AMSGrad(η, β, IdDict()) AMSGrad(η = 0.001, β = (0.9, 0.999)) = AMSGrad(η, β, IdDict())
function update!(o::AMSGrad, x, Δ) function apply!(o::AMSGrad, x, Δ)
η, β = o.eta, o.beta η, β = o.eta, o.beta
mt, vt, v̂t = get!(o.state, x, (fill(ϵ, size(x)), fill(ϵ, size(x)), fill(ϵ, size(x)))) mt, vt, v̂t = get!(o.state, x, (fill(ϵ, size(x)), fill(ϵ, size(x)), fill(ϵ, size(x))))
@. mt = β[1] * mt + (1 - β[1]) * Δ @. mt = β[1] * mt + (1 - β[1]) * Δ
@ -211,7 +211,7 @@ end
NADAM(η = 0.001, β = (0.9, 0.999)) = NADAM(η, β, IdDict()) NADAM(η = 0.001, β = (0.9, 0.999)) = NADAM(η, β, IdDict())
function update!(o::NADAM, x, Δ) function apply!(o::NADAM, x, Δ)
η, β = o.eta, o.beta η, β = o.eta, o.beta
β1p, β2p = o.beta β1p, β2p = o.beta
mt, vt = get!(o.state, x, (zero(x), zero(x))) mt, vt = get!(o.state, x, (zero(x), zero(x)))
@ -228,7 +228,7 @@ end
[ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam. [ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam.
""" """
ADAMW(η = 0.001, β = (0.9, 0.999), decay = 0) = ADAMW(η = 0.001, β = (0.9, 0.999), decay = 0) =
Optimiser(ADAM(η, β), WeightDecay(wd)) Optimiser(ADAM(η, β), WeightDecay(decay))
# Compose optimizers # Compose optimizers
@ -250,13 +250,21 @@ Optimiser(o...) = Optimiser(Any[o...])
Base.getindex(c::Optimiser, i::AbstractArray) = Optimiser(c.os[i]...) Base.getindex(c::Optimiser, i::AbstractArray) = Optimiser(c.os[i]...)
function update!(o::Optimiser, x, Δ) function apply!(o::Optimiser, x, Δ)
for opt in o.os for opt in o.os
Δ = update!(opt, x, Δ) Δ = apply!(opt, x, Δ)
end end
return Δ return Δ
end end
"""
`InvDecay(γ)`
Apply inverse time decay to an optimiser
```julia
Optimiser(InvDecay(..), Opt(..))
```
"""
mutable struct InvDecay mutable struct InvDecay
gamma::Float64 gamma::Float64
state::IdDict state::IdDict
@ -264,7 +272,7 @@ end
InvDecay(γ = 0.001) = InvDecay(γ, IdDict()) InvDecay(γ = 0.001) = InvDecay(γ, IdDict())
function update!(o::InvDecay, x, Δ) function apply!(o::InvDecay, x, Δ)
γ = o.gamma γ = o.gamma
n = get!(o.state, x, 1) n = get!(o.state, x, 1)
Δ .*= 1 / (1 + γ * n) Δ .*= 1 / (1 + γ * n)
@ -272,6 +280,16 @@ function update!(o::InvDecay, x, Δ)
return Δ return Δ
end end
"""
`ExpDecay(eta, decay, decay_step, clip)`
Schedule the learning rate `eta` by `decay` every `decay_step` till a minimum of `clip`.
To apply exponential decay to an optimiser:
```julia
Optimiser(ExpDecay(..), Opt(..))
```
"""
mutable struct ExpDecay mutable struct ExpDecay
eta::Float64 eta::Float64
decay::Float64 decay::Float64
@ -282,7 +300,7 @@ end
ExpDecay(opt = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4) = ExpDecay(opt, decay, decay_step, clip, IdDict()) ExpDecay(opt = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4) = ExpDecay(opt, decay, decay_step, clip, IdDict())
function update!(o::ExpDecay, x, Δ) function apply!(o::ExpDecay, x, Δ)
η, s, decay = o.eta, o.step, o.decay η, s, decay = o.eta, o.step, o.decay
n = o.current[x] = get(o.current, x, 0) + 1 n = o.current[x] = get(o.current, x, 0) + 1
if o.current[x]%s == 0 && count(x -> x%s == 0, values(o.current)) == 1 if o.current[x]%s == 0 && count(x -> x%s == 0, values(o.current)) == 1
@ -292,13 +310,18 @@ function update!(o::ExpDecay, x, Δ)
@. Δ *= decay @. Δ *= decay
end end
"""
`WeightDecay(wd)`
Decay the weight parameter by `wd`
"""
mutable struct WeightDecay mutable struct WeightDecay
wd::Real wd::Real
end end
WeightDecay() = WeightDecay(0) WeightDecay() = WeightDecay(0)
function update!(o::WeightDecay, x, Δ) function apply!(o::WeightDecay, x, Δ)
wd = o.wd wd = o.wd
@. Δ += wd * x @. Δ += wd * x
end end
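A note on the `update!` to `apply!` rename running through this file: `apply!(opt, x, Δ)` now only transforms the gradient in place (keeping any per-parameter state), while the `update!`/`_update_params!` methods in `train.jl` below apply the result to the parameter. A custom rule written against this interface might look like the following sketch; `ClippedDescent` is hypothetical and not part of this diff:

```julia
# Hypothetical optimiser: gradient descent with elementwise gradient clipping.
mutable struct ClippedDescent
  eta::Float64
  clip::Float64
end

# apply! rescales the step in place and returns it; it never touches the parameter.
function apply!(o::ClippedDescent, x, Δ)
  Δ .= clamp.(Δ, -o.clip, o.clip)
  Δ .*= o.eta
end
```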


@@ -1,10 +1,14 @@
 using Juno
-using Flux.Tracker: data, grad, back!
+import Flux.Tracker: data, grad, back!, update!
 import Base.depwarn

-function update!(opt, xs)
+function update!(opt, x, x̄)
+  update!(x, apply!(opt, x, copy(data(x̄))))
+end
+
+function _update_params!(opt, xs)
   for x in xs
-    Δ = update!(opt, x.data, x.grad)
+    Δ = apply!(opt, x.data, x.grad)
     x.data .-= Δ
     Δ .= 0
   end
@ -45,7 +49,7 @@ function stop()
end end
""" """
train!(model, loss, data, opt) train!(loss, params, data, opt; cb)
For each datapoint `d` in `data` computes the gradient of `loss(d...)` through For each datapoint `d` in `data` computes the gradient of `loss(d...)` through
backpropagation and calls the optimizer `opt`. backpropagation and calls the optimizer `opt`.
@ -54,11 +58,11 @@ Takes a callback as keyword argument `cb`. For example, this will print "trainin
every 10 seconds: every 10 seconds:
```julia ```julia
Flux.train!(model, loss, data, opt, Flux.train!(loss, params, data, opt,
cb = throttle(() -> println("training"), 10)) cb = throttle(() -> println("training"), 10))
``` ```
The callback can return `:stop` to interrupt the training loop. The callback can call `Flux.stop()` to interrupt the training loop.
Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays. Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
""" """
@ -69,7 +73,7 @@ function train!(loss, ps, data, opt; cb = () -> ())
try try
l = loss(d...) l = loss(d...)
@interrupts back!(l) @interrupts back!(l)
update!(opt, ps) _update_params!(opt, ps)
if cb() == :stop if cb() == :stop
depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop) depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
break break


@ -6,7 +6,7 @@ using MacroTools: @q, @forward
import Base: == import Base: ==
export TrackedArray, TrackedVector, TrackedMatrix, Params, gradient, export TrackedArray, TrackedVector, TrackedMatrix, Params, gradient,
param, back! jacobian, hessian, param, back!
tracker(x) = nothing tracker(x) = nothing
@ -61,24 +61,20 @@ macro grad(ex)
@q(Tracker._forward($(args...)) where $(T...) = $body) |> esc @q(Tracker._forward($(args...)) where $(T...) = $body) |> esc
end end
function update!(x, Δ)
x.data .+= data(Δ)
tracker(x).grad .= 0
return x
end
include("idset.jl") include("idset.jl")
include("back.jl") include("back.jl")
include("numeric.jl") include("numeric.jl")
include("lib/real.jl") include("lib/real.jl")
include("lib/array.jl") include("lib/array.jl")
include("forward.jl")
""" """
hook(f, x) -> x hook(f, x) -> x
Hook into gradient backpropagation. `x` is unmodified, but when backpropagating Hook into gradient backpropagation. `x` is unmodified, but when backpropagating
`f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse `f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse
the sign of the gradient applied to `x`.""" the sign of the gradient applied to `x`.
"""
hook(f, x) = istracked(x) ? track(hook, f, x) : x hook(f, x) = istracked(x) ? track(hook, f, x) : x
@grad hook(f, x) = data(x), Δ -> (nothing, f(Δ)) @grad hook(f, x) = data(x), Δ -> (nothing, f(Δ))


@ -67,7 +67,7 @@ function back!(x, Δ; once = true)
end end
function gradient_(f, xs...) function gradient_(f, xs...)
xs = param.(xs) xs = param.(data.(xs))
l = f(xs...) l = f(xs...)
losscheck(l) losscheck(l)
back!(l) back!(l)
@ -179,3 +179,30 @@ end
gradient(f, xs...; nest = false) = gradient(f, xs...; nest = false) =
nest ? gradient_nested(f, xs...) : gradient_(f, xs...) nest ? gradient_nested(f, xs...) : gradient_(f, xs...)
gradient(f, ps::Params) = gradient_nested(f, ps)
# Jacobians and Hessians
import ..Flux
"""
J = jacobian(m,x)
Calculate the output jacobian `J = d/dx m(x)` such that each row `i` of `J` corresponds to the gradient `J[i,:] = ∇ₓ(m(x)[i])`
"""
function jacobian(m,x)
xp = param(x)
y = m(xp)
k = length(y)
n = length(x)
J = Matrix{eltype(x)}(undef,k,n)
for i = 1:k
Flux.back!(y[i], once = false) # Populate gradient accumulator
J[i,:] = xp.grad
xp.grad .= 0 # Reset gradient accumulator
end
J
end
hessian(f, x) = jacobian(x -> gradient(f, x, nest=true)[1], x)
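A quick illustration of the `jacobian` and `hessian` helpers added here (the function below is just an example, not part of the diff):

```julia
using Flux
using Flux.Tracker: jacobian, hessian

m(x) = [sum(x .^ 2), prod(x)]     # maps a length-3 vector to 2 outputs
x = [1.0, 2.0, 3.0]

J = jacobian(m, x)                # 2×3 matrix; row i is the gradient of m(x)[i]
H = hessian(x -> sum(x .^ 2), x)  # 3×3 matrix; here equal to 2I
```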

src/tracker/forward.jl (new file, 53 lines)

@ -0,0 +1,53 @@
using ForwardDiff
seed(x::Real, ::Val) = Dual(x, true)
function seed(x, ::Val{N}, offset = 0) where N
map(x, reshape(1:length(x), size(x))) do x, i
Dual(x, ntuple(j -> j+offset == i, Val(N)))
end
end
extract(x::ForwardDiff.Dual) = x.value, [x.partials...]
function extract(xs::AbstractArray{ForwardDiff.Dual{T,V,N}}) where {T,V,N}
J = similar(xs, V, N, length(xs))
for i = 1:length(xs), j = 1:N
J[j, i] = xs[i].partials.values[j]
end
return map(x -> x.value, xs), J
end
function forward_jacobian(f, x, ::Val{N}) where N
y, _J = extract(f(seed(x, Val(N))))
J = similar(_J, length(x), length(y))
J[1:N,:] = _J
offset = 0
while offset + N < length(x)
offset += N
_, _J = extract(f(seed(x, Val(N), offset)))
range = (1+offset):min(N+offset,length(x))
J[range,:] = @view _J[range.-offset,:]
end
return y, J
end
function forward_jacobian(f, x)
if length(x) < ForwardDiff.DEFAULT_CHUNK_THRESHOLD
forward_jacobian(f, x, Val(length(x)))
else
forward_jacobian(f, x, Val(ForwardDiff.DEFAULT_CHUNK_THRESHOLD))
end
end
forwarddiff(f, x) = istracked(x) ? track(forwarddiff, f, x) : f(x)
vec_scalar(x) = vec(x)
vec_scalar(x::Real) = [x]
reshape_scalar(x, y) = reshape(y, size(x))
reshape_scalar(x::Real, y) = y[]
@grad function forwarddiff(f, x)
  y, J = forward_jacobian(f, data(x))
  return y, Δ -> (nothing, reshape_scalar(x, J*vec_scalar(Δ)))
end
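For orientation, `forwarddiff(f, x)` evaluates `f(x)` while differentiating that subexpression with ForwardDiff, leaving the rest of the computation on the reverse-mode tracker. A hedged usage sketch:

```julia
using Flux
using Flux.Tracker: forwarddiff, gradient

# sin.(x) is differentiated in forward mode inside a reverse-mode gradient.
f(x) = sum(forwarddiff(x -> sin.(x), x))

x = rand(3)
gradient(f, x)[1]   # ≈ cos.(x)
```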
end


@ -65,6 +65,12 @@ Base.setindex!(xs::TrackedArray, v, i...) =
back!(::TrackedArray) = error("Value is not scalar; use `back!(sum(x))` or `back!(x, Δ)`") back!(::TrackedArray) = error("Value is not scalar; use `back!(sum(x))` or `back!(x, Δ)`")
function update!(x::TrackedArray, Δ)
x.data .+= data(Δ)
tracker(x).grad .= 0
return x
end
# Fallthrough methods # Fallthrough methods
for f in :[Base.size, Base.ndims, Base.collect].args for f in :[Base.size, Base.ndims, Base.collect].args
@ -115,8 +121,8 @@ Base.:-(xs::TrackedArray) = track(-, xs)
Base.transpose(xs::TrackedArray) = track(transpose, xs) Base.transpose(xs::TrackedArray) = track(transpose, xs)
Base.adjoint(xs::TrackedArray) = track(adjoint, xs) Base.adjoint(xs::TrackedArray) = track(adjoint, xs)
@grad transpose(xs) = transpose(data(xs)), Δ -> (reshape(transpose(Δ), size(xs)),) @grad transpose(xs) = transpose(data(xs)), Δ -> (trim(xs, transpose(Δ)),)
@grad adjoint(xs) = data(xs)', Δ -> (reshape(Δ', size(xs)),) @grad adjoint(xs) = data(xs)', Δ -> (trim(xs, Δ'),)
Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...) Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...)
@ -136,30 +142,28 @@ Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...)
end end
end end
for f in [:vcat, :hcat] function combinations(xs, n)
UArray = :(Union{TrackedArray,Vector,Matrix,Adjoint,Transpose}) n < 1 && return [[]]
@eval begin cs = combinations(xs, n-1)
# This section is a bit of a hack since julia doesn't have a standardised [[x, c...] for x in xs, c in cs]
# promotion mechanism for concatenation yet end
# https://github.com/JuliaLang/julia/pull/20815
# It should support tracked concatenation with rank ∈ (1,2) with a for i = 0:2, c = combinations([:AbstractArray, :TrackedArray], i), f = [:hcat, :vcat]
# TrackedArray anywhere among the arguments This works as long as base has cnames = map(_ -> gensym(), c)
# other functions that captures `(::Union{Vector,RowVector,Matrix}...)`. @eval Base.$f($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::TrackedArray, xs::AbstractArray...) =
Base.$f(a::$UArray...) = track($f, a...) track($f, $(cnames...), x, xs...)
end
# It should support tracked concatenation with rank>2 if the TrackedArray is for i = 0:2, c = combinations([:AbstractVecOrMat, :TrackedVecOrMat], i), f = [:hcat, :vcat]
# first cnames = map(_ -> gensym(), c)
Base.$f(a::TrackedArray, b::AbstractArray...) = track($f, a, b...) @eval Base.$f($([:($x::$c{T}) for (x, c) in zip(cnames, c)]...), x::TrackedVecOrMat{T}, xs::AbstractVecOrMat{T}...) where T =
Base.$f(a::TrackedArray, b::$UArray...) = track($f, a, b...) # resolves ambiguity introduced by previous row track($f, $(cnames...), x, xs...)
end
# It should support tracked concatenation with rank>2 if the TrackedArray is for i = 0:2, c = combinations([:AbstractVector, :TrackedVector], i), f = [:hcat, :vcat]
# second cnames = map(_ -> gensym(), c)
Base.$f(a::Array, b::TrackedArray, c::AbstractArray...) = track($f, a, b, c...) @eval Base.$f($([:($x::$c{T}) for (x, c) in zip(cnames, c)]...), x::TrackedVector{T}, xs::AbstractVector{T}...) where T =
Base.$f(a::Union{Vector,Matrix,Adjoint,Transpose}, b::TrackedArray, track($f, $(cnames...), x, xs...)
c::$UArray...) =
track($f, a, b, c...) # resolves ambiguity introduced by previous row
end
end end
@grad function vcat(xs...) @grad function vcat(xs...)
@ -192,10 +196,11 @@ end
end end
end end
Base.cat(a::TrackedArray; dims) = track(cat, a, dims = dims) for i = 0:2, c = combinations([:AbstractArray, :TrackedArray], i)
Base.cat(a::TrackedArray, b::TrackedArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims) cnames = map(_ -> gensym(), c)
Base.cat(a::TrackedArray, b::AbstractArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims) @eval Base.cat($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::TrackedArray, xs::AbstractArray...; dims) =
Base.cat(a::AbstractArray, b::TrackedArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims) track(cat, $(cnames...), x, xs..., dims = dims)
end
@grad function cat(Xs...; dims) @grad function cat(Xs...; dims)
cat(data.(Xs)..., dims = dims), function (Δ) cat(data.(Xs)..., dims = dims), function (Δ)
@ -218,8 +223,11 @@ Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64}}) = track(reshape, xs,
@grad reshape(xs, dims) = reshape(data(xs), dims), Δ -> (reshape(Δ, size(xs)),nothing) @grad reshape(xs, dims) = reshape(data(xs), dims), Δ -> (reshape(Δ, size(xs)),nothing)
Base.permutedims(xs::TrackedArray, dims) = track(permutedims, xs, dims) Base.permutedims(xs::TrackedArray, perm) = track(permutedims, xs, perm)
@grad permutedims(xs, dims) = permutedims(data(xs), dims), Δ -> (permutedims(Δ, invperm(dims)),nothing) @grad permutedims(xs, perm) = permutedims(data(xs), perm), Δ -> (permutedims(Δ, invperm(perm)),nothing)
Base.PermutedDimsArray(xs::TrackedArray, perm) = track(PermutedDimsArray, xs, perm)
@grad PermutedDimsArray(xs, perm) = PermutedDimsArray(data(xs), perm), Δ -> (PermutedDimsArray(Δ, invperm(perm)),nothing)
function _kron(mat1::AbstractMatrix,mat2::AbstractMatrix)
  m1, n1 = size(mat1)

View File

@ -1,4 +1,4 @@
struct TrackedReal{T<:Real} <: Real
mutable struct TrackedReal{T<:Real} <: Real
  data::T
  tracker::Tracked{T}
end
@ -16,6 +16,12 @@ function back!(x::TrackedReal; once = true)
return back!(x, 1, once = once)
end
function update!(x::TrackedReal, Δ)
x.data += data(Δ)
tracker(x).grad = 0
return x
end
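A minimal usage sketch of the `update!` method added above, assuming `param` and `data` from Tracker (the values are illustrative):

using Flux.Tracker: param, data, update!

x = param(3.0)
update!(x, 4.0)          # accumulates the step into x.data and resets the stored gradient
@assert data(x) == 7.0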
function Base.show(io::IO, x::TrackedReal)
  T = get(io, :typeinfo, Any)
  show(io, data(x))
@ -33,6 +39,8 @@ Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} =
  error("Not implemented: convert tracked $S to tracked $T")
(T::Type{<:TrackedReal})(x::Real) = convert(T, x)
for op in [:(==), :≈, :<, :(<=)]
  @eval Base.$op(x::TrackedReal, y::Real) = Base.$op(data(x), y)
  @eval Base.$op(x::Real, y::TrackedReal) = Base.$op(x, data(y))
@ -46,11 +54,19 @@ for f in :[isinf, isnan, isfinite].args
  @eval Base.$f(x::TrackedReal) = Base.$f(data(x))
end

Base.Printf.fix_dec(x::TrackedReal, n::Int) = Base.Printf.fix_dec(data(x), n)
Base.Printf.fix_dec(x::TrackedReal, n::Int, a...) = Base.Printf.fix_dec(data(x), n, a...)
Base.float(x::TrackedReal) = x
Base.promote_rule(::Type{TrackedReal{S}},::Type{T}) where {S,T} =
  TrackedReal{promote_type(S,T)}
using Random
for f in :[rand, randn, randexp].args
@eval Random.$f(rng::AbstractRNG,::Type{TrackedReal{T}}) where {T} = param(rand(rng,T))
end
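A quick sketch of the sampling overloads just added; calling with an explicit RNG hits the new methods directly (the choice of `Random.GLOBAL_RNG` is illustrative):

using Random
using Flux.Tracker: TrackedReal

r = rand(Random.GLOBAL_RNG, TrackedReal{Float64})   # a tracked scalar produced via param
@assert r isa TrackedReal{Float64}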
using DiffRules, SpecialFunctions, NaNMath

for (M, f, arity) in DiffRules.diffrules()
@ -85,6 +101,13 @@ import Base:^
^(a::TrackedReal, b::Integer) = track(^, a, b)
# Hack for conversions
using ForwardDiff: Dual
(T::Type{<:Real})(x::Dual) = Dual(T(x.value), map(T, x.partials.values))
(Dual{T,V,N})(x::Dual) where {T,V,N} = invoke(Dual{T,V,N}, Tuple{Number}, x)
# Tuples

struct TrackedTuple{T<:Tuple}

View File

@ -1,4 +1,4 @@
import Adapt: adapt
import Adapt: adapt, adapt_storage

import .Tracker: IdSet

children(x) = ()
@ -14,11 +14,6 @@ function treelike(m::Module, T, fs = fieldnames(T))
  end
end
function treelike(T, fs = fieldnames(T))
Base.depwarn("`treelike(T)` is deprecated, use `@treelike T`", :treelike)
treelike(Base._current_module(), T, fs)
end
macro treelike(T, fs = nothing)
  fs == nothing || isexpr(fs, :tuple) || error("@treelike T (a, b)")
  fs = fs == nothing ? [] : [:($(map(QuoteNode, fs.args)...),)]
@ -69,3 +64,22 @@ gpu_adaptor = identity
end

gpu(x) = mapleaves(gpu_adaptor, x)
# Precision
adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Real}) = convert.(T, xs)
paramtype(T::Type{<:Real}, m) = mapleaves(x -> adapt(T, x), m)
f32(m) = paramtype(Float32, m)
f64(m) = paramtype(Float64, m)
# General parameter map
function mapparams(f, m)
mapleaves(m) do x
Tracker.istracked(x) ? param(f(Tracker.data(x))) :
x isa Union{AbstractArray,Number} ? f(x) :
x
end
end
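A brief usage sketch of the precision helpers and `mapparams` introduced here; the model is illustrative, and `mapparams` is assumed to be called as `Flux.mapparams` since it is not exported:

using Flux

m = Chain(Dense(10, 5, relu), Dense(5, 2))   # Flux parameters default to Float32
m64 = f64(m)                                 # promote every parameter to Float64
@assert eltype(m64[1].W.data) == Float64

m_scaled = Flux.mapparams(x -> x ./ 2, m)    # apply a function to every parameter leaf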

View File

@ -10,8 +10,8 @@ zeros(dims...) = Base.zeros(Float32, dims...)
unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))

stack(xs, dim) = cat(dim, unsqueeze.(xs, dim)...)
unstack(xs, dim) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)]

stack(xs, dim) = cat(unsqueeze.(xs, dim)..., dims=dim)
unstack(xs, dim) = [copy(selectdim(xs, dim, i)) for i in 1:size(xs, dim)]

"""
    chunk(xs, n)
@ -139,25 +139,6 @@ function throttle(f, timeout; leading=true, trailing=false)
  end
end
"""
J = jacobian(m,x)
Calculate the output jacobian `J = d/dx m(x)` such that each row `i` of `J` corresponds to the gradient `J[i,:] = ∇ₓ(m(x)[i])`
"""
function jacobian(m,x)
xp = param(x)
y = m(xp)
k = length(y)
n = length(x)
J = Matrix{eltype(x)}(undef,n,k)
for i = 1:k
Flux.back!(y[i], once = false) # Populate gradient accumulator
J[:,i] = xp.grad
xp.grad .= 0 # Reset gradient accumulator
end
J'
end
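The removed `jacobian` helper built the matrix row by row from repeated backward passes. A minimal sketch of the equivalent computation with `Tracker.gradient`, using an illustrative linear model whose Jacobian is just its weight matrix:

using Flux
using Flux.Tracker: gradient, data

W = rand(3, 5)
m = x -> W * x
x = rand(5)
J = vcat([data(gradient(x -> m(x)[i], x)[1])' for i in 1:3]...)   # row i is ∇ₓ m(x)[i]
@assert J ≈ W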
""" """
@jit ... @jit ...

View File

@ -11,6 +11,8 @@ x = param(randn(5, 5))
cx = gpu(x)
@test cx isa TrackedArray && cx.data isa CuArray
@test Flux.onecold(param(gpu([1.,2.,3.]))) == 3
x = Flux.onehotbatch([1, 2, 3], 1:3)
cx = gpu(x)
@test cx isa Flux.OneHotMatrix && cx.data isa CuArray

View File

@ -21,3 +21,15 @@ end
  @test size(m(r)) == (10, 5)
end
@testset "Depthwise Conv" begin
r = zeros(Float32, 28, 28, 3, 5)
m1 = DepthwiseConv((2, 2), 3=>5)
@test size(m1(r), 3) == 15
m2 = DepthwiseConv((2, 2), 3)
@test size(m2(r), 3) == 3
end
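For context, a small sketch of the channel arithmetic this testset relies on (sizes are illustrative): a depthwise convolution with `in => multiplier` channels produces `in * multiplier` output channels.

using Flux

x  = zeros(Float32, 28, 28, 3, 1)     # WHCN layout, 3 input channels
m1 = DepthwiseConv((2, 2), 3 => 5)    # channel multiplier of 5
m2 = DepthwiseConv((2, 2), 3)         # multiplier defaults to 1
@assert size(m1(x), 3) == 3 * 5
@assert size(m2(x), 3) == 3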

View File

@ -49,4 +49,16 @@ const ϵ = 1e-7
@testset "logitbinarycrossentropy" begin @testset "logitbinarycrossentropy" begin
@test logitbinarycrossentropy.(logŷ, y) binarycrossentropy.(σ.(logŷ), y; ϵ=0) @test logitbinarycrossentropy.(logŷ, y) binarycrossentropy.(σ.(logŷ), y; ϵ=0)
end end
@testset "no spurious promotions" begin
for T in (Float16, Float32, Float64)
y = rand(T, 2)
ŷ = rand(T, 2)
for f in (mse, crossentropy, logitcrossentropy)
fwd, back = Flux.Tracker.forward(mse, ŷ, y)
@test typeof(fwd) == Flux.Tracker.TrackedReal{T}
@test eltype(back(one(T))[1]) == Flux.Tracker.TrackedReal{T}
end
end
end
end

View File

@ -4,7 +4,7 @@ using Flux.Tracker
using Test

@testset "Optimise" begin
  w = randn(10, 10)
  @testset for Opt in [ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, Descent, ADAM, Nesterov, RMSProp, Momentum]
  @testset for Opt in [ADAMW, ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, Descent, ADAM, Nesterov, RMSProp, Momentum]
    w′ = param(randn(10, 10))
    loss(x) = Flux.mse(w*x, w′*x)
    opt = Opt(0.001)
@ -17,7 +17,7 @@ using Test
    for t = 1: 10^5
      l = loss(rand(10))
      back!(l)
      delta = Optimise.update!(opt, w′.data, w′.grad)
      delta = Optimise.apply!(opt, w′.data, w′.grad)
      w′.data .-= delta
    end
    @test Flux.mse(w, w′) < 0.01
@ -33,7 +33,7 @@ end
    for t = 1:10^5
      l = loss(rand(10))
      back!(l)
      delta = Optimise.update!(opt, w′.data, w′.grad)
      delta = Optimise.apply!(opt, w′.data, w′.grad)
      w′.data .-= delta
    end
    @test Flux.mse(w, w′) < 0.01

View File

@ -1,18 +1,3 @@
# Pkg.test runs with --check_bounds=1, forcing all bounds checks.
# This is incompatible with CUDAnative (see JuliaGPU/CUDAnative.jl#98)
if Base.JLOptions().check_bounds == 1
file = @__FILE__
run(```
$(Base.julia_cmd())
--color=$(Base.have_color ? "yes" : "no")
--compiled-modules=$(Bool(Base.JLOptions().use_compiled_modules) ? "yes" : "no")
--startup-file=$(Base.JLOptions().startupfile != 2 ? "yes" : "no")
--code-coverage=$(["none", "user", "all"][1+Base.JLOptions().code_coverage])
$(file)
```)
exit()
end
using Flux, Test, Random, Statistics
using Random

View File

@ -1,6 +1,6 @@
using Flux
using Flux.Tracker, Test, NNlib
using Flux.Tracker: TrackedReal, gradcheck, grad, checkpoint
using Flux.Tracker: TrackedReal, gradient, gradcheck, grad, checkpoint, forwarddiff
using NNlib: conv, ∇conv_data, depthwiseconv
using Printf: @sprintf
using LinearAlgebra: diagm, dot, LowerTriangular, norm
@ -42,12 +42,7 @@ function promotiontest(f, A, B, C)
  r0 = f(A, B, C)
  r1 = f(param(A), B, C)
  r2 = f(A, param(B), C)

  if all(ndims.((A,B,C)) .≤ 2) && f ∈ [hcat, vcat]
    r3 = f(A, B, param(C))
  else
    @test_throws MethodError f(A, B, param(C)) # until julia#20815 is resolved
    r3 = r2
  end

  r3 = f(A, B, param(C))
  r4 = f(param(A), param(B), param(C))

  @test !isa(r0, TrackedArray)
@ -121,6 +116,7 @@ end
end

@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))
@test gradtest(x -> PermutedDimsArray(x, [3,1,2]), rand(4,5,6))
@test gradtest(x -> repeat(x; inner=2), rand(5))
@test gradtest(x -> repeat(x; inner=2, outer=3), rand(5))
@ -202,6 +198,8 @@ end
@test gradtest(x -> meanpool(x, (2,2)), rand(10, 10, 3, 2))
@test gradtest(x -> meanpool(x, (2,2,2)), rand(5, 5, 5, 3, 2))
@test gradtest(x -> Float64.(x), 5)
@testset "equality & order" begin @testset "equality & order" begin
# TrackedReal # TrackedReal
@test param(2)^2 == param(4) @test param(2)^2 == param(4)
@ -273,7 +271,7 @@ Tracker.back!(b)
back!(z)
@test grad.((x,y)) == (3, 2)

@test Tracker.gradient(2, 3) do x, y
@test gradient(2, 3) do x, y
  xy = Tracker.collect([x, y])
  xy[1]*xy[2]
end == (3, 2)
@ -299,4 +297,31 @@ end
  @test count == 3
end
@testset "Updates" begin
xs = param([1, 2, 3])
Tracker.update!(xs, param([4, 5, 6]))
@test xs == [5, 7, 9]
x = param(3)
Tracker.update!(x, param(4))
@test x == 7
end
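A sketch of how these updates slot into a hand-rolled descent step; it assumes the TrackedArray method of `update!` clears the gradient like the scalar method shown earlier:

using Flux
using Flux.Tracker: back!, grad, update!

W = param(randn(2, 2))
back!(sum(W .* W))                 # populate W's gradient
update!(W, -0.1 .* grad(W))        # take a gradient-descent step and reset the gradient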
@testset "Params" begin
W = param(randn(5, 10))
x = rand(10)
dW = gradient(W -> sum(W*x), W)[1]
gs = gradient(() -> sum(W*x), Tracker.Params([W]))
@test gs[W] == dW
end
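For comparison, a sketch of the same implicit-parameter style with more than one parameter (the names are illustrative):

using Flux
using Flux.Tracker: gradient, Params

W, b = param(randn(5, 10)), param(zeros(5))
x = rand(10)
gs = gradient(() -> sum(W*x .+ b), Params([W, b]))
@assert size(gs[W]) == size(W) && size(gs[b]) == size(b)   # gradients are looked up by parameter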
@testset "Forward" begin
@test @inferred(Tracker.forward_jacobian(x -> [sum(x)], rand(5,5), Val(12)))[2] ==
reshape(ones(25), :, 1)
@test gradient([2, 3]) do x
forwarddiff(x) do x
x[1]*x[2]
end
end == ([3, 2],)
end
end #testset

View File

@ -1,5 +1,5 @@
using Flux
using Flux: throttle, jacobian, glorot_uniform, glorot_normal
using Flux: throttle, jacobian, glorot_uniform, glorot_normal, stack, unstack
using StatsBase: std
using Random
using Test
@ -86,3 +86,22 @@ end
  m = RNN(10, 5)
  @test size.(params(m)) == [(5, 10), (5, 5), (5,), (5,)]
end
@testset "Precision" begin
m = Chain(Dense(10, 5, relu), Dense(5, 2))
x = rand(10)
@test eltype(m[1].W.data) == Float32
@test eltype(m(x).data) == Float32
@test eltype(f64(m)(x).data) == Float64
@test eltype(f64(m)[1].W.data) == Float64
@test eltype(f32(f64(m))[1].W.data) == Float32
@test Tracker.isleaf(f32(f64(m))[1].W)
end
@testset "Stacking" begin
stacked_array=[ 8 9 3 5; 9 6 6 9; 9 1 7 2; 7 4 10 6 ]
unstacked_array=[[8, 9, 9, 7], [9, 6, 1, 4], [3, 6, 7, 10], [5, 9, 2, 6]]
@test unstack(stacked_array, 2) == unstacked_array
@test stack(unstacked_array, 2) == stacked_array
@test stack(unstack(stacked_array, 1), 1) == stacked_array
end