# Documentation:
language: julia
- linux
# - osx
- 1.0
- nightly
# uncomment the following lines to override the default test script
# script:
# - if [[ -a .git/shallow ]]; then git fetch --unshallow; fi
# - julia -e 'Pkg.clone(pwd());"Flux"); Pkg.test("Flux"; coverage=true)'
- julia: nightly
- julia -e 'using Pkg; ps=Pkg.PackageSpec(name="Documenter", version="0.19"); Pkg.add(ps);; Pkg.add("NNlib")'
- julia -e 'using Pkg; cd(Pkg.dir("Flux")); include(joinpath("docs", "make.jl"))'
- stage: "Documentation"
julia: 1.0
os: linux
- julia --project=docs/ -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd()));
- julia --project=docs/ docs/make.jl
after_success: skip
## uncomment the following lines to override the default test script
- julia --color=yes -e 'using Pkg; Pkg.activate(); Pkg.instantiate(); Pkg.test()'

@ -1,5 +1,3 @@
# This file is machine-generated - editing it directly is not advised
deps = ["Markdown", "Test"]
git-tree-sha1 = "6621d9645702c1c4e6970cc6a3eae440c768000b"
@ -8,9 +6,9 @@ version = "0.2.1"
deps = ["LinearAlgebra", "Test"]
git-tree-sha1 = "04d15700419b6949d76be1428ab6e0277ff43b06"
git-tree-sha1 = "53d8fec4f662088c1202530e338a11a919407f3b"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "0.4.1"
version = "0.4.2"
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
@ -53,15 +51,15 @@ version = "0.2.0"
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "ec61a16eed883ad0cfa002d7489b3ce6d039bb9a"
git-tree-sha1 = "49269e311ffe11ac5b334681d212329002a9832a"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "1.4.0"
version = "1.5.1"
deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
git-tree-sha1 = "8fc6e166e24fda04b2b648d4260cdad241788c54"
git-tree-sha1 = "ca971f03e146cf144a9e2f2ce59674f5bf0e8038"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.14.0"
version = "0.15.0"
deps = ["Printf"]
@ -79,12 +77,12 @@ version = "0.0.3"
deps = ["Random", "Test"]
git-tree-sha1 = "c49ec69428ffea0c1d1bbdc63d1a70f5df5860ad"
git-tree-sha1 = "09d69da75967ec48a8b1ad0897ec9144ee052bf9"
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "0.0.7"
version = "0.0.8"
deps = ["Random", "Serialization", "Sockets"]
deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
@ -95,19 +93,19 @@ version = "0.5.3"
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"]
git-tree-sha1 = "b91250044374764e7c29af59a774c4b8d6100b6e"
git-tree-sha1 = "e393bd3b9102659fb24fe88caedec41f2bc2e7de"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.1"
version = "0.10.2"
deps = ["Markdown"]
deps = ["LinearAlgebra", "Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
deps = ["Base64", "Logging", "Media", "Profile", "Test"]
git-tree-sha1 = "3c29a199713e7ec62cfdc11f44d7760219d5f658"
git-tree-sha1 = "ce6246e19061e36cbdce954caaae717498daeed8"
uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
version = "0.5.3"
version = "0.5.4"
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
@ -140,18 +138,20 @@ version = "0.5.0"
deps = ["Dates", "InteractiveUtils", "SparseArrays", "Test"]
git-tree-sha1 = "adc26d2ee85a49c413464110d922cf21efc9d233"
git-tree-sha1 = "d1d2585677f2bd93a97cfeb8faa7a0de0f982042"
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
version = "0.3.1"
version = "0.4.0"
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
deps = ["Libdl", "LinearAlgebra", "MacroTools", "Requires", "Test"]
git-tree-sha1 = "51330bb45927379007e089997bf548fbe232589d"
git-tree-sha1 = "5a8ed87d61b1ccb71d99235c2a96287addebbb9f"
repo-rev = "master"
repo-url = ""
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.4.3"
version = "0.4.3+"
deps = ["Compat"]
@ -228,19 +228,19 @@ version = "0.7.2"
deps = ["InteractiveUtils", "LinearAlgebra", "Random", "Statistics", "Test"]
git-tree-sha1 = "97c4bf0f647488dd7ac01ea12be5885f88762938"
git-tree-sha1 = "1eb114d6e23a817cd3e99abc3226190876d7c898"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "0.10.0"
version = "0.10.2"
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
deps = ["DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "Test"]
git-tree-sha1 = "2722397d88f8ffef551948f6c20e1d74a743298c"
deps = ["DataStructures", "DelimitedFiles", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "Test"]
git-tree-sha1 = "7b596062316c7d846b67bf625d5963a832528598"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.26.0"
version = "0.27.0"
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
@ -259,7 +259,7 @@ uuid = "30578b45-9adc-5946-b283-645ec420af67"
version = "0.4.0"
deps = ["Random", "SHA"]
deps = ["Random"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

@ -13,10 +13,12 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SHA = "ea8e919c-243c-51af-8825-aaa63cd721ce"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

docs/Manifest.toml Normal file
@ -0,0 +1,288 @@
deps = ["Markdown", "Test"]
git-tree-sha1 = "6621d9645702c1c4e6970cc6a3eae440c768000b"
uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
version = "0.2.1"
deps = ["LinearAlgebra", "Test"]
git-tree-sha1 = "04d15700419b6949d76be1428ab6e0277ff43b06"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "0.4.1"
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
deps = ["Compat", "Libdl", "SHA", "URIParser"]
git-tree-sha1 = "12093ca6cdd0ee547c39b1870e0c9c3f154d9ca9"
uuid = "9e28174c-4ba2-5203-b857-d8d62c4213ee"
version = "0.8.10"
deps = ["Libdl", "Pkg", "SHA", "Test"]
git-tree-sha1 = "055eb2690182ebc31087859c3dd8598371d3ef9e"
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
version = "0.5.3"
deps = ["BinaryProvider", "Libdl", "Test", "TranscodingStreams"]
git-tree-sha1 = "e3df104c84dfc108f0ca203fd7f5bbdc98641ae9"
uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
version = "0.5.1"
deps = ["FixedPointNumbers", "Random", "Test"]
git-tree-sha1 = "f73b0e10f2a5756de7019818a41654686da06b09"
uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
version = "0.7.5"
deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Printf", "Reexport", "Test"]
git-tree-sha1 = "9f0a0210450acb91c730b730a994f8eef1d3d543"
uuid = "5ae59095-9a9b-59fe-a467-6f913c188581"
version = "0.9.5"
deps = ["Test"]
git-tree-sha1 = "efdaf19ab11c7889334ca247ff4c9f7c322817b0"
uuid = "bbf7d656-a473-5ed7-a52c-81e309532950"
version = "0.2.0"
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "ec61a16eed883ad0cfa002d7489b3ce6d039bb9a"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "1.4.0"
deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
git-tree-sha1 = "ca971f03e146cf144a9e2f2ce59674f5bf0e8038"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.15.0"
deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
deps = ["Mmap"]
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
deps = ["Compat", "StaticArrays"]
git-tree-sha1 = "db8acf46717b13d6c48deb7a12007c7f85a70cf7"
uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
version = "0.0.3"
deps = ["Random", "Test"]
git-tree-sha1 = "c49ec69428ffea0c1d1bbdc63d1a70f5df5860ad"
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "0.0.7"
deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
deps = ["LibGit2", "Markdown", "Pkg", "Test"]
git-tree-sha1 = "1df01539a1c952cef21f2d2d1c092c2bcf0177d7"
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
version = "0.6.0"
deps = ["Base64", "DocStringExtensions", "InteractiveUtils", "LibGit2", "Logging", "Markdown", "Pkg", "REPL", "Random", "Test", "Unicode"]
git-tree-sha1 = "a6db1c69925cdc53aafb38caec4446be26e0c617"
uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
version = "0.21.0"
deps = ["Test"]
git-tree-sha1 = "b8045033701c3b10bf2324d7203404be7aef88ba"
uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
version = "0.5.3"
deps = ["AbstractTrees", "Adapt", "CodecZlib", "Colors", "DiffRules", "ForwardDiff", "Juno", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsBase", "Test", "ZipFile"]
path = ".."
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
version = "0.6.10+"
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"]
git-tree-sha1 = "b91250044374764e7c29af59a774c4b8d6100b6e"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.1"
deps = ["LinearAlgebra", "Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
deps = ["Base64", "Logging", "Media", "Profile", "Test"]
git-tree-sha1 = "3c29a199713e7ec62cfdc11f44d7760219d5f658"
uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
version = "0.5.3"
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
deps = ["Libdl"]
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
deps = ["Compat"]
git-tree-sha1 = "c443e1c8d58a4e9f61b708ad0a88286c7042145b"
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
version = "0.4.4"
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
deps = ["MacroTools", "Test"]
git-tree-sha1 = "75a54abd10709c01f1b86b84ec225d26e840ed58"
uuid = "e89f7d12-3494-54d1-8411-f7d8b9ae1f27"
version = "0.5.0"
deps = ["Dates", "InteractiveUtils", "SparseArrays", "Test"]
git-tree-sha1 = "adc26d2ee85a49c413464110d922cf21efc9d233"
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
version = "0.3.1"
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
deps = ["Libdl", "LinearAlgebra", "MacroTools", "Requires", "Test"]
git-tree-sha1 = "51330bb45927379007e089997bf548fbe232589d"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.4.3"
deps = ["Compat"]
git-tree-sha1 = "ce3b85e484a5d4c71dd5316215069311135fa9f2"
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
version = "0.3.2"
deps = ["Random", "Serialization", "Test"]
git-tree-sha1 = "85619a3f3e17bb4761fe1b1fd47f0e979f964d5b"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.0.2"
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
deps = ["Printf"]
uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"
deps = ["InteractiveUtils", "Markdown", "Sockets"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
deps = ["Serialization"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
deps = ["Pkg"]
git-tree-sha1 = "7b1d07f411bc8ddb7977ec7f377b97b158514fe0"
uuid = "189a3867-3050-52da-a836-e630ba90ab69"
version = "0.2.0"
deps = ["Test"]
git-tree-sha1 = "f6fbf4ba64d295e146e49e021207993b6b48c7d1"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "0.5.2"
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
deps = ["Distributed", "Mmap", "Random", "Serialization"]
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
deps = ["DataStructures", "Random", "Test"]
git-tree-sha1 = "03f5898c9959f8115e30bc7226ada7d0df554ddd"
uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
version = "0.3.1"
deps = ["LinearAlgebra", "Random"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
deps = ["BinDeps", "BinaryProvider", "Libdl", "Test"]
git-tree-sha1 = "0b45dc2e45ed77f445617b99ff2adf0f5b0f23ea"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "0.7.2"
deps = ["InteractiveUtils", "LinearAlgebra", "Random", "Statistics", "Test"]
git-tree-sha1 = "1eb114d6e23a817cd3e99abc3226190876d7c898"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "0.10.2"
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
deps = ["DataStructures", "DelimitedFiles", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "Test"]
git-tree-sha1 = "7b596062316c7d846b67bf625d5963a832528598"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.27.0"
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
deps = ["Pkg", "Random", "Test"]
git-tree-sha1 = "a34a2d588e2d2825602bf14a24216d5c8b0921ec"
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
version = "0.8.1"
deps = ["Test", "Unicode"]
git-tree-sha1 = "6ddf8244220dfda2f17539fa8c9de20d6c575b69"
uuid = "30578b45-9adc-5946-b283-645ec420af67"
version = "0.4.0"
deps = ["Random"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
deps = ["BinaryProvider", "Libdl", "Printf", "Test"]
git-tree-sha1 = "4000c633efe994b2e10b31b6d91382c4b7412dac"
uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
version = "0.8.0"

@ -0,0 +1,4 @@
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"

@ -1,11 +1,12 @@
using Documenter, Flux, NNlib
makedocs(modules=[Flux, NNlib],
doctest = false,
format = :html,
doctest = true,
analytics = "UA-36890222-9",
sitename = "Flux",
assets = ["../flux.css"],
# Uncomment below for local build
#format = Documenter.HTML(prettyurls = false),
assets = ["assets/flux.css"],
pages = ["Home" => "",
"Building Models" =>
["Basics" => "models/",
@ -22,10 +23,4 @@ makedocs(modules=[Flux, NNlib],
["Backpropagation" => "internals/"],
"Community" => ""])
repo = "",
target = "build",
osname = "linux",
julia = "1.0",
deps = nothing,
make = nothing)
deploydocs(repo = "")

@ -0,0 +1,113 @@
@import url(',400i');
body {
font-family: Lato, "Segoe UI",Roboto,"Helvetica Neue",Arial,sans-serif;
nav.toc {
padding-top: 0;
background: rgb(240, 240, 240);
line-height: 2em;
cursor: default;
user-select: none;
h1+h2 {
margin-top: 0;
/* Green banner in ToC */
nav.toc > h1 {
margin-top: 0;
padding-top: 0.4em;
padding-bottom: 0.5em;
border-bottom: 5px solid white;
box-shadow: 0px -2px 5px rgb(60,60,60);
margin-bottom: 0.5em;
background: rgb(60, 150, 60);
font-style: italic;
font-weight: normal;
font-size: 50pt;
text-transform: lowercase;
text-shadow: 2px 2px 5px rgba(0,0,0,0.2);
color: white;
/* Reduce ToC font size */
.toctext {
font-size: 10pt;
/* Fade out non-clickable ToC headers */
nav.toc ul span.toctext {
color: rgb(180, 180, 180);
nav.toc ul .toctext {
color: rgb(100, 100, 100);
nav.toc ul a.toctext:hover {
color: inherit;
background: rgb(220, 220, 220);
cursor: default;
nav.toc li.current > .toctext {
background: linear-gradient(90deg, rgb(245,245,245) 0%, white 90%);
font-weight: normal;
nav.toc ul.internal li.toplevel {
font-weight: normal;
/* Content */
article { max-width: none; }
article > p, article > ul {
max-width: 45em;
/* Links */
a, a:visited { color: rgb(0, 120, 0); }
article p a { border-bottom: 1px solid rgb(200, 230, 200); }
a:hover, a:visited:hover { color: rgb(0, 80, 0); }
/* Article Links */
article p a { border-bottom: 1px solid rgb(200, 230, 200); }
article p a:hover, article a:visited:hover { color: rgb(0, 120, 0); }
article p a:hover { border-bottom: 1px solid rgb(150, 200, 150); }
/* Docstrings */
article section.docstring {
padding: 0.5em 0;
border-left: none;
border-right: none;
border-bottom: none;
/* Code */
article pre, article p > code {
background: rgb(245, 250, 245);
article pre {
border: none;
max-width: none;
padding: 1em;
border-radius: 10px 0px 0px 10px;
margin-left: -1em;
margin-right: -2em;
.hljs-comment {
font-style: italic;
.hljs-number {
color: rgb(0, 150, 150);

@ -1,10 +1,22 @@
# GPU Support
## Installation
To get GPU support for NVIDIA graphics cards, you need to install `CuArrays.jl`
**Steps needed**
1. Install [NVIDIA toolkit](
2. Install [NVIDIA cuDNN library](
3. In Julia's terminal run `]add CuArrays`
## GPU Usage
Support for array operations on other hardware backends, like GPUs, is provided by external packages like [CuArrays](https://github.com/JuliaGPU/CuArrays.jl). Flux is agnostic to array types, so we simply need to move model weights and data to the GPU and Flux will handle it.
For example, we can use `CuArrays` (with the `cu` converter) to run our [basic example](models/basics.md) on an NVIDIA GPU.
(Note that you need to have CUDA available to use CuArrays please see the [CuArrays.jl ( instructions for more details.)
(Note that you need to have CUDA available to use CuArrays – please see the [CuArrays.jl](https://github.com/JuliaGPU/CuArrays.jl) instructions for more details.)
using CuArrays

View File

@ -4,49 +4,56 @@
Flux's core feature is taking gradients of Julia code. The `gradient` function takes another Julia function `f` and a set of arguments, and returns the gradient with respect to each argument. (It's a good idea to try pasting these examples in the Julia terminal.)
using Flux.Tracker
```jldoctest basics
julia> using Flux.Tracker
f(x) = 3x^2 + 2x + 1
julia> f(x) = 3x^2 + 2x + 1;
# df/dx = 6x + 2
df(x) = Tracker.gradient(f, x)[1]
julia> df(x) = Tracker.gradient(f, x; nest = true)[1]; # df/dx = 6x + 2
df(2) # 14.0 (tracked)
julia> df(2)
14.0 (tracked)
# d²f/dx² = 6
d2f(x) = Tracker.gradient(df, x)[1]
julia> d2f(x) = Tracker.gradient(df, x; nest = true)[1]; # d²f/dx² = 6
d2f(2) # 6.0 (tracked)
julia> d2f(2)
6.0 (tracked)
(We'll learn more about why these numbers show up as `(tracked)` below.)
When a function has many parameters, we can pass them all in explicitly:
f(W, b, x) = W * x + b
```jldoctest basics
julia> f(W, b, x) = W * x + b;
Tracker.gradient(f, 2, 3, 4)
(4.0 (tracked), 1.0, 2.0 (tracked))
julia> Tracker.gradient(f, 2, 3, 4)
(4.0 (tracked), 1.0 (tracked), 2.0 (tracked))
But machine learning models can have *hundreds* of parameters! Flux offers a nice way to handle this. We can tell Flux to treat something as a parameter via `param`. Then we can collect these together and tell `gradient` to collect the gradients of all of them at once.
But machine learning models can have *hundreds* of parameters! Flux offers a nice way to handle this. We can tell Flux to treat something as a parameter via `param`. Then we can collect these together and tell `gradient` to collect the gradients of all `params` at once.
W = param(2) # 2.0 (tracked)
b = param(3) # 3.0 (tracked)
```jldoctest basics
julia> using Flux
f(x) = W * x + b
julia> W = param(2)
2.0 (tracked)
params = Params([W, b])
grads = Tracker.gradient(() -> f(4), params)
julia> b = param(3)
3.0 (tracked)
grads[W] # 4.0
grads[b] # 1.0
julia> f(x) = W * x + b;
julia> grads = Tracker.gradient(() -> f(4), params(W, b));
julia> grads[W]
julia> grads[b]
There are a few things to notice here. Firstly, `W` and `b` now show up as *tracked*. Tracked things behave like normal numbers or arrays, but keep records of everything you do with them, allowing Flux to calculate their gradients. `gradient` takes a zero-argument function; no arguments are necessary because the `Params` tell it what to differentiate.
There are a few things to notice here. Firstly, `W` and `b` now show up as *tracked*. Tracked things behave like normal numbers or arrays, but keep records of everything you do with them, allowing Flux to calculate their gradients. `gradient` takes a zero-argument function; no arguments are necessary because the `params` tell it what to differentiate.
This will come in really handy when dealing with big, complicated models. For now, though, let's start with something simple.
@ -77,7 +84,7 @@ using Flux.Tracker
W = param(W)
b = param(b)
gs = Tracker.gradient(() -> loss(x, y), Params([W, b]))
gs = Tracker.gradient(() -> loss(x, y), params(W, b))
Now that we have gradients, we can pull them out and update `W` to train the model. The `update!(W, Δ)` function applies `W = W + Δ`, which we can use for gradient descent.
@ -102,6 +109,8 @@ All deep learning in Flux, however complex, is a simple generalisation of this e
It's common to create more complex models than the linear regression above. For example, we might want to have two linear layers with a nonlinearity like [sigmoid]( (`σ`) in between them. In the above style we could write this as:
using Flux
W1 = param(rand(3, 5))
b1 = param(rand(3))
layer1(x) = W1 * x .+ b1

@ -14,6 +14,7 @@ MeanPool
## Recurrent Layers

View File

@ -3,7 +3,7 @@
Consider a [simple linear regression](../models/basics.md). We create some dummy data, calculate a loss, and backpropagate to calculate gradients for the parameters `W` and `b`.
using Flux.Tracker
using Flux, Flux.Tracker
W = param(rand(2, 5))
b = param(rand(2))
@ -14,8 +14,8 @@ loss(x, y) = sum((predict(x) .- y).^2)
x, y = rand(5), rand(2) # Dummy data
l = loss(x, y) # ~ 3
params = Params([W, b])
grads = Tracker.gradient(() -> loss(x, y), params)
θ = Params([W, b])
grads = Tracker.gradient(() -> loss(x, y), θ)
We want to update each parameter, using the gradient, in order to improve (reduce) the loss. Here's one way to do that:
@ -23,44 +23,30 @@ We want to update each parameter, using the gradient, in order to improve (reduc
using Flux.Tracker: grad, update!
function sgd()
η = 0.1 # Learning Rate
for p in (W, b)
update!(p, -η * grads[p])
η = 0.1 # Learning Rate
for p in (W, b)
update!(p, -η * grads[p])
If we call `sgd`, the parameters `W` and `b` will change and our loss should go down.
There are two pieces here: one is that we need a list of trainable parameters for the model (`[W, b]` in this case), and the other is the update step. In this case the update is simply gradient descent (`x .-= η .* Δ`), but we might choose to do something more advanced, like adding momentum.
In this case, getting the variables is trivial, but you can imagine it'd be more of a pain with some complex stack of layers.
Running this will alter the parameters `W` and `b` and our loss should go down. Flux provides a more general way to do optimiser updates like this.
m = Chain(
Dense(10, 5, σ),
Dense(5, 2), softmax)
opt = Descent(0.1) # Gradient descent with learning rate 0.1
for p in (W, b)
update!(opt, p, grads[p])
Instead of having to write `[m[1].W, m[1].b, ...]`, Flux provides a params function `params(m)` that returns a list of all parameters in the model for you.
For the update step, there's nothing whatsoever wrong with writing the loop above – it'll work just fine – but Flux provides various *optimisers* that make it more convenient.
opt = SGD([W, b], 0.1) # Gradient descent with learning rate 0.1
opt() # Carry out the update, modifying `W` and `b`.
An optimiser takes a parameter list and returns a function that does the same thing as `update` above. We can pass either `opt` or `update` to our [training loop](, which will then run the optimiser after every mini-batch of data.
An optimiser `update!` accepts a parameter and a gradient, and updates the parameter according to the chosen rule. We can also pass `opt` to our [training loop](, which will update all parameters of the model in a loop. However, we can now easily replace `Descent` with a more advanced optimiser such as `ADAM`.
## Optimiser Reference
All optimisers return a function that, when called, will update the parameters passed to it.
All optimisers return an object that, when passed to `train!`, will update the parameters passed to it.

@ -9,7 +9,7 @@ To actually train a model we need three things:
With these we can call `Flux.train!`:
Flux.train!(objective, data, opt)
Flux.train!(objective, params, data, opt)
There are plenty of examples in the [model zoo](https://github.com/FluxML/model-zoo).
@ -24,9 +24,10 @@ m = Chain(
Dense(32, 10), softmax)
loss(x, y) = Flux.mse(m(x), y)
ps = Flux.params(m)
# later
Flux.train!(loss, data, opt)
Flux.train!(loss, ps, data, opt)
The objective will almost always be defined in terms of some *cost function* that measures the distance of the prediction `m(x)` from the target `y`. Flux has several of these built in, like `mse` for mean squared error or `crossentropy` for cross entropy loss, but you can calculate it however you want.
@ -78,7 +79,7 @@ julia> @epochs 2 Flux.train!(...)
`train!` takes an additional argument, `cb`, that's used for callbacks so that you can observe the training process. For example:
train!(objective, data, opt, cb = () -> println("training"))
train!(objective, ps, data, opt, cb = () -> println("training"))
Callbacks are called for every batch of training data. You can slow this down using `Flux.throttle(f, timeout)` which prevents `f` from being called more than once every `timeout` seconds.
@ -89,6 +90,6 @@ A more typical callback might look like this:
test_x, test_y = # ... create single batch of test data ...
evalcb() = @show(loss(test_x, test_y))
Flux.train!(objective, data, opt,
Flux.train!(objective, ps, data, opt,
cb = throttle(evalcb, 5))

@ -6,9 +6,9 @@ using Base: tail
using MacroTools, Juno, Requires, Reexport, Statistics, Random
using MacroTools: @forward
export Chain, Dense, RNN, LSTM, GRU, Conv, MaxPool, MeanPool,
export Chain, Dense, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
DepthwiseConv, Dropout, LayerNorm, BatchNorm,
params, mapleaves, cpu, gpu
params, mapleaves, cpu, gpu, f32, f64
@reexport using NNlib

View File

@ -1,6 +1,22 @@
module CUDA
using ..CuArrays
using Pkg.TOML
function version_check()
minor_version = 9
project = joinpath(dirname(pathof(CuArrays)), "../Project.toml")
project = TOML.parse(String(read(project)))
version = VersionNumber(get(project, "version", "0.0.0"))
if !(version.major == 0 && version.minor == minor_version)
@warn """
Flux is only supported with CuArrays v0.$minor_version.
Try running `] pin CuArrays@0.$minor_version`.
if !applicable(CuArray{UInt8}, undef, 1)
(T::Type{<:CuArray})(::UndefInitializer, sz...) = T(sz...)

@ -1,11 +1,27 @@
module Data
import ..Flux
import SHA
export CMUDict, cmudict
"""
    deps(path...)

Return the location of `path...` inside the `deps` directory that sits two
levels above the directory containing this file. With no arguments, return
the `deps` directory itself.
"""
function deps(path...)
    root = joinpath(@__DIR__, "..", "..", "deps")
    return joinpath(root, path...)
end
function download_and_verify(url, path, hash)
tmppath = tempname()
download(url, tmppath)
hash_download = open(tmppath) do f
if hash_download !== hash
msg = "Hash Mismatch!\n"
msg *= " Expected sha256: $hash\n"
msg *= " Calculated sha256: $hash_download"
mv(tmppath, path; force=true)
function __init__()

@ -2,23 +2,25 @@ module CMUDict
export cmudict
using ..Data: deps
using ..Data: deps, download_and_verify
const version = "0.7b"
const cache_prefix = ""
function load()
suffixes = ["", ".phones", ".symbols"]
suffixes_and_hashes = [("" , "209a8b4cd265013e96f4658632a9878103b0c5abf62b50d4ef3ae1be226b29e4"),
(".phones" , "ffb588a5e55684723582c7256e1d2f9fadb130011392d9e59237c76e34c2cfd6"),
(".symbols", "408ccaae803641c6d7b626b6299949320c2dbca96b2220fd3fb17887b023b027")]
if isdir(deps("cmudict"))
if all(isfile(deps("cmudict", "cmudict$x")) for x in suffixes)
if all(isfile(deps("cmudict", "cmudict$x")) for (x, _) in suffixes_and_hashes)
@info "Downloading CMUDict dataset"
for x in suffixes
deps("cmudict", "cmudict$x"))
for (x, hash) in suffixes_and_hashes
deps("cmudict", "cmudict$x"), hash)

@ -1,19 +1,20 @@
module FashionMNIST
using ..MNIST: gzopen, imageheader, rawimage, labelheader, rawlabel
using ..Data: download_and_verify
const dir = joinpath(@__DIR__, "../../deps/fashion-mnist")
function load()
cd(dir) do
for file in ["train-images-idx3-ubyte",
for (file, hash) in [("train-images-idx3-ubyte", "3aede38d61863908ad78613f6a32ed271626dd12800ba2636569512369268a84"),
("train-labels-idx1-ubyte", "a04f17134ac03560a47e3764e11b92fc97de4d1bfaf8ba1a3aa29af54cc90845"),
("t10k-images-idx3-ubyte" , "346e55b948d973a97e58d2351dde16a484bd415d4595297633bb08f03db6a073"),
("t10k-labels-idx1-ubyte" , "67da17c76eaffca5446c3361aaab5c3cd6d1c2608764d35dfb1850b086bf8dd5")]
isfile(file) && continue
@info "Downloading Fashion-MNIST dataset"
download("$file.gz", "$file.gz")
download_and_verify("$file.gz", "$file.gz", hash)
open(file, "w") do io
write(io, gzopen(read, "$file.gz"))

@ -1,6 +1,7 @@
module MNIST
using CodecZlib, Colors
using ..Data: download_and_verify
const Gray = Colors.Gray{Colors.N0f8}
@ -15,13 +16,13 @@ end
function load()
cd(dir) do
for file in ["train-images-idx3-ubyte",
for (file, hash) in [("train-images-idx3-ubyte", "440fcabf73cc546fa21475e81ea370265605f56be210a4024d2ca8f203523609"),
("train-labels-idx1-ubyte", "3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c"),
("t10k-images-idx3-ubyte" , "8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6"),
("t10k-labels-idx1-ubyte" , "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6")]
isfile(file) && continue
@info "Downloading MNIST dataset"
download("$file.gz", "$file.gz")
download_and_verify("$file.gz", "$file.gz", hash)
open(file, "w") do io
write(io, gzopen(read, "$file.gz"))

@ -1,13 +1,13 @@
module Sentiment
using ZipFile
using ..Data: deps
using ..Data: deps, download_and_verify
function load()
isfile(deps("")) && return
@info "Downloading sentiment treebank dataset"
deps(""), "5c613a4f673fc74097d523a2c83f38e0cc462984d847b82c7aaf36b01cbbbfcc")
getfile(r, name) = r.files[findfirst(x -> == name, r.files)]

@ -21,8 +21,8 @@ struct Chain{T<:Tuple}
Chain(xs...) = new{typeof(xs)}(xs)
@forward Chain.layers Base.getindex, Base.first, Base.last, Base.lastindex
@forward Chain.layers Base.iterate
@forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
Base.iterate, Base.lastindex
# The children of a `Chain` are the contents of its `layers` field.
function children(c::Chain)
    return c.layers
end
# Build a new `Chain` from the result of broadcasting `f` over `c.layers`.
function mapchildren(f, c::Chain)
    mapped = f.(c.layers)
    return Chain(mapped...)
end

@ -1,4 +1,4 @@
using NNlib: conv, depthwiseconv
using NNlib: conv, ∇conv_data, depthwiseconv
# Map `Val(N)` to `Val(N - 2)`; the subtraction is performed while the
# generated method body is built, so the returned expression holds a literal.
@generated function sub2(::Val{N}) where {N}
    return :(Val($(N - 2)))
end
@ -13,7 +13,7 @@ Standard convolutional layer. `size` should be a tuple like `(2, 2)`.
`in` and `out` specify the number of input and output channels respectively.
Data should be stored in WHCN order. In other words, a 100×100 RGB image would
be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
be a `100×100×3×1` array, and a batch of 50 would be a `100×100×3×50` array.
Takes the keyword arguments `pad`, `stride` and `dilation`.
@ -57,6 +57,54 @@ end
(a::Conv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
ConvTranspose(size, in=>out)
ConvTranspose(size, in=>out, relu)
Standard convolutional transpose layer. `size` should be a tuple like `(2, 2)`.
`in` and `out` specify the number of input and output channels respectively.
Data should be stored in WHCN order. In other words, a 100×100 RGB image would
be a `100×100×3×1` array, and a batch of 50 would be a `100×100×3×50` array.
Takes the keyword arguments `pad`, `stride` and `dilation`.
struct ConvTranspose{N,F,A,V}
# Construct a `ConvTranspose` from an existing weight array `w` and bias
# vector `b`. Each of `stride`, `pad` and `dilation` is passed through `expand`
# together with `sub2(Val(N))`, where `N` is the number of dimensions of `w`.
function ConvTranspose(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
                       stride = 1, pad = 0, dilation = 1) where {T,N}
    dims = sub2(Val(N))
    return ConvTranspose(σ, w, b, expand.(dims, (stride, pad, dilation))...)
end
# Construct a `ConvTranspose` from a kernel size tuple `k` and a channel pair
# `ch`. The weights come from `init(k..., reverse(ch)...)` and the bias is
# `zeros(ch[2])`; both are wrapped with `param` before being handed to the
# array-based constructor along with the keyword options.
function ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
                       init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N
    weight = param(init(k..., reverse(ch)...))
    bias = param(zeros(ch[2]))
    return ConvTranspose(weight, bias, σ; stride = stride, pad = pad, dilation = dilation)
end
@treelike ConvTranspose
function (c::ConvTranspose)(x::AbstractArray)
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
σ.(∇conv_data(x, c.weight, stride = c.stride, pad = c.pad, dilation = c.dilation) .+ b)
function, l::ConvTranspose)
print(io, "ConvTranspose(", size(l.weight)[1:ndims(l.weight)-2])
print(io, ", ", size(l.weight, ndims(l.weight)), "=>", size(l.weight, ndims(l.weight)-1))
l.σ == identity || print(io, ", ", l.σ)
print(io, ")")
(a::ConvTranspose{<:Any,<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
invoke(a, Tuple{AbstractArray}, x)
(a::ConvTranspose{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
DepthwiseConv(size, in)
DepthwiseConv(size, in=>mul)
@ -83,12 +131,12 @@ DepthwiseConv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0) where {T,N} =
DepthwiseConv(σ, w, b, expand.(sub2(Val(N)), (stride, pad))...)
DepthwiseConv(k::NTuple{N,Integer}, ch::Integer, σ = identity; init = initn,
DepthwiseConv(k::NTuple{N,Integer}, ch::Integer, σ = identity; init = glorot_uniform,
stride = 1, pad = 0) where N =
DepthwiseConv(param(init(k..., 1, ch)), param(zeros(ch)), σ,
stride = stride, pad = pad)
DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = glorot_uniform,
stride::NTuple{N,Integer} = map(_->1,k),
pad::NTuple{N,Integer} = map(_->0,k)) where N =
DepthwiseConv(param(init(k..., ch[2], ch[1])), param(zeros(ch[2]*ch[1])), σ,

@ -106,7 +106,7 @@ mutable struct BatchNorm{F,V,W,N}
BatchNorm(chs::Integer, λ = identity;
initβ = (i) -> zeros(i), initγ = (i) -> ones(i), ϵ = 1e-5, momentum = .1) =
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
BatchNorm(λ, param(initβ(chs)), param(initγ(chs)),
zeros(chs), ones(chs), ϵ, momentum, true)
@ -138,7 +138,9 @@ function (BN::BatchNorm)(x)
let λ = BN.λ
λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ sqrt.(σ² .+ BN.ϵ)) .+ reshape(β, affine_shape...))
temp = reshape(γ, affine_shape...) .* ((x .- μ) ./ sqrt.(σ² .+ BN.ϵ))
# This is intentionally not fused because of an extreme slowdown doing so
λ.(temp .+ reshape(β, affine_shape...))

@ -84,7 +84,7 @@ end
RNNCell(in::Integer, out::Integer, σ = tanh;
init = glorot_uniform) =
RNNCell(σ, param(init(out, in)), param(init(out, out)),
param(zeros(out)), param(init(out)))
param(init(out)), param(zeros(out)))
function (m::RNNCell)(h, x)
σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
@ -122,8 +122,8 @@ end
function LSTMCell(in::Integer, out::Integer;
init = glorot_uniform)
cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(zeros(out*4)),
param(init(out)), param(init(out)))
cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(init(out*4)),
param(zeros(out)), param(zeros(out)))[gate(out, 2)] .= 1
return cell
@ -169,7 +169,7 @@ end
GRUCell(in, out; init = glorot_uniform) =
GRUCell(param(init(out*3, in)), param(init(out*3, out)),
param(zeros(out*3)), param(init(out)))
param(init(out*3)), param(zeros(out)))
function (m::GRUCell)(h, x)
b, o = m.b, size(h, 1)

@ -2,16 +2,14 @@ using NNlib: logsoftmax, logσ
# Cost functions
# Mean squared error between the prediction `ŷ` and the target `y`:
# the sum of squared element-wise differences divided by `length(y)`.
mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)
# Mean squared error between the prediction `ŷ` and the target `y`.
# The sum of squared differences is scaled by the rational `1 // length(y)`; for
# floating-point inputs the result keeps the inputs' element type.
mse(ŷ, y) = sum((ŷ .- y).^2) * 1 // length(y)
function crossentropy(::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
-sum(y .* log.() .* weight) / size(y, 2)
-sum(y .* log.() .* weight) * 1 // size(y, 2)
@deprecate logloss(x, y) crossentropy(x, y)
function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
return -sum(y .* logsoftmax(logŷ) .* weight) / size(y, 2)
return -sum(y .* logsoftmax(logŷ) .* weight) * 1 // size(y, 2)
@ -42,12 +40,17 @@ but it is more numerically stable.
logitbinarycrossentropy(logŷ, y) = (1 - y)*logŷ - logσ(logŷ)
normalise(x::AbstractArray; dims=1)
Normalise each column of `x` to mean 0 and standard deviation 1.
Normalise `x` to mean 0 and standard deviation 1 across the dimensions given by `dims`. Defaults to normalising each column (`dims = 1`).
function normalise(x::AbstractVecOrMat)
μ′ = mean(x, dims = 1)
σ = std(x, dims = 1, mean = μ′)
function normalise(x::AbstractArray; dims=1)
μ′ = mean(x, dims = dims)
σ = std(x, dims = dims, mean = μ′, corrected=false)
return (x .- μ′) ./ σ
function normalise(x::AbstractArray, dims)
# The suggested replacement uses the same argument name (`x`) as the deprecated signature.
Base.depwarn("`normalise(x::AbstractArray, dims)` is deprecated, use `normalise(x, dims=dims)` instead.", :normalise)
normalise(x, dims = dims)

View File

@ -68,3 +68,6 @@ end
# Products of a tracked matrix with a one-hot vector/matrix are routed explicitly
# (via `invoke`) to the `AbstractMatrix` methods for one-hot arguments.
a::TrackedMatrix * b::OneHotVector = invoke(*, Tuple{AbstractMatrix,OneHotVector}, a, b)
a::TrackedMatrix * b::OneHotMatrix = invoke(*, Tuple{AbstractMatrix,OneHotMatrix}, a, b)
# `onecold` on tracked arrays unwraps them with `data` and forwards any extra
# arguments unchanged.
onecold(x::TrackedVector, l...) = onecold(data(x), l...)
onecold(x::TrackedMatrix, l...) = onecold(data(x), l...)

@ -4,7 +4,7 @@ using Flux: Params
check_decay(opt, decay) = decay == 0 ? opt : Optimiser(opt, InvDecay(decay))
# legacy update rule
updaterule(opt, ps) = () -> update!(opt, ps)
updaterule(opt, ps) = () -> _update_params!(opt, ps)
function SGD(params::Union{AbstractArray, Params}, η = 0.1; decay = 0.)
depwarn("SGD(params) is deprecated; use Descent(η::Float64) instead", :SGD)
@ -117,7 +117,7 @@ struct OldOptimiser
update!(opt::OldOptimiser, ps) = opt.func()
_update_params!(opt::OldOptimiser, ps) = opt.func()
# Train function
function train!(loss, data, opt; cb = () -> ())

@ -18,7 +18,7 @@ end
Descent() = Descent(0.1)
function update!(o::Descent, x, Δ)
function apply!(o::Descent, x, Δ)
Δ .*= o.eta
@ -35,7 +35,7 @@ end
Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict())
function update!(o::Momentum, x, Δ)
function apply!(o::Momentum, x, Δ)
η, ρ = o.eta, o.rho
v = get!(o.velocity, x, zero(x))::typeof(x)
@. v = ρ * v - η * Δ
@ -55,7 +55,7 @@ end
Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict())
function update!(o::Nesterov, x, Δ)
function apply!(o::Nesterov, x, Δ)
η, ρ = o.eta, o.rho
v = get!(o.velocity, x, zero(x))::typeof(x)
d = @. ρ^2 * v - (1+ρ) * η * Δ
@ -78,7 +78,7 @@ end
RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict())
function update!(o::RMSProp, x, Δ)
function apply!(o::RMSProp, x, Δ)
η, ρ = o.eta, o.rho
acc = get!(o.acc, x, zero(x))::typeof(x)
@. acc = ρ * acc + (1 - ρ) * Δ^2
@ -98,7 +98,7 @@ end
ADAM(η = 0.001, β = (0.9, 0.999)) = ADAM(η, β, IdDict())
function update!(o::ADAM, x, Δ)
function apply!(o::ADAM, x, Δ)
η, β = o.eta, o.beta
mt, vt, βp = get!(o.state, x, (zero(x), zero(x), β))
@. mt = β[1] * mt + (1 - β[1]) * Δ
@ -122,7 +122,7 @@ end
AdaMax(η = 0.001, β = (0.9, 0.999)) = AdaMax(η, β, IdDict())
function update!(o::AdaMax, x, Δ)
function apply!(o::AdaMax, x, Δ)
η, β = o.eta, o.beta
mt, ut, βp = get!(o.state, x, (zero(x), zero(x), β))
@. mt = β[1] * mt + (1 - β[1]) * Δ
@ -145,7 +145,7 @@ end
ADAGrad(η = 0.1) = ADAGrad(η, IdDict())
function update!(o::ADAGrad, x, Δ)
function apply!(o::ADAGrad, x, Δ)
η = o.eta
acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(x)
@. acc += Δ^2
@ -165,7 +165,7 @@ end
ADADelta(ρ = 0.9) = ADADelta(ρ, IdDict())
function update!(o::ADADelta, x, Δ)
function apply!(o::ADADelta, x, Δ)
ρ = o.rho
acc, Δacc = get!(o.state, x, (zero(x), zero(x)))
@. acc = ρ * acc + (1 - ρ) * Δ^2
@ -188,7 +188,7 @@ end
AMSGrad(η = 0.001, β = (0.9, 0.999)) = AMSGrad(η, β, IdDict())
function update!(o::AMSGrad, x, Δ)
function apply!(o::AMSGrad, x, Δ)
η, β = o.eta, o.beta
mt, vt, v̂t = get!(o.state, x, (fill(ϵ, size(x)), fill(ϵ, size(x)), fill(ϵ, size(x))))
@. mt = β[1] * mt + (1 - β[1]) * Δ
@ -211,7 +211,7 @@ end
NADAM(η = 0.001, β = (0.9, 0.999)) = NADAM(η, β, IdDict())
function update!(o::NADAM, x, Δ)
function apply!(o::NADAM, x, Δ)
η, β = o.eta, o.beta
β1p, β2p = o.beta
mt, vt = get!(o.state, x, (zero(x), zero(x)))
@ -228,7 +228,7 @@ end
[ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam.
ADAMW(η = 0.001, β = (0.9, 0.999), decay = 0) =
Optimiser(ADAM(η, β), WeightDecay(wd))
Optimiser(ADAM(η, β), WeightDecay(decay))
# Compose optimizers
@ -250,13 +250,21 @@ Optimiser(o...) = Optimiser(Any[o...])
Base.getindex(c::Optimiser, i::AbstractArray) = Optimiser(c.os[i]...)
function update!(o::Optimiser, x, Δ)
function apply!(o::Optimiser, x, Δ)
for opt in o.os
Δ = update!(opt, x, Δ)
Δ = apply!(opt, x, Δ)
return Δ
Apply inverse time decay to an optimiser
Optimiser(InvDecay(..), Opt(..))
mutable struct InvDecay
@ -264,7 +272,7 @@ end
InvDecay(γ = 0.001) = InvDecay(γ, IdDict())
function update!(o::InvDecay, x, Δ)
function apply!(o::InvDecay, x, Δ)
γ = o.gamma
n = get!(o.state, x, 1)
Δ .*= 1 / (1 + γ * n)
@ -272,6 +280,16 @@ function update!(o::InvDecay, x, Δ)
return Δ
`ExpDecay(eta, decay, decay_step, clip)`
Schedule the learning rate `eta` by `decay` every `decay_step` till a minimum of `clip`.
To apply exponential decay to an optimiser:
Optimiser(ExpDecay(..), Opt(..))
mutable struct ExpDecay
@ -282,7 +300,7 @@ end
ExpDecay(opt = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4) = ExpDecay(opt, decay, decay_step, clip, IdDict())
function update!(o::ExpDecay, x, Δ)
function apply!(o::ExpDecay, x, Δ)
η, s, decay = o.eta, o.step, o.decay
n = o.current[x] = get(o.current, x, 0) + 1
if o.current[x]%s == 0 && count(x -> x%s == 0, values(o.current)) == 1
@ -292,13 +310,18 @@ function update!(o::ExpDecay, x, Δ)
@. Δ *= decay
Decay the weight parameter by `wd`
mutable struct WeightDecay
WeightDecay() = WeightDecay(0)
function update!(o::WeightDecay, x, Δ)
function apply!(o::WeightDecay, x, Δ)
wd = o.wd
@. Δ += wd * x

@ -1,10 +1,14 @@
using Juno
using Flux.Tracker: data, grad, back!
import Flux.Tracker: data, grad, back!, update!
import Base.depwarn
function update!(opt, xs)
function update!(opt, x, )
update!(x, apply!(opt, x, copy(data())))
function _update_params!(opt, xs)
for x in xs
Δ = update!(opt,, x.grad)
Δ = apply!(opt,, x.grad) .-= Δ
Δ .= 0
@ -45,7 +49,7 @@ function stop()
train!(model, loss, data, opt)
train!(loss, params, data, opt; cb)
For each datapoint `d` in `data` computes the gradient of `loss(d...)` through
backpropagation and calls the optimizer `opt`.
@ -54,11 +58,11 @@ Takes a callback as keyword argument `cb`. For example, this will print "trainin
every 10 seconds:
Flux.train!(model, loss, data, opt,
Flux.train!(loss, params, data, opt,
cb = throttle(() -> println("training"), 10))
The callback can return `:stop` to interrupt the training loop.
The callback can call `Flux.stop()` to interrupt the training loop.
Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
@ -69,7 +73,7 @@ function train!(loss, ps, data, opt; cb = () -> ())
l = loss(d...)
@interrupts back!(l)
update!(opt, ps)
_update_params!(opt, ps)
if cb() == :stop
depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)

View File

@ -6,7 +6,7 @@ using MacroTools: @q, @forward
import Base: ==
export TrackedArray, TrackedVector, TrackedMatrix, Params, gradient,
param, back!
jacobian, hessian, param, back!
tracker(x) = nothing
@ -61,24 +61,20 @@ macro grad(ex)
@q(Tracker._forward($(args...)) where $(T...) = $body) |> esc
function update!(x, Δ) .+= data(Δ)
tracker(x).grad .= 0
return x
hook(f, x) -> x
Hook into gradient backpropagation. `x` is unmodified, but when backpropagating
`f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse
the sign of the gradient applied to `x`."""
the sign of the gradient applied to `x`.
hook(f, x) = istracked(x) ? track(hook, f, x) : x
@grad hook(f, x) = data(x), Δ -> (nothing, f(Δ))

@ -67,7 +67,7 @@ function back!(x, Δ; once = true)
function gradient_(f, xs...)
xs = param.(xs)
xs = param.(data.(xs))
l = f(xs...)
@ -147,8 +147,10 @@ end
back(::Grads, ::Nothing, _) = return
collectmemaybe(xs) = xs
function forward(f, ps::Params)
y = f()
y = collectmemaybe(f())
y, function (Δ)
g = Grads(ps)
if istracked(y)
@ -179,3 +181,30 @@ end
# Entry point for gradients of `f` with respect to the positional arguments `xs`.
# `nest = true` selects `gradient_nested`; the default uses `gradient_`.
gradient(f, xs...; nest = false) =
nest ? gradient_nested(f, xs...) : gradient_(f, xs...)
# A `Params` collection always goes through `gradient_nested`.
gradient(f, ps::Params) = gradient_nested(f, ps)
# Jacobians and Hessians
import ..Flux
J = jacobian(m,x)
Calculate the output jacobian `J = d/dx m(x)` such that each row `i` of `J` corresponds to the gradient `J[i,:] = ∇ₓ(m(x)[i])`
function jacobian(m,x)
xp = param(x)
y = m(xp)
k = length(y)
n = length(x)
J = Matrix{eltype(x)}(undef,k,n)
for i = 1:k
Flux.back!(y[i], once = false) # Populate gradient accumulator
J[i,:] = xp.grad
xp.grad .= 0 # Reset gradient accumulator
# Hessian of `f` at `x`, computed as the Jacobian of the nested gradient
# (`nest=true`) of `f` with respect to its single argument.
hessian(f, x) = jacobian(x -> gradient(f, x, nest=true)[1], x)

@ -0,0 +1,53 @@
using ForwardDiff
seed(x::Real, ::Val) = Dual(x, true)
function seed(x, ::Val{N}, offset = 0) where N
map(x, reshape(1:length(x), size(x))) do x, i
Dual(x, ntuple(j -> j+offset == i, Val(N)))
extract(x::ForwardDiff.Dual) = x.value, [x.partials...]
function extract(xs::AbstractArray{ForwardDiff.Dual{T,V,N}}) where {T,V,N}
J = similar(xs, V, N, length(xs))
for i = 1:length(xs), j = 1:N
J[j, i] = xs[i].partials.values[j]
return map(x -> x.value, xs), J
function forward_jacobian(f, x, ::Val{N}) where N
y, _J = extract(f(seed(x, Val(N))))
J = similar(_J, length(x), length(y))
J[1:N,:] = _J
offset = 0
while offset + N < length(x)
offset += N
_, _J = extract(f(seed(x, Val(N), offset)))
range = (1+offset):min(N+offset,length(x))
J[range,:] = @view _J[range.-offset,:]
return y, J
function forward_jacobian(f, x)
if length(x) < ForwardDiff.DEFAULT_CHUNK_THRESHOLD
forward_jacobian(f, x, Val(length(x)))
forward_jacobian(f, x, Val(ForwardDiff.DEFAULT_CHUNK_THRESHOLD))
forwarddiff(f, x) = istracked(x) ? track(forwarddiff, f, x) : f(x)
# Flatten the argument into a vector; a bare real number becomes a one-element vector.
function vec_scalar(x)
    return vec(x)
end

function vec_scalar(x::Real)
    return [x]
end

# Give `y` the shape of `x`, or unwrap the single element of `y` when `x` is a real number.
function reshape_scalar(x, y)
    return reshape(y, size(x))
end

function reshape_scalar(x::Real, y)
    return y[]
end
@grad function forwarddiff(f, x)
y, J = forward_jacobian(f, data(x))
return y, -> (nothing, reshape_scalar(x, J*vec_scalar()))

@ -1,7 +1,7 @@
import Base: *
import LinearAlgebra
import LinearAlgebra: inv, \, /
import LinearAlgebra: inv, det, logdet, logabsdet, \, /
using Statistics
using LinearAlgebra: Transpose, Adjoint, diagm, diag
@ -65,6 +65,12 @@ Base.setindex!(xs::TrackedArray, v, i...) =
back!(::TrackedArray) = error("Value is not scalar; use `back!(sum(x))` or `back!(x, Δ)`")
function update!(x::TrackedArray, Δ) .+= data(Δ)
tracker(x).grad .= 0
return x
# Fallthrough methods
for f in :[Base.size, Base.ndims, Base.collect].args
@ -115,8 +121,17 @@ Base.:-(xs::TrackedArray) = track(-, xs)
Base.transpose(xs::TrackedArray) = track(transpose, xs)
Base.adjoint(xs::TrackedArray) = track(adjoint, xs)
@grad transpose(xs) = transpose(data(xs)), Δ -> (reshape(transpose(Δ), size(xs)),)
@grad adjoint(xs) = data(xs)', Δ -> (reshape(Δ', size(xs)),)
@grad transpose(xs) = transpose(data(xs)), Δ -> (trim(xs, transpose(Δ)),)
@grad adjoint(xs) = data(xs)', Δ -> (trim(xs, Δ'),)
# Determinant of a tracked array. The pullback returns
# `Δ * det(xs) * transpose(inv(xs))` for the single argument.
det(xs::TrackedArray) = track(det, xs)
@grad det(xs) = det(data(xs)), Δ -> (Δ * det(xs) * transpose(inv(xs)),)
# Log-determinant. The pullback returns `Δ * transpose(inv(xs))`.
logdet(xs::TrackedArray) = track(logdet, xs)
@grad logdet(xs) = logdet(data(xs)), Δ -> (Δ * transpose(inv(xs)),)
# `logabsdet`: only the first component of the incoming sensitivity (`Δ[1]`) is
# propagated, using the same `transpose(inv(xs))` factor as `logdet`.
logabsdet(xs::TrackedArray) = track(logabsdet, xs)
@grad logabsdet(xs) = logabsdet(data(xs)), Δ -> (Δ[1] * transpose(inv(xs)),)
Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...)
@ -142,11 +157,9 @@ function combinations(xs, n)
[[x, c...] for x in xs, c in cs]
combinations([AbstractArray, TrackedArray], 2)
for i = 0:2, c = combinations([:AbstractArray, :TrackedArray], i), f = [:hcat, :vcat]
for i = 0:2, c = combinations([:AbstractArray, :TrackedArray, :Number], i), f = [:hcat, :vcat]
cnames = map(_ -> gensym(), c)
@eval Base.$f($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::TrackedArray, xs::AbstractArray...) =
@eval Base.$f($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::Union{TrackedArray,TrackedReal}, xs::Union{AbstractArray,Number}...) =
track($f, $(cnames...), x, xs...)
@ -219,8 +232,11 @@ Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64}}) = track(reshape, xs,
@grad reshape(xs, dims) = reshape(data(xs), dims), Δ -> (reshape(Δ, size(xs)),nothing)
Base.permutedims(xs::TrackedArray, dims) = track(permutedims, xs, dims)
@grad permutedims(xs, dims) = permutedims(data(xs), dims), Δ -> (permutedims(Δ, invperm(dims)),nothing)
# Tracked `permutedims`: the forward pass permutes the underlying data; the pullback
# applies the inverse permutation to `Δ` and returns `nothing` for `perm`.
Base.permutedims(xs::TrackedArray, perm) = track(permutedims, xs, perm)
@grad permutedims(xs, perm) = permutedims(data(xs), perm), Δ -> (permutedims(Δ, invperm(perm)),nothing)
# Same scheme for `PermutedDimsArray`: the pullback wraps `Δ` with the inverse
# permutation and returns `nothing` for `perm`.
Base.PermutedDimsArray(xs::TrackedArray, perm) = track(PermutedDimsArray, xs, perm)
@grad PermutedDimsArray(xs, perm) = PermutedDimsArray(data(xs), perm), Δ -> (PermutedDimsArray(Δ, invperm(perm)),nothing)
function _kron(mat1::AbstractMatrix,mat2::AbstractMatrix)
m1, n1 = size(mat1)
@ -305,9 +321,9 @@ dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys)
@grad dot(xs, ys) = dot(data(xs), data(ys)), Δ -> (Δ .* ys, Δ .* xs)
# Hacks to get std working
Statistics.std(x::TrackedArray; dims = :, mean = Statistics.mean(x, dims = dims)) = _std(x,mean,dims)
_std(x::TrackedArray, mean, dims) = sqrt.(sum((x .- mean).^2, dims = dims) ./ (mapreduce(i -> size(x,i),*, dims) - 1))
_std(x::TrackedArray, mean, ::Colon) = sqrt.(sum((x .- mean).^2) ./ (length(x) - 1))
Statistics.std(x::TrackedArray; dims = :, mean = Statistics.mean(x, dims = dims), corrected::Bool = true) = _std(x,mean,dims,corrected)
_std(x::TrackedArray, mean, dims, corrected) = sqrt.(sum((x .- mean).^2, dims = dims) ./ (mapreduce(i -> size(x,i),*, dims) - corrected))
_std(x::TrackedArray, mean, ::Colon, corrected) = sqrt.(sum((x .- mean).^2) ./ (length(x) - corrected))
LinearAlgebra.norm(x::TrackedArray, p::Real = 2) =
sum(abs.(x).^p .+ eps(0f0))^(1/p) # avoid d(sqrt(x))/dx == Inf at 0
@ -357,7 +373,7 @@ x::TrackedVector * y::TrackedVector = track(*, x, y)
# NNlib
using NNlib
import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, depthwiseconv, maxpool, meanpool
import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, ∇conv_data, depthwiseconv, maxpool, meanpool
softmax(xs::TrackedArray) = track(softmax, xs)
@ -384,8 +400,18 @@ conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...)
@grad conv(x, w; kw...) =
conv(data(x), data(w); kw...),
Δ -> nobacksies(:conv,
(NNlib.∇conv_data(data.((Δ, x, w))...; kw...),
NNlib.∇conv_filter(data.((Δ, x, w))...; kw...)))
(NNlib.∇conv_data(data.((Δ, w))...; size=size(x), kw...),
NNlib.∇conv_filter(data.((Δ, x))...; size=size(w), kw...)))
# Record `∇conv_data` on the tape whenever the input, the weights, or both are tracked.
∇conv_data(x::TrackedArray, w::TrackedArray; kw...) = track(∇conv_data, x, w; kw...)
∇conv_data(x::AbstractArray, w::TrackedArray; kw...) = track(∇conv_data, x, w; kw...)
∇conv_data(x::TrackedArray, w::AbstractArray; kw...) = track(∇conv_data, x, w; kw...)
# The forward pass runs on the untracked data. The pullback returns one entry per
# argument (`x`, then `w`), computed with `NNlib.conv` and `NNlib.∇conv_filter` on
# untracked data and wrapped in `nobacksies`.
# NOTE(review): `nobacksies` presumably rejects differentiating through this pullback
# again — confirm against its definition.
@grad ∇conv_data(x, w; kw...) =
∇conv_data(data(x), data(w); kw...),
Δ -> nobacksies(:conv,
(NNlib.conv(data.((Δ, w))...; size=size(x), kw...),
NNlib.∇conv_filter(data.((x, Δ))...; size=size(w), kw...)))
maxpool(x::TrackedArray, k; kw...) = track(maxpool, x, k; kw...)

@ -1,4 +1,4 @@
struct TrackedReal{T<:Real} <: Real
mutable struct TrackedReal{T<:Real} <: Real
@ -16,6 +16,12 @@ function back!(x::TrackedReal; once = true)
return back!(x, 1, once = once)
function update!(x::TrackedReal, Δ) += data(Δ)
tracker(x).grad = 0
return x
function, x::TrackedReal)
T = get(io, :typeinfo, Any)
show(io, data(x))
@ -33,6 +39,8 @@ Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} =
error("Not implemented: convert tracked $S to tracked $T")
(T::Type{<:TrackedReal})(x::Real) = convert(T, x)
for op in [:(==), :≈, :<, :(<=)]
@eval Base.$op(x::TrackedReal, y::Real) = Base.$op(data(x), y)
@eval Base.$op(x::Real, y::TrackedReal) = Base.$op(x, data(y))
@ -53,6 +61,12 @@ Base.float(x::TrackedReal) = x
Base.promote_rule(::Type{TrackedReal{S}},::Type{T}) where {S,T} =
using Random
for f in :[rand, randn, randexp].args
# Forward to the generator being defined (`rand`, `randn` or `randexp`) instead of
# always calling `rand`, so each method samples from its own distribution before
# wrapping the value with `param`.
@eval Random.$f(rng::AbstractRNG,::Type{TrackedReal{T}}) where {T} = param(Random.$f(rng,T))
using DiffRules, SpecialFunctions, NaNMath
for (M, f, arity) in DiffRules.diffrules()
@ -87,6 +101,13 @@ import Base:^
^(a::TrackedReal, b::Integer) = track(^, a, b)
# Hack for conversions
using ForwardDiff: Dual
# Converting a `Dual` to a real type `T` converts its value and every partial with `T`
# and rebuilds a `Dual` from the results.
(T::Type{<:Real})(x::Dual) = Dual(T(x.value), map(T, x.partials.values))
# Constructing a fully-parameterised `Dual{T,V,N}` from a `Dual` is routed explicitly to
# the constructor method taking a `Number`, bypassing the method defined just above.
(Dual{T,V,N})(x::Dual) where {T,V,N} = invoke(Dual{T,V,N}, Tuple{Number}, x)
# Tuples
struct TrackedTuple{T<:Tuple}
@ -134,3 +155,6 @@ end
function back_(g::Grads, c::Call{typeof(collect)}, Δ)
foreach((x, Δ) -> back(g, x, Δ), c.args[1], Δ)
collectmemaybe(xs::AbstractArray{>:TrackedReal}) = collect(xs)
collectmemaybe(xs::AbstractArray{<:TrackedReal}) = collect(xs)

@ -1,11 +1,13 @@
import Adapt: adapt
import Adapt: adapt, adapt_storage
import .Tracker: IdSet
# Fallbacks for arbitrary objects: no children, and `mapchildren` returns the object as is.
function children(x)
    return ()
end

function mapchildren(f, x)
    return x
end

# Tuples and named tuples are their own children; `mapchildren` applies `f` to each entry.
children(x::Union{Tuple,NamedTuple}) = x
mapchildren(f, x::Union{Tuple,NamedTuple}) = map(f, x)
function treelike(m::Module, T, fs = fieldnames(T))
@eval m begin
@ -14,11 +16,6 @@ function treelike(m::Module, T, fs = fieldnames(T))
function treelike(T, fs = fieldnames(T))
Base.depwarn("`treelike(T)` is deprecated, use `@treelike T`", :treelike)
treelike(Base._current_module(), T, fs)
macro treelike(T, fs = nothing)
fs == nothing || isexpr(fs, :tuple) || error("@treelike T (a, b)")
fs = fs == nothing ? [] : [:($(map(QuoteNode, fs.args)...),)]
@ -69,3 +66,22 @@ gpu_adaptor = identity
gpu(x) = mapleaves(gpu_adaptor, x)
# Precision
# Adapting a real-valued array to a real element type `T` converts every element to `T`.
adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Real}) = convert.(T, xs)
# Apply `adapt(T, ·)` to every leaf of the model `m` via `mapleaves`.
paramtype(T::Type{<:Real}, m) = mapleaves(x -> adapt(T, x), m)
# Convenience wrappers for the two common precisions.
f32(m) = paramtype(Float32, m)
f64(m) = paramtype(Float64, m)
# General parameter map
function mapparams(f, m)
mapleaves(m) do x
Tracker.istracked(x) ? param(f( :
x isa Union{AbstractArray,Number} ? f(x) :

@ -10,8 +10,8 @@ zeros(dims...) = Base.zeros(Float32, dims...)
# Insert a singleton dimension at position `dim` of `xs` by reshaping.
function unsqueeze(xs, dim)
    dims = size(xs)
    leading = dims[1:dim-1]
    trailing = dims[dim:end]
    return reshape(xs, (leading..., 1, trailing...))
end
stack(xs, dim) = cat(unsqueeze.(xs, dim)...; dims=dim)
unstack(xs, dim) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)]
stack(xs, dim) = cat(unsqueeze.(xs, dim)..., dims=dim)
# Split `xs` along dimension `dim` into a vector of independent (copied) slices.
function unstack(xs, dim)
    return map(i -> copy(selectdim(xs, dim, i)), 1:size(xs, dim))
end
chunk(xs, n)
@ -139,25 +139,6 @@ function throttle(f, timeout; leading=true, trailing=false)
J = jacobian(m,x)
Calculate the output jacobian `J = d/dx m(x)` such that each row `i` of `J` corresponds to the gradient `J[i,:] = ∇ₓ(m(x)[i])`
function jacobian(m,x)
xp = param(x)
y = m(xp)
k = length(y)
n = length(x)
J = Matrix{eltype(x)}(undef,n,k)
for i = 1:k
Flux.back!(y[i], once = false) # Populate gradient accumulator
J[:,i] = xp.grad
xp.grad .= 0 # Reset gradient accumulator
@jit ...

View File

@ -11,6 +11,8 @@ x = param(randn(5, 5))
cx = gpu(x)
@test cx isa TrackedArray && isa CuArray
@test Flux.onecold(param(gpu([1.,2.,3.]))) == 3
x = Flux.onehotbatch([1, 2, 3], 1:3)
cx = gpu(x)
@test cx isa Flux.OneHotMatrix && isa CuArray

@ -21,3 +21,15 @@ end
@test size(m(r)) == (10, 5)
@testset "Depthwise Conv" begin
r = zeros(Float32, 28, 28, 3, 5)
m1 = DepthwiseConv((2, 2), 3=>5)
@test size(m1(r), 3) == 15
m2 = DepthwiseConv((2, 2), 3)
@test size(m2(r), 3) == 3

@ -98,4 +98,9 @@ end
y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5])
@test m(x) == y
let m = BatchNorm(32), x = randn(Float32, 416, 416, 32, 1);
@test (@allocated m(x)) < 100_000_000

@ -49,4 +49,16 @@ const ϵ = 1e-7
@testset "logitbinarycrossentropy" begin
# The logit formulation should agree numerically with applying `σ` first.
@test logitbinarycrossentropy.(logŷ, y) ≈ binarycrossentropy.(σ.(logŷ), y; ϵ=0)
@testset "no spurious promotions" begin
for T in (Float16, Float32, Float64)
y = rand(T, 2)
ŷ = rand(T, 2)
for f in (mse, crossentropy, logitcrossentropy)
# Exercise the loss currently being iterated (`f`) rather than `mse` on every pass.
fwd, back = Flux.Tracker.forward(f, ŷ, y)
@test typeof(fwd) == Flux.Tracker.TrackedReal{T}
@test eltype(back(one(T))[1]) == Flux.Tracker.TrackedReal{T}

@ -4,7 +4,7 @@ using Flux.Tracker
using Test
@testset "Optimise" begin
w = randn(10, 10)
@testset for Opt in [ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, Descent, ADAM, Nesterov, RMSProp, Momentum]
@testset for Opt in [ADAMW, ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, Descent, ADAM, Nesterov, RMSProp, Momentum]
w = param(randn(10, 10))
loss(x) = Flux.mse(w*x, w*x)
opt = Opt(0.001)
@ -17,7 +17,7 @@ using Test
for t = 1: 10^5
l = loss(rand(10))
delta = Optimise.update!(opt,, w.grad)
delta = Optimise.apply!(opt,, w.grad) .-= delta
@test Flux.mse(w, w) < 0.01
@ -33,7 +33,7 @@ end
for t = 1:10^5
l = loss(rand(10))
delta = Optimise.update!(opt,, w.grad)
delta = Optimise.apply!(opt,, w.grad) .-= delta
@test Flux.mse(w, w) < 0.01

@ -1,18 +1,3 @@
# Pkg.test runs with --check_bounds=1, forcing all bounds checks.
# This is incompatible with CUDAnative (see JuliaGPU/CUDAnative.jl#98)
if Base.JLOptions().check_bounds == 1
file = @__FILE__
--color=$(Base.have_color ? "yes" : "no")
--compiled-modules=$(Bool(Base.JLOptions().use_compiled_modules) ? "yes" : "no")
--startup-file=$(Base.JLOptions().startupfile != 2 ? "yes" : "no")
--code-coverage=$(["none", "user", "all"][1+Base.JLOptions().code_coverage])
using Flux, Test, Random, Statistics
using Random

@ -1,9 +1,9 @@
using Flux
using Flux.Tracker, Test, NNlib
using Flux.Tracker: TrackedReal, gradcheck, grad, checkpoint
using NNlib: conv, depthwiseconv
using Flux.Tracker: TrackedReal, gradient, gradcheck, grad, checkpoint, forwarddiff
using NNlib: conv, ∇conv_data, depthwiseconv
using Printf: @sprintf
using LinearAlgebra: diagm, dot, LowerTriangular, norm
using LinearAlgebra: diagm, dot, LowerTriangular, norm, det, logdet, logabsdet
using Statistics: mean, std
using Random
# using StatsBase
@ -34,6 +34,10 @@ gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
@test gradtest(x -> x', rand(5))
@test gradtest(det, (4, 4))
@test gradtest(logdet, map((x) -> x*x', (rand(4, 4),))[1])
@test gradtest((x) -> logabsdet(x)[1], (4, 4))
@testset "indexing & slicing" begin
gradtest(x->view(x, 1:2, 1:2), rand(4, 4))
@ -113,9 +117,17 @@ end
promotiontest((x...) -> cat(x..., dims = 3), rand(4,5,3), rand(4,5,1), rand(4,5,2))
@testset "scalars" begin
@test vcat(param([1, 2, 3]), 1) isa TrackedArray
@test vcat(1, param([1, 2, 3])) isa TrackedArray
@test hcat(1, param([1 2 3;])) isa TrackedArray
@test vcat(param(1), 2) isa TrackedArray
@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))
@test gradtest(x -> PermutedDimsArray(x, [3,1,2]), rand(4,5,6))
@test gradtest(x -> repeat(x; inner=2), rand(5))
@test gradtest(x -> repeat(x; inner=2, outer=3), rand(5))
@ -166,6 +178,10 @@ end
@test gradtest(x -> std(x), rand(5,5))
@test gradtest(x -> std(x, dims = 1), rand(5,5))
@test gradtest(x -> std(x, dims = 1, corrected = false), rand(5,5))
@test gradtest(x -> Flux.normalise(x), rand(4,3))
@test gradtest(x -> Flux.normalise(x, dims = 2), rand(3,4))
@test gradtest((x, y) -> x .* y, rand(5), rand(5))
@test gradtest(dot, rand(5), rand(5))
@ -177,18 +193,28 @@ end
2y + x
@test gradtest(conv, rand(10, 3, 2), randn(Float64,2, 3, 2))
@test gradtest(conv, rand(10, 10, 3, 2), randn(Float64,2, 2, 3, 2))
@test gradtest(conv, rand(10, 10, 10, 3, 2), randn(Float64,2, 2, 2, 3, 2))
@test gradtest(conv, rand(10, 3, 2), randn(Float64, 2, 3, 2))
@test gradtest(conv, rand(10, 10, 3, 2), randn(Float64, 2, 2, 3, 2))
@test gradtest(conv, rand(10, 10, 10, 3, 2), randn(Float64, 2, 2, 2, 3, 2))
@test gradtest(∇conv_data, rand(10, 3, 2), randn(Float64, 2, 2, 3))
@test gradtest(∇conv_data, rand(10, 10, 3, 2), randn(Float64,2, 2, 2, 3))
@test gradtest(∇conv_data, rand(10, 10, 10, 3, 2), randn(Float64,2, 2, 2, 2, 3))
@test gradtest(depthwiseconv, rand(10,10,3,2), randn(2, 2, 2, 3))
@test gradtest(∇conv_data, rand(10, 3, 2), randn(Float64, 2, 2, 3))
@test gradtest(∇conv_data, rand(10, 10, 3, 2), randn(Float64, 2, 2, 2, 3))
@test gradtest(∇conv_data, rand(10, 10, 10, 3, 2), randn(Float64, 2, 2, 2, 2, 3))
@test gradtest(x -> maxpool(x, (2,2)), rand(10, 10, 3, 2))
@test gradtest(x -> maxpool(x, (2,2,2)), rand(10, 10, 10, 3, 2))
@test gradtest(x -> meanpool(x, (2,2)), rand(10, 10, 3, 2))
@test gradtest(x -> meanpool(x, (2,2,2)), rand(5, 5, 5, 3, 2))
@test gradtest(x -> Float64.(x), 5)
@testset "equality & order" begin
# TrackedReal
@test param(2)^2 == param(4)
@ -260,7 +286,7 @@ Tracker.back!(b)
@test grad.((x,y)) == (3, 2)
@test Tracker.gradient(2, 3) do x, y
@test gradient(2, 3) do x, y
xy = Tracker.collect([x, y])
end == (3, 2)
@ -286,4 +312,36 @@ end
@test count == 3
@testset "Updates" begin
xs = param([1, 2, 3])
Tracker.update!(xs, param([4, 5, 6]))
@test xs == [5, 7, 9]
x = param(3)
Tracker.update!(x, param(4))
@test x == 7
@testset "Params" begin
W = param(randn(5, 10))
x = rand(10)
dW = gradient(W -> sum(W*x), W)[1]
gs = gradient(() -> sum(W*x), Tracker.Params([W]))
@test gs[W] == dW
@testset "Forward" begin
@test @inferred(Tracker.forward_jacobian(x -> [sum(x)], rand(5,5), Val(12)))[2] ==
reshape(ones(25), :, 1)
@test gradient([2, 3]) do x
forwarddiff(x) do x
end == ([3, 2],)
@testset "Custom Sensitivities" begin
y, back = Tracker.forward(x -> [3x^2, 2x], 5)
@test back([1, 1]) == (32,)
end #testset

@ -1,5 +1,5 @@
using Flux
using Flux: throttle, jacobian, glorot_uniform, glorot_normal, stack
using Flux: throttle, jacobian, glorot_uniform, glorot_normal, stack, unstack
using StatsBase: std
using Random
using Test
@ -87,8 +87,27 @@ end
@test size.(params(m)) == [(5, 10), (5, 5), (5,), (5,)]
@testset "Basic" begin
@testset "Basic Stacking" begin
x = randn(3,3)
stacked = stack([x, x], 2)
@test size(stacked) == (3,2,3)
@testset "Precision" begin
m = Chain(Dense(10, 5, relu), Dense(5, 2))
x = rand(10)
@test eltype(m[1] == Float32
@test eltype(m(x).data) == Float32
@test eltype(f64(m)(x).data) == Float64
@test eltype(f64(m)[1] == Float64
@test eltype(f32(f64(m))[1] == Float32
@test Tracker.isleaf(f32(f64(m))[1].W)
@testset "Stacking" begin
stacked_array=[ 8 9 3 5; 9 6 6 9; 9 1 7 2; 7 4 10 6 ]
unstacked_array=[[8, 9, 9, 7], [9, 6, 1, 4], [3, 6, 7, 10], [5, 9, 2, 6]]
@test unstack(stacked_array, 2) == unstacked_array
@test stack(unstacked_array, 2) == stacked_array
@test stack(unstack(stacked_array, 1), 1) == stacked_array