Merge remote-tracking branch 'upstream/master' into drop_shape

chengchingwen 2019-05-14 00:50:59 +08:00
commit 2fc2a5282c
56 changed files with 1657 additions and 1651 deletions

37
.gitlab-ci.yml Normal file
View File

@ -0,0 +1,37 @@
before_script:
- export CI_DISABLE_CURNN_TEST=true
variables:
CI_IMAGE_TAG: 'cuda'
include:
- 'https://raw.githubusercontent.com/JuliaGPU/gitlab-ci/master/templates/v3/common.yml'
.flux:
extends: .test
script:
- julia -e 'using InteractiveUtils;
versioninfo()'
- mkdir $JULIA_DEPOT_PATH # Pkg3.jl#325
- julia -e 'using Pkg;
Pkg.add("CuArrays");'
- julia --project -e 'using Pkg;
Pkg.instantiate();
Pkg.build();
Pkg.test(; coverage=true);'
test:v1.0:
extends: .flux
variables:
CI_VERSION_TAG: 'v1.0'
only:
- staging
- trying
test:v1.1:
extends: .flux
variables:
CI_VERSION_TAG: 'v1.1'
only:
- staging
- trying

View File

@ -23,3 +23,7 @@ jobs:
Pkg.instantiate()' Pkg.instantiate()'
- julia --project=docs/ docs/make.jl - julia --project=docs/ docs/make.jl
after_success: skip after_success: skip
## uncomment the following lines to override the default test script
script:
- julia --color=yes -e 'using Pkg; Pkg.activate(); Pkg.instantiate(); Pkg.test()'

29
CITATION.bib Normal file
View File

@ -0,0 +1,29 @@
@article{Flux.jl-2018,
author = {Michael Innes and
Elliot Saba and
Keno Fischer and
Dhairya Gandhi and
Marco Concetto Rudilosso and
Neethu Mariya Joy and
Tejan Karmali and
Avik Pal and
Viral Shah},
title = {Fashionable Modelling with Flux},
journal = {CoRR},
volume = {abs/1811.01457},
year = {2018},
url = {http://arxiv.org/abs/1811.01457},
archivePrefix = {arXiv},
eprint = {1811.01457},
timestamp = {Thu, 22 Nov 2018 17:58:30 +0100},
biburl = {https://dblp.org/rec/bib/journals/corr/abs-1811-01457},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
@article{innes:2018,
author = {Mike Innes},
title = {Flux: Elegant Machine Learning with Julia},
journal = {Journal of Open Source Software},
year = {2018},
doi = {10.21105/joss.00602},
}

View File

@ -1,6 +1,6 @@
The Flux.jl package is licensed under the MIT "Expat" License: The Flux.jl package is licensed under the MIT "Expat" License:
> Copyright (c) 2016: Mike Innes. > Copyright (c) 2016-19: Julia Computing, Inc., Mike Innes and Contributors
> >
> Permission is hereby granted, free of charge, to any person obtaining > Permission is hereby granted, free of charge, to any person obtaining
> a copy of this software and associated documentation files (the > a copy of this software and associated documentation files (the

View File

@ -27,11 +27,17 @@ git-tree-sha1 = "055eb2690182ebc31087859c3dd8598371d3ef9e"
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232" uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
version = "0.5.3" version = "0.5.3"
[[CSTParser]]
deps = ["LibGit2", "Test", "Tokenize"]
git-tree-sha1 = "437c93bc191cd55957b3f8dee7794b6131997c56"
uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f"
version = "0.5.2"
[[CodecZlib]] [[CodecZlib]]
deps = ["BinaryProvider", "Libdl", "Test", "TranscodingStreams"] deps = ["BinaryProvider", "Libdl", "Test", "TranscodingStreams"]
git-tree-sha1 = "e3df104c84dfc108f0ca203fd7f5bbdc98641ae9" git-tree-sha1 = "36bbf5374c661054d41410dc53ff752972583b9b"
uuid = "944b1d66-785c-5afd-91f1-9de20f533193" uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
version = "0.5.1" version = "0.5.2"
[[ColorTypes]] [[ColorTypes]]
deps = ["FixedPointNumbers", "Random", "Test"] deps = ["FixedPointNumbers", "Random", "Test"]
@ -53,9 +59,15 @@ version = "0.2.0"
[[Compat]] [[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "ec61a16eed883ad0cfa002d7489b3ce6d039bb9a" git-tree-sha1 = "84aa74986c5b9b898b0d1acaf3258741ee64754f"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "1.4.0" version = "2.1.0"
[[Crayons]]
deps = ["Test"]
git-tree-sha1 = "f621b8ef51fd2004c7cf157ea47f027fdeac5523"
uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
version = "4.0.0"
[[DataStructures]] [[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"] deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
@ -73,15 +85,15 @@ uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
[[DiffResults]] [[DiffResults]]
deps = ["Compat", "StaticArrays"] deps = ["Compat", "StaticArrays"]
git-tree-sha1 = "db8acf46717b13d6c48deb7a12007c7f85a70cf7" git-tree-sha1 = "34a4a1e8be7bc99bc9c611b895b5baf37a80584c"
uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
version = "0.0.3" version = "0.0.4"
[[DiffRules]] [[DiffRules]]
deps = ["Random", "Test"] deps = ["Random", "Test"]
git-tree-sha1 = "c49ec69428ffea0c1d1bbdc63d1a70f5df5860ad" git-tree-sha1 = "dc0869fb2f5b23466b32ea799bd82c76480167f7"
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "0.0.7" version = "0.0.10"
[[Distributed]] [[Distributed]]
deps = ["Random", "Serialization", "Sockets"] deps = ["Random", "Serialization", "Sockets"]
@ -95,9 +107,9 @@ version = "0.5.3"
[[ForwardDiff]] [[ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"] deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"]
git-tree-sha1 = "e393bd3b9102659fb24fe88caedec41f2bc2e7de" git-tree-sha1 = "4c4d727f1b7e0092134fabfab6396b8945c1ea5b"
uuid = "f6369f11-7733-5829-9624-2563aa707210" uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.2" version = "0.10.3"
[[InteractiveUtils]] [[InteractiveUtils]]
deps = ["Markdown"] deps = ["Markdown"]
@ -105,9 +117,9 @@ uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
[[Juno]] [[Juno]]
deps = ["Base64", "Logging", "Media", "Profile", "Test"] deps = ["Base64", "Logging", "Media", "Profile", "Test"]
git-tree-sha1 = "ce6246e19061e36cbdce954caaae717498daeed8" git-tree-sha1 = "4e4a8d43aa7ecec66cadaf311fbd1e5c9d7b9175"
uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d" uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
version = "0.5.4" version = "0.7.0"
[[LibGit2]] [[LibGit2]]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
@ -123,10 +135,10 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
[[MacroTools]] [[MacroTools]]
deps = ["Compat"] deps = ["CSTParser", "Compat", "DataStructures", "Test"]
git-tree-sha1 = "c443e1c8d58a4e9f61b708ad0a88286c7042145b" git-tree-sha1 = "daecd9e452f38297c686eba90dba2a6d5da52162"
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
version = "0.4.4" version = "0.5.0"
[[Markdown]] [[Markdown]]
deps = ["Base64"] deps = ["Base64"]
@ -148,10 +160,10 @@ version = "0.4.0"
uuid = "a63ad114-7e13-5084-954f-fe012c677804" uuid = "a63ad114-7e13-5084-954f-fe012c677804"
[[NNlib]] [[NNlib]]
deps = ["Libdl", "LinearAlgebra", "MacroTools", "Requires", "Test"] deps = ["Libdl", "LinearAlgebra", "Requires", "Statistics", "TimerOutputs"]
git-tree-sha1 = "51330bb45927379007e089997bf548fbe232589d" git-tree-sha1 = "0c667371391fc6bb31f7f12f96a56a17098b3de8"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.4.3" version = "0.6.0"
[[NaNMath]] [[NaNMath]]
deps = ["Compat"] deps = ["Compat"]
@ -161,9 +173,9 @@ version = "0.3.2"
[[OrderedCollections]] [[OrderedCollections]]
deps = ["Random", "Serialization", "Test"] deps = ["Random", "Serialization", "Test"]
git-tree-sha1 = "85619a3f3e17bb4761fe1b1fd47f0e979f964d5b" git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.0.2" version = "1.1.0"
[[Pkg]] [[Pkg]]
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
@ -228,29 +240,47 @@ version = "0.7.2"
[[StaticArrays]] [[StaticArrays]]
deps = ["InteractiveUtils", "LinearAlgebra", "Random", "Statistics", "Test"] deps = ["InteractiveUtils", "LinearAlgebra", "Random", "Statistics", "Test"]
git-tree-sha1 = "1eb114d6e23a817cd3e99abc3226190876d7c898" git-tree-sha1 = "3841b39ed5f047db1162627bf5f80a9cd3e39ae2"
uuid = "90137ffa-7385-5640-81b9-e52037218182" uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "0.10.2" version = "0.10.3"
[[Statistics]] [[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"] deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[[StatsBase]] [[StatsBase]]
deps = ["DataStructures", "DelimitedFiles", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "Test"] deps = ["DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"]
git-tree-sha1 = "7b596062316c7d846b67bf625d5963a832528598" git-tree-sha1 = "8a0f4b09c7426478ab677245ab2b0b68552143c7"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.27.0" version = "0.30.0"
[[Test]] [[Test]]
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[[TimerOutputs]]
deps = ["Crayons", "Printf", "Test", "Unicode"]
git-tree-sha1 = "b80671c06f8f8bae08c55d67b5ce292c5ae2660c"
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
version = "0.5.0"
[[Tokenize]]
deps = ["Printf", "Test"]
git-tree-sha1 = "3e83f60b74911d3042d3550884ca2776386a02b8"
uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624"
version = "0.5.3"
[[Tracker]]
deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics", "Test"]
git-tree-sha1 = "0bec1b68c63a0e8a58d3944261cbf4cc9577c8a1"
uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
version = "0.2.0"
[[TranscodingStreams]] [[TranscodingStreams]]
deps = ["Pkg", "Random", "Test"] deps = ["Random", "Test"]
git-tree-sha1 = "a34a2d588e2d2825602bf14a24216d5c8b0921ec" git-tree-sha1 = "a25d8e5a28c3b1b06d3859f30757d43106791919"
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
version = "0.8.1" version = "0.9.4"
[[URIParser]] [[URIParser]]
deps = ["Test", "Unicode"] deps = ["Test", "Unicode"]
@ -267,6 +297,6 @@ uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
[[ZipFile]] [[ZipFile]]
deps = ["BinaryProvider", "Libdl", "Printf", "Test"] deps = ["BinaryProvider", "Libdl", "Printf", "Test"]
git-tree-sha1 = "4000c633efe994b2e10b31b6d91382c4b7412dac" git-tree-sha1 = "5f6f663890dfb9bad6af75a86a43f67904e5050e"
uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
version = "0.8.0" version = "0.8.1"

25
NEWS.md Normal file
View File

@ -0,0 +1,25 @@
# v0.8.0
* New [ConvTranspose layer](https://github.com/FluxML/Flux.jl/pull/311).
* New [Maxout layer](https://github.com/FluxML/Flux.jl/pull/647)
* Datasets are now [hash verified on download](https://github.com/FluxML/Flux.jl/pull/585) to avoid corruption.
* We now [zero the initial state for RNNs](https://github.com/FluxML/Flux.jl/pull/590/).
* [Normalisation can now work on arbitrary `dims`.](https://github.com/FluxML/Flux.jl/pull/592)
* Many docs and bugfixes thanks to @KristofferC and others.
* [NamedTuples now work like Tuples](https://github.com/FluxML/Flux.jl/pull/603) when doing `mapleaves`.
* New "performance tips" [section of the docs](https://github.com/FluxML/Flux.jl/pull/615).
* The training loop is [now more readable](https://github.com/FluxML/Flux.jl/pull/651) and better shows how to use the lower-level APIs.
* New [AlphaDropout](https://github.com/FluxML/Flux.jl/pull/656).
* [Data.Iris](https://github.com/FluxML/Flux.jl/pull/652) makes Fisher's Iris dataset available with `Iris.labels` and `Iris.features`.
* New [InstanceNorm](https://github.com/FluxML/Flux.jl/pull/634), as popularized by [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022).
* New [GroupNorm](https://github.com/FluxML/Flux.jl/pull/696), as described in [Group Normalization](https://arxiv.org/abs/1803.08494).
AD Changes:
* `det`, `logdet` and `logabsdet` [now have adjoints](https://github.com/FluxML/Flux.jl/pull/596/files).
* Support for [PermuteDimsArray](https://github.com/FluxML/Flux.jl/pull/576).
* Flux.Tracker is now its [own package](https://github.com/FluxML/Tracker.jl), in preparation for replacing it with Zygote.
# v0.7.0
Despite the heroic efforts of scholars and archeologists, pre-0.7 history is lost to the sands of time.

View File

@ -1,25 +1,35 @@
name = "Flux" name = "Flux"
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
version = "0.8.3"
[deps] [deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193" CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d" Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df" Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" SHA = "ea8e919c-243c-51af-8825-aaa63cd721ce"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
[compat]
NNlib = "0.6"
Tracker = "0.2"
julia = "0.7, 1"
[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[targets]
test = ["Test"]

View File

@ -2,7 +2,7 @@
<img width="400px" src="https://raw.githubusercontent.com/FluxML/fluxml.github.io/master/logo.png"/> <img width="400px" src="https://raw.githubusercontent.com/FluxML/fluxml.github.io/master/logo.png"/>
</p> </p>
[![Build Status](https://travis-ci.org/FluxML/Flux.jl.svg?branch=master)](https://travis-ci.org/FluxML/Flux.jl) [![](https://img.shields.io/badge/docs-stable-blue.svg)](https://fluxml.github.io/Flux.jl/stable/) [![](https://img.shields.io/badge/chat-on%20slack-yellow.svg)](https://slackinvite.julialang.org/) [![DOI](http://joss.theoj.org/papers/10.21105/joss.00602/status.svg)](https://doi.org/10.21105/joss.00602) [![Build Status](https://travis-ci.org/FluxML/Flux.jl.svg?branch=master)](https://travis-ci.org/FluxML/Flux.jl) [![](https://img.shields.io/badge/docs-stable-blue.svg)](https://fluxml.github.io/Flux.jl/stable/) [![](https://img.shields.io/badge/chat-on%20slack-yellow.svg)](https://slackinvite.julialang.org/) [![DOI](https://joss.theoj.org/papers/10.21105/joss.00602/status.svg)](https://doi.org/10.21105/joss.00602)
Flux is an elegant approach to machine learning. It's a 100% pure-Julia stack, and provides lightweight abstractions on top of Julia's native GPU and AD support. Flux makes the easy things easy while remaining fully hackable. Flux is an elegant approach to machine learning. It's a 100% pure-Julia stack, and provides lightweight abstractions on top of Julia's native GPU and AD support. Flux makes the easy things easy while remaining fully hackable.
@ -10,7 +10,7 @@ Flux is an elegant approach to machine learning. It's a 100% pure-Julia stack, a
julia> Pkg.add("Flux") julia> Pkg.add("Flux")
``` ```
See the [documentation](http://fluxml.github.io/Flux.jl/) or the [model zoo](https://github.com/FluxML/model-zoo/) for examples. See the [documentation](https://fluxml.github.io/Flux.jl/) or the [model zoo](https://github.com/FluxML/model-zoo/) for examples.
If you use Flux in research, please cite the following paper: If you use Flux in research, please cite the following paper:

View File

@ -10,9 +10,4 @@ ZipFile
AbstractTrees AbstractTrees
Reexport Reexport
StatsBase StatsBase
Tracker
# AD
ForwardDiff 0.5.0
DiffRules
SpecialFunctions
NaNMath

4
bors.toml Normal file
View File

@ -0,0 +1,4 @@
status = [
"ci/gitlab/%"
]
timeout-sec = 14400

View File

@ -1,3 +1,5 @@
# This file is machine-generated - editing it directly is not advised
[[AbstractTrees]] [[AbstractTrees]]
deps = ["Markdown", "Test"] deps = ["Markdown", "Test"]
git-tree-sha1 = "6621d9645702c1c4e6970cc6a3eae440c768000b" git-tree-sha1 = "6621d9645702c1c4e6970cc6a3eae440c768000b"
@ -6,9 +8,9 @@ version = "0.2.1"
[[Adapt]] [[Adapt]]
deps = ["LinearAlgebra", "Test"] deps = ["LinearAlgebra", "Test"]
git-tree-sha1 = "04d15700419b6949d76be1428ab6e0277ff43b06" git-tree-sha1 = "53d8fec4f662088c1202530e338a11a919407f3b"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "0.4.1" version = "0.4.2"
[[Base64]] [[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
@ -25,11 +27,17 @@ git-tree-sha1 = "055eb2690182ebc31087859c3dd8598371d3ef9e"
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232" uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
version = "0.5.3" version = "0.5.3"
[[CSTParser]]
deps = ["LibGit2", "Test", "Tokenize"]
git-tree-sha1 = "437c93bc191cd55957b3f8dee7794b6131997c56"
uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f"
version = "0.5.2"
[[CodecZlib]] [[CodecZlib]]
deps = ["BinaryProvider", "Libdl", "Test", "TranscodingStreams"] deps = ["BinaryProvider", "Libdl", "Test", "TranscodingStreams"]
git-tree-sha1 = "e3df104c84dfc108f0ca203fd7f5bbdc98641ae9" git-tree-sha1 = "36bbf5374c661054d41410dc53ff752972583b9b"
uuid = "944b1d66-785c-5afd-91f1-9de20f533193" uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
version = "0.5.1" version = "0.5.2"
[[ColorTypes]] [[ColorTypes]]
deps = ["FixedPointNumbers", "Random", "Test"] deps = ["FixedPointNumbers", "Random", "Test"]
@ -51,9 +59,15 @@ version = "0.2.0"
[[Compat]] [[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "ec61a16eed883ad0cfa002d7489b3ce6d039bb9a" git-tree-sha1 = "84aa74986c5b9b898b0d1acaf3258741ee64754f"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "1.4.0" version = "2.1.0"
[[Crayons]]
deps = ["Test"]
git-tree-sha1 = "f621b8ef51fd2004c7cf157ea47f027fdeac5523"
uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
version = "4.0.0"
[[DataStructures]] [[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"] deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
@ -71,31 +85,31 @@ uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
[[DiffResults]] [[DiffResults]]
deps = ["Compat", "StaticArrays"] deps = ["Compat", "StaticArrays"]
git-tree-sha1 = "db8acf46717b13d6c48deb7a12007c7f85a70cf7" git-tree-sha1 = "34a4a1e8be7bc99bc9c611b895b5baf37a80584c"
uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
version = "0.0.3" version = "0.0.4"
[[DiffRules]] [[DiffRules]]
deps = ["Random", "Test"] deps = ["Random", "Test"]
git-tree-sha1 = "c49ec69428ffea0c1d1bbdc63d1a70f5df5860ad" git-tree-sha1 = "dc0869fb2f5b23466b32ea799bd82c76480167f7"
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "0.0.7" version = "0.0.10"
[[Distributed]] [[Distributed]]
deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"] deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
[[DocStringExtensions]] [[DocStringExtensions]]
deps = ["LibGit2", "Markdown", "Pkg", "Test"] deps = ["LibGit2", "Markdown", "Pkg", "Test"]
git-tree-sha1 = "1df01539a1c952cef21f2d2d1c092c2bcf0177d7" git-tree-sha1 = "4d30e889c9f106a51ffa4791a88ffd4765bf20c3"
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
version = "0.6.0" version = "0.7.0"
[[Documenter]] [[Documenter]]
deps = ["Base64", "DocStringExtensions", "InteractiveUtils", "LibGit2", "Logging", "Markdown", "Pkg", "REPL", "Random", "Test", "Unicode"] deps = ["Base64", "DocStringExtensions", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "Pkg", "REPL", "Random", "Test", "Unicode"]
git-tree-sha1 = "a6db1c69925cdc53aafb38caec4446be26e0c617" git-tree-sha1 = "13a6d15102410d8e70146533b759fc48d844a1d0"
uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4" uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
version = "0.21.0" version = "0.22.3"
[[FixedPointNumbers]] [[FixedPointNumbers]]
deps = ["Test"] deps = ["Test"]
@ -104,26 +118,32 @@ uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
version = "0.5.3" version = "0.5.3"
[[Flux]] [[Flux]]
deps = ["AbstractTrees", "Adapt", "CodecZlib", "Colors", "DiffRules", "ForwardDiff", "Juno", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsBase", "Test", "ZipFile"] deps = ["AbstractTrees", "Adapt", "CodecZlib", "Colors", "DelimitedFiles", "Juno", "LinearAlgebra", "MacroTools", "NNlib", "Pkg", "Printf", "Random", "Reexport", "Requires", "SHA", "Statistics", "StatsBase", "Tracker", "ZipFile"]
path = ".." path = ".."
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
version = "0.6.10+" version = "0.8.2+"
[[ForwardDiff]] [[ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"] deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"]
git-tree-sha1 = "b91250044374764e7c29af59a774c4b8d6100b6e" git-tree-sha1 = "4c4d727f1b7e0092134fabfab6396b8945c1ea5b"
uuid = "f6369f11-7733-5829-9624-2563aa707210" uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.1" version = "0.10.3"
[[InteractiveUtils]] [[InteractiveUtils]]
deps = ["LinearAlgebra", "Markdown"] deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
[[JSON]]
deps = ["Dates", "Distributed", "Mmap", "Sockets", "Test", "Unicode"]
git-tree-sha1 = "1f7a25b53ec67f5e9422f1f551ee216503f4a0fa"
uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
version = "0.20.0"
[[Juno]] [[Juno]]
deps = ["Base64", "Logging", "Media", "Profile", "Test"] deps = ["Base64", "Logging", "Media", "Profile", "Test"]
git-tree-sha1 = "3c29a199713e7ec62cfdc11f44d7760219d5f658" git-tree-sha1 = "4e4a8d43aa7ecec66cadaf311fbd1e5c9d7b9175"
uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d" uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
version = "0.5.3" version = "0.7.0"
[[LibGit2]] [[LibGit2]]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
@ -139,10 +159,10 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
[[MacroTools]] [[MacroTools]]
deps = ["Compat"] deps = ["CSTParser", "Compat", "DataStructures", "Test"]
git-tree-sha1 = "c443e1c8d58a4e9f61b708ad0a88286c7042145b" git-tree-sha1 = "daecd9e452f38297c686eba90dba2a6d5da52162"
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
version = "0.4.4" version = "0.5.0"
[[Markdown]] [[Markdown]]
deps = ["Base64"] deps = ["Base64"]
@ -156,18 +176,18 @@ version = "0.5.0"
[[Missings]] [[Missings]]
deps = ["Dates", "InteractiveUtils", "SparseArrays", "Test"] deps = ["Dates", "InteractiveUtils", "SparseArrays", "Test"]
git-tree-sha1 = "adc26d2ee85a49c413464110d922cf21efc9d233" git-tree-sha1 = "d1d2585677f2bd93a97cfeb8faa7a0de0f982042"
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
version = "0.3.1" version = "0.4.0"
[[Mmap]] [[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804" uuid = "a63ad114-7e13-5084-954f-fe012c677804"
[[NNlib]] [[NNlib]]
deps = ["Libdl", "LinearAlgebra", "MacroTools", "Requires", "Test"] deps = ["Libdl", "LinearAlgebra", "Requires", "Statistics", "TimerOutputs"]
git-tree-sha1 = "51330bb45927379007e089997bf548fbe232589d" git-tree-sha1 = "0c667371391fc6bb31f7f12f96a56a17098b3de8"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.4.3" version = "0.6.0"
[[NaNMath]] [[NaNMath]]
deps = ["Compat"] deps = ["Compat"]
@ -177,9 +197,9 @@ version = "0.3.2"
[[OrderedCollections]] [[OrderedCollections]]
deps = ["Random", "Serialization", "Test"] deps = ["Random", "Serialization", "Test"]
git-tree-sha1 = "85619a3f3e17bb4761fe1b1fd47f0e979f964d5b" git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.0.2" version = "1.1.0"
[[Pkg]] [[Pkg]]
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
@ -244,29 +264,47 @@ version = "0.7.2"
[[StaticArrays]] [[StaticArrays]]
deps = ["InteractiveUtils", "LinearAlgebra", "Random", "Statistics", "Test"] deps = ["InteractiveUtils", "LinearAlgebra", "Random", "Statistics", "Test"]
git-tree-sha1 = "1eb114d6e23a817cd3e99abc3226190876d7c898" git-tree-sha1 = "3841b39ed5f047db1162627bf5f80a9cd3e39ae2"
uuid = "90137ffa-7385-5640-81b9-e52037218182" uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "0.10.2" version = "0.10.3"
[[Statistics]] [[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"] deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[[StatsBase]] [[StatsBase]]
deps = ["DataStructures", "DelimitedFiles", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "Test"] deps = ["DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"]
git-tree-sha1 = "7b596062316c7d846b67bf625d5963a832528598" git-tree-sha1 = "8a0f4b09c7426478ab677245ab2b0b68552143c7"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.27.0" version = "0.30.0"
[[Test]] [[Test]]
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[[TimerOutputs]]
deps = ["Crayons", "Printf", "Test", "Unicode"]
git-tree-sha1 = "b80671c06f8f8bae08c55d67b5ce292c5ae2660c"
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
version = "0.5.0"
[[Tokenize]]
deps = ["Printf", "Test"]
git-tree-sha1 = "3e83f60b74911d3042d3550884ca2776386a02b8"
uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624"
version = "0.5.3"
[[Tracker]]
deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics", "Test"]
git-tree-sha1 = "0bec1b68c63a0e8a58d3944261cbf4cc9577c8a1"
uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
version = "0.2.0"
[[TranscodingStreams]] [[TranscodingStreams]]
deps = ["Pkg", "Random", "Test"] deps = ["Random", "Test"]
git-tree-sha1 = "a34a2d588e2d2825602bf14a24216d5c8b0921ec" git-tree-sha1 = "a25d8e5a28c3b1b06d3859f30757d43106791919"
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
version = "0.8.1" version = "0.9.4"
[[URIParser]] [[URIParser]]
deps = ["Test", "Unicode"] deps = ["Test", "Unicode"]
@ -275,7 +313,7 @@ uuid = "30578b45-9adc-5946-b283-645ec420af67"
version = "0.4.0" version = "0.4.0"
[[UUIDs]] [[UUIDs]]
deps = ["Random"] deps = ["Random", "SHA"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
[[Unicode]] [[Unicode]]
@ -283,6 +321,6 @@ uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
[[ZipFile]] [[ZipFile]]
deps = ["BinaryProvider", "Libdl", "Printf", "Test"] deps = ["BinaryProvider", "Libdl", "Printf", "Test"]
git-tree-sha1 = "4000c633efe994b2e10b31b6d91382c4b7412dac" git-tree-sha1 = "5f6f663890dfb9bad6af75a86a43f67904e5050e"
uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
version = "0.8.0" version = "0.8.1"

View File

@ -1,7 +1,7 @@
using Documenter, Flux, NNlib using Documenter, Flux, NNlib
makedocs(modules=[Flux, NNlib], makedocs(modules=[Flux, NNlib],
doctest = false, doctest = true,
analytics = "UA-36890222-9", analytics = "UA-36890222-9",
sitename = "Flux", sitename = "Flux",
# Uncomment below for local build # Uncomment below for local build
@ -19,6 +19,7 @@ makedocs(modules=[Flux, NNlib],
"One-Hot Encoding" => "data/onehot.md", "One-Hot Encoding" => "data/onehot.md",
"GPU Support" => "gpu.md", "GPU Support" => "gpu.md",
"Saving & Loading" => "saving.md", "Saving & Loading" => "saving.md",
"Performance Tips" => "performance.md",
"Internals" => "Internals" =>
["Backpropagation" => "internals/tracker.md"], ["Backpropagation" => "internals/tracker.md"],
"Community" => "community.md"]) "Community" => "community.md"])

View File

@ -1,5 +1,17 @@
# GPU Support # GPU Support
## Installation
To get GPU support for NVIDIA graphics cards, you need to install `CuArrays.jl`.
**Steps needed**
1. Install the [NVIDIA toolkit](https://developer.nvidia.com/cuda-downloads)
2. Install the [NVIDIA cuDNN library](https://developer.nvidia.com/cudnn)
3. In Julia's terminal, run `]add CuArrays`
## GPU Usage
Support for array operations on other hardware backends, like GPUs, is provided by external packages like [CuArrays](https://github.com/JuliaGPU/CuArrays.jl). Flux is agnostic to array types, so we simply need to move model weights and data to the GPU and Flux will handle it. Support for array operations on other hardware backends, like GPUs, is provided by external packages like [CuArrays](https://github.com/JuliaGPU/CuArrays.jl). Flux is agnostic to array types, so we simply need to move model weights and data to the GPU and Flux will handle it.
For example, we can use `CuArrays` (with the `cu` converter) to run our [basic example](models/basics.md) on an NVIDIA GPU. For example, we can use `CuArrays` (with the `cu` converter) to run our [basic example](models/basics.md) on an NVIDIA GPU.
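As a quick illustration of that workflow, here is a minimal sketch, assuming `CuArrays` is installed and working; the model and data are made up for the example:

```julia
using Flux, CuArrays

m = Chain(Dense(10, 5, σ), Dense(5, 2)) |> gpu  # move the model weights to the GPU
x = rand(Float32, 10) |> gpu                    # move the input data to the GPU

m(x)          # the forward pass now runs on the GPU
m(x) |> cpu   # bring results back to the CPU if needed
```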

View File

@ -4,49 +4,56 @@
Flux's core feature is taking gradients of Julia code. The `gradient` function takes another Julia function `f` and a set of arguments, and returns the gradient with respect to each argument. (It's a good idea to try pasting these examples in the Julia terminal.) Flux's core feature is taking gradients of Julia code. The `gradient` function takes another Julia function `f` and a set of arguments, and returns the gradient with respect to each argument. (It's a good idea to try pasting these examples in the Julia terminal.)
```julia ```jldoctest basics
using Flux.Tracker julia> using Flux.Tracker
f(x) = 3x^2 + 2x + 1 julia> f(x) = 3x^2 + 2x + 1;
# df/dx = 6x + 2 julia> df(x) = Tracker.gradient(f, x; nest = true)[1]; # df/dx = 6x + 2
df(x) = Tracker.gradient(f, x)[1]
df(2) # 14.0 (tracked) julia> df(2)
14.0 (tracked)
# d²f/dx² = 6 julia> d2f(x) = Tracker.gradient(df, x; nest = true)[1]; # d²f/dx² = 6
d2f(x) = Tracker.gradient(df, x)[1]
d2f(2) # 6.0 (tracked) julia> d2f(2)
6.0 (tracked)
``` ```
(We'll learn more about why these numbers show up as `(tracked)` below.) (We'll learn more about why these numbers show up as `(tracked)` below.)
When a function has many parameters, we can pass them all in explicitly: When a function has many parameters, we can pass them all in explicitly:
```julia ```jldoctest basics
f(W, b, x) = W * x + b julia> f(W, b, x) = W * x + b;
Tracker.gradient(f, 2, 3, 4) julia> Tracker.gradient(f, 2, 3, 4)
(4.0 (tracked), 1.0, 2.0 (tracked)) (4.0 (tracked), 1.0 (tracked), 2.0 (tracked))
``` ```
But machine learning models can have *hundreds* of parameters! Flux offers a nice way to handle this. We can tell Flux to treat something as a parameter via `param`. Then we can collect these together and tell `gradient` to collect the gradients of all of them at once. But machine learning models can have *hundreds* of parameters! Flux offers a nice way to handle this. We can tell Flux to treat something as a parameter via `param`. Then we can collect these together and tell `gradient` to collect the gradients of all `params` at once.
```julia ```jldoctest basics
W = param(2) # 2.0 (tracked) julia> using Flux
b = param(3) # 3.0 (tracked)
f(x) = W * x + b julia> W = param(2)
2.0 (tracked)
params = Params([W, b]) julia> b = param(3)
grads = Tracker.gradient(() -> f(4), params) 3.0 (tracked)
grads[W] # 4.0 julia> f(x) = W * x + b;
grads[b] # 1.0
julia> grads = Tracker.gradient(() -> f(4), params(W, b));
julia> grads[W]
4.0 (tracked)
julia> grads[b]
1.0 (tracked)
``` ```
There are a few things to notice here. Firstly, `W` and `b` now show up as *tracked*. Tracked things behave like normal numbers or arrays, but keep records of everything you do with them, allowing Flux to calculate their gradients. `gradient` takes a zero-argument function; no arguments are necessary because the `Params` tell it what to differentiate. There are a few things to notice here. Firstly, `W` and `b` now show up as *tracked*. Tracked things behave like normal numbers or arrays, but keep records of everything you do with them, allowing Flux to calculate their gradients. `gradient` takes a zero-argument function; no arguments are necessary because the `params` tell it what to differentiate.
This will come in really handy when dealing with big, complicated models. For now, though, let's start with something simple. This will come in really handy when dealing with big, complicated models. For now, though, let's start with something simple.
@ -77,7 +84,7 @@ using Flux.Tracker
W = param(W) W = param(W)
b = param(b) b = param(b)
gs = Tracker.gradient(() -> loss(x, y), Params([W, b])) gs = Tracker.gradient(() -> loss(x, y), params(W, b))
``` ```
Now that we have gradients, we can pull them out and update `W` to train the model. The `update!(W, Δ)` function applies `W = W + Δ`, which we can use for gradient descent. Now that we have gradients, we can pull them out and update `W` to train the model. The `update!(W, Δ)` function applies `W = W + Δ`, which we can use for gradient descent.
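A minimal sketch of that update step, continuing with the `W`, `b`, `gs`, `x`, `y`, and `loss` from the snippet above (the learning rate of 0.1 is just an illustrative choice, not something fixed by the API):

```julia
using Flux.Tracker: update!

# one step of gradient descent: nudge each parameter against its gradient
for p in (W, b)
  update!(p, -0.1 .* gs[p])
end

loss(x, y)  # should now be lower than before
```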
@ -102,6 +109,8 @@ All deep learning in Flux, however complex, is a simple generalisation of this e
It's common to create more complex models than the linear regression above. For example, we might want to have two linear layers with a nonlinearity like [sigmoid](https://en.wikipedia.org/wiki/Sigmoid_function) (`σ`) in between them. In the above style we could write this as: It's common to create more complex models than the linear regression above. For example, we might want to have two linear layers with a nonlinearity like [sigmoid](https://en.wikipedia.org/wiki/Sigmoid_function) (`σ`) in between them. In the above style we could write this as:
```julia ```julia
using Flux
W1 = param(rand(3, 5)) W1 = param(rand(3, 5))
b1 = param(rand(3)) b1 = param(rand(3))
layer1(x) = W1 * x .+ b1 layer1(x) = W1 * x .+ b1
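# The hunk above stops at `layer1`; purely as an illustration of the pattern the
# paragraph describes (a sketch, not quoted from the file), the composition could
# continue with a second affine layer and a sigmoid in between:
W2 = param(rand(2, 3))
b2 = param(rand(2))
layer2(x) = W2 * x .+ b2

model(x) = layer2(σ.(layer1(x)))   # sigmoid between the two affine layers

model(rand(5))                     # 2-element tracked output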

View File

@ -5,15 +5,18 @@ These core layers form the foundation of almost all neural networks.
```@docs ```@docs
Chain Chain
Dense Dense
```
## Convolution and Pooling Layers
These layers are used to build convolutional neural networks (CNNs).
```@docs
Conv Conv
MaxPool MaxPool
MeanPool MeanPool
```
## Additional Convolution Layers
```@docs
DepthwiseConv DepthwiseConv
ConvTranspose
``` ```
## Recurrent Layers ## Recurrent Layers
@ -27,6 +30,14 @@ GRU
Flux.Recur Flux.Recur
``` ```
## Other General Purpose Layers
These are marginally more obscure than the Basic Layers.
But in contrast to the layers described in the other sections, they are not readily grouped around a particular purpose (e.g. CNNs or RNNs).
```@docs
Maxout
```
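A brief usage sketch of the new `Maxout` layer; the `Maxout(f, n_alts)` constructor used here is the one added in the Maxout PR referenced in NEWS.md, so treat the exact call signature as an assumption of this sketch:

```julia
using Flux

# four Dense alternatives; the layer returns the element-wise maximum of their outputs
m = Maxout(() -> Dense(10, 5), 4)

x = rand(Float32, 10, 8)   # a batch of 8 ten-feature observations
size(m(x))                 # (5, 8)
```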
## Activation Functions ## Activation Functions
Non-linearities that go between layers of your model. Most of these functions are defined in [NNlib](https://github.com/FluxML/NNlib.jl) but are available by default in Flux. Non-linearities that go between layers of your model. Most of these functions are defined in [NNlib](https://github.com/FluxML/NNlib.jl) but are available by default in Flux.
@ -49,5 +60,7 @@ These layers don't affect the structure of the network but may improve training
Flux.testmode! Flux.testmode!
BatchNorm BatchNorm
Dropout Dropout
AlphaDropout
LayerNorm LayerNorm
GroupNorm
``` ```

View File

@ -77,7 +77,7 @@ If you use the `RNN(10, 5)` constructor as opposed to `RNNCell` you'll s
```julia ```julia
julia> RNN(10, 5) julia> RNN(10, 5)
Recur(RNNCell(Dense(15, 5))) Recur(RNNCell(10, 5, tanh))
``` ```
## Sequences ## Sequences
@ -114,3 +114,13 @@ truncate!(m)
Calling `truncate!` wipes the slate clean, so we can call the model with more inputs without building up an expensive gradient computation. Calling `truncate!` wipes the slate clean, so we can call the model with more inputs without building up an expensive gradient computation.
`truncate!` makes sense when you are working with multiple chunks of a large sequence, but we may also want to work with a set of independent sequences. In this case the hidden state should be completely reset to its original value, throwing away any accumulated information. `reset!` does this for you. `truncate!` makes sense when you are working with multiple chunks of a large sequence, but we may also want to work with a set of independent sequences. In this case the hidden state should be completely reset to its original value, throwing away any accumulated information. `reset!` does this for you.
In general, when training with recurrent layers in your model, you'll want to call `reset!` or `truncate!` for each loss calculation:
```julia
function loss(x,y)
l = Flux.mse(m(x), y)
Flux.reset!(m)
return l
end
```

76
docs/src/performance.md Normal file
View File

@ -0,0 +1,76 @@
# Performance Tips
All the usual [Julia performance tips](https://docs.julialang.org/en/v1/manual/performance-tips/) apply.
As always, [profiling your code](https://docs.julialang.org/en/v1/manual/profile/#Profiling-1) is a useful way of finding bottlenecks.
Below are some Flux-specific tips and reminders.
## Don't use more precision than you need.
Flux works great with all kinds of number types.
But often you do not need to be working with, say, `Float64` (let alone `BigFloat`).
Switching to `Float32` can give you a significant speed up,
not because the operations are faster, but because the memory usage is halved,
which means allocations happen faster and you use less memory overall.
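For example, a sketch of switching a model and its data to `Float32`; `f32` converts a model's parameters and appears in the `src/Flux.jl` export list later in this diff, while the model itself is made up for the example:

```julia
using Flux

m = Chain(Dense(784, 32, relu), Dense(32, 10)) |> f32  # Float32 weights
x = rand(Float32, 784, 100)                            # matching Float32 data

m(x)  # stays in Float32 end to end
```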
## Make sure your custom activation functions preserve the type of their inputs
Not only should your activation functions be [type-stable](https://docs.julialang.org/en/v1/manual/performance-tips/#Write-%22type-stable%22-functions-1),
they should also preserve the type of their inputs.
A very artificial example using an activation function like
```julia
my_tanh(x) = Float64(tanh(x))
```
will result in performance on `Float32` input that is orders of magnitude slower than the normal `tanh` would be,
because it forces slow mixed-type multiplication in the dense layers.
This means that if you change your data, say, from `Float64` to `Float32` (which should give a speedup: see above),
you will instead see a large slow-down.
This can occur sneakily, because type promotion can be triggered by interaction with numeric literals.
For example, the following runs into the same problem as above:
```julia
leaky_tanh(x) = 0.01x + tanh(x)
```
While one could change the activation function (e.g. to use `0.01f0x`) to avoid this whenever the input type changes,
the idiomatic (and safe) way is to use `oftype`:
```julia
leaky_tanh(x) = oftype(x/1, 0.01)*x + tanh(x)
```
## Evaluate batches as Matrices of features, rather than sequences of Vector features
While it can sometimes be tempting to process your observations (feature vectors) one at a time,
e.g.
```julia
function loss_total(xs::AbstractVector{<:Vector}, ys::AbstractVector{<:Vector})
sum(zip(xs, ys)) do (x, y_target)
y_pred = model(x) # evaluate the model
return loss(y_pred, y_target)
end
end
```
It is much faster to concatenate them into a matrix,
as this will hit BLAS matrix-matrix multiplication, which is much faster than the equivalent sequence of matrix-vector multiplications,
even though this means allocating new memory to store them contiguously.
```julia
x_batch = reduce(hcat, xs)
y_batch = reduce(hcat, ys)
...
function loss_total(x_batch::Matrix, y_batch::Matrix)
y_preds = model(x_batch)
sum(loss.(y_preds, y_batch))
end
```
When doing this kind of concatenation, use `reduce(hcat, xs)` rather than `hcat(xs...)`.
This will avoid the splatting penalty, and will hit the optimised `reduce` method.

View File

@ -3,7 +3,7 @@
Consider a [simple linear regression](../models/basics.md). We create some dummy data, calculate a loss, and backpropagate to calculate gradients for the parameters `W` and `b`. Consider a [simple linear regression](../models/basics.md). We create some dummy data, calculate a loss, and backpropagate to calculate gradients for the parameters `W` and `b`.
```julia ```julia
using Flux.Tracker using Flux, Flux.Tracker
W = param(rand(2, 5)) W = param(rand(2, 5))
b = param(rand(2)) b = param(rand(2))
@ -14,8 +14,8 @@ loss(x, y) = sum((predict(x) .- y).^2)
x, y = rand(5), rand(2) # Dummy data x, y = rand(5), rand(2) # Dummy data
l = loss(x, y) # ~ 3 l = loss(x, y) # ~ 3
params = Params([W, b]) θ = Params([W, b])
grads = Tracker.gradient(() -> loss(x, y), params) grads = Tracker.gradient(() -> loss(x, y), θ)
``` ```
We want to update each parameter, using the gradient, in order to improve (reduce) the loss. Here's one way to do that: We want to update each parameter, using the gradient, in order to improve (reduce) the loss. Here's one way to do that:
@ -35,7 +35,7 @@ Running this will alter the parameters `W` and `b` and our loss should go down.
opt = Descent(0.1) # Gradient descent with learning rate 0.1 opt = Descent(0.1) # Gradient descent with learning rate 0.1
for p in (W, b) for p in (W, b)
update!(opt, p, -η * grads[p]) update!(opt, p, grads[p])
end end
``` ```
@ -49,5 +49,12 @@ All optimisers return an object that, when passed to `train!`, will update the p
Descent Descent
Momentum Momentum
Nesterov Nesterov
RMSProp
ADAM ADAM
AdaMax
ADAGrad
ADADelta
AMSGrad
NADAM
ADAMW
``` ```
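Putting that together, a hedged sketch of the usual pattern; `model`, `loss`, and `data` are placeholders standing in for your own definitions:

```julia
using Flux

opt = ADAM(0.001)            # any optimiser from the list above works here
ps  = params(model)          # the parameters the optimiser should update

Flux.train!(loss, ps, data, opt)   # one pass over `data`, updating `ps` in place
```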

View File

@ -93,3 +93,11 @@ evalcb() = @show(loss(test_x, test_y))
Flux.train!(objective, ps, data, opt, Flux.train!(objective, ps, data, opt,
cb = throttle(evalcb, 5)) cb = throttle(evalcb, 5))
``` ```
Calling `Flux.stop()` in a callback will exit the training loop early.
```julia
cb = function ()
accuracy() > 0.9 && Flux.stop()
end
```

View File

@ -14,7 +14,7 @@
journal = {arXiv}, journal = {arXiv},
volume = {abs/1712.03112}, volume = {abs/1712.03112},
year = {2017}, year = {2017},
url = {http://arxiv.org/abs/1712.03112}, url = {https://arxiv.org/abs/1712.03112},
} }
@online{MLPL, @online{MLPL,
@ -29,7 +29,7 @@
author = {Mike Innes and others}, author = {Mike Innes and others},
title = {Generic GPU Kernels}, title = {Generic GPU Kernels},
year = 2017, year = 2017,
url = {http://mikeinnes.github.io/2017/08/24/cudanative.html}, url = {https://mikeinnes.github.io/2017/08/24/cudanative.html},
urldate = {2018-02-16} urldate = {2018-02-16}
} }

View File

@ -6,15 +6,14 @@ using Base: tail
using MacroTools, Juno, Requires, Reexport, Statistics, Random using MacroTools, Juno, Requires, Reexport, Statistics, Random
using MacroTools: @forward using MacroTools: @forward
export Chain, Dense, RNN, LSTM, GRU, Conv, MaxPool, MeanPool, export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
DepthwiseConv, Dropout, LayerNorm, BatchNorm, DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
params, mapleaves, cpu, gpu params, mapleaves, cpu, gpu, f32, f64
@reexport using NNlib @reexport using NNlib
include("tracker/Tracker.jl") using Tracker
using .Tracker using Tracker: data
using .Tracker: data
export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param
include("optimise/Optimise.jl") include("optimise/Optimise.jl")

View File

@ -1,17 +1,18 @@
module CUDA module CUDA
using ..CuArrays using ..CuArrays
import ..CuArrays.CUDAdrv: CuPtr, CU_NULL
using Pkg.TOML using Pkg.TOML
function version_check() function version_check()
minor_version = 9 major_version = 1
project = joinpath(dirname(pathof(CuArrays)), "../Project.toml") project = joinpath(dirname(pathof(CuArrays)), "../Project.toml")
project = TOML.parse(String(read(project))) project = TOML.parse(String(read(project)))
version = VersionNumber(get(project, "version", "0.0.0")) version = VersionNumber(get(project, "version", "0.0.0"))
if !(version.major == 0 && version.minor == minor_version) if version.major != major_version
@warn """ @warn """
Flux is only supported with CuArrays v0.$minor_version. Flux is only supported with CuArrays v$major_version.x.
Try running `] pin CuArrays@0.$minor_version`. Try running `] pin CuArrays@$major_version`.
""" """
end end
end end

View File

@ -17,7 +17,7 @@ function DropoutDesc(ρ::Real; seed::Integer=0)
@check ccall((:cudnnDropoutGetStatesSize,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Csize_t}),handle(),s) @check ccall((:cudnnDropoutGetStatesSize,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Csize_t}),handle(),s)
states = CuArray{UInt8}(undef, s[]) # TODO: can we drop this when ρ=0? states = CuArray{UInt8}(undef, s[]) # TODO: can we drop this when ρ=0?
desc = DropoutDesc(d[], states) desc = DropoutDesc(d[], states)
@check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Nothing},Cfloat,Ptr{Nothing},Csize_t,Culonglong), @check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Nothing},Cfloat,CuPtr{Nothing},Csize_t,Culonglong),
desc,handle(),ρ,states,length(states),seed) desc,handle(),ρ,states,length(states),seed)
finalizer(desc) do x finalizer(desc) do x
@check ccall((:cudnnDestroyDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x) @check ccall((:cudnnDestroyDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x)
@ -79,18 +79,18 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
mean = zeros(CuArray{T}, dims...) mean = zeros(CuArray{T}, dims...)
ivar = ones(CuArray{T}, dims...) ivar = ones(CuArray{T}, dims...)
else else
mean = C_NULL mean = CU_NULL
ivar = C_NULL ivar = CU_NULL
end end
@check ccall((:cudnnBatchNormalizationForwardTraining, libcudnn), cudnnStatus_t, @check ccall((:cudnnBatchNormalizationForwardTraining, libcudnn), cudnnStatus_t,
(cudnnHandle_t,cudnnBatchNormMode_t, (cudnnHandle_t,cudnnBatchNormMode_t,
Ptr{T}, Ptr{T}, Ptr{T}, Ptr{T},
Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, CuPtr{T},
Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, CuPtr{T},
Ptr{Nothing}, Ptr{T}, Ptr{T}, Ptr{Nothing}, CuPtr{T}, CuPtr{T},
Cdouble, Ptr{T}, Ptr{T}, Cdouble, CuPtr{T}, CuPtr{T},
Cdouble, Ptr{T}, Ptr{T}), Cdouble, CuPtr{T}, CuPtr{T}),
handle(), BATCHNORM_SPATIAL, handle(), BATCHNORM_SPATIAL,
Ref(T(alpha)), Ref(T(beta)), Ref(T(alpha)), Ref(T(beta)),
xd, x, xd, x,
@ -107,10 +107,10 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
@check ccall((:cudnnBatchNormalizationForwardInference, libcudnn), cudnnStatus_t, @check ccall((:cudnnBatchNormalizationForwardInference, libcudnn), cudnnStatus_t,
(Ptr{cudnnHandle_t},cudnnBatchNormMode_t, (Ptr{cudnnHandle_t},cudnnBatchNormMode_t,
Ptr{T}, Ptr{T}, Ptr{T}, Ptr{T},
Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, CuPtr{T},
Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, CuPtr{T},
Ptr{Nothing}, Ptr{T}, Ptr{T}, Ptr{Nothing}, CuPtr{T}, CuPtr{T},
Ptr{T}, Ptr{T}, CuPtr{T}, CuPtr{T},
Cdouble), Cdouble),
handle(), BATCHNORM_SPATIAL, handle(), BATCHNORM_SPATIAL,
Ref(T(alpha)), Ref(T(beta)), Ref(T(alpha)), Ref(T(beta)),
@ -159,7 +159,7 @@ function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
mean, ivar = cache.mean, cache.ivar mean, ivar = cache.mean, cache.ivar
info("mean and ivar are fetched from the cache") info("mean and ivar are fetched from the cache")
else else
mean, ivar = C_NULL, C_NULL mean, ivar = CU_NULL, CU_NULL
end end
if eps < BATCHNORM_MIN_EPS if eps < BATCHNORM_MIN_EPS
@ -170,11 +170,11 @@ function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
(cudnnHandle_t,cudnnBatchNormMode_t, (cudnnHandle_t,cudnnBatchNormMode_t,
Ptr{T}, Ptr{T}, Ptr{T}, Ptr{T},
Ptr{T}, Ptr{T}, Ptr{T}, Ptr{T},
Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, CuPtr{T},
Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, CuPtr{T},
Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, CuPtr{T},
Ptr{Nothing}, Ptr{T}, Ptr{T}, Ptr{T}, Ptr{Nothing}, CuPtr{T}, CuPtr{T}, CuPtr{T},
Cdouble, Ptr{T}, Ptr{T}), Cdouble, CuPtr{T}, CuPtr{T}),
handle(), BATCHNORM_SPATIAL, handle(), BATCHNORM_SPATIAL,
Ref(T(alpha)), Ref(T(beta)), Ref(T(alpha)), Ref(T(beta)),
Ref(T(dalpha)), Ref(T(dbeta)), Ref(T(dalpha)), Ref(T(dbeta)),
@ -194,7 +194,7 @@ end
# Flux Interface # Flux Interface
(BN::Flux.BatchNorm)(x::Union{CuParam{T,2},CuParam{T,4},CuParam{T,5}}, cache = nothing) where T<:Union{Float32, Float64} = (BN::Flux.BatchNorm)(x::Union{CuParam{T,2},CuParam{T,4},CuParam{T,5}}, cache = nothing) where T<:Union{Float32, Float64} =
batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum; cache = cache, alpha = 1, beta = 0, eps = BN.ϵ, training = BN.active) BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum; cache = cache, alpha = 1, beta = 0, eps = BN.ϵ, training = BN.active))
batchnorm(g::TrackedArray, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T}, batchnorm(g::TrackedArray, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T},
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} = running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =

View File

@ -101,18 +101,18 @@ function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd
if reserve == nothing if reserve == nothing
@check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t, @check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t,
(Ptr{Nothing}, Ptr{Nothing}, Cint, (Ptr{Nothing}, Ptr{Nothing}, Cint,
Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, CuPtr{T},
Ptr{Nothing}, Csize_t), CuPtr{Nothing}, Csize_t),
handle(), rnn, seqlen, handle(), rnn, seqlen,
xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co, xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
workspace, length(workspace)) workspace, length(workspace))
else else
@check ccall((:cudnnRNNForwardTraining, libcudnn), cudnnStatus_t, @check ccall((:cudnnRNNForwardTraining, libcudnn), cudnnStatus_t,
(Ptr{Nothing}, Ptr{Nothing}, Cint, (Ptr{Nothing}, Ptr{Nothing}, Cint,
Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t), CuPtr{Nothing}, Csize_t, CuPtr{Nothing}, Csize_t),
handle(), rnn, seqlen, handle(), rnn, seqlen,
xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co, xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
workspace, length(workspace), reserve, length(reserve)) workspace, length(workspace), reserve, length(reserve))
@ -121,7 +121,7 @@ end
xDesc(x) = [TensorDesc(eltype(x), (1, size(x, 1), size(x, 2)))] xDesc(x) = [TensorDesc(eltype(x), (1, size(x, 1), size(x, 2)))]
hDesc(h::Nothing) = C_NULL, C_NULL hDesc(h::Nothing) = C_NULL, CU_NULL
hDesc(x::Integer) = (@assert x == 0; hDesc(nothing)) hDesc(x::Integer) = (@assert x == 0; hDesc(nothing))
function hDesc(h::CuArray) function hDesc(h::CuArray)
TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h
@ -169,10 +169,10 @@ function cudnnRNNBackwardData(rnn::RNNDesc{T}, seqlen, yd, y, dyd, dy, dhod, dho
wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, rs) where T wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, rs) where T
@check ccall((:cudnnRNNBackwardData,libcudnn),cudnnStatus_t, @check ccall((:cudnnRNNBackwardData,libcudnn),cudnnStatus_t,
(Ptr{Nothing}, Ptr{Nothing}, Cint, (Ptr{Nothing}, Ptr{Nothing}, Cint,
Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing},
Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t), CuPtr{Nothing}, Csize_t, CuPtr{Nothing}, Csize_t),
handle(), rnn, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco, handle(), rnn, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, length(ws), rs, length(rs)) wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, length(ws), rs, length(rs))
end end
@ -199,12 +199,12 @@ function cudnnRNNBackwardWeights(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, yd, y, d
workspace, reserve) where T workspace, reserve) where T
@check ccall((:cudnnRNNBackwardWeights,libcudnn), cudnnStatus_t, @check ccall((:cudnnRNNBackwardWeights,libcudnn), cudnnStatus_t,
(Ptr{Nothing}, Ptr{Nothing}, Cint, # handle, rnnDesc, seqLength (Ptr{Nothing}, Ptr{Nothing}, Cint, # handle, rnnDesc, seqLength
Ptr{Ptr{Nothing}}, Ptr{T}, #x Ptr{Ptr{Nothing}}, CuPtr{T}, #x
Ptr{Nothing}, Ptr{T}, #hx Ptr{Nothing}, CuPtr{T}, #hx
Ptr{Ptr{Nothing}}, Ptr{T}, #y Ptr{Ptr{Nothing}}, CuPtr{T}, #y
Ptr{Nothing}, Csize_t, #ws CuPtr{Nothing}, Csize_t, #ws
Ptr{Nothing}, Ptr{T}, #dw Ptr{Nothing}, CuPtr{T}, #dw
Ptr{Nothing}, Csize_t), #rs CuPtr{Nothing}, Csize_t), #rs
handle(), rnn, seqlen, xd, x, hd, h, yd, y, handle(), rnn, seqlen, xd, x, hd, h, yd, y,
workspace, length(workspace), dwd, dw, reserve, length(reserve)) workspace, length(workspace), dwd, dw, reserve, length(reserve))
end end
@ -1,11 +1,27 @@
module Data module Data
import ..Flux import ..Flux
import SHA
export CMUDict, cmudict export CMUDict, cmudict
deps(path...) = joinpath(@__DIR__, "..", "..", "deps", path...) deps(path...) = joinpath(@__DIR__, "..", "..", "deps", path...)
function download_and_verify(url, path, hash)
tmppath = tempname()
download(url, tmppath)
hash_download = open(tmppath) do f
bytes2hex(SHA.sha256(f))
end
if hash_download !== hash
msg = "Hash Mismatch!\n"
msg *= " Expected sha256: $hash\n"
msg *= " Calculated sha256: $hash_download"
error(msg)
end
mv(tmppath, path; force=true)
end
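For context, the verification above boils down to hashing the downloaded file and comparing hex digests. A minimal sketch of that core step follows; the file content and the commented-out call are illustrative placeholders, not values used by any Flux dataset.

```julia
using SHA  # standard library used by download_and_verify above

path = tempname()
write(path, "hello")

# Hash the file the same way download_and_verify does and compare against a known digest.
digest = open(path) do f
    bytes2hex(SHA.sha256(f))
end
digest == "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824"  # true

# A real call then looks like (placeholder URL and digest):
# Flux.Data.download_and_verify("https://example.com/data.bin",
#                               Flux.Data.deps("data.bin"),
#                               "<64-character sha256 hex digest>")
```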
function __init__() function __init__()
mkpath(deps()) mkpath(deps())
end end
@ -23,4 +39,7 @@ include("tree.jl")
include("sentiment.jl") include("sentiment.jl")
using .Sentiment using .Sentiment
include("iris.jl")
export Iris
end end
@ -2,23 +2,25 @@ module CMUDict
export cmudict export cmudict
using ..Data: deps using ..Data: deps, download_and_verify
const version = "0.7b" const version = "0.7b"
const cache_prefix = "https://cache.julialang.org" const cache_prefix = "https://cache.julialang.org"
function load() function load()
suffixes = ["", ".phones", ".symbols"] suffixes_and_hashes = [("" , "209a8b4cd265013e96f4658632a9878103b0c5abf62b50d4ef3ae1be226b29e4"),
(".phones" , "ffb588a5e55684723582c7256e1d2f9fadb130011392d9e59237c76e34c2cfd6"),
(".symbols", "408ccaae803641c6d7b626b6299949320c2dbca96b2220fd3fb17887b023b027")]
if isdir(deps("cmudict")) if isdir(deps("cmudict"))
if all(isfile(deps("cmudict", "cmudict$x")) for x in suffixes) if all(isfile(deps("cmudict", "cmudict$x")) for (x, _) in suffixes_and_hashes)
return return
end end
end end
@info "Downloading CMUDict dataset" @info "Downloading CMUDict dataset"
mkpath(deps("cmudict")) mkpath(deps("cmudict"))
for x in suffixes for (x, hash) in suffixes_and_hashes
download("$cache_prefix/http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-$version$x", download_and_verify("$cache_prefix/https://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-$version$x",
deps("cmudict", "cmudict$x")) deps("cmudict", "cmudict$x"), hash)
end end
end end
@ -1,19 +1,20 @@
module FashionMNIST module FashionMNIST
using ..MNIST: gzopen, imageheader, rawimage, labelheader, rawlabel using ..MNIST: gzopen, imageheader, rawimage, labelheader, rawlabel
using ..Data: download_and_verify
const dir = joinpath(@__DIR__, "../../deps/fashion-mnist") const dir = joinpath(@__DIR__, "../../deps/fashion-mnist")
function load() function load()
mkpath(dir) mkpath(dir)
cd(dir) do cd(dir) do
for file in ["train-images-idx3-ubyte", for (file, hash) in [("train-images-idx3-ubyte", "3aede38d61863908ad78613f6a32ed271626dd12800ba2636569512369268a84"),
"train-labels-idx1-ubyte", ("train-labels-idx1-ubyte", "a04f17134ac03560a47e3764e11b92fc97de4d1bfaf8ba1a3aa29af54cc90845"),
"t10k-images-idx3-ubyte", ("t10k-images-idx3-ubyte" , "346e55b948d973a97e58d2351dde16a484bd415d4595297633bb08f03db6a073"),
"t10k-labels-idx1-ubyte"] ("t10k-labels-idx1-ubyte" , "67da17c76eaffca5446c3361aaab5c3cd6d1c2608764d35dfb1850b086bf8dd5")]
isfile(file) && continue isfile(file) && continue
@info "Downloading Fashion-MNIST dataset" @info "Downloading Fashion-MNIST dataset"
download("http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/$file.gz", "$file.gz") download_and_verify("http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/$file.gz", "$file.gz", hash)
open(file, "w") do io open(file, "w") do io
write(io, gzopen(read, "$file.gz")) write(io, gzopen(read, "$file.gz"))
end end
86
src/data/iris.jl Normal file
@ -0,0 +1,86 @@
"""
Iris
Fisher's classic iris dataset.
Measurements from 3 different species of iris: setosa, versicolor and
virginica. There are 50 examples of each species.
There are 4 measurements for each example: sepal length, sepal width, petal
length and petal width. The measurements are in centimeters.
The module retrieves the data from the [UCI Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/iris).
"""
module Iris
using DelimitedFiles
using ..Data: deps, download_and_verify
# Uncomment if the iris.data file is cached to cache.julialang.org.
const cache_prefix = "https://cache.julialang.org/"
function load()
isfile(deps("iris.data")) && return
@info "Downloading iris dataset."
download_and_verify("$(cache_prefix)https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data",
deps("iris.data"),
"6f608b71a7317216319b4d27b4d9bc84e6abd734eda7872b71a458569e2656c0")
end
"""
labels()
Get the labels of the iris dataset, a 150-element array of strings listing the
species of each example.
```jldoctest
julia> labels = Flux.Data.Iris.labels();
julia> summary(labels)
"150-element Array{String,1}"
julia> labels[1]
"Iris-setosa"
```
"""
function labels()
load()
iris = readdlm(deps("iris.data"), ',')
Vector{String}(iris[1:end, end])
end
"""
features()
Get the features of the iris dataset. This is a 4x150 matrix of Float64
elements. It has a row for each feature (sepal length, sepal width,
petal length, petal width) and a column for each example.
```jldoctest
julia> features = Flux.Data.Iris.features();
julia> summary(features)
"4×150 Array{Float64,2}"
julia> features[:, 1]
4-element Array{Float64,1}:
5.1
3.5
1.4
0.2
```
"""
function features()
load()
iris = readdlm(deps("iris.data"), ',')
Matrix{Float64}(iris[1:end, 1:4]')
end
end
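As a rough sketch of how the Iris module above might be consumed (the one-hot pairing is illustrative, not part of the module):

```julia
using Flux

features = Flux.Data.Iris.features()   # 4×150 Matrix{Float64}, one column per example
labels   = Flux.Data.Iris.labels()     # 150-element Vector{String}

# One-hot encode the species so each column of `features` lines up with a target column.
classes = unique(labels)
Y = Flux.onehotbatch(labels, classes)
size(features), size(Y)                # ((4, 150), (3, 150))
```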
@ -1,6 +1,7 @@
module MNIST module MNIST
using CodecZlib, Colors using CodecZlib, Colors
using ..Data: download_and_verify
const Gray = Colors.Gray{Colors.N0f8} const Gray = Colors.Gray{Colors.N0f8}
@ -15,13 +16,13 @@ end
function load() function load()
mkpath(dir) mkpath(dir)
cd(dir) do cd(dir) do
for file in ["train-images-idx3-ubyte", for (file, hash) in [("train-images-idx3-ubyte", "440fcabf73cc546fa21475e81ea370265605f56be210a4024d2ca8f203523609"),
"train-labels-idx1-ubyte", ("train-labels-idx1-ubyte", "3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c"),
"t10k-images-idx3-ubyte", ("t10k-images-idx3-ubyte" , "8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6"),
"t10k-labels-idx1-ubyte"] ("t10k-labels-idx1-ubyte" , "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6")]
isfile(file) && continue isfile(file) && continue
@info "Downloading MNIST dataset" @info "Downloading MNIST dataset"
download("https://cache.julialang.org/http://yann.lecun.com/exdb/mnist/$file.gz", "$file.gz") download_and_verify("https://cache.julialang.org/http://yann.lecun.com/exdb/mnist/$file.gz", "$file.gz", hash)
open(file, "w") do io open(file, "w") do io
write(io, gzopen(read, "$file.gz")) write(io, gzopen(read, "$file.gz"))
end end
@ -1,13 +1,13 @@
module Sentiment module Sentiment
using ZipFile using ZipFile
using ..Data: deps using ..Data: deps, download_and_verify
function load() function load()
isfile(deps("sentiment.zip")) && return isfile(deps("sentiment.zip")) && return
@info "Downloading sentiment treebank dataset" @info "Downloading sentiment treebank dataset"
download("https://cache.julialang.org/https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip", download_and_verify("https://cache.julialang.org/https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip",
deps("sentiment.zip")) deps("sentiment.zip"), "5c613a4f673fc74097d523a2c83f38e0cc462984d847b82c7aaf36b01cbbbfcc")
end end
getfile(r, name) = r.files[findfirst(x -> x.name == name, r.files)] getfile(r, name) = r.files[findfirst(x -> x.name == name, r.files)]
@ -40,7 +40,24 @@ function Base.show(io::IO, c::Chain)
print(io, ")") print(io, ")")
end end
activations(c::Chain, x) = accumulate((x, m) -> m(x), c.layers, init = x)
# This is a temporary and naive implementation
# it might be replaced in the future for better performance
# see issue https://github.com/FluxML/Flux.jl/issues/702
# Johnny Chen -- @johnnychen94
"""
activations(c::Chain, input)
Calculate the forward results of each layer in Chain `c` with `input` as the model input.
"""
function activations(c::Chain, input)
rst = []
for l in c
x = get(rst, length(rst), input)
push!(rst, l(x))
end
return rst
end
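A small sketch of the naive `activations` helper above; the layer sizes are arbitrary, and the call is qualified since the function may not be exported:

```julia
using Flux

m = Chain(Dense(10, 5, relu), Dense(5, 2), softmax)
x = rand(10)

acts = Flux.activations(m, x)   # one entry per layer, in order
length(acts)                    # 3
size(acts[1]), size(acts[end])  # ((5,), (2,))
```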
""" """
Dense(in::Integer, out::Integer, σ = identity) Dense(in::Integer, out::Integer, σ = identity)
@ -88,6 +105,14 @@ function Base.show(io::IO, l::Dense)
print(io, ")") print(io, ")")
end end
# Try to avoid hitting generic matmul in some simple cases
# Base's matmul is so slow that it's worth the extra conversion to hit BLAS
(a::Dense{<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
invoke(a, Tuple{AbstractArray}, x)
(a::Dense{<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
a(T.(x))
""" """
Diagonal(in::Integer) Diagonal(in::Integer)
@ -117,10 +142,50 @@ function Base.show(io::IO, l::Diagonal)
print(io, "Diagonal(", length(l.α), ")") print(io, "Diagonal(", length(l.α), ")")
end end
# Try to avoid hitting generic matmul in some simple cases
# Base's matmul is so slow that it's worth the extra conversion to hit BLAS
(a::Dense{<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
invoke(a, Tuple{AbstractArray}, x)
(a::Dense{<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} = """
a(T.(x)) Maxout(over)
`Maxout` is a neural network layer that has a number of internal layers,
which all receive the same input. The layer returns the elementwise maximum
of the internal layers' outputs.
Maxout over linear dense layers satisfies the universal approximation theorem.
Reference:
Ian J. Goodfellow, David Warde-Farley, Mehdi Mirza, Aaron Courville, and Yoshua Bengio.
2013. Maxout networks.
In Proceedings of the 30th International Conference on International Conference on Machine Learning - Volume 28 (ICML'13),
Sanjoy Dasgupta and David McAllester (Eds.), Vol. 28. JMLR.org III-1319-III-1327.
https://arxiv.org/pdf/1302.4389.pdf
"""
struct Maxout{FS<:Tuple}
over::FS
end
"""
Maxout(f, n_alts)
Constructs a Maxout layer over `n_alts` instances of the layer given by `f`.
The function takes no arguments and should return some callable layer.
Conventionally this is a linear dense layer.
For example, the following constructs a `Maxout` layer over 4 internal
dense linear layers, each identical in structure (784 inputs, 128 outputs):
```julia
insize = 784
outsize = 128
Maxout(()->Dense(insize, outsize), 4)
```
"""
function Maxout(f, n_alts)
over = Tuple(f() for _ in 1:n_alts)
return Maxout(over)
end
@treelike Maxout
function (mo::Maxout)(input::AbstractArray)
mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
end
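A brief usage sketch for `Maxout`; the dimensions are chosen only for illustration:

```julia
using Flux

# Four dense "alternatives", each 10 => 5; the layer returns their elementwise maximum.
mo = Flux.Maxout(() -> Dense(10, 5), 4)
x  = rand(10)
size(mo(x))   # (5,)
```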
@ -1,10 +1,7 @@
using NNlib: conv, depthwiseconv using NNlib: conv, ∇conv_data, depthwiseconv
@generated sub2(::Val{N}) where N = :(Val($(N-2)))
expand(N, i::Tuple) = i expand(N, i::Tuple) = i
expand(N, i::Integer) = ntuple(_ -> i, N) expand(N, i::Integer) = ntuple(_ -> i, N)
""" """
Conv(size, in=>out) Conv(size, in=>out)
Conv(size, in=>out, relu) Conv(size, in=>out, relu)
@ -12,23 +9,36 @@ expand(N, i::Integer) = ntuple(_ -> i, N)
Standard convolutional layer. `size` should be a tuple like `(2, 2)`. Standard convolutional layer. `size` should be a tuple like `(2, 2)`.
`in` and `out` specify the number of input and output channels respectively. `in` and `out` specify the number of input and output channels respectively.
Data should be stored in WHCN order. In other words, a 100×100 RGB image would Example: Applying Conv layer to a 1-channel input using a 2x2 window size,
be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array. giving us a 16-channel output. Output is activated with ReLU.
size = (2,2)
in = 1
out = 16
Conv((2, 2), 1=>16, relu)
Data should be stored in WHCN order (width, height, # channels, # batches).
In other words, a 100×100 RGB image would be a `100×100×3×1` array,
and a batch of 50 would be a `100×100×3×50` array.
Takes the keyword arguments `pad`, `stride` and `dilation`. Takes the keyword arguments `pad`, `stride` and `dilation`.
""" """
struct Conv{N,F,A,V} struct Conv{N,M,F,A,V}
σ::F σ::F
weight::A weight::A
bias::V bias::V
stride::NTuple{N,Int} stride::NTuple{N,Int}
pad::NTuple{N,Int} pad::NTuple{M,Int}
dilation::NTuple{N,Int} dilation::NTuple{N,Int}
end end
Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity; function Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N} = stride = 1, pad = 0, dilation = 1) where {T,N}
Conv(σ, w, b, expand.(sub2(Val(N)), (stride, pad, dilation))...) stride = expand(Val(N-2), stride)
pad = expand(Val(2*(N-2)), pad)
dilation = expand(Val(N-2), dilation)
return Conv(σ, w, b, stride, pad, dilation)
end
Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N = init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
@ -41,7 +51,8 @@ function (c::Conv)(x::AbstractArray)
# TODO: breaks gpu broadcast :( # TODO: breaks gpu broadcast :(
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1))) # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1) σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
σ.(conv(x, c.weight, stride = c.stride, pad = c.pad, dilation = c.dilation) .+ b) cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
σ.(conv(x, c.weight, cdims) .+ b)
end end
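The forward pass now builds a `DenseConvDims` object and passes it to NNlib's `conv`; the same call can be made directly against NNlib (sizes here are arbitrary):

```julia
using NNlib

x = rand(Float32, 28, 28, 1, 1)   # WHCN input: 28×28, 1 channel, batch of 1
w = rand(Float32, 3, 3, 1, 8)     # 3×3 kernel mapping 1 => 8 channels

cdims = DenseConvDims(x, w; stride = (1, 1), padding = (0, 0, 0, 0), dilation = (1, 1))
y = conv(x, w, cdims)
size(y)                           # (26, 26, 8, 1)
```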
function Base.show(io::IO, l::Conv) function Base.show(io::IO, l::Conv)
@ -57,6 +68,73 @@ end
(a::Conv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} = (a::Conv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
a(T.(x)) a(T.(x))
"""
ConvTranspose(size, in=>out)
ConvTranspose(size, in=>out, relu)
Standard convolutional transpose layer. `size` should be a tuple like `(2, 2)`.
`in` and `out` specify the number of input and output channels respectively.
Data should be stored in WHCN order. In other words, a 100×100 RGB image would
be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
Takes the keyword arguments `pad`, `stride` and `dilation`.
"""
struct ConvTranspose{N,M,F,A,V}
σ::F
weight::A
bias::V
stride::NTuple{N,Int}
pad::NTuple{M,Int}
dilation::NTuple{N,Int}
end
function ConvTranspose(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N}
stride = expand(Val(N-2), stride)
pad = expand(Val(2*(N-2)), pad)
dilation = expand(Val(N-2), dilation)
return ConvTranspose(σ, w, b, stride, pad, dilation)
end
ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
ConvTranspose(param(init(k..., reverse(ch)...)), param(zeros(ch[2])), σ,
stride = stride, pad = pad, dilation = dilation)
@treelike ConvTranspose
function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
# Calculate size of "input", from ∇conv_data()'s perspective...
combined_pad = (c.pad[1:2:end] .+ c.pad[2:2:end])
I = (size(x)[1:end-2] .- 1).*c.stride .+ 1 .+ (size(c.weight)[1:end-2] .- 1).*c.dilation .- combined_pad
C_in = size(c.weight)[end-1]
batch_size = size(x)[end]
# Create DenseConvDims() that looks like the corresponding conv()
return DenseConvDims((I..., C_in, batch_size), size(c.weight);
stride=c.stride,
padding=c.pad,
dilation=c.dilation,
)
end
function (c::ConvTranspose)(x::AbstractArray)
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
cdims = conv_transpose_dims(c, x)
return σ.(∇conv_data(x, c.weight, cdims) .+ b)
end
function Base.show(io::IO, l::ConvTranspose)
print(io, "ConvTranspose(", size(l.weight)[1:ndims(l.weight)-2])
print(io, ", ", size(l.weight, ndims(l.weight)), "=>", size(l.weight, ndims(l.weight)-1))
l.σ == identity || print(io, ", ", l.σ)
print(io, ")")
end
(a::ConvTranspose{<:Any,<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
invoke(a, Tuple{AbstractArray}, x)
(a::ConvTranspose{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
a(T.(x))
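A rough shape check for the new `ConvTranspose` layer. The layer is built from plain arrays here so the sketch only depends on NNlib's `∇conv_data`; Flux's own `ConvTranspose((2, 2), 16 => 8)` constructor wraps the weights in `param` instead.

```julia
using Flux

w  = randn(2, 2, 8, 16)   # (k..., out_channels, in_channels), matching a 16 => 8 layer
b  = zeros(8)
ct = ConvTranspose(w, b, relu; stride = 2)

x = rand(14, 14, 16, 1)   # WHCN input
size(ct(x))               # (28, 28, 8, 1): a stride-2 transpose conv doubles W and H here
```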
""" """
DepthwiseConv(size, in) DepthwiseConv(size, in)
DepthwiseConv(size, in=>mul) DepthwiseConv(size, in=>mul)
@ -71,26 +149,32 @@ be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
Takes the keyword arguments `pad` and `stride`. Takes the keyword arguments `pad` and `stride`.
""" """
struct DepthwiseConv{N,F,A,V} struct DepthwiseConv{N,M,F,A,V}
σ::F σ::F
weight::A weight::A
bias::V bias::V
stride::NTuple{N,Int} stride::NTuple{N,Int}
pad::NTuple{N,Int} pad::NTuple{M,Int}
dilation::NTuple{N,Int}
end end
DepthwiseConv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity; function DepthwiseConv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0) where {T,N} = stride = 1, pad = 0, dilation = 1) where {T,N}
DepthwiseConv(σ, w, b, expand.(sub2(Val(N)), (stride, pad))...) stride = expand(Val(N-2), stride)
pad = expand(Val(2*(N-2)), pad)
dilation = expand(Val(N-2), dilation)
return DepthwiseConv(σ, w, b, stride, pad, dilation)
end
DepthwiseConv(k::NTuple{N,Integer}, ch::Integer, σ = identity; init = initn, DepthwiseConv(k::NTuple{N,Integer}, ch::Integer, σ = identity; init = glorot_uniform,
stride = 1, pad = 0) where N = stride = 1, pad = 0, dilation = 1) where N =
DepthwiseConv(param(init(k..., 1, ch)), param(zeros(ch)), σ, DepthwiseConv(param(init(k..., 1, ch)), param(zeros(ch)), σ,
stride = stride, pad = pad) stride = stride, pad = pad, dilation=dilation)
DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn, DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = glorot_uniform,
stride::NTuple{N,Integer} = map(_->1,k), stride::NTuple{N,Integer} = map(_->1,k),
pad::NTuple{N,Integer} = map(_->0,k)) where N = pad::NTuple{N,Integer} = map(_->0,2 .* k),
dilation::NTuple{N,Integer} = map(_->1,k)) where N =
DepthwiseConv(param(init(k..., ch[2], ch[1])), param(zeros(ch[2]*ch[1])), σ, DepthwiseConv(param(init(k..., ch[2], ch[1])), param(zeros(ch[2]*ch[1])), σ,
stride = stride, pad = pad) stride = stride, pad = pad)
@ -98,7 +182,8 @@ DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity
function (c::DepthwiseConv)(x) function (c::DepthwiseConv)(x)
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1) σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
σ.(depthwiseconv(x, c.weight, stride = c.stride, pad = c.pad) .+ b) cdims = DepthwiseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
σ.(depthwiseconv(x, c.weight, cdims) .+ b)
end end
function Base.show(io::IO, l::DepthwiseConv) function Base.show(io::IO, l::DepthwiseConv)
@ -108,6 +193,12 @@ function Base.show(io::IO, l::DepthwiseConv)
print(io, ")") print(io, ")")
end end
(a::DepthwiseConv{<:Any,<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
invoke(a, Tuple{AbstractArray}, x)
(a::DepthwiseConv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
a(T.(x))
""" """
MaxPool(k) MaxPool(k)
@ -115,16 +206,23 @@ Max pooling layer. `k` stands for the size of the window for each dimension of t
Takes the keyword arguments `pad` and `stride`. Takes the keyword arguments `pad` and `stride`.
""" """
struct MaxPool{N} struct MaxPool{N,M}
k::NTuple{N,Int} k::NTuple{N,Int}
pad::NTuple{N,Int} pad::NTuple{M,Int}
stride::NTuple{N,Int} stride::NTuple{N,Int}
end end
MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N = function MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N
MaxPool(k, expand(Val(N), pad), expand(Val(N), stride)) stride = expand(Val(N), stride)
pad = expand(Val(2*N), pad)
(m::MaxPool)(x) = maxpool(x, m.k; pad = m.pad, stride = m.stride) return MaxPool(k, pad, stride)
end
function (m::MaxPool)(x)
pdims = PoolDims(x, m.k; padding=m.pad, stride=m.stride)
return maxpool(x, pdims)
end
function Base.show(io::IO, m::MaxPool) function Base.show(io::IO, m::MaxPool)
print(io, "MaxPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")") print(io, "MaxPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
@ -137,16 +235,22 @@ Mean pooling layer. `k` stands for the size of the window for each dimension of
Takes the keyword arguments `pad` and `stride`. Takes the keyword arguments `pad` and `stride`.
""" """
struct MeanPool{N} struct MeanPool{N,M}
k::NTuple{N,Int} k::NTuple{N,Int}
pad::NTuple{N,Int} pad::NTuple{M,Int}
stride::NTuple{N,Int} stride::NTuple{N,Int}
end end
MeanPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N = function MeanPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N
MeanPool(k, expand(Val(N), pad), expand(Val(N), stride)) stride = expand(Val(N), stride)
pad = expand(Val(2*N), pad)
return MeanPool(k, pad, stride)
end
(m::MeanPool)(x) = meanpool(x, m.k; pad = m.pad, stride = m.stride) function (m::MeanPool)(x)
pdims = PoolDims(x, m.k; padding=m.pad, stride=m.stride)
return meanpool(x, pdims)
end
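The pooling layers now construct a `PoolDims` object; the underlying NNlib call looks like this (window and input sizes are arbitrary):

```julia
using NNlib

x = reshape(Float64.(1:16), 4, 4, 1, 1)   # WHCN input
pdims = PoolDims(x, (2, 2); padding = (0, 0, 0, 0), stride = (2, 2))

maxpool(x, pdims)    # 2×2×1×1 array of window maxima
meanpool(x, pdims)   # 2×2×1×1 array of window means
```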
function Base.show(io::IO, m::MeanPool) function Base.show(io::IO, m::MeanPool)
print(io, "MeanPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")") print(io, "MeanPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
@ -61,6 +61,37 @@ end
_testmode!(a::Dropout, test) = (a.active = !test) _testmode!(a::Dropout, test) = (a.active = !test)
"""
AlphaDropout(p)
A dropout layer. It is used in Self-Normalizing Neural Networks.
(https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf)
The AlphaDropout layer ensures that the mean and variance of activations remain the same as before.
"""
mutable struct AlphaDropout{F}
p::F
active::Bool
end
function AlphaDropout(p)
@assert 0 p 1
AlphaDropout(p,true)
end
function (a::AlphaDropout)(x)
a.active || return x
λ = eltype(x)(1.0507009873554804934193349852946)
α = eltype(x)(1.6732632423543772848170429916717)
α1 = eltype(x)(-λ*α)
noise = randn(eltype(x), size(x))
x = @. x*(noise > (1 - a.p)) + α1 * (noise <= (1 - a.p))
A = (a.p + a.p * (1 - a.p) * α1 ^ 2)^0.5
B = -A * α1 * (1 - a.p)
x = @. A * x + B
return x
end
_testmode!(a::AlphaDropout, test) = (a.active = !test)
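A short sketch of `AlphaDropout` in and out of training mode; in practice the layer would sit inside a `Chain`:

```julia
using Flux

ad = Flux.AlphaDropout(0.5)
x  = randn(Float32, 4, 3)

y = ad(x)            # stochastic while the layer is active (training mode)
Flux.testmode!(ad)   # flips `active` off via _testmode! above
ad(x) === x          # true: an inactive AlphaDropout passes its input through untouched
```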
""" """
LayerNorm(h::Integer) LayerNorm(h::Integer)
@ -124,39 +155,39 @@ mutable struct BatchNorm{F,V,W,N}
end end
BatchNorm(chs::Integer, λ = identity; BatchNorm(chs::Integer, λ = identity;
initβ = (i) -> zeros(i), initγ = (i) -> ones(i), ϵ = 1e-5, momentum = .1) = initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
BatchNorm(λ, param(initβ(chs)), param(initγ(chs)), BatchNorm(λ, param(initβ(chs)), param(initγ(chs)),
zeros(chs), ones(chs), ϵ, momentum, true) zeros(chs), ones(chs), ϵ, momentum, true)
function (BN::BatchNorm)(x) function (BN::BatchNorm)(x)
size(x, ndims(x)-1) == length(BN.β) || size(x, ndims(x)-1) == length(BN.β) ||
error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))") error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))")
γ, β = BN.γ, BN.β
dims = length(size(x)) dims = length(size(x))
channels = size(x, dims-1) channels = size(x, dims-1)
affine_shape = ones(Int, dims) affine_shape = ones(Int, dims)
affine_shape[end-1] = channels affine_shape[end-1] = channels
m = prod(size(x)[1:end-2]) * size(x)[end] m = prod(size(x)[1:end-2]) * size(x)[end]
γ = reshape(BN.γ, affine_shape...)
β = reshape(BN.β, affine_shape...)
if !BN.active if !BN.active
μ = reshape(BN.μ, affine_shape...) μ = reshape(BN.μ, affine_shape...)
σ² = reshape(BN.σ², affine_shape...) σ² = reshape(BN.σ², affine_shape...)
ϵ = BN.ϵ
else else
T = eltype(x) T = eltype(x)
ϵ = data(convert(T, BN.ϵ))
axes = [1:dims-2; dims] # axes to reduce along (all but channels axis) axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
μ = mean(x, dims = axes) μ = mean(x, dims = axes)
σ² = sum((x .- μ) .^ 2, dims = axes) ./ m σ² = sum((x .- μ) .^ 2, dims = axes) ./ m
ϵ = data(convert(T, BN.ϵ))
# update moving mean/std # update moving mean/std
mtm = data(convert(T, BN.momentum)) mtm = data(convert(T, BN.momentum))
BN.μ = (1 - mtm) .* BN.μ .+ mtm .* reshape(data(μ), :) BN.μ = (1 - mtm) .* BN.μ .+ mtm .* reshape(data(μ), :)
BN.σ² = ((1 - mtm) .* BN.σ² .+ mtm .* reshape(data(σ²), :) .* m ./ (m - 1)) BN.σ² = (1 - mtm) .* BN.σ² .+ (mtm * m / (m - 1)) .* reshape(data(σ²), :)
end end
let λ = BN.λ let λ = BN.λ
λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ sqrt.(σ² .+ BN.ϵ)) .+ reshape(β, affine_shape...)) = (x .- μ) ./ sqrt.(σ² .+ ϵ)
λ.(γ .* .+ β)
end end
end end
@ -173,3 +204,209 @@ function Base.show(io::IO, l::BatchNorm)
(l.λ == identity) || print(io, ", λ = $(l.λ)") (l.λ == identity) || print(io, ", λ = $(l.λ)")
print(io, ")") print(io, ")")
end end
"""
InstanceNorm(channels::Integer, σ = identity;
initβ = zeros, initγ = ones,
ϵ = 1e-8, momentum = .1)
Instance Normalization layer. The `channels` input should be the size of the
channel dimension in your data (see below).
Given an array with `N` dimensions, call the `N-1`th the channel dimension. (For
a batch of feature vectors this is just the data dimension, for `WHCN` images
it's the usual channel dimension.)
`InstanceNorm` computes the mean and variance for each `W×H×1×1` slice and
shifts them to have a new mean and variance (corresponding to the learnable,
per-channel `bias` and `scale` parameters).
See [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022).
Example:
```julia
m = Chain(
Dense(28^2, 64),
InstanceNorm(64, relu),
Dense(64, 10),
InstanceNorm(10),
softmax)
```
"""
expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...)
mutable struct InstanceNorm{F,V,W,N}
λ::F # activation function
β::V # bias
γ::V # scale
μ::W # moving mean
σ²::W # moving std
ϵ::N
momentum::N
active::Bool
end
InstanceNorm(chs::Integer, λ = identity;
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
InstanceNorm(λ, param(initβ(chs)), param(initγ(chs)),
zeros(chs), ones(chs), ϵ, momentum, true)
function (in::InstanceNorm)(x)
size(x, ndims(x)-1) == length(in.β) ||
error("InstanceNorm expected $(length(in.β)) channels, got $(size(x, ndims(x)-1))")
ndims(x) > 2 ||
error("InstanceNorm requires at least 3 dimensions. With 2 dimensions an array of zeros would be returned")
# these are repeated later on depending on the batch size
dims = length(size(x))
c = size(x, dims-1)
bs = size(x, dims)
affine_shape = ones(Int, dims)
affine_shape[end-1] = c
affine_shape[end] = bs
m = prod(size(x)[1:end-2])
γ, β = expand_inst(in.γ, affine_shape), expand_inst(in.β, affine_shape)
if !in.active
μ = expand_inst(in.μ, affine_shape)
σ² = expand_inst(in.σ², affine_shape)
ϵ = in.ϵ
else
T = eltype(x)
ϵ = data(convert(T, in.ϵ))
axes = 1:dims-2 # axes to reduce along (all but channels and batch size axes)
μ = mean(x, dims = axes)
σ² = mean((x .- μ) .^ 2, dims = axes)
# update moving mean/std
mtm = data(convert(T, in.momentum))
in.μ = dropdims(mean(repeat((1 - mtm) .* in.μ, outer=[1, bs]) .+ mtm .* reshape(data(μ), (c, bs)), dims = 2), dims=2)
in.σ² = dropdims(mean((repeat((1 - mtm) .* in.σ², outer=[1, bs]) .+ (mtm * m / (m - 1)) .* reshape(data(σ²), (c, bs))), dims = 2), dims=2)
end
let λ = in.λ
= (x .- μ) ./ sqrt.(σ² .+ ϵ)
λ.(γ .* .+ β)
end
end
children(in::InstanceNorm) =
(in.λ, in.β, in.γ, in.μ, in.σ², in.ϵ, in.momentum, in.active)
mapchildren(f, in::InstanceNorm) = # e.g. mapchildren(cu, in)
InstanceNorm(in.λ, f(in.β), f(in.γ), f(in.μ), f(in.σ²), in.ϵ, in.momentum, in.active)
_testmode!(in::InstanceNorm, test) = (in.active = !test)
function Base.show(io::IO, l::InstanceNorm)
print(io, "InstanceNorm($(join(size(l.β), ", "))")
(l.λ == identity) || print(io, ", λ = $(l.λ)")
print(io, ")")
end
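Since `InstanceNorm` requires at least three dimensions, a quick shape sketch with a WHCN input (sizes arbitrary):

```julia
using Flux

inorm = Flux.InstanceNorm(3)    # 3 channels
x = rand(Float32, 8, 8, 3, 2)   # two 8×8 images with 3 channels each
size(inorm(x))                  # (8, 8, 3, 2); statistics are taken per W×H slice
```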
"""
Group Normalization.
This layer can outperform Batch-Normalization and Instance-Normalization.
GroupNorm(chs::Integer, G::Integer, λ = identity;
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
ϵ = 1f-5, momentum = 0.1f0)
``chs`` is the number of channels, the channel dimension of your input.
For an array of N dimensions, the (N-1)th index is the channel dimension.
``G`` is the number of groups along which the statistics are computed.
The number of channels must be an integer multiple of the number of groups.
Example:
```
m = Chain(Conv((3,3), 1=>32, leakyrelu;pad = 1),
GroupNorm(32,16)) # 32 channels, 16 groups (G = 16), thus 2 channels per group used
```
Link : https://arxiv.org/pdf/1803.08494.pdf
"""
mutable struct GroupNorm{F,V,W,N,T}
G::T # number of groups
λ::F # activation function
β::V # bias
γ::V # scale
μ::W # moving mean
σ²::W # moving std
ϵ::N
momentum::N
active::Bool
end
GroupNorm(chs::Integer, G::Integer, λ = identity;
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
GroupNorm(G, λ, param(initβ(chs)), param(initγ(chs)),
zeros(G,1), ones(G,1), ϵ, momentum, true)
function(gn::GroupNorm)(x)
size(x,ndims(x)-1) == length(gn.β) || error("Group Norm expected $(length(gn.β)) channels, but got $(size(x,ndims(x)-1)) channels")
ndims(x) > 2 || error("Need to pass at least 3 channels for Group Norm to work")
(size(x,ndims(x) -1))%gn.G == 0 || error("The number of groups ($(gn.G)) must divide the number of channels ($(size(x,ndims(x) -1)))")
dims = length(size(x))
groups = gn.G
channels = size(x, dims-1)
batches = size(x,dims)
channels_per_group = div(channels,groups)
affine_shape = ones(Int, dims)
# Output reshaped to (W,H...,C/G,G,N)
affine_shape[end-1] = channels
μ_affine_shape = ones(Int,dims + 1)
μ_affine_shape[end-1] = groups
m = prod(size(x)[1:end-2]) * channels_per_group
γ = reshape(gn.γ, affine_shape...)
β = reshape(gn.β, affine_shape...)
y = reshape(x,((size(x))[1:end-2]...,channels_per_group,groups,batches))
if !gn.active
og_shape = size(x)
μ = reshape(gn.μ, μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
σ² = reshape(gn.σ², μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
ϵ = gn.ϵ
else
T = eltype(x)
og_shape = size(x)
axes = [(1:ndims(y)-2)...] # axes to reduce along (all but channels axis)
μ = mean(y, dims = axes)
σ² = mean((y .- μ) .^ 2, dims = axes)
ϵ = data(convert(T, gn.ϵ))
# update moving mean/std
mtm = data(convert(T, gn.momentum))
gn.μ = mean((1 - mtm) .* gn.μ .+ mtm .* reshape(data(μ), (groups,batches)),dims=2)
gn.σ² = mean((1 - mtm) .* gn.σ² .+ (mtm * m / (m - 1)) .* reshape(data(σ²), (groups,batches)),dims=2)
end
let λ = gn.λ
= (y .- μ) ./ sqrt.(σ² .+ ϵ)
# Reshape x̂
= reshape(,og_shape)
λ.(γ .* .+ β)
end
end
children(gn::GroupNorm) =
(gn.λ, gn.β, gn.γ, gn.μ, gn.σ², gn.ϵ, gn.momentum, gn.active)
mapchildren(f, gn::GroupNorm) = # e.g. mapchildren(cu, BN)
GroupNorm(gn.G,gn.λ, f(gn.β), f(gn.γ), f(gn.μ), f(gn.σ²), gn.ϵ, gn.momentum, gn.active)
_testmode!(gn::GroupNorm, test) = (gn.active = !test)
function Base.show(io::IO, l::GroupNorm)
print(io, "GroupNorm($(join(size(l.β), ", "))")
(l.λ == identity) || print(io, ", λ = $(l.λ)")
print(io, ")")
end
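And a corresponding shape sketch for `GroupNorm`, where the channel count must be divisible by the number of groups:

```julia
using Flux

gn = Flux.GroupNorm(32, 16)     # 32 channels split into 16 groups of 2
x  = rand(Float32, 7, 7, 32, 4)
size(gn(x))                     # (7, 7, 32, 4)
```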
@ -84,7 +84,7 @@ end
RNNCell(in::Integer, out::Integer, σ = tanh; RNNCell(in::Integer, out::Integer, σ = tanh;
init = glorot_uniform) = init = glorot_uniform) =
RNNCell(σ, param(init(out, in)), param(init(out, out)), RNNCell(σ, param(init(out, in)), param(init(out, out)),
param(zeros(out)), param(init(out))) param(init(out)), param(zeros(out)))
function (m::RNNCell)(h, x) function (m::RNNCell)(h, x)
σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
@ -122,8 +122,8 @@ end
function LSTMCell(in::Integer, out::Integer; function LSTMCell(in::Integer, out::Integer;
init = glorot_uniform) init = glorot_uniform)
cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(zeros(out*4)), cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(init(out*4)),
param(init(out)), param(init(out))) param(zeros(out)), param(zeros(out)))
cell.b.data[gate(out, 2)] .= 1 cell.b.data[gate(out, 2)] .= 1
return cell return cell
end end
@ -153,7 +153,7 @@ Base.show(io::IO, l::LSTMCell) =
Long Short Term Memory recurrent layer. Behaves like an RNN but generally Long Short Term Memory recurrent layer. Behaves like an RNN but generally
exhibits a longer memory span over sequences. exhibits a longer memory span over sequences.
See [this article](http://colah.github.io/posts/2015-08-Understanding-LSTMs/) See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals. for a good overview of the internals.
""" """
LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...)) LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
@ -169,7 +169,7 @@ end
GRUCell(in, out; init = glorot_uniform) = GRUCell(in, out; init = glorot_uniform) =
GRUCell(param(init(out*3, in)), param(init(out*3, out)), GRUCell(param(init(out*3, in)), param(init(out*3, out)),
param(zeros(out*3)), param(init(out))) param(init(out*3)), param(zeros(out)))
function (m::GRUCell)(h, x) function (m::GRUCell)(h, x)
b, o = m.b, size(h, 1) b, o = m.b, size(h, 1)
@ -194,7 +194,7 @@ Base.show(io::IO, l::GRUCell) =
Gated Recurrent Unit layer. Behaves like an RNN but generally Gated Recurrent Unit layer. Behaves like an RNN but generally
exhibits a longer memory span over sequences. exhibits a longer memory span over sequences.
See [this article](http://colah.github.io/posts/2015-08-Understanding-LSTMs/) See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals. for a good overview of the internals.
""" """
GRU(a...; ka...) = Recur(GRUCell(a...; ka...)) GRU(a...; ka...) = Recur(GRUCell(a...; ka...))
@ -8,8 +8,6 @@ function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
-sum(y .* log.() .* weight) * 1 // size(y, 2) -sum(y .* log.() .* weight) * 1 // size(y, 2)
end end
@deprecate logloss(x, y) crossentropy(x, y)
function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1) function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
return -sum(y .* logsoftmax(logŷ) .* weight) * 1 // size(y, 2) return -sum(y .* logsoftmax(logŷ) .* weight) * 1 // size(y, 2)
end end
@ -42,12 +40,17 @@ but it is more numerically stable.
logitbinarycrossentropy(logŷ, y) = (1 - y)*logŷ - logσ(logŷ) logitbinarycrossentropy(logŷ, y) = (1 - y)*logŷ - logσ(logŷ)
""" """
normalise(x::AbstractVecOrMat) normalise(x::AbstractArray; dims=1)
Normalise each column of `x` to mean 0 and standard deviation 1. Normalises x to mean 0 and standard deviation 1, across the dimensions given by dims. Defaults to normalising over columns.
""" """
function normalise(x::AbstractVecOrMat) function normalise(x::AbstractArray; dims=1)
μ′ = mean(x, dims = 1) μ′ = mean(x, dims = dims)
σ = std(x, dims = 1, mean = μ′) σ = std(x, dims = dims, mean = μ′, corrected=false)
return (x .- μ′) ./ σ return (x .- μ′) ./ σ
end end
function normalise(x::AbstractArray, dims)
Base.depwarn("`normalise(x::AbstractArray, dims)` is deprecated, use `normalise(a, dims=dims)` instead.", :normalise)
normalise(x, dims = dims)
end
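A small sketch of the generalised `normalise`, which now uses the uncorrected (population) standard deviation:

```julia
using Flux, Statistics

x = [1.0 2.0; 3.0 4.0; 5.0 6.0]

Flux.normalise(x)             # each column shifted to mean 0 and scaled to std 1
Flux.normalise(x, dims = 2)   # the same, but across each row instead

# Equivalent by hand for the default dims = 1:
μ = mean(x, dims = 1)
(x .- μ) ./ std(x, dims = 1, mean = μ, corrected = false)
```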
@ -9,6 +9,8 @@ Base.size(xs::OneHotVector) = (Int64(xs.of),)
Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix
Base.getindex(xs::OneHotVector, ::Colon) = OneHotVector(xs.ix, xs.of)
A::AbstractMatrix * b::OneHotVector = A[:, b.ix] A::AbstractMatrix * b::OneHotVector = A[:, b.ix]
struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool} struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool}
@ -18,9 +20,12 @@ end
Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data)) Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data))
Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs.data[j][i] Base.getindex(xs::OneHotMatrix, i::Union{Integer, AbstractVector}, j::Integer) = xs.data[j][i]
Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i] Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i]
Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i]) Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i])
Base.getindex(xs::OneHotMatrix, ::Colon, ::Colon) = OneHotMatrix(xs.height, copy(xs.data))
Base.getindex(xs::OneHotMatrix, i::Integer, ::Colon) = map(x -> x[i], xs.data)
A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)] A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)]
@ -39,6 +44,29 @@ adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data)
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data)) cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
end end
"""
onehot(l, labels[, unk])
Create a [`OneHotVector`](@ref) whose `l`-th element is `true`, based on the possible set of `labels`.
If `unk` is given, it returns `onehot(unk, labels)` when the input label `l` is not found in `labels`; otherwise
it will error.
## Examples
```jldoctest
julia> onehot(:b, [:a, :b, :c])
3-element Flux.OneHotVector:
false
true
false
julia> onehot(:c, [:a, :b, :c])
3-element Flux.OneHotVector:
false
false
true
```
"""
function onehot(l, labels) function onehot(l, labels)
i = something(findfirst(isequal(l), labels), 0) i = something(findfirst(isequal(l), labels), 0)
i > 0 || error("Value $l is not in labels") i > 0 || error("Value $l is not in labels")
@ -51,16 +79,53 @@ function onehot(l, labels, unk)
OneHotVector(i, length(labels)) OneHotVector(i, length(labels))
end end
"""
onehotbatch(ls, labels[, unk...])
Create a [`OneHotMatrix`](@ref) from a batch of labels based on the possible set of `labels`, returning
`onehot(unk, labels)` for any label in `ls` that is not found in `labels`.
## Examples
```jldoctest
julia> onehotbatch([:b, :a, :b], [:a, :b, :c])
3×3 Flux.OneHotMatrix:
false true false
true false true
false false false
```
"""
onehotbatch(ls, labels, unk...) = onehotbatch(ls, labels, unk...) =
OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls]) OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])
Base.argmax(xs::OneHotVector) = xs.ix
"""
onecold(y[, labels = 1:length(y)])
Inverse operation of [`onehot`](@ref).
## Examples
```jldoctest
julia> onecold([true, false, false], [:a, :b, :c])
:a
julia> onecold([0.3, 0.2, 0.5], [:a, :b, :c])
:c
```
"""
onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)] onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)]
onecold(y::AbstractMatrix, labels...) = onecold(y::AbstractMatrix, labels...) =
dropdims(mapslices(y -> onecold(y, labels...), y, dims=1), dims=1) dropdims(mapslices(y -> onecold(y, labels...), y, dims=1), dims=1)
onecold(y::OneHotMatrix, labels...) =
mapreduce(x -> Flux.onecold(x, labels...), |, y.data, dims = 2, init = 0)
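`onecold` inverts `onehot`, and also accepts a plain vector or matrix of scores; a quick sketch using the labels from the docstrings above:

```julia
using Flux: onehotbatch, onecold

ys = onehotbatch([:b, :a, :c], [:a, :b, :c])
size(ys)                                               # (3, 3) OneHotMatrix

onecold([0.1, 0.7, 0.2], [:a, :b, :c])                 # :b, the label of the largest entry
onecold([0.1 0.9; 0.7 0.05; 0.2 0.05], [:a, :b, :c])   # [:b, :a], column-wise for a matrix
```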
function argmax(xs...) function argmax(xs...)
Base.depwarn("`argmax(...) is deprecated, use `onecold(...)` instead.", :argmax) Base.depwarn("`argmax(...)` is deprecated, use `onecold(...)` instead.", :argmax)
return onecold(xs...) return onecold(xs...)
end end
@ -68,3 +133,6 @@ end
a::TrackedMatrix * b::OneHotVector = invoke(*, Tuple{AbstractMatrix,OneHotVector}, a, b) a::TrackedMatrix * b::OneHotVector = invoke(*, Tuple{AbstractMatrix,OneHotVector}, a, b)
a::TrackedMatrix * b::OneHotMatrix = invoke(*, Tuple{AbstractMatrix,OneHotMatrix}, a, b) a::TrackedMatrix * b::OneHotMatrix = invoke(*, Tuple{AbstractMatrix,OneHotMatrix}, a, b)
onecold(x::TrackedVector, l...) = onecold(data(x), l...)
onecold(x::TrackedMatrix, l...) = onecold(data(x), l...)
@ -4,7 +4,7 @@ using Flux: Params
check_decay(opt, decay) = decay == 0 ? opt : Optimiser(opt, InvDecay(decay)) check_decay(opt, decay) = decay == 0 ? opt : Optimiser(opt, InvDecay(decay))
# legacy update rule # legacy update rule
updaterule(opt, ps) = () -> update!(opt, ps) updaterule(opt, ps) = () -> _update_params!(opt, ps)
function SGD(params::Union{AbstractArray, Params}, η = 0.1; decay = 0.) function SGD(params::Union{AbstractArray, Params}, η = 0.1; decay = 0.)
depwarn("SGD(params) is deprecated; use Descent(η::Float64) instead", :SGD) depwarn("SGD(params) is deprecated; use Descent(η::Float64) instead", :SGD)
@ -117,7 +117,7 @@ struct OldOptimiser
func func
end end
update!(opt::OldOptimiser, ps) = opt.func() _update_params!(opt::OldOptimiser, ps) = opt.func()
# Train function # Train function
function train!(loss, data, opt; cb = () -> ()) function train!(loss, data, opt; cb = () -> ())
@ -18,7 +18,7 @@ end
Descent() = Descent(0.1) Descent() = Descent(0.1)
function update!(o::Descent, x, Δ) function apply!(o::Descent, x, Δ)
Δ .*= o.eta Δ .*= o.eta
end end
@ -35,9 +35,9 @@ end
Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict()) Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict())
function update!(o::Momentum, x, Δ) function apply!(o::Momentum, x, Δ)
η, ρ = o.eta, o.rho η, ρ = o.eta, o.rho
v = get!(o.velocity, x, zero(x))::typeof(x) v = get!(o.velocity, x, zero(x))::typeof(data(x))
@. v = ρ * v - η * Δ @. v = ρ * v - η * Δ
@. Δ = -v @. Δ = -v
end end
@ -55,9 +55,9 @@ end
Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict()) Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict())
function update!(o::Nesterov, x, Δ) function apply!(o::Nesterov, x, Δ)
η, ρ = o.eta, o.rho η, ρ = o.eta, o.rho
v = get!(o.velocity, x, zero(x))::typeof(x) v = get!(o.velocity, x, zero(x))::typeof(data(x))
d = @. ρ^2 * v - (1+ρ) * η * Δ d = @. ρ^2 * v - (1+ρ) * η * Δ
@. v = ρ*v - η*Δ @. v = ρ*v - η*Δ
@. Δ = -d @. Δ = -d
@ -66,7 +66,7 @@ end
""" """
RMSProp(η = 0.001, ρ = 0.9) RMSProp(η = 0.001, ρ = 0.9)
[RMSProp](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) [RMSProp](https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
optimiser. Parameters other than learning rate don't need tuning. Often a good optimiser. Parameters other than learning rate don't need tuning. Often a good
choice for recurrent networks. choice for recurrent networks.
""" """
@ -78,9 +78,9 @@ end
RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict()) RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict())
function update!(o::RMSProp, x, Δ) function apply!(o::RMSProp, x, Δ)
η, ρ = o.eta, o.rho η, ρ = o.eta, o.rho
acc = get!(o.acc, x, zero(x))::typeof(x) acc = get!(o.acc, x, zero(x))::typeof(data(x))
@. acc = ρ * acc + (1 - ρ) * Δ^2 @. acc = ρ * acc + (1 - ρ) * Δ^2
@. Δ *= η / (acc + ϵ) @. Δ *= η / (acc + ϵ)
end end
@ -98,7 +98,7 @@ end
ADAM(η = 0.001, β = (0.9, 0.999)) = ADAM(η, β, IdDict()) ADAM(η = 0.001, β = (0.9, 0.999)) = ADAM(η, β, IdDict())
function update!(o::ADAM, x, Δ) function apply!(o::ADAM, x, Δ)
η, β = o.eta, o.beta η, β = o.eta, o.beta
mt, vt, βp = get!(o.state, x, (zero(x), zero(x), β)) mt, vt, βp = get!(o.state, x, (zero(x), zero(x), β))
@. mt = β[1] * mt + (1 - β[1]) * Δ @. mt = β[1] * mt + (1 - β[1]) * Δ
@ -122,7 +122,7 @@ end
AdaMax(η = 0.001, β = (0.9, 0.999)) = AdaMax(η, β, IdDict()) AdaMax(η = 0.001, β = (0.9, 0.999)) = AdaMax(η, β, IdDict())
function update!(o::AdaMax, x, Δ) function apply!(o::AdaMax, x, Δ)
η, β = o.eta, o.beta η, β = o.eta, o.beta
mt, ut, βp = get!(o.state, x, (zero(x), zero(x), β)) mt, ut, βp = get!(o.state, x, (zero(x), zero(x), β))
@. mt = β[1] * mt + (1 - β[1]) * Δ @. mt = β[1] * mt + (1 - β[1]) * Δ
@ -145,9 +145,9 @@ end
ADAGrad(η = 0.1) = ADAGrad(η, IdDict()) ADAGrad(η = 0.1) = ADAGrad(η, IdDict())
function update!(o::ADAGrad, x, Δ) function apply!(o::ADAGrad, x, Δ)
η = o.eta η = o.eta
acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(x) acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(data(x))
@. acc += Δ^2 @. acc += Δ^2
@. Δ *= η / (acc + ϵ) @. Δ *= η / (acc + ϵ)
end end
@ -155,7 +155,7 @@ end
""" """
ADADelta(ρ = 0.9, ϵ = 1e-8) ADADelta(ρ = 0.9, ϵ = 1e-8)
[ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need [ADADelta](https://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
tuning. tuning.
""" """
mutable struct ADADelta mutable struct ADADelta
@ -165,7 +165,7 @@ end
ADADelta(ρ = 0.9) = ADADelta(ρ, IdDict()) ADADelta(ρ = 0.9) = ADADelta(ρ, IdDict())
function update!(o::ADADelta, x, Δ) function apply!(o::ADADelta, x, Δ)
ρ = o.rho ρ = o.rho
acc, Δacc = get!(o.state, x, (zero(x), zero(x))) acc, Δacc = get!(o.state, x, (zero(x), zero(x)))
@. acc = ρ * acc + (1 - ρ) * Δ^2 @. acc = ρ * acc + (1 - ρ) * Δ^2
@ -188,7 +188,7 @@ end
AMSGrad(η = 0.001, β = (0.9, 0.999)) = AMSGrad(η, β, IdDict()) AMSGrad(η = 0.001, β = (0.9, 0.999)) = AMSGrad(η, β, IdDict())
function update!(o::AMSGrad, x, Δ) function apply!(o::AMSGrad, x, Δ)
η, β = o.eta, o.beta η, β = o.eta, o.beta
mt, vt, v̂t = get!(o.state, x, (fill(ϵ, size(x)), fill(ϵ, size(x)), fill(ϵ, size(x)))) mt, vt, v̂t = get!(o.state, x, (fill(ϵ, size(x)), fill(ϵ, size(x)), fill(ϵ, size(x))))
@. mt = β[1] * mt + (1 - β[1]) * Δ @. mt = β[1] * mt + (1 - β[1]) * Δ
@ -211,7 +211,7 @@ end
NADAM(η = 0.001, β = (0.9, 0.999)) = NADAM(η, β, IdDict()) NADAM(η = 0.001, β = (0.9, 0.999)) = NADAM(η, β, IdDict())
function update!(o::NADAM, x, Δ) function apply!(o::NADAM, x, Δ)
η, β = o.eta, o.beta η, β = o.eta, o.beta
β1p, β2p = o.beta β1p, β2p = o.beta
mt, vt = get!(o.state, x, (zero(x), zero(x))) mt, vt = get!(o.state, x, (zero(x), zero(x)))
@ -250,9 +250,9 @@ Optimiser(o...) = Optimiser(Any[o...])
Base.getindex(c::Optimiser, i::AbstractArray) = Optimiser(c.os[i]...) Base.getindex(c::Optimiser, i::AbstractArray) = Optimiser(c.os[i]...)
function update!(o::Optimiser, x, Δ) function apply!(o::Optimiser, x, Δ)
for opt in o.os for opt in o.os
Δ = update!(opt, x, Δ) Δ = apply!(opt, x, Δ)
end end
return Δ return Δ
end end
@ -272,7 +272,7 @@ end
InvDecay(γ = 0.001) = InvDecay(γ, IdDict()) InvDecay(γ = 0.001) = InvDecay(γ, IdDict())
function update!(o::InvDecay, x, Δ) function apply!(o::InvDecay, x, Δ)
γ = o.gamma γ = o.gamma
n = get!(o.state, x, 1) n = get!(o.state, x, 1)
Δ .*= 1 / (1 + γ * n) Δ .*= 1 / (1 + γ * n)
@ -300,14 +300,14 @@ end
ExpDecay(opt = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4) = ExpDecay(opt, decay, decay_step, clip, IdDict()) ExpDecay(opt = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4) = ExpDecay(opt, decay, decay_step, clip, IdDict())
function update!(o::ExpDecay, x, Δ) function apply!(o::ExpDecay, x, Δ)
η, s, decay = o.eta, o.step, o.decay η, s, decay = o.eta, o.step, o.decay
n = o.current[x] = get(o.current, x, 0) + 1 n = o.current[x] = get(o.current, x, 0) + 1
if o.current[x]%s == 0 && count(x -> x%s == 0, values(o.current)) == 1 if o.current[x]%s == 0 && count(x -> x%s == 0, values(o.current)) == 1
η = max(η * decay^(s / n), o.clip) η = max(η * decay^(s / n), o.clip)
o.eta = η o.eta = η
end end
@. Δ *= decay @. Δ *= η
end end
""" """
@ -321,7 +321,7 @@ end
WeightDecay() = WeightDecay(0) WeightDecay() = WeightDecay(0)
function update!(o::WeightDecay, x, Δ) function apply!(o::WeightDecay, x, Δ)
wd = o.wd wd = o.wd
@. Δ += wd * x @. Δ += wd * data(x)
end end
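With the rename from `update!` to `apply!`, an optimiser only transforms a gradient; writing it into the parameter happens in the new `update!(opt, x, x̄)` defined in train.jl. A rough sketch of that flow, assuming tracked parameters from the external Tracker package:

```julia
using Flux
using Flux.Tracker: Params, gradient

W = param(randn(3, 2))
b = param(zeros(3))
loss(x) = sum(W * x .+ b)

opt = Descent(0.1)
ps  = Params([W, b])
gs  = gradient(() -> loss(rand(2)), ps)   # Grads keyed by parameter

# apply! rescales each gradient by the learning rate, then the parameters are updated in place.
Flux.Optimise.update!(opt, ps, gs)
```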
@ -1,12 +1,23 @@
using Juno using Juno
using Flux.Tracker: data, grad, back! import Flux.Tracker: Params, gradient, data, update!
import Base.depwarn import Base.depwarn
function update!(opt, xs) function update!(opt, x, )
update!(x, -apply!(opt, x, data()))
end
function update!(opt, xs::Params, gs)
for x in xs for x in xs
Δ = update!(opt, x.data, x.grad) update!(opt, x, gs[x])
x.data .-= Δ end
Δ .= 0 end
# Added as an internal API but everyone started using it.
function _update_params!(opt, xs)
depwarn("`_update_params!` is deprecated, use `update!` instead.", :stop)
for x in xs
update!(opt, x, Tracker.grad(x))
x.tracker.grad = Tracker.zero_grad!(x.tracker.grad)
end end
end end
@ -15,16 +26,6 @@ call(f, xs...) = f(xs...)
runall(f) = f runall(f) = f
runall(fs::AbstractVector) = () -> foreach(call, fs) runall(fs::AbstractVector) = () -> foreach(call, fs)
# The AD generates fairly large backtraces that are unhelpful if you interrupt
# while training; this just cleans that up.
macro interrupts(ex)
:(try $(esc(ex))
catch e
e isa InterruptException || rethrow()
throw(e)
end)
end
struct StopException <: Exception end struct StopException <: Exception end
""" """
stop() stop()
@ -63,13 +64,14 @@ The callback can call `Flux.stop()` to interrupt the training loop.
Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays. Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
""" """
function train!(loss, ps, data, opt; cb = () -> ()) function train!(loss, ps, data, opt; cb = () -> ())
ps = Params(ps)
cb = runall(cb) cb = runall(cb)
opt = runall(opt)
@progress for d in data @progress for d in data
try try
l = loss(d...) gs = gradient(ps) do
@interrupts back!(l) loss(d...)
update!(opt, ps) end
update!(opt, ps, gs)
if cb() == :stop if cb() == :stop
depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop) depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
break break
@ -1,111 +0,0 @@
module Tracker
using MacroTools
using MacroTools: @q, @forward
import Base: ==
export TrackedArray, TrackedVector, TrackedMatrix, Params, gradient,
param, back!
tracker(x) = nothing
istracked(x) = tracker(x) nothing
isleaf(x) = !istracked(x) || isleaf(tracker(x))
grad(x) = grad(tracker(x))
grad(::Nothing) = nothing
data(x) = x
struct Call{F,As<:Tuple}
func::F
args::As
end
Call(f::F, args::T) where {F,T} = Call{F,T}(f, args)
Call() = Call(nothing, ())
# When deserialising, the object_id changes
a::Call == b::Call = a.func == b.func && a.args == b.args
@inline (c::Call)() = c.func(data.(c.args)...)
mutable struct Tracked{T}
ref::UInt32
f::Call
isleaf::Bool
grad::T
Tracked{T}(f::Call) where T = new(0, f, false)
Tracked{T}(f::Call, grad::T) where T = new(0, f, false, grad)
Tracked{T}(f::Call{Nothing}, grad::T) where T = new(0, f, true, grad)
end
istracked(x::Tracked) = true
isleaf(x::Tracked) = x.f == Call()
grad(x::Tracked) = x.grad
track(f::Call, x) = Tracked{typeof(x)}(f)
function _forward end
function track(f::F, xs...; kw...) where F
y, back = _forward(f, xs...; kw...)
track(Call(back, tracker.(xs)), y)
end
macro grad(ex)
@capture(shortdef(ex), (name_(args__) = body_) |
(name_(args__) where {T__} = body_)) || error("Need a function definition")
T == nothing && (T = [])
isexpr(name, :(::)) || (name = :(::typeof($name)))
insert!(args, 1+isexpr(args[1], :parameters) , name)
@q(Tracker._forward($(args...)) where $(T...) = $body) |> esc
end
include("idset.jl")
include("back.jl")
include("numeric.jl")
include("lib/real.jl")
include("lib/array.jl")
"""
hook(f, x) -> x
Hook into gradient backpropagation. `x` is unmodified, but when backpropagating
`f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse
the sign of the gradient applied to `x`."""
hook(f, x) = istracked(x) ? track(hook, f, x) : x
@grad hook(f, x) = data(x), Δ -> (nothing, f(Δ))
"""
checkpoint(f, args...)
Behaves like `f(args...)`, but avoids storing the intermediate values needed for
calculating gradients. Instead, `f(args...)` will be called again during the
backward pass. This can be used to save memory in larger models.
"""
checkpoint(f, args...) = track(checkpoint, f, args...)
@grad function checkpoint(f, args...)
data(f(args...)), function (Δ)
y, back = forward(f, args...)
(nothing, back(Δ)...)
end
end
nobacksies(f, x) = track(nobacksies, f, x)
nobacksies(f, xs::Tuple) = map(x -> nobacksies(f, x), xs)
@grad nobacksies(f::Symbol, x) = data(x), Δ -> error("Nested AD not defined for $f")
@grad nobacksies(f::String, x) = data(x), Δ -> error(f)
param(x::Number) = TrackedReal(float(x))
param(xs::AbstractArray) = TrackedArray(float.(xs))
@grad identity(x) = data(x), Δ -> (Δ,)
param(x::TrackedReal) = track(identity, x)
param(x::TrackedArray) = track(identity, x)
import Adapt: adapt, adapt_structure
adapt_structure(T, xs::TrackedArray) = param(adapt(T, data(xs)))
end
@ -1,183 +0,0 @@
init_grad(x) = zero(x)
zero_grad!(x) = zero(x)
zero_grad!(x::AbstractArray) = (x .= 0)
scan(c::Call) = foreach(scan, c.args)
function scan(x::Tracked)
x.isleaf && return
ref = x.ref += 1
if ref == 1
scan(x.f)
isdefined(x, :grad) && (x.grad = zero_grad!(x.grad))
end
return
end
function scan(x)
istracked(x) && scan(tracker(x))
return
end
function back_(c::Call, Δ, once)
Δs = c.func(Δ)
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
error("Gradient is not a tuple of length $(length(c.args))")
foreach((x, d) -> back(x, d, once), c.args, data.(Δs))
end
back_(::Call{Nothing}, Δ, once) = nothing
back_(::Call{Missing}, Δ, once) = error("`back!` was already used")
accum!(x, Δ) = x .+ Δ
accum!(x::AbstractArray, Δ) = (x .+= Δ)
function back(x::Tracked, Δ, once)
x.isleaf && (x.grad = accum!(x.grad, Δ); return)
ref = x.ref -= 1
grad = if isdefined(x, :grad)
x.grad = accum!(x.grad, Δ)
elseif ref > 0
x.grad = Δ
else
Δ
end
if ref == 0
back_(x.f, grad, once)
once && !x.isleaf && (x.f = Call(missing, ()))
end
return
end
back(::Nothing, Δ, once) = return
# Interface methods
# TODO: if an error occurs in `back` the refcounts will be broken
# and `back` will silently fail to update.
# (but only if you re-use intermediate values between passes)
# Refcounts are also probably not safe in some situations (e.g. back called
# from within a backpropagator)
function back!(x, Δ; once = true)
istracked(x) || return
scan(x)
back(tracker(x), Δ, once)
return
end
function gradient_(f, xs...)
xs = param.(data.(xs))
l = f(xs...)
losscheck(l)
back!(l)
nobacksies("Use `gradient(...; nest = true)` for nested derivatives",
grad.(xs))
end
# Out-of-place gradients
struct Params
order::Vector{Any}
params::IdSet{Any}
Params() = new([], IdSet())
end
@forward Params.order Base.iterate, Base.length
function Base.push!(ps::Params, x)
if !(x in ps.params)
push!(ps.order, x)
push!(ps.params, x)
end
return ps
end
Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps)
Params(xs) = push!(Params(), xs...)
function Base.show(io::IO, ps::Params)
print(io, "Params([")
join(io, ps.order, ", ")
print(io, "])")
end
struct Grads
grads::IdDict{Any,Any}
end
Base.show(io::IO, ps::Grads) = println(io, "Grads(...)")
Grads() = Grads(IdDict())
@forward Grads.grads Base.setindex!, Base.haskey, Base.length, Base.iterate
Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps))
Base.getindex(g::Grads, x::Tracked) = g.grads[x]
function Base.getindex(g::Grads, x)
istracked(x) || error("Object not tracked: $x")
g[tracker(x)]
end
accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
function back_(g::Grads, c::Call, Δ)
Δs = c.func(Δ)
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
error("Gradient is not a tuple of length $(length(c.args))")
foreach((x, Δ) -> back(g, x, Δ), c.args, Δs)
end
back_(g::Grads, ::Call{Nothing}, Δ) = nothing
function back(g::Grads, x::Tracked, Δ)
x.isleaf && (accum!(g, x, Δ); return)
ref = x.ref -= 1
if ref > 0 || haskey(g, x)
accum!(g, x, Δ)
ref == 0 && back_(g, x.f, g[x])
else
ref == 0 && back_(g, x.f, Δ)
end
return
end
back(::Grads, ::Nothing, _) = return
function forward(f, ps::Params)
y = f()
y, function (Δ)
g = Grads(ps)
if istracked(y)
scan(y)
back(g, tracker(y), Δ)
end
return g
end
end
function forward(f, args...)
args = param.(args)
y, back = forward(() -> f(args...), Params(args))
y, Δ -> getindex.(Ref(back(Δ)), args)
end
function losscheck(x)
x isa Real || error("Function output is not scalar")
isinf(x) && error("Loss is infinite")
isnan(x) && error("Loss is NaN")
end
function gradient_nested(f, args...)
y, back = forward(f, args...)
losscheck(y)
return back(1)
end
gradient(f, xs...; nest = false) =
nest ? gradient_nested(f, xs...) : gradient_(f, xs...)
gradient(f, ps::Params) = gradient_nested(f, ps)
@ -1,28 +0,0 @@
struct IdSet{T} <: AbstractSet{T}
dict::IdDict{T,Nothing}
IdSet{T}() where T = new(IdDict{T,Nothing}())
end
Base.eltype(::IdSet{T}) where T = T
IdSet() = IdSet{Any}()
Base.push!(s::IdSet) = s
Base.push!(s::IdSet{T}, x::T) where T = (s.dict[x] = nothing; s)
Base.delete!(s::IdSet{T}, x::T) where T = (delete!(s.dict, x); s)
Base.in(x, s::IdSet) = haskey(s.dict, x)
IdSet{T}(xs) where T = push!(IdSet{T}(), xs...)
IdSet(xs) = IdSet{eltype(xs)}(xs)
Base.collect(s::IdSet) = Base.collect(keys(s.dict))
Base.similar(s::IdSet, T::Type) = IdSet{T}()
@forward IdSet.dict Base.length
function Base.iterate(v::IdSet, state...)
y = Base.iterate(keys(v.dict), state...)
y === nothing && return nothing
return (y[1], y[2])
end
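# Why an identity-based set (sketch): two arrays that are `==` but not `===`
# are kept as separate entries, which is what parameter tracking relies on.
#
#   a, b = [1.0], [1.0]
#   s = IdSet()
#   push!(s, a); push!(s, b)
#   length(s)        # 2 — a value-based Set would collapse them
#   [1.0] in s       # false: membership is by object identity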

View File

@ -1,494 +0,0 @@
import Base: *
import LinearAlgebra
import LinearAlgebra: inv, \, /
using Statistics
using LinearAlgebra: Transpose, Adjoint, diagm, diag
struct TrackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
tracker::Tracked{A}
data::A
grad::A
TrackedArray{T,N,A}(t::Tracked{A}, data::A) where {T,N,A} = new(t, data)
TrackedArray{T,N,A}(t::Tracked{A}, data::A, grad::A) where {T,N,A} = new(t, data, grad)
end
data(x::TrackedArray) = x.data
tracker(x::TrackedArray) = x.tracker
TrackedVector{T,A} = TrackedArray{T,1,A}
TrackedMatrix{T,A} = TrackedArray{T,2,A}
TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}}
track(c::Call, x::AbstractArray) = TrackedArray(c, x)
TrackedArray(c::Call, x::A) where A <: AbstractArray =
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c), x)
TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, Δ), x, Δ)
TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, zero(x))
Base.eltype(x::Type{<:TrackedArray{T}}) where T <: Real = TrackedReal{T}
Base.convert(::Type{T}, x::S) where {T<:TrackedArray,S<:T} = x
Base.convert(T::Type{<:TrackedArray}, x::TrackedArray) =
error("Not implemented: convert $(typeof(x)) to $T")
Base.convert(::Type{<:TrackedArray{T,N,A}}, x::AbstractArray) where {T,N,A} =
TrackedArray(convert(A, x))
Base.show(io::IO, t::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
@isdefined(A) ?
print(io, "TrackedArray{…,$A}") :
invoke(show, Tuple{IO,DataType}, io, t)
function Base.summary(io::IO, x::TrackedArray)
print(io, "Tracked ")
summary(io, data(x))
end
Base.print_array(io::IO, x::TrackedArray) = Base.print_array(io, data(x))
function Base.show(io::IO, x::TrackedArray)
show(io, data(x))
print(io, " (tracked)")
end
Base.copy(x::TrackedArray) = x
Base.setindex!(xs::TrackedArray, v, i...) =
error("Can't differentiate `setindex!`")
back!(::TrackedArray) = error("Value is not scalar; use `back!(sum(x))` or `back!(x, Δ)`")
function update!(x::TrackedArray, Δ)
x.data .+= data(Δ)
tracker(x).grad .= 0
return x
end
# Fallthrough methods
for f in :[Base.size, Base.ndims, Base.collect].args
@eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...)
end
Base.size(x::TrackedArray, i::Integer, j::Integer, is::Integer...) =
size(data(x), i, j, is...)
Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
similar(data(x), dims...)
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
for op in [:(==), :≈]
@eval Base.$op(x::TrackedArray, y::AbstractArray) = Base.$op(data(x), y)
@eval Base.$op(x::AbstractArray, y::TrackedArray) = Base.$op(x, data(y))
@eval Base.$op(x::TrackedArray, y::TrackedArray) = Base.$op(data(x), data(y))
end
# Array Stdlib
Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...)
@grad function getindex(xs::AbstractArray, i...)
data(xs)[i...], function (Δ)
Δ′ = zero(xs)
Δ′[i...] = data(Δ)
(nobacksies(:getindex, Δ′), map(_->nothing, i)...)
end
end
Base.view(x::TrackedArray, inds...) = track(Base.view, x, inds...)
@grad function view(x::AbstractArray, inds...)
view(data(x), inds...), function (Δ)
grad_output = zero(x)
subgrad = view(grad_output, inds...)
subgrad[:] = data(Δ)
(nobacksies(:view, grad_output), map(_->nothing, inds)...)
end
end
Base.:-(xs::TrackedArray) = track(-, xs)
@grad -(xs) = -data(xs), Δ -> (-Δ,)
Base.transpose(xs::TrackedArray) = track(transpose, xs)
Base.adjoint(xs::TrackedArray) = track(adjoint, xs)
@grad transpose(xs) = transpose(data(xs)), Δ -> (trim(xs, transpose(Δ)),)
@grad adjoint(xs) = data(xs)', Δ -> (trim(xs, Δ'),)
Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...)
@grad function repeat(xs; inner=ntuple(x->1, ndims(xs)), outer=ntuple(x->1, ndims(xs)))
repeat(data(xs), inner = inner, outer = outer), function (Δ)
Δ′ = zero(xs)
S = size(xs)
# Loop through each element of Δ, calculate source dimensions, accumulate into Δ′
for (dest_idx, val) in pairs(IndexCartesian(), data(Δ))
# First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then
# wrap around based on original size S.
src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)]
Δ′[src_idx...] += val
end
(nobacksies(:repeat, Δ′),)
end
end
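# Worked example of the index arithmetic above: with a length-3 source and
# inner = 2, destination indices 1:6 fold back onto sources 1,1,2,2,3,3.
#
#   inner, S = (2,), (3,)
#   [mod1(div(d - 1, inner[1]) + 1, S[1]) for d in 1:6]   # == [1, 1, 2, 2, 3, 3]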
function combinations(xs, n)
n < 1 && return [[]]
cs = combinations(xs, n-1)
[[x, c...] for x in xs, c in cs]
end
for i = 0:2, c = combinations([:AbstractArray, :TrackedArray], i), f = [:hcat, :vcat]
cnames = map(_ -> gensym(), c)
@eval Base.$f($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::TrackedArray, xs::AbstractArray...) =
track($f, $(cnames...), x, xs...)
end
for i = 0:2, c = combinations([:AbstractVecOrMat, :TrackedVecOrMat], i), f = [:hcat, :vcat]
cnames = map(_ -> gensym(), c)
@eval Base.$f($([:($x::$c{T}) for (x, c) in zip(cnames, c)]...), x::TrackedVecOrMat{T}, xs::AbstractVecOrMat{T}...) where T =
track($f, $(cnames...), x, xs...)
end
for i = 0:2, c = combinations([:AbstractVector, :TrackedVector], i), f = [:hcat, :vcat]
cnames = map(_ -> gensym(), c)
@eval Base.$f($([:($x::$c{T}) for (x, c) in zip(cnames, c)]...), x::TrackedVector{T}, xs::AbstractVector{T}...) where T =
track($f, $(cnames...), x, xs...)
end
@grad function vcat(xs...)
vcat(data.(xs)...), function (Δ)
start = 0
Δs = [begin
i = map(_ -> :, size(xsi)) |> Base.tail
d = Δ[start+1:start+size(xsi,1), i...]
start += size(xsi, 1)
d
end for xsi in xs]
return (Δs...,)
end
end
@grad function hcat(xs...)
hcat(data.(xs)...), function (Δ)
start = 0
Δs = [begin
d = if ndims(xsi) == 1
Δ[:, start+1]
else
i = map(_ -> :, size(xsi)) |> Base.tail |> Base.tail
Δ[:, start+1:start+size(xsi,2), i...]
end
start += size(xsi, 2)
d
end for xsi in xs]
return (Δs...,)
end
end
for i = 0:2, c = combinations([:AbstractArray, :TrackedArray], i)
cnames = map(_ -> gensym(), c)
@eval Base.cat($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::TrackedArray, xs::AbstractArray...; dims) =
track(cat, $(cnames...), x, xs..., dims = dims)
end
@grad function cat(Xs...; dims)
cat(data.(Xs)..., dims = dims), function (Δ)
start = ntuple(i -> 0, Val(ndims(Δ)))
Δs = [begin
dim_xs = 1:ndims(xs)
till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val(ndims(Δ)))
xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val(ndims(Δ)))
d = reshape(Δ[xs_in_Δ...],size(xs))
start = start .+ till_xs
d
end for xs in Xs]
return (Δs...,)
end
end
Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) = reshape(xs, dims)
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Union{Int64,Colon}}}) = reshape(xs, Base._reshape_uncolon(xs, dims))
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64}}) = track(reshape, xs, dims)
@grad reshape(xs, dims) = reshape(data(xs), dims), Δ -> (reshape(Δ, size(xs)),nothing)
Base.permutedims(xs::TrackedArray, dims) = track(permutedims, xs, dims)
@grad permutedims(xs, dims) = permutedims(data(xs), dims), Δ -> (permutedims(Δ, invperm(dims)),nothing)
function _kron(mat1::AbstractMatrix,mat2::AbstractMatrix)
m1, n1 = size(mat1)
mat1_rsh = reshape(mat1,(1,m1,1,n1))
m2, n2 = size(mat2)
mat2_rsh = reshape(mat2,(m2,1,n2,1))
return reshape(mat1_rsh.*mat2_rsh, (m1*m2,n1*n2))
end
Base.kron(a::TrackedMatrix, b::TrackedMatrix) = _kron(a, b)
Base.kron(a::TrackedMatrix, b::AbstractMatrix) = _kron(a, b)
Base.kron(a::AbstractMatrix, b::TrackedMatrix) = _kron(a, b)
inv(A::TrackedArray) = Tracker.track(inv, A)
@grad function inv(A)
return inv(Tracker.data(A)), function (Δ)
Ainv = inv(A)
∇A = - Ainv' * Δ * Ainv'
return (∇A, )
end
end
# (/) rdivide
A::TrackedArray / B::TrackedArray = Tracker.track(/, A, B)
A::AbstractVecOrMat / B::TrackedArray = Tracker.track(/, A, B)
A::TrackedArray / B::AbstractVecOrMat = Tracker.track(/, A, B)
@grad function (A / B)
return Tracker.data(A) / Tracker.data(B), function (Δ)
Binv = inv(B)
∇B = - Binv' * A' * Δ * Binv'
return (Δ * Binv', ∇B)
end
end
# (\) ldivide (left vec divide needs more work to resolve dispatch ambiguity)
A::TrackedArray \ B::TrackedArray = Tracker.track(\, A, B)
A::AbstractArray \ B::TrackedArray = Tracker.track(\, A, B)
A::TrackedArray \ B::AbstractVecOrMat = Tracker.track(\, A, B)
@grad function (A \ B)
return Tracker.data(A) \ Tracker.data(B), function (Δ)
Ainv = inv(A)
∇A = - Ainv' * Δ * B' * Ainv'
return (∇A, Ainv' * Δ)
end
end
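# The adjoints above follow the standard matrix-calculus identities, with Δ the
# gradient flowing into the result:
#   Y = inv(A):  ∇A = -inv(A)' * Δ * inv(A)'
#   Y = A / B:   ∇A = Δ * inv(B)'                    ∇B = -inv(B)' * A' * Δ * inv(B)'
#   Y = A \ B:   ∇A = -inv(A)' * Δ * B' * inv(A)'    ∇B = inv(A)' * Δ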
# Reductions
Base.sum(xs::TrackedArray; dims = :) = track(sum, xs, dims = dims)
Base.sum(f::Union{Function,Type},xs::TrackedArray) = sum(f.(xs))
@grad sum(xs; dims = :) = sum(data(xs), dims = dims),
Δ -> (zero(xs) .+ Δ, )
Base.prod(xs::TrackedArray, dim) = track(prod, xs, dim)
Base.prod(xs::TrackedArray) = track(prod, xs)
Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs))
@grad prod(xs) = prod(data(xs)), Δ -> (prod(xs) ./ xs .* Δ,)
@grad prod(xs, dim) = prod(data(xs), dims = dim),
Δ -> (nobacksies(:prod,
reshape(.*(circshift.([reshape(data(xs), length(xs))], 1:length(xs)-1)...), size(xs)) .* Δ),
nothing)
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
Statistics.mean(xs::TrackedArray; dims = :) = track(mean, xs, dims = dims)
Base.maximum(xs::TrackedArray; dims = :) = track(maximum, xs, dims = dims)
Base.minimum(xs::TrackedArray; dims = :) = track(minimum, xs, dims = dims)
import LinearAlgebra: dot
dot(xs::TrackedVector, ys::TrackedVector) = track(dot, xs, ys)
dot(xs::AbstractVector, ys::TrackedVector) = track(dot, xs, ys)
dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys)
@grad dot(xs, ys) = dot(data(xs), data(ys)), Δ -> (Δ .* ys, Δ .* xs)
# Hacks to get std working
Statistics.std(x::TrackedArray; dims = :, mean = Statistics.mean(x, dims = dims)) = _std(x,mean,dims)
_std(x::TrackedArray, mean, dims) = sqrt.(sum((x .- mean).^2, dims = dims) ./ (mapreduce(i -> size(x,i),*, dims) - 1))
_std(x::TrackedArray, mean, ::Colon) = sqrt.(sum((x .- mean).^2) ./ (length(x) - 1))
LinearAlgebra.norm(x::TrackedArray, p::Real = 2) =
sum(abs.(x).^p .+ eps(0f0))^(1/p) # avoid d(sqrt(x))/dx == Inf at 0
@grad mean(xs; dims = :) = mean(data(xs), dims=dims), Δ -> (_backmean(xs,Δ,dims),)
_backmean(xs, Δ, ::Colon) = zero(xs) .+ Δ ./ length(xs)
_backmean(xs, Δ, dims) = zero(xs) .+ Δ ./ mapreduce(i -> size(data(xs),i),*,dims)
@grad function maximum(xs; dims = :)
maximum(data(xs), dims = dims), function (Δ)
Δ′ = zero(xs)
_, i = findmax(data(xs), dims = dims)
Δ′[i] = data(Δ)
return (nobacksies(:maximum, Δ′),)
end
end
@grad function minimum(xs; dims = :)
minimum(data(xs), dims = dims), function (Δ)
Δ′ = zero(xs)
_, i = findmin(data(xs), dims = dims)
Δ′[i] = data(Δ)
return (nobacksies(:minimum, Δ′),)
end
end
# BLAS
LinearAlgebra.diagm(x::Pair{<:Integer, <:TrackedVector}) = track(diagm, x...)
@grad diagm(i, x) = diagm(i => data(x)), Δ -> (nothing, diag(Δ, i))
x::TrackedMatrix * y::AbstractMatrix = track(*, x, y)
x::AbstractMatrix * y::TrackedMatrix = track(*, x, y)
x::TrackedMatrix * y::TrackedMatrix = track(*, x, y)
x::TrackedMatrix * y::AbstractVector = track(*, x, y)
x::AbstractMatrix * y::TrackedVector = track(*, x, y)
x::TrackedMatrix * y::TrackedVector = track(*, x, y)
x::TrackedVector * y::AbstractVector = track(*, x, y)
x::AbstractVector * y::TrackedVector = track(*, x, y)
x::TrackedVector * y::TrackedVector = track(*, x, y)
@grad a::AbstractMatrix * b::AbstractVecOrMat =
data(a)*data(b), Δ -> (Δ * transpose(b), transpose(a) * Δ)
# NNlib
using NNlib
import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, depthwiseconv, maxpool, meanpool
softmax(xs::TrackedArray) = track(softmax, xs)
@grad softmax(xs) = softmax(data(xs)), Δ -> (nobacksies(:softmax, ∇softmax(data(Δ), data(xs))),)
logsoftmax(xs::TrackedArray) = track(logsoftmax, xs)
@grad logsoftmax(xs) = logsoftmax(data(xs)), Δ -> (nobacksies(:logsoftmax, ∇logsoftmax(data(Δ), data(xs))),)
depthwiseconv(x::TrackedArray, w::TrackedArray; kw...) = track(depthwiseconv, x, w; kw...)
depthwiseconv(x::AbstractArray, w::TrackedArray; kw...) = track(depthwiseconv, x, w; kw...)
depthwiseconv(x::TrackedArray, w::AbstractArray; kw...) = track(depthwiseconv, x, w; kw...)
@grad depthwiseconv(x, w; kw...) =
depthwiseconv(data(x), data(w); kw...),
Δ -> nobacksies(:depthwiseconv,
(NNlib.∇depthwiseconv_data(data.((Δ, x, w))...; kw...),
NNlib.∇depthwiseconv_filter(data.((Δ, x, w))...; kw...)))
conv(x::TrackedArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
conv(x::AbstractArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...)
@grad conv(x, w; kw...) =
conv(data(x), data(w); kw...),
Δ -> nobacksies(:conv,
(NNlib.∇conv_data(data.((Δ, x, w))...; kw...),
NNlib.∇conv_filter(data.((Δ, x, w))...; kw...)))
maxpool(x::TrackedArray, k; kw...) = track(maxpool, x, k; kw...)
@grad function maxpool(x, k; kw...)
y = maxpool(data(x), k; kw...)
y, Δ -> (nobacksies(:maxpool, NNlib.∇maxpool(data.((Δ, y, x))..., k; kw...)), nothing)
end
meanpool(x::TrackedArray, k; kw...) = track(meanpool, x, k; kw...)
@grad function meanpool(x, k; kw...)
y = meanpool(data(x), k; kw...)
y, Δ -> (nobacksies(:meanpool, NNlib.∇meanpool(data.((Δ, y, x))..., k; kw...)), nothing)
end
# Broadcasting
using ForwardDiff: Dual, partials, value
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
unbroadcast(x::AbstractArray, Δ) =
size(x) == size(Δ) ? Δ :
length(x) == length(Δ) ? trim(x, Δ) :
trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
unbroadcast(x::Number, Δ) = sum(Δ)
unbroadcast(x::Base.RefValue, _) = nothing
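# Example of the shape bookkeeping above (sketch): a gradient broadcast over
# extra columns is summed back down to the original shape.
#
#   unbroadcast(rand(3), ones(3, 4))   # == [4.0, 4.0, 4.0]
#   unbroadcast(2.0, ones(3, 4))       # == 12.0 (scalars get a full sum)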
dual(x, p) = x
dual(x::Real, p) = Dual(x, p)
function partial(f::F, Δ, i, args::Vararg{Any,N}) where {F,N}
dargs = ntuple(j -> dual(args[j], i==j), Val(N))
return Δ * f(dargs...).partials[1]
end
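# Sketch of how `partial` extracts an elementwise derivative via dual numbers:
# argument i is seeded with a unit perturbation, the others carry a zero partial.
#
#   partial(*, 1.0, 1, 3.0, 4.0)   # == 4.0, i.e. ∂(a*b)/∂a at (3, 4), scaled by Δ = 1
#   partial(*, 1.0, 2, 3.0, 4.0)   # == 3.0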
@inline function ∇broadcast(f::F, args::Vararg{Any,N}) where {F,N}
y = broadcast(f, data.(args)...)
eltype(y) <: Real || return y
eltype(y) == Bool && return y
function back(Δ)
Δargs = ntuple(i -> partial.(f, Δ, i, args...), Val(N))
dxs = map(unbroadcast, args, Δargs)
return dxs
end
# So we can return non-tracked arrays
track(Call(back, tracker.(args)), y)
end
using Base.Broadcast: BroadcastStyle, ArrayStyle, Broadcasted, broadcasted
struct TrackedStyle <: BroadcastStyle end
Broadcast.BroadcastStyle(::Type{<:Union{TrackedArray,TrackedReal}}) = TrackedStyle()
Broadcast.BroadcastStyle(::TrackedStyle, ::BroadcastStyle) = TrackedStyle()
# We have to re-build the original broadcast struct to get the appropriate array
# style. We need this primarily to support CuArrays' broadcasting fixes.
broadcast_rebuild(xs) = data(xs)
broadcast_rebuild(bc::Broadcasted) =
broadcasted(bc.f, broadcast_rebuild.(bc.args)...)
preprocess(x) = x
function Base.Broadcast.materialize(bc::Broadcasted{TrackedStyle})
bc1 = Broadcast.flatten(bc)
bc2 = Broadcast.flatten(broadcast_rebuild(bc))
∇broadcast(bc2.f, bc1.args...)
end
using Requires
# https://github.com/FluxML/Flux.jl/issues/353
if VERSION < v"1.1.0-DEV.548"
@init Requires.isprecompiling() || @eval Base.Broadcast begin
function flatten(bc::Broadcasted{Style}) where {Style}
isflat(bc) && return bc
args = cat_nested(bc)
let makeargs = make_makeargs(bc), f = bc.f
newf = @inline function(args::Vararg{Any,N}) where N
f(makeargs(args...)...)
end
return Broadcasted{Style}(newf, args, bc.axes)
end
end
@inline function make_makeargs(makeargs, t::Tuple{<:Broadcasted,Vararg{Any}})
bc = t[1]
let makeargs = make_makeargs(makeargs, tail(t)), f = bc.f
let makeargs = make_makeargs(makeargs, bc.args)
headargs, tailargs = make_headargs(bc.args), make_tailargs(bc.args)
return @inline function(args::Vararg{Any,N}) where N
args1 = makeargs(args...)
a, b = headargs(args1...), tailargs(args1...)
(f(a...), b...)
end
end
end
end
end
end

View File

@ -1,154 +0,0 @@
mutable struct TrackedReal{T<:Real} <: Real
data::T
tracker::Tracked{T}
end
TrackedReal(x::Real) = TrackedReal(x, Tracked{typeof(x)}(Call(), zero(x)))
data(x::TrackedReal) = x.data
tracker(x::TrackedReal) = x.tracker
track(f::Call, x::Real) = TrackedReal(x, Tracked{typeof(x)}(f, zero(x)))
function back!(x::TrackedReal; once = true)
isinf(x) && error("Loss is Inf")
isnan(x) && error("Loss is NaN")
return back!(x, 1, once = once)
end
function update!(x::TrackedReal, Δ)
x.data += data(Δ)
tracker(x).grad = 0
return x
end
function Base.show(io::IO, x::TrackedReal)
T = get(io, :typeinfo, Any)
show(io, data(x))
T <: TrackedReal || print(io, " (tracked)")
end
Base.decompose(x::TrackedReal) = Base.decompose(data(x))
Base.copy(x::TrackedReal) = x
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{T}) where T = x
Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x))
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} =
error("Not implemented: convert tracked $S to tracked $T")
for op in [:(==), :≈, :<, :(<=)]
@eval Base.$op(x::TrackedReal, y::Real) = Base.$op(data(x), y)
@eval Base.$op(x::Real, y::TrackedReal) = Base.$op(x, data(y))
@eval Base.$op(x::TrackedReal, y::TrackedReal) = Base.$op(data(x), data(y))
end
Base.eps(x::TrackedReal) = eps(data(x))
Base.eps(::Type{TrackedReal{T}}) where T = eps(T)
for f in :[isinf, isnan, isfinite].args
@eval Base.$f(x::TrackedReal) = Base.$f(data(x))
end
Base.Printf.fix_dec(x::TrackedReal, n::Int, a...) = Base.Printf.fix_dec(data(x), n, a...)
Base.float(x::TrackedReal) = x
Base.promote_rule(::Type{TrackedReal{S}},::Type{T}) where {S,T} =
TrackedReal{promote_type(S,T)}
using Random
for f in :[rand, randn, randexp].args
@eval Random.$f(rng::AbstractRNG,::Type{TrackedReal{T}}) where {T} = param(rand(rng,T))
end
using DiffRules, SpecialFunctions, NaNMath
for (M, f, arity) in DiffRules.diffrules()
arity == 1 || continue
@eval begin
@grad $M.$f(a::Real) =
$M.$f(data(a)), Δ -> (Δ * $(DiffRules.diffrule(M, f, :a)),)
$M.$f(a::TrackedReal) = track($M.$f, a)
end
end
# Work around zero(π) not working, for some reason
_zero(::Irrational) = nothing
_zero(x) = zero(x)
for (M, f, arity) in DiffRules.diffrules()
arity == 2 || continue
da, db = DiffRules.diffrule(M, f, :a, :b)
f = :($M.$f)
@eval begin
@grad $f(a::TrackedReal, b::TrackedReal) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db)
@grad $f(a::TrackedReal, b::Real) = $f(data(a), b), Δ -> (Δ * $da, _zero(b))
@grad $f(a::Real, b::TrackedReal) = $f(a, data(b)), Δ -> (_zero(a), Δ * $db)
$f(a::TrackedReal, b::TrackedReal) = track($f, a, b)
$f(a::TrackedReal, b::Real) = track($f, a, b)
$f(a::Real, b::TrackedReal) = track($f, a, b)
end
end
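# For orientation (a rough sketch, not the exact expansion): for the rule
# (M, f) = (Base, :*) the loop above generates approximately
#
#   @grad Base.:*(a::TrackedReal, b::TrackedReal) = *(data(a), data(b)), Δ -> (Δ * b, Δ * a)
#   Base.:*(a::TrackedReal, b::TrackedReal) = track(Base.:*, a, b)
#
# plus the mixed TrackedReal/Real methods, with the derivative expressions
# supplied by DiffRules.diffrule.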
# Eliminating ambiguity
import Base:^
^(a::TrackedReal, b::Integer) = track(^, a, b)
# Hack for conversions
using ForwardDiff: Dual
(T::Type{<:Real})(x::Dual) = Dual(T(x.value), map(T, x.partials.values))
# Tuples
struct TrackedTuple{T<:Tuple}
data::T
tracker::Tracked{T}
end
data(xs::TrackedTuple) = xs.data
tracker(xs::TrackedTuple) = xs.tracker
accum!(x::Tuple, Δ::Tuple) = accum!.(x, Δ)
init_grad(x::Tuple) = init_grad.(x)
zero_grad!(x::Tuple) = zero_grad!.(x)
track(f::Call, xs::Tuple) = TrackedTuple(xs, Tracked{typeof(xs)}(f, zero.(xs)))
function Base.show(io::IO, xs::TrackedTuple)
show(io, data(xs))
print(io, " (tracked)")
end
Base.length(x::TrackedTuple) = length(data(x))
Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i)
@grad function getindex(xs::TrackedTuple, i)
data(xs)[i], Δ -> (ntuple(j -> i == j ? Δ : 0, length(xs)), nothing)
end
# Array collection
function collect(xs)
xs = Base.collect(xs)
track(Call(collect, (tracker.(xs),)), data.(xs))
end
function scan(c::Call{typeof(collect)})
foreach(scan, c.args[1])
end
function back_(c::Call{typeof(collect)}, Δ, once)
foreach((x, d) -> back(x, d, once), c.args[1], data(Δ))
end
function back_(g::Grads, c::Call{typeof(collect)}, Δ)
foreach((x, Δ) -> back(g, x, Δ), c.args[1], Δ)
end

View File

@ -1,18 +0,0 @@
function ngradient(f, xs::AbstractArray...)
grads = zero.(xs)
for (x, Δ) in zip(xs, grads), i in 1:length(x)
δ = sqrt(eps())
tmp = x[i]
x[i] = tmp - δ/2
y1 = f(xs...)
x[i] = tmp + δ/2
y2 = f(xs...)
x[i] = tmp
Δ[i] = (y2-y1)/δ
end
return grads
end
gradcheck(f, xs...) =
all(isapprox.(ngradient(f, xs...),
data.(gradient(f, xs...)), rtol = 1e-5, atol = 1e-5))
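# Usage sketch: `gradcheck` compares Tracker's gradients against the central
# finite differences computed by `ngradient` above (δ = √eps, tolerance 1e-5).
#
#   gradcheck(x -> sum(sin.(x)), rand(5))                # should be true
#   gradcheck((x, W) -> sum(W * x), rand(3), rand(2, 3))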

View File

@ -1,11 +1,13 @@
import Adapt: adapt
import Adapt: adapt, adapt_storage
import .Tracker: IdSet
children(x) = ()
mapchildren(f, x) = x
children(x::Tuple) = x
children(x::NamedTuple) = x
mapchildren(f, x::Tuple) = map(f, x)
mapchildren(f, x::NamedTuple) = map(f, x)
function treelike(m::Module, T, fs = fieldnames(T))
@eval m begin
@ -14,11 +16,6 @@ function treelike(m::Module, T, fs = fieldnames(T))
end
end
function treelike(T, fs = fieldnames(T))
Base.depwarn("`treelike(T)` is deprecated, use `@treelike T`", :treelike)
treelike(Base._current_module(), T, fs)
end
macro treelike(T, fs = nothing)
fs == nothing || isexpr(fs, :tuple) || error("@treelike T (a, b)")
fs = fs == nothing ? [] : [:($(map(QuoteNode, fs.args)...),)]
@ -69,3 +66,22 @@ gpu_adaptor = identity
end
gpu(x) = mapleaves(gpu_adaptor, x)
# Precision
adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Real}) = convert.(T, xs)
paramtype(T::Type{<:Real}, m) = mapleaves(x -> adapt(T, x), m)
f32(m) = paramtype(Float32, m)
f64(m) = paramtype(Float64, m)
# General parameter map
function mapparams(f, m)
mapleaves(m) do x
Tracker.istracked(x) ? param(f(Tracker.data(x))) :
x isa Union{AbstractArray,Number} ? f(x) :
x
end
end
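# Usage sketch (assumes a model such as the Chain below is in scope):
#
#   m   = Chain(Dense(10, 5, relu), Dense(5, 2))
#   m64 = f64(m)                        # parameters promoted to Float64
#   m32 = f32(m64)                      # and back to Float32
#   m10 = mapparams(x -> 10 .* x, m)    # arbitrary transform over params/arrays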

View File

@ -10,8 +10,8 @@ zeros(dims...) = Base.zeros(Float32, dims...)
unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
stack(xs, dim) = cat(dim, unsqueeze.(xs, dim)...)
stack(xs, dim) = cat(unsqueeze.(xs, dim)..., dims=dim)
unstack(xs, dim) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)]
unstack(xs, dim) = [copy(selectdim(xs, dim, i)) for i in 1:size(xs, dim)]
"""
chunk(xs, n)
@ -139,25 +139,6 @@ function throttle(f, timeout; leading=true, trailing=false)
end
end
"""
J = jacobian(m,x)
Calculate the output jacobian `J = d/dx m(x)` such that each row `i` of `J` corresponds to the gradient `J[i,:] = ∇ₓ(m(x)[i])`
"""
function jacobian(m,x)
xp = param(x)
y = m(xp)
k = length(y)
n = length(x)
J = Matrix{eltype(x)}(undef,n,k)
for i = 1:k
Flux.back!(y[i], once = false) # Populate gradient accumulator
J[:,i] = xp.grad
xp.grad .= 0 # Reset gradient accumulator
end
J'
end
""" """
@jit ... @jit ...

View File

@ -11,6 +11,8 @@ x = param(randn(5, 5))
cx = gpu(x)
@test cx isa TrackedArray && cx.data isa CuArray
@test Flux.onecold(param(gpu([1.,2.,3.]))) == 3
x = Flux.onehotbatch([1, 2, 3], 1:3)
cx = gpu(x)
@test cx isa Flux.OneHotMatrix && cx.data isa CuArray
@ -36,8 +38,16 @@ Flux.back!(sum(l))
end
@testset "onecold gpu" begin
y = Flux.onehotbatch(ones(3), 1:10) |> gpu;
@test Flux.onecold(y) isa CuArray
@test y[3,:] isa CuArray
end
if CuArrays.libcudnn != nothing
@info "Testing Flux/CUDNN"
include("cudnn.jl")
if !haskey(ENV, "CI_DISABLE_CURNN_TEST")
include("curnn.jl")
end
end

View File

@ -14,3 +14,9 @@ using Test
@test FashionMNIST.labels() isa Vector{Int64}
@test Data.Sentiment.train() isa Vector{Data.Tree{Any}}
@test Iris.features() isa Matrix
@test size(Iris.features()) == (4,150)
@test Iris.labels() isa Vector{String}
@test size(Iris.labels()) == (150,)

View File

@ -1,6 +1,18 @@
using Test, Random
import Flux: activations
@testset "basic" begin
@testset "helpers" begin
@testset "activations" begin
dummy_model = Chain(Dense(10,5,σ),Dense(5,2),softmax)
x = rand(10)
@test activations(Chain(), x) == []
@test activations(dummy_model, x)[1] == dummy_model[1](x)
@test activations(dummy_model, x)[2] == x |> dummy_model[1] |> dummy_model[2]
@test activations(Chain(identity, x->:foo), x)[2] == :foo # results include `Any` type
end
end
@testset "Chain" begin
@test_nowarn Chain(Dense(10, 5, σ), Dense(5, 2))(randn(10))
@test_throws DimensionMismatch Chain(Dense(10, 5, σ),Dense(2, 1))(randn(10))
@ -30,4 +42,34 @@ using Test, Random
@test Flux.Diagonal(2)([1,2]) == [1,2]
@test Flux.Diagonal(2)([1 2; 3 4]) == [1 2; 3 4]
end
@testset "Maxout" begin
# Note that the normal common usage of Maxout is as per the docstring
# These are abnormal constructors used for testing purposes
@testset "Constructor" begin
mo = Maxout(() -> identity, 4)
input = rand(40)
@test mo(input) == input
end
@testset "simple alternatives" begin
mo = Maxout((x -> x, x -> 2x, x -> 0.5x))
input = rand(40)
@test mo(input) == 2*input
end
@testset "complex alternatives" begin
mo = Maxout((x -> [0.5; 0.1]*x, x -> [0.2; 0.7]*x))
input = [3.0 2.0]
target = [0.5, 0.7].*input
@test mo(input) == target
end
@testset "params" begin
mo = Maxout(()->Dense(32, 64), 4)
ps = params(mo)
@test length(ps) == 8 #4 alts, each with weight and bias
end
end
end

View File

@ -4,9 +4,9 @@ using Flux: maxpool, meanpool
@testset "Pooling" begin
x = randn(Float32, 10, 10, 3, 2)
mp = MaxPool((2, 2))
@test mp(x) == maxpool(x, (2,2))
@test mp(x) == maxpool(x, PoolDims(x, 2))
mp = MeanPool((2, 2))
@test mp(x) == meanpool(x, (2,2))
@test mp(x) == meanpool(x, PoolDims(x, 2))
end
@testset "CNN" begin
@ -21,3 +21,69 @@ end
@test size(m(r)) == (10, 5)
end
@testset "asymmetric padding" begin
r = ones(Float32, 28, 28, 1, 1)
m = Conv((3, 3), 1=>1, relu; pad=(0,1,1,2))
m.weight.data[:] .= 1.0
m.bias.data[:] .= 0.0
y_hat = Flux.data(m(r))[:,:,1,1]
@test size(y_hat) == (27, 29)
@test y_hat[1, 1] ≈ 6.0
@test y_hat[2, 2] ≈ 9.0
@test y_hat[end, 1] ≈ 4.0
@test y_hat[1, end] ≈ 3.0
@test y_hat[1, end-1] ≈ 6.0
@test y_hat[end, end] ≈ 2.0
end
@testset "Depthwise Conv" begin
r = zeros(Float32, 28, 28, 3, 5)
m1 = DepthwiseConv((2, 2), 3=>5)
@test size(m1(r), 3) == 15
m2 = DepthwiseConv((2, 2), 3)
@test size(m2(r), 3) == 3
x = zeros(Float64, 28, 28, 3, 5)
m3 = DepthwiseConv((2, 2), 3 => 5)
@test size(m3(r), 3) == 15
m4 = DepthwiseConv((2, 2), 3)
@test size(m4(r), 3) == 3
end
@testset "ConvTranspose" begin
x = zeros(Float32, 28, 28, 1, 1)
y = Conv((3,3), 1 => 1)(x)
x_hat = ConvTranspose((3, 3), 1 => 1)(y)
@test size(x_hat) == size(x)
end
@testset "Conv with non quadratic window #700" begin
data = zeros(Float32, 7,7,1,1)
data[4,4,1,1] = 1
l = Conv((3,3), 1=>1)
expected = zeros(eltype(l.weight),5,5,1,1)
expected[2:end-1,2:end-1,1,1] = l.weight
@test expected == l(data)
l = Conv((3,1), 1=>1)
expected = zeros(eltype(l.weight),5,7,1,1)
expected[2:end-1,4,1,1] = l.weight
@test expected == l(data)
l = Conv((1,3), 1=>1)
expected = zeros(eltype(l.weight),7,5,1,1)
expected[4,2:end-1,1,1] = l.weight
@test expected == l(data)
@test begin
# we test that the next expression does not throw
randn(Float32, 10,10,1,1) |> Conv((6,1), 1=>1, Flux.σ)
true
end
end

View File

@ -108,4 +108,216 @@ end
y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5])
@test m(x) == y
end
let m = BatchNorm(32), x = randn(Float32, 416, 416, 32, 1);
m(x)
@test (@allocated m(x)) < 100_000_000
end
end
@testset "InstanceNorm" begin
# helper functions
expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...)
# begin tests
let m = InstanceNorm(2), sizes = (3, 2, 2),
x = param(reshape(collect(1:prod(sizes)), sizes))
@test m.β.data == [0, 0] # initβ(2)
@test m.γ.data == [1, 1] # initγ(2)
@test m.active
m(x)
#julia> x
#[:, :, 1] =
# 1.0 4.0
# 2.0 5.0
# 3.0 6.0
#
#[:, :, 2] =
# 7.0 10.0
# 8.0 11.0
# 9.0 12.0
#
# μ will be
# (1. + 2. + 3.) / 3 = 2.
# (4. + 5. + 6.) / 3 = 5.
#
# (7. + 8. + 9.) / 3 = 8.
# (10. + 11. + 12.) / 3 = 11.
#
# ∴ update rule with momentum:
# (1. - .1) * 0 + .1 * (2. + 8.) / 2 = .5
# (1. - .1) * 0 + .1 * (5. + 11.) / 2 = .8
@test m.μ ≈ [0.5, 0.8]
# momentum * var * num_items / (num_items - 1) + (1 - momentum) * sigma_sq
# julia> reshape(mean(.1 .* var(x.data, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1.
# 2-element Array{Float64,1}:
# 1.
# 1.
@test m.σ² ≈ reshape(mean(.1 .* var(x.data, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1.
testmode!(m)
@test !m.active
x = m(x).data
@test isapprox(x[1], (1 - 0.5) / sqrt(1. + 1f-5), atol = 1.0e-5)
end
# with activation function
let m = InstanceNorm(2, sigmoid), sizes = (3, 2, 2),
x = param(reshape(collect(1:prod(sizes)), sizes))
affine_shape = collect(sizes)
affine_shape[1] = 1
@test m.active
m(x)
testmode!(m)
@test !m.active
y = m(x).data
@test isapprox(y, data(sigmoid.((x .- expand_inst(m.μ, affine_shape)) ./ sqrt.(expand_inst(m.σ², affine_shape) .+ m.ϵ))), atol = 1.0e-7)
end
let m = InstanceNorm(2), sizes = (2, 4, 1, 2, 3),
x = param(reshape(collect(1:prod(sizes)), sizes))
y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
y = reshape(m(y), sizes...)
@test m(x) == y
end
# check that μ, σ², and the output are the correct size for higher rank tensors
let m = InstanceNorm(2), sizes = (5, 5, 3, 4, 2, 6),
x = param(reshape(collect(1:prod(sizes)), sizes))
y = m(x)
@test size(m.μ) == (sizes[end - 1], )
@test size(m.σ²) == (sizes[end - 1], )
@test size(y) == sizes
end
# show that instance norm is equal to batch norm when channel and batch dims are squashed
let m_inorm = InstanceNorm(2), m_bnorm = BatchNorm(12), sizes = (5, 5, 3, 4, 2, 6),
x = param(reshape(collect(1:prod(sizes)), sizes))
@test m_inorm(x) == reshape(m_bnorm(reshape(x, (sizes[1:end - 2]..., :, 1))), sizes)
end
let m = InstanceNorm(32), x = randn(Float32, 416, 416, 32, 1);
m(x)
@test (@allocated m(x)) < 100_000_000
end
end
@testset "GroupNorm" begin
# begin tests
squeeze(x) = dropdims(x, dims = tuple(findall(size(x) .== 1)...)) # To remove all singular dimensions
let m = GroupNorm(4,2), sizes = (3,4,2),
x = param(reshape(collect(1:prod(sizes)), sizes))
@test m.β.data == [0, 0, 0, 0] # initβ(4)
@test m.γ.data == [1, 1, 1, 1] # initγ(4)
@test m.active
m(x)
#julia> x
#[:, :, 1] =
# 1.0 4.0 7.0 10.0
# 2.0 5.0 8.0 11.0
# 3.0 6.0 9.0 12.0
#
#[:, :, 2] =
# 13.0 16.0 19.0 22.0
# 14.0 17.0 20.0 23.0
# 15.0 18.0 21.0 24.0
#
# μ will be
# (1. + 2. + 3. + 4. + 5. + 6.) / 6 = 3.5
# (7. + 8. + 9. + 10. + 11. + 12.) / 6 = 9.5
#
# (13. + 14. + 15. + 16. + 17. + 18.) / 6 = 15.5
# (19. + 20. + 21. + 22. + 23. + 24.) / 6 = 21.5
#
# μ =
# 3.5 15.5
# 9.5 21.5
#
# ∴ update rule with momentum:
# (1. - .1) * 0 + .1 * (3.5 + 15.5) / 2 = 0.95
# (1. - .1) * 0 + .1 * (9.5 + 21.5) / 2 = 1.55
@test m.μ ≈ [0.95, 1.55]
# julia> mean(var(reshape(x,3,2,2,2),dims=(1,2)).* .1,dims=2) .+ .9*1.
# 2-element Array{Tracker.TrackedReal{Float64},1}:
# 1.25
# 1.25
@test m.σ² ≈ mean(squeeze(var(reshape(x,3,2,2,2),dims=(1,2))).*.1,dims=2) .+ .9*1.
testmode!(m)
@test !m.active
x = m(x).data
println(x[1])
@test isapprox(x[1], (1 - 0.95) / sqrt(1.25 + 1f-5), atol = 1.0e-5)
end
# with activation function
let m = GroupNorm(4,2, sigmoid), sizes = (3, 4, 2),
x = param(reshape(collect(1:prod(sizes)), sizes))
μ_affine_shape = ones(Int,length(sizes) + 1)
μ_affine_shape[end-1] = 2 # Number of groups
affine_shape = ones(Int,length(sizes) + 1)
affine_shape[end-2] = 2 # Channels per group
affine_shape[end-1] = 2 # Number of groups
affine_shape[1] = sizes[1]
affine_shape[end] = sizes[end]
og_shape = size(x)
@test m.active
m(x)
testmode!(m)
@test !m.active
y = m(x)
x_ = reshape(x,affine_shape...)
out = reshape(data(sigmoid.((x_ .- reshape(m.μ,μ_affine_shape...)) ./ sqrt.(reshape(m.σ²,μ_affine_shape...) .+ m.ϵ))),og_shape)
@test isapprox(y, out, atol = 1.0e-7)
end
let m = GroupNorm(2,2), sizes = (2, 4, 1, 2, 3),
x = param(reshape(collect(1:prod(sizes)), sizes))
y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
y = reshape(m(y), sizes...)
@test m(x) == y
end
# check that μ, σ², and the output are the correct size for higher rank tensors
let m = GroupNorm(4,2), sizes = (5, 5, 3, 4, 4, 6),
x = param(reshape(collect(1:prod(sizes)), sizes))
y = m(x)
@test size(m.μ) == (m.G,1)
@test size(m.σ²) == (m.G,1)
@test size(y) == sizes
end
# show that group norm is the same as instance norm when the group size is the same as the number of channels
let IN = InstanceNorm(4), GN = GroupNorm(4,4), sizes = (2,2,3,4,5),
x = param(reshape(collect(1:prod(sizes)), sizes))
@test IN(x) ≈ GN(x)
end
# show that group norm is the same as batch norm for a group of size 1 and batch of size 1
let BN = BatchNorm(4), GN = GroupNorm(4,4), sizes = (2,2,3,4,1),
x = param(reshape(collect(1:prod(sizes)), sizes))
@test BN(x) ≈ GN(x)
end
end

View File

@ -11,3 +11,9 @@ using Test
@test onecold(a, labels) == 'C'
@test onecold(A, labels) == ['C', 'A', 'D']
end
@testset "onehotbatch indexing" begin
y = Flux.onehotbatch(ones(3), 1:10)
@test y[:,1] isa Flux.OneHotVector
@test y[:,:] isa Flux.OneHotMatrix
end

View File

@ -4,21 +4,15 @@ using Flux.Tracker
using Test
@testset "Optimise" begin
w = randn(10, 10)
@testset for Opt in [ADAMW, ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, Descent, ADAM, Nesterov, RMSProp, Momentum]
@testset for opt in [ADAMW(), ADAGrad(0.1), AdaMax(), ADADelta(0.9), AMSGrad(),
NADAM(), Descent(0.1), ADAM(), Nesterov(), RMSProp(),
Momentum()]
w = param(randn(10, 10))
loss(x) = Flux.mse(w*x, w*x)
opt = Opt(0.001)
if opt isa Descent || opt isa ADAGrad
opt = Opt(0.1)
end
if opt isa ADADelta
opt = Opt(0.9)
end
for t = 1: 10^5
l = loss(rand(10))
back!(l)
delta = Optimise.update!(opt, w.data, w.grad)
w.data .-= delta
θ = Params([w])
θ̄ = gradient(() -> loss(rand(10)), θ)
Optimise.update!(opt, θ, θ̄)
end
@test Flux.mse(w, w) < 0.01
end
@ -33,7 +27,7 @@ end
for t = 1:10^5
l = loss(rand(10))
back!(l)
delta = Optimise.update!(opt, w.data, w.grad)
delta = Optimise.apply!(opt, w.data, w.grad)
w.data .-= delta
end
@test Flux.mse(w, w) < 0.01
@ -59,3 +53,36 @@ end
cbs()
@test x == 1
end
@testset "ExpDecay" begin
w = randn(10, 10)
o = ExpDecay(0.1, 0.1, 1000, 1e-4)
w1 = param(randn(10,10))
loss(x) = Flux.mse(w*x, w1*x)
flag = 1
decay_steps = []
for t = 1:10^5
l = loss(rand(10))
back!(l)
prev_eta = o.eta
prev_grad = collect(w1.grad)
delta = Optimise.apply!(o, w1.data, w1.grad)
w1.data .-= delta
new_eta = o.eta
if new_eta != prev_eta
push!(decay_steps, t)
end
array = fill(o.eta, size(prev_grad))
if array .* prev_grad != delta
flag = 0
end
end
@test flag == 1
# Test to check if decay happens at decay steps. Eta reaches clip value eventually.
ground_truth = []
for i in 1:11
push!(ground_truth, 1000*i) # Expected decay steps for this example.
end
@test decay_steps == ground_truth
@test o.eta == o.clip
end

View File

@ -1,308 +1,15 @@
using Flux
using Flux, Test
using Flux.Tracker, Test, NNlib
using Tracker: gradcheck
using Flux.Tracker: TrackedReal, gradient, gradcheck, grad, checkpoint
using NNlib: conv, depthwiseconv
using Printf: @sprintf
using LinearAlgebra: diagm, dot, LowerTriangular, norm
using Statistics: mean, std
using Random
# using StatsBase
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
@testset "Tracker" begin
@test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2)
@test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)
@test gradtest((x, W, b) -> logσ.(W*x .+ b), 5, (2,5), 2)
@test gradtest((x, W, b) -> logσ.(W*x .+ b), (5,3), (2,5), 2)
@test gradtest((w, x) -> w'*x, randn(Float64,10, 2), randn(Float64,10))
@test gradtest((w, x) -> w*x', randn(Float64,5,5), randn(Float64,5,5))
@test gradtest(x -> sum(x, dims = (2, 3)), (3,4,5))
@test gradtest(x -> sum(x, dims = 1), randn(Float64,2,3))
@test gradtest(x -> sum(x, dims = [1,2]), randn(Float64,2,3))
@test gradtest(x -> sum(x), randn(Float64,2,3))
@test gradtest(x -> prod(x, dims=(2, 3)), (3,4,5))
@test gradtest(x -> prod(x), (3,4,5))
@test gradtest(x -> softmax(x).*(1:3), 3)
@testset "Tracker" begin
@test gradtest(x -> softmax(x).*(1:3), (3,5))
@test gradtest(x -> logsoftmax(x).*(1:3), 3)
@test gradtest(x -> logsoftmax(x).*(1:3), (3,5))
@test gradtest(Flux.mse, rand(5,5), rand(5, 5))
@test gradtest(Flux.crossentropy, rand(5,5), rand(5, 5))
@test gradtest(x -> x', rand(5))
@test gradtest(x -> Flux.normalise(x), rand(4,3))
@test gradtest(x -> Flux.normalise(x, dims = 2), rand(3,4))
@testset "indexing & slicing" begin
gradtest(x->view(x, 1:2, 1:2), rand(4, 4))
end
function promotiontest(f, A, B, C)
r0 = f(A, B, C)
r1 = f(param(A), B, C)
r2 = f(A, param(B), C)
r3 = f(A, B, param(C))
r4 = f(param(A), param(B), param(C))
@test !isa(r0, TrackedArray)
@test all(isa.([r1,r2,r3,r4], TrackedArray))
@test r1 == r2 == r3 == r4
@test r0 == Flux.data(r4)
end
@testset "concat" begin
cat1(x...) = cat(x..., dims = 1)
cat2(x...) = cat(x..., dims = 2)
@testset for vcatf in [vcat, cat1]
@test gradtest(vcatf, rand(5), rand(3))
@test gradtest(vcatf, rand(5), rand(3), rand(8))
@test gradtest(vcatf, rand(5)', rand(5)')
@test gradtest(vcatf, rand(5,2), rand(3,2), rand(8,2))
@test gradtest(vcatf, rand(5,2,3), rand(3,2,3), rand(8,2,3))
@test gradtest(vcatf, rand(5), rand(3,1))
@test gradtest(vcatf, rand(5)', rand(2,5))
end
@testset for hcatf in [hcat, cat2]
@test gradtest(hcatf, rand(5), rand(5))
@test gradtest(hcatf, rand(5)', rand(5)')
@test gradtest(hcatf, rand(2,5), rand(2,3), rand(2,8))
@test gradtest(hcatf, rand(2,5,3), rand(2,3,3), rand(2,8,3))
@test gradtest(hcatf, rand(5), rand(5), rand(5,2))
@test gradtest(hcatf, rand(5)', rand(1,3))
@test gradtest(hcatf, rand(5), rand(5,2))
end
@testset for catf in [vcat, cat1, hcat, cat2, (x...) -> cat(x..., dims = 3), (x...) -> cat(x..., dims = (1,2))]
@test gradtest(catf, rand(5))
@test gradtest(catf, rand(5)')
@test gradtest(catf, rand(2,5))
@test gradtest(catf, rand(2,5,3))
end
@test gradtest((x...) -> cat(x..., dims = 3), rand(2,5,2), rand(2,5,3), rand(2,5,4))
@testset "cat($dim, ...)" for dim in 3:5
catdim = (x...) -> cat(x..., dims = dim)
@test gradtest(catdim, rand(5), rand(5), rand(5))
@test gradtest(catdim, rand(2,5), rand(2,5), rand(2,5))
@test gradtest(catdim, rand(2,5,3), rand(2,5,3), rand(2,5,3))
end
@test !isa(vcat(rand(2)), TrackedArray)
@test !isa(hcat(rand(2)), TrackedArray)
@test !isa(cat(rand(2), dims=1), TrackedArray)
@test gradtest((a,b)->cat(a, b, dims = (2,3,5)), rand(2,3), rand(2,4,2,1))
@testset "promotiontest" begin
@testset for fcat in [hcat, vcat, (x...) -> cat(x..., dims = 3), (x...) -> cat(x..., dims = (1,2))]
promotiontest(fcat, rand(2), rand(2), rand(2))
promotiontest(fcat, rand(2)', rand(2)', rand(2)')
promotiontest(fcat, rand(2,2), rand(2,2), rand(2,2))
promotiontest(fcat, rand(2,2,2), rand(2,2,2), rand(2,2,2))
end
promotiontest(vcat, rand(1,2), rand(2)', rand(2,2))
promotiontest(hcat, rand(2,1), rand(2), rand(2,2))
promotiontest(vcat, rand(3,4,5), rand(1,4,5), rand(2,4,5))
promotiontest(hcat, rand(4,3,5), rand(4,1,5), rand(4,2,5))
promotiontest((x...) -> cat(x..., dims = 3), rand(4,5,3), rand(4,5,1), rand(4,5,2))
end
end
end
@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))
@test gradtest(x -> repeat(x; inner=2), rand(5))
@test gradtest(x -> repeat(x; inner=2, outer=3), rand(5))
@test gradtest(x -> repeat(x; inner=(2,2,1), outer=(1,1,3)), rand(5,4,3))
@test gradtest(kron, rand(5), rand(3))
@test gradtest(kron, rand(5), rand(3), rand(8))
@test gradtest(kron, rand(5,1), rand(3,1))
@test gradtest(kron, rand(5,1), rand(3,1), rand(8,1))
@test gradtest(kron, rand(5,2), rand(3,2), rand(8,2))
@test gradtest(x -> diagm(0 => x), rand(3))
@test gradtest(W -> inv(log.(W * W)), (5,5))
@test gradtest((A, B) -> A / B , (1,5), (5,5))
@test gradtest((A, B) -> log.(A * A) / exp.(B * B), (5,5), (5,5))
@test gradtest((A, B) -> log.(A * A) \ exp.(B * B), (5,5), (5,5))
@testset "mean" begin
@test gradtest(mean, rand(2, 3))
@test gradtest(x -> mean(x, dims=1), rand(2, 3))
@test gradtest(x -> mean(x, dims=2), rand(2, 3))
@test gradtest(x -> mean(x, dims=3), rand(2, 3, 4))
@test gradtest(x -> mean(x, dims=[1, 2]), rand(2, 3, 4))
end
@testset "maximum" begin
@test gradtest(maximum, rand(2, 3))
@test gradtest(x -> maximum(x, dims=1), rand(2, 3))
@test gradtest(x -> maximum(x, dims=2), rand(2, 3))
@test gradtest(x -> maximum(x, dims=3), rand(2, 3, 4))
@test gradtest(x -> maximum(x, dims=[1, 2]), rand(2, 3, 4))
end
@testset "minimum" begin
@test gradtest(minimum, rand(2, 3))
@test gradtest(x -> minimum(x, dims=1), rand(2, 3))
@test gradtest(x -> minimum(x, dims=2), rand(2, 3))
@test gradtest(x -> minimum(x, dims=3), rand(2, 3, 4))
@test gradtest(x -> minimum(x, dims=[1, 2]), rand(2, 3, 4))
end
@test gradtest(x -> std(x), rand(5,5))
@test gradtest(x -> std(x, dims = 1), rand(5,5))
@test gradtest((x, y) -> x .* y, rand(5), rand(5))
@test gradtest(dot, rand(5), rand(5))
@test gradtest(norm, rand(5))
@test gradtest(rand(5)) do x
y = x.^2
2y + x
end
@test gradtest(conv, rand(10, 3, 2), randn(Float64,2, 3, 2))
@test gradtest(conv, rand(10, 10, 3, 2), randn(Float64,2, 2, 3, 2))
@test gradtest(conv, rand(10, 10, 10, 3, 2), randn(Float64,2, 2, 2, 3, 2))
@test gradtest(depthwiseconv, rand(10,10,3,2), randn(2, 2, 2, 3))
@test gradtest(x -> maxpool(x, (2,2)), rand(10, 10, 3, 2))
@test gradtest(x -> maxpool(x, (2,2,2)), rand(10, 10, 10, 3, 2))
@test gradtest(x -> meanpool(x, (2,2)), rand(10, 10, 3, 2))
@test gradtest(x -> meanpool(x, (2,2,2)), rand(5, 5, 5, 3, 2))
@test gradtest(x -> Float64.(x), 5)
@testset "equality & order" begin
# TrackedReal
@test param(2)^2 == param(4)
@test param(2)^2 == 4
@test 4 == param(2)^2
@test param(2)^2 ≈ param(4)
@test param(2)^2 ≈ 4
@test 4 ≈ param(2)^2
@test (param([1,2,3]) .< 2) == [true, false, false]
@test (param([1,2,3]) .<= 2) == [true, true, false]
@test (2 .> param([1,2,3])) == [true, false, false]
@test (2 .>= param([1,2,3])) == [true, true, false]
# TrackedArray
@test param([1,2,3]).^2 == param([1,4,9])
@test [1,2,3].^2 == param([1,4,9])
@test param([1,2,3]).^2 == [1,4,9]
@test param([1,2,3]).^2 ≈ param([1,4,9])
@test [1,2,3].^2 ≈ param([1,4,9])
@test param([1,2,3]).^2 ≈ [1,4,9]
end
@testset "reshape" begin
x = reshape(param(rand(2,2,2)), 4, 2)
@test x isa TrackedArray
@test size(x) == (4,2)
x = reshape(param([1]), (1,:))
@test x isa TrackedArray
@test size(x) == (1,1)
x = reshape(param(rand(2)), (2,:))
@test x isa TrackedArray
@test size(x) == (2,1)
x = reshape(param(rand(2,2)), (1,:,2))
@test x isa TrackedArray
@test size(x) == (1,2,2)
end
@testset "Intermediates" begin
x = param([1])
l = sum((x .+ x).^2)
Flux.back!(l, once = false)
@test x.grad == [8]
x.grad .= 0
Flux.back!(l, once = false)
@test x.grad == [8]
end
@testset "Fallbacks" begin
xs = param([1 2; 3 4])
@test similar(xs) isa Matrix{Float64}
end
@test @sprintf("%.2f", sum(param([1,2,3]))) == "6.00"
@inferred NNlib.conv(param(rand(10,10,3,2)),randn(Float64,2,2,3,4))
b = param(rand())
Tracker.back!(b)
@test Tracker.grad(b) == 1
@testset "collect" begin
x, y = param(2), param(3)
xy = Tracker.collect([x, y])
@test xy isa TrackedArray{Float64}
z = xy[1]*xy[2]
back!(z)
@test grad.((x,y)) == (3, 2)
@test gradient(2, 3) do x, y
xy = Tracker.collect([x, y])
xy[1]*xy[2]
end == (3, 2)
end
# Gradient Hooks
@testset "Hooks" begin
x = param(2)
y = Tracker.hook(-, x)
back!(y)
@test grad(x) == -1
end
@testset "Checkpointing" begin
count = 0
function mul(a, b)
count += 1
a * b
end
@test gradient(x -> mul(5, x), 3)[1] == 5
@test count == 1
@test gradient(x -> checkpoint(mul, 5, x), 3)[1] == 5
@test count == 3
end
@testset "Updates" begin
xs = param([1, 2, 3])
Tracker.update!(xs, param([4, 5, 6]))
@test xs == [5, 7, 9]
x = param(3)
Tracker.update!(x, param(4))
@test x == 7
end
@testset "Params" begin
W = param(randn(5, 10))
x = rand(10)
dW = gradient(W -> sum(W*x), W)[1]
gs = gradient(() -> sum(W*x), Tracker.Params([W]))
@test gs[W] == dW
end
end #testset

View File

@ -1,5 +1,5 @@
using Flux
using Flux: throttle, jacobian, glorot_uniform, glorot_normal
using Flux: throttle, jacobian, glorot_uniform, glorot_normal, stack, unstack
using StatsBase: std
using Random
using Test
@ -86,3 +86,28 @@ end
m = RNN(10, 5)
@test size.(params(m)) == [(5, 10), (5, 5), (5,), (5,)]
end
@testset "Basic Stacking" begin
x = randn(3,3)
stacked = stack([x, x], 2)
@test size(stacked) == (3,2,3)
end
@testset "Precision" begin
m = Chain(Dense(10, 5, relu), Dense(5, 2))
x = rand(10)
@test eltype(m[1].W.data) == Float32
@test eltype(m(x).data) == Float32
@test eltype(f64(m)(x).data) == Float64
@test eltype(f64(m)[1].W.data) == Float64
@test eltype(f32(f64(m))[1].W.data) == Float32
@test Tracker.isleaf(f32(f64(m))[1].W)
end
@testset "Stacking" begin
stacked_array=[ 8 9 3 5; 9 6 6 9; 9 1 7 2; 7 4 10 6 ]
unstacked_array=[[8, 9, 9, 7], [9, 6, 1, 4], [3, 6, 7, 10], [5, 9, 2, 6]]
@test unstack(stacked_array, 2) == unstacked_array
@test stack(unstacked_array, 2) == stacked_array
@test stack(unstack(stacked_array, 1), 1) == stacked_array
end