Merge branch 'master' into onecold

commit 9bbbd17e4b
Dhairya Gandhi, 2019-04-30 19:09:36 +05:30, committed by GitHub
GPG Key ID: 4AEE18F83AFDEB23 (no known key found for this signature in database)
47 changed files with 1349 additions and 1726 deletions

.gitlab-ci.yml (new file)

@@ -0,0 +1,37 @@
before_script:
  - export CI_DISABLE_CURNN_TEST=true

variables:
  CI_IMAGE_TAG: 'cuda'

include:
  - 'https://raw.githubusercontent.com/JuliaGPU/gitlab-ci/master/templates/v3/common.yml'

.flux:
  extends: .test
  script:
    - julia -e 'using InteractiveUtils;
                versioninfo()'
    - mkdir $JULIA_DEPOT_PATH # Pkg3.jl#325
    - julia -e 'using Pkg;
                Pkg.add("CuArrays");'
    - julia --project -e 'using Pkg;
                          Pkg.instantiate();
                          Pkg.build();
                          Pkg.test(; coverage=true);'

test:v1.0:
  extends: .flux
  variables:
    CI_VERSION_TAG: 'v1.0'
  only:
    - staging
    - trying

test:v1.1:
  extends: .flux
  variables:
    CI_VERSION_TAG: 'v1.1'
  only:
    - staging
    - trying

LICENSE.md

@@ -1,6 +1,6 @@
The Flux.jl package is licensed under the MIT "Expat" License:
> Copyright (c) 2016: Mike Innes.
> Copyright (c) 2016-19: Julia Computing, Inc., Mike Innes and Contributors
>
> Permission is hereby granted, free of charge, to any person obtaining
> a copy of this software and associated documentation files (the

Manifest.toml

@@ -1,3 +1,5 @@
# This file is machine-generated - editing it directly is not advised
[[AbstractTrees]]
deps = ["Markdown", "Test"]
git-tree-sha1 = "6621d9645702c1c4e6970cc6a3eae440c768000b"
@@ -25,11 +27,17 @@ git-tree-sha1 = "055eb2690182ebc31087859c3dd8598371d3ef9e"
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
version = "0.5.3"
[[CSTParser]]
deps = ["LibGit2", "Test", "Tokenize"]
git-tree-sha1 = "437c93bc191cd55957b3f8dee7794b6131997c56"
uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f"
version = "0.5.2"
[[CodecZlib]]
deps = ["BinaryProvider", "Libdl", "Test", "TranscodingStreams"]
git-tree-sha1 = "e3df104c84dfc108f0ca203fd7f5bbdc98641ae9"
git-tree-sha1 = "36bbf5374c661054d41410dc53ff752972583b9b"
uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
version = "0.5.1"
version = "0.5.2"
[[ColorTypes]]
deps = ["FixedPointNumbers", "Random", "Test"]
@@ -51,9 +59,15 @@ version = "0.2.0"
[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "49269e311ffe11ac5b334681d212329002a9832a"
git-tree-sha1 = "84aa74986c5b9b898b0d1acaf3258741ee64754f"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "1.5.1"
version = "2.1.0"
[[Crayons]]
deps = ["Test"]
git-tree-sha1 = "f621b8ef51fd2004c7cf157ea47f027fdeac5523"
uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
version = "4.0.0"
[[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
@@ -71,18 +85,18 @@ uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
[[DiffResults]]
deps = ["Compat", "StaticArrays"]
git-tree-sha1 = "db8acf46717b13d6c48deb7a12007c7f85a70cf7"
git-tree-sha1 = "34a4a1e8be7bc99bc9c611b895b5baf37a80584c"
uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
version = "0.0.3"
version = "0.0.4"
[[DiffRules]]
deps = ["Random", "Test"]
git-tree-sha1 = "09d69da75967ec48a8b1ad0897ec9144ee052bf9"
git-tree-sha1 = "dc0869fb2f5b23466b32ea799bd82c76480167f7"
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "0.0.8"
version = "0.0.10"
[[Distributed]]
deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"]
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
[[FixedPointNumbers]]
@@ -93,19 +107,19 @@ version = "0.5.3"
[[ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"]
git-tree-sha1 = "e393bd3b9102659fb24fe88caedec41f2bc2e7de"
git-tree-sha1 = "4c4d727f1b7e0092134fabfab6396b8945c1ea5b"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.2"
version = "0.10.3"
[[InteractiveUtils]]
deps = ["LinearAlgebra", "Markdown"]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
[[Juno]]
deps = ["Base64", "Logging", "Media", "Profile", "Test"]
git-tree-sha1 = "ce6246e19061e36cbdce954caaae717498daeed8"
git-tree-sha1 = "4e4a8d43aa7ecec66cadaf311fbd1e5c9d7b9175"
uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
version = "0.5.4"
version = "0.7.0"
[[LibGit2]]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
@@ -121,10 +135,10 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
[[MacroTools]]
deps = ["Compat"]
git-tree-sha1 = "c443e1c8d58a4e9f61b708ad0a88286c7042145b"
deps = ["CSTParser", "Compat", "DataStructures", "Test"]
git-tree-sha1 = "daecd9e452f38297c686eba90dba2a6d5da52162"
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
version = "0.4.4"
version = "0.5.0"
[[Markdown]]
deps = ["Base64"]
@@ -146,12 +160,10 @@ version = "0.4.0"
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
[[NNlib]]
deps = ["Libdl", "LinearAlgebra", "MacroTools", "Requires", "Test"]
git-tree-sha1 = "5a8ed87d61b1ccb71d99235c2a96287addebbb9f"
repo-rev = "master"
repo-url = "https://github.com/FluxML/NNlib.jl.git"
deps = ["Libdl", "LinearAlgebra", "Requires", "Statistics", "TimerOutputs"]
git-tree-sha1 = "0c667371391fc6bb31f7f12f96a56a17098b3de8"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.4.3+"
version = "0.6.0"
[[NaNMath]]
deps = ["Compat"]
@@ -161,9 +173,9 @@ version = "0.3.2"
[[OrderedCollections]]
deps = ["Random", "Serialization", "Test"]
git-tree-sha1 = "85619a3f3e17bb4761fe1b1fd47f0e979f964d5b"
git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.0.2"
version = "1.1.0"
[[Pkg]]
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
@@ -228,29 +240,47 @@ version = "0.7.2"
[[StaticArrays]]
deps = ["InteractiveUtils", "LinearAlgebra", "Random", "Statistics", "Test"]
git-tree-sha1 = "1eb114d6e23a817cd3e99abc3226190876d7c898"
git-tree-sha1 = "3841b39ed5f047db1162627bf5f80a9cd3e39ae2"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "0.10.2"
version = "0.10.3"
[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[[StatsBase]]
deps = ["DataStructures", "DelimitedFiles", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "Test"]
git-tree-sha1 = "7b596062316c7d846b67bf625d5963a832528598"
deps = ["DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"]
git-tree-sha1 = "8a0f4b09c7426478ab677245ab2b0b68552143c7"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.27.0"
version = "0.30.0"
[[Test]]
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[[TimerOutputs]]
deps = ["Crayons", "Printf", "Test", "Unicode"]
git-tree-sha1 = "b80671c06f8f8bae08c55d67b5ce292c5ae2660c"
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
version = "0.5.0"
[[Tokenize]]
deps = ["Printf", "Test"]
git-tree-sha1 = "3e83f60b74911d3042d3550884ca2776386a02b8"
uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624"
version = "0.5.3"
[[Tracker]]
deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics", "Test"]
git-tree-sha1 = "0bec1b68c63a0e8a58d3944261cbf4cc9577c8a1"
uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
version = "0.2.0"
[[TranscodingStreams]]
deps = ["Pkg", "Random", "Test"]
git-tree-sha1 = "a34a2d588e2d2825602bf14a24216d5c8b0921ec"
deps = ["Random", "Test"]
git-tree-sha1 = "a25d8e5a28c3b1b06d3859f30757d43106791919"
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
version = "0.8.1"
version = "0.9.4"
[[URIParser]]
deps = ["Test", "Unicode"]
@@ -259,7 +289,7 @@ uuid = "30578b45-9adc-5946-b283-645ec420af67"
version = "0.4.0"
[[UUIDs]]
deps = ["Random"]
deps = ["Random", "SHA"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
[[Unicode]]
@@ -267,6 +297,6 @@ uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
[[ZipFile]]
deps = ["BinaryProvider", "Libdl", "Printf", "Test"]
git-tree-sha1 = "4000c633efe994b2e10b31b6d91382c4b7412dac"
git-tree-sha1 = "5f6f663890dfb9bad6af75a86a43f67904e5050e"
uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
version = "0.8.0"
version = "0.8.1"

NEWS.md (new file)

@@ -0,0 +1,25 @@
# v0.8.0
* New [ConvTranspose layer](https://github.com/FluxML/Flux.jl/pull/311).
* New [Maxout layer](https://github.com/FluxML/Flux.jl/pull/647)
* Datasets are now [hash verified on download](https://github.com/FluxML/Flux.jl/pull/585) to avoid corruption.
* We now [zero the initial state for RNNs](https://github.com/FluxML/Flux.jl/pull/590/).
* [Normalisation can now work on arbitrary `dims`.](https://github.com/FluxML/Flux.jl/pull/592)
* Many docs and bugfixes thanks to @KristofferC and others.
* [NamedTuples now work like Tuples](https://github.com/FluxML/Flux.jl/pull/603) when doing `mapleaves`.
* New "performance tips" [section of the docs](https://github.com/FluxML/Flux.jl/pull/615).
* The training loop is [now more readable](https://github.com/FluxML/Flux.jl/pull/651) and better shows how to use the lower-level APIs.
* New [AlphaDropout](https://github.com/FluxML/Flux.jl/pull/656).
* [Data.Iris](https://github.com/FluxML/Flux.jl/pull/652) makes Fisher's Iris dataset available with `Iris.labels` and `Iris.features`.
* New [InstanceNorm](https://github.com/FluxML/Flux.jl/pull/634), as popularized by [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022).
* New [GroupNorm](https://github.com/FluxML/Flux.jl/pull/696), as described in [Group Normalization](https://arxiv.org/abs/1803.08494).
AD Changes:
* `det`, `logdet` and `logabsdet` [now have adjoints](https://github.com/FluxML/Flux.jl/pull/596/files).
* Support for [PermuteDimsArray](https://github.com/FluxML/Flux.jl/pull/576).
* Flux.Tracker is now its [own package](https://github.com/FluxML/Tracker.jl), in preparation for replacing it with Zygote.
# v0.7.0
Despite the heroic efforts of scholars and archeologists, pre-0.7 history is lost to the sands of time.

Project.toml

@@ -6,21 +6,29 @@ AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SHA = "ea8e919c-243c-51af-8825-aaa63cd721ce"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
[compat]
NNlib = "0.6"
Tracker = "0.2"
julia = "0.7, 1"
[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[targets]
test = ["Test"]

README.md

@@ -2,7 +2,7 @@
<img width="400px" src="https://raw.githubusercontent.com/FluxML/fluxml.github.io/master/logo.png"/>
</p>
[![Build Status](https://travis-ci.org/FluxML/Flux.jl.svg?branch=master)](https://travis-ci.org/FluxML/Flux.jl) [![](https://img.shields.io/badge/docs-stable-blue.svg)](https://fluxml.github.io/Flux.jl/stable/) [![](https://img.shields.io/badge/chat-on%20slack-yellow.svg)](https://slackinvite.julialang.org/) [![DOI](http://joss.theoj.org/papers/10.21105/joss.00602/status.svg)](https://doi.org/10.21105/joss.00602)
[![Build Status](https://travis-ci.org/FluxML/Flux.jl.svg?branch=master)](https://travis-ci.org/FluxML/Flux.jl) [![](https://img.shields.io/badge/docs-stable-blue.svg)](https://fluxml.github.io/Flux.jl/stable/) [![](https://img.shields.io/badge/chat-on%20slack-yellow.svg)](https://slackinvite.julialang.org/) [![DOI](https://joss.theoj.org/papers/10.21105/joss.00602/status.svg)](https://doi.org/10.21105/joss.00602)
Flux is an elegant approach to machine learning. It's a 100% pure-Julia stack, and provides lightweight abstractions on top of Julia's native GPU and AD support. Flux makes the easy things easy while remaining fully hackable.
@@ -10,7 +10,7 @@ Flux is an elegant approach to machine learning. It's a 100% pure-Julia stack, a
julia> Pkg.add("Flux")
```
See the [documentation](http://fluxml.github.io/Flux.jl/) or the [model zoo](https://github.com/FluxML/model-zoo/) for examples.
See the [documentation](https://fluxml.github.io/Flux.jl/) or the [model zoo](https://github.com/FluxML/model-zoo/) for examples.
If you use Flux in research, please cite the following paper:

REQUIRE

@@ -10,9 +10,4 @@ ZipFile
AbstractTrees
Reexport
StatsBase
# AD
ForwardDiff 0.5.0
DiffRules
SpecialFunctions
NaNMath
Tracker

bors.toml (new file)

@@ -0,0 +1,4 @@
status = [
"ci/gitlab/%"
]
timeout-sec = 14400

docs/Manifest.toml

@@ -1,3 +1,5 @@
# This file is machine-generated - editing it directly is not advised
[[AbstractTrees]]
deps = ["Markdown", "Test"]
git-tree-sha1 = "6621d9645702c1c4e6970cc6a3eae440c768000b"
@@ -6,9 +8,9 @@ version = "0.2.1"
[[Adapt]]
deps = ["LinearAlgebra", "Test"]
git-tree-sha1 = "04d15700419b6949d76be1428ab6e0277ff43b06"
git-tree-sha1 = "53d8fec4f662088c1202530e338a11a919407f3b"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "0.4.1"
version = "0.4.2"
[[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
@@ -25,11 +27,17 @@ git-tree-sha1 = "055eb2690182ebc31087859c3dd8598371d3ef9e"
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
version = "0.5.3"
[[CSTParser]]
deps = ["LibGit2", "Test", "Tokenize"]
git-tree-sha1 = "437c93bc191cd55957b3f8dee7794b6131997c56"
uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f"
version = "0.5.2"
[[CodecZlib]]
deps = ["BinaryProvider", "Libdl", "Test", "TranscodingStreams"]
git-tree-sha1 = "e3df104c84dfc108f0ca203fd7f5bbdc98641ae9"
git-tree-sha1 = "36bbf5374c661054d41410dc53ff752972583b9b"
uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
version = "0.5.1"
version = "0.5.2"
[[ColorTypes]]
deps = ["FixedPointNumbers", "Random", "Test"]
@@ -51,9 +59,15 @@ version = "0.2.0"
[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "ec61a16eed883ad0cfa002d7489b3ce6d039bb9a"
git-tree-sha1 = "84aa74986c5b9b898b0d1acaf3258741ee64754f"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "1.4.0"
version = "2.1.0"
[[Crayons]]
deps = ["Test"]
git-tree-sha1 = "f621b8ef51fd2004c7cf157ea47f027fdeac5523"
uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
version = "4.0.0"
[[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
@@ -71,31 +85,31 @@ uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
[[DiffResults]]
deps = ["Compat", "StaticArrays"]
git-tree-sha1 = "db8acf46717b13d6c48deb7a12007c7f85a70cf7"
git-tree-sha1 = "34a4a1e8be7bc99bc9c611b895b5baf37a80584c"
uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
version = "0.0.3"
version = "0.0.4"
[[DiffRules]]
deps = ["Random", "Test"]
git-tree-sha1 = "c49ec69428ffea0c1d1bbdc63d1a70f5df5860ad"
git-tree-sha1 = "dc0869fb2f5b23466b32ea799bd82c76480167f7"
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "0.0.7"
version = "0.0.10"
[[Distributed]]
deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"]
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
[[DocStringExtensions]]
deps = ["LibGit2", "Markdown", "Pkg", "Test"]
git-tree-sha1 = "1df01539a1c952cef21f2d2d1c092c2bcf0177d7"
git-tree-sha1 = "4d30e889c9f106a51ffa4791a88ffd4765bf20c3"
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
version = "0.6.0"
version = "0.7.0"
[[Documenter]]
deps = ["Base64", "DocStringExtensions", "InteractiveUtils", "LibGit2", "Logging", "Markdown", "Pkg", "REPL", "Random", "Test", "Unicode"]
git-tree-sha1 = "a6db1c69925cdc53aafb38caec4446be26e0c617"
deps = ["Base64", "DocStringExtensions", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "Pkg", "REPL", "Random", "Test", "Unicode"]
git-tree-sha1 = "13a6d15102410d8e70146533b759fc48d844a1d0"
uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
version = "0.21.0"
version = "0.22.3"
[[FixedPointNumbers]]
deps = ["Test"]
@@ -104,26 +118,32 @@ uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
version = "0.5.3"
[[Flux]]
deps = ["AbstractTrees", "Adapt", "CodecZlib", "Colors", "DiffRules", "ForwardDiff", "Juno", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsBase", "Test", "ZipFile"]
deps = ["AbstractTrees", "Adapt", "CodecZlib", "Colors", "DelimitedFiles", "Juno", "LinearAlgebra", "MacroTools", "NNlib", "Pkg", "Printf", "Random", "Reexport", "Requires", "SHA", "Statistics", "StatsBase", "Tracker", "ZipFile"]
path = ".."
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
version = "0.6.10+"
version = "0.8.2+"
[[ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"]
git-tree-sha1 = "b91250044374764e7c29af59a774c4b8d6100b6e"
git-tree-sha1 = "4c4d727f1b7e0092134fabfab6396b8945c1ea5b"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.1"
version = "0.10.3"
[[InteractiveUtils]]
deps = ["LinearAlgebra", "Markdown"]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
[[JSON]]
deps = ["Dates", "Distributed", "Mmap", "Sockets", "Test", "Unicode"]
git-tree-sha1 = "1f7a25b53ec67f5e9422f1f551ee216503f4a0fa"
uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
version = "0.20.0"
[[Juno]]
deps = ["Base64", "Logging", "Media", "Profile", "Test"]
git-tree-sha1 = "3c29a199713e7ec62cfdc11f44d7760219d5f658"
git-tree-sha1 = "4e4a8d43aa7ecec66cadaf311fbd1e5c9d7b9175"
uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
version = "0.5.3"
version = "0.7.0"
[[LibGit2]]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
@@ -139,10 +159,10 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
[[MacroTools]]
deps = ["Compat"]
git-tree-sha1 = "c443e1c8d58a4e9f61b708ad0a88286c7042145b"
deps = ["CSTParser", "Compat", "DataStructures", "Test"]
git-tree-sha1 = "daecd9e452f38297c686eba90dba2a6d5da52162"
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
version = "0.4.4"
version = "0.5.0"
[[Markdown]]
deps = ["Base64"]
@@ -156,18 +176,18 @@ version = "0.5.0"
[[Missings]]
deps = ["Dates", "InteractiveUtils", "SparseArrays", "Test"]
git-tree-sha1 = "adc26d2ee85a49c413464110d922cf21efc9d233"
git-tree-sha1 = "d1d2585677f2bd93a97cfeb8faa7a0de0f982042"
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
version = "0.3.1"
version = "0.4.0"
[[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
[[NNlib]]
deps = ["Libdl", "LinearAlgebra", "MacroTools", "Requires", "Test"]
git-tree-sha1 = "51330bb45927379007e089997bf548fbe232589d"
deps = ["Libdl", "LinearAlgebra", "Requires", "Statistics", "TimerOutputs"]
git-tree-sha1 = "0c667371391fc6bb31f7f12f96a56a17098b3de8"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.4.3"
version = "0.6.0"
[[NaNMath]]
deps = ["Compat"]
@@ -177,9 +197,9 @@ version = "0.3.2"
[[OrderedCollections]]
deps = ["Random", "Serialization", "Test"]
git-tree-sha1 = "85619a3f3e17bb4761fe1b1fd47f0e979f964d5b"
git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.0.2"
version = "1.1.0"
[[Pkg]]
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
@@ -244,29 +264,47 @@ version = "0.7.2"
[[StaticArrays]]
deps = ["InteractiveUtils", "LinearAlgebra", "Random", "Statistics", "Test"]
git-tree-sha1 = "1eb114d6e23a817cd3e99abc3226190876d7c898"
git-tree-sha1 = "3841b39ed5f047db1162627bf5f80a9cd3e39ae2"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "0.10.2"
version = "0.10.3"
[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[[StatsBase]]
deps = ["DataStructures", "DelimitedFiles", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "Test"]
git-tree-sha1 = "7b596062316c7d846b67bf625d5963a832528598"
deps = ["DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"]
git-tree-sha1 = "8a0f4b09c7426478ab677245ab2b0b68552143c7"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.27.0"
version = "0.30.0"
[[Test]]
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[[TimerOutputs]]
deps = ["Crayons", "Printf", "Test", "Unicode"]
git-tree-sha1 = "b80671c06f8f8bae08c55d67b5ce292c5ae2660c"
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
version = "0.5.0"
[[Tokenize]]
deps = ["Printf", "Test"]
git-tree-sha1 = "3e83f60b74911d3042d3550884ca2776386a02b8"
uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624"
version = "0.5.3"
[[Tracker]]
deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics", "Test"]
git-tree-sha1 = "0bec1b68c63a0e8a58d3944261cbf4cc9577c8a1"
uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
version = "0.2.0"
[[TranscodingStreams]]
deps = ["Pkg", "Random", "Test"]
git-tree-sha1 = "a34a2d588e2d2825602bf14a24216d5c8b0921ec"
deps = ["Random", "Test"]
git-tree-sha1 = "a25d8e5a28c3b1b06d3859f30757d43106791919"
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
version = "0.8.1"
version = "0.9.4"
[[URIParser]]
deps = ["Test", "Unicode"]
@@ -275,7 +313,7 @@ uuid = "30578b45-9adc-5946-b283-645ec420af67"
version = "0.4.0"
[[UUIDs]]
deps = ["Random"]
deps = ["Random", "SHA"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
[[Unicode]]
@@ -283,6 +321,6 @@ uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
[[ZipFile]]
deps = ["BinaryProvider", "Libdl", "Printf", "Test"]
git-tree-sha1 = "4000c633efe994b2e10b31b6d91382c4b7412dac"
git-tree-sha1 = "5f6f663890dfb9bad6af75a86a43f67904e5050e"
uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
version = "0.8.0"
version = "0.8.1"

docs/make.jl

@@ -1,7 +1,7 @@
using Documenter, Flux, NNlib
makedocs(modules=[Flux, NNlib],
doctest = false,
doctest = true,
analytics = "UA-36890222-9",
sitename = "Flux",
# Uncomment below for local build
@@ -19,6 +19,7 @@ makedocs(modules=[Flux, NNlib],
"One-Hot Encoding" => "data/onehot.md",
"GPU Support" => "gpu.md",
"Saving & Loading" => "saving.md",
"Performance Tips" => "performance.md",
"Internals" =>
["Backpropagation" => "internals/tracker.md"],
"Community" => "community.md"])

docs/src/models/basics.md

@@ -4,45 +4,53 @@
Flux's core feature is taking gradients of Julia code. The `gradient` function takes another Julia function `f` and a set of arguments, and returns the gradient with respect to each argument. (It's a good idea to try pasting these examples in the Julia terminal.)
```julia
using Flux.Tracker
```jldoctest basics
julia> using Flux.Tracker
f(x) = 3x^2 + 2x + 1
julia> f(x) = 3x^2 + 2x + 1;
# df/dx = 6x + 2
df(x) = Tracker.gradient(f, x; nest = true)[1]
julia> df(x) = Tracker.gradient(f, x; nest = true)[1]; # df/dx = 6x + 2
df(2) # 14.0 (tracked)
julia> df(2)
14.0 (tracked)
# d²f/dx² = 6
d2f(x) = Tracker.gradient(df, x; nest = true)[1]
julia> d2f(x) = Tracker.gradient(df, x; nest = true)[1]; # d²f/dx² = 6
d2f(2) # 6.0 (tracked)
julia> d2f(2)
6.0 (tracked)
```
(We'll learn more about why these numbers show up as `(tracked)` below.)
When a function has many parameters, we can pass them all in explicitly:
```julia
f(W, b, x) = W * x + b
```jldoctest basics
julia> f(W, b, x) = W * x + b;
Tracker.gradient(f, 2, 3, 4)
# (4.0 (tracked), 1.0 (tracked), 2.0 (tracked))
julia> Tracker.gradient(f, 2, 3, 4)
(4.0 (tracked), 1.0 (tracked), 2.0 (tracked))
```
But machine learning models can have *hundreds* of parameters! Flux offers a nice way to handle this. We can tell Flux to treat something as a parameter via `param`. Then we can collect these together and tell `gradient` to collect the gradients of all `params` at once.
```julia
W = param(2) # 2.0 (tracked)
b = param(3) # 3.0 (tracked)
```jldoctest basics
julia> using Flux
f(x) = W * x + b
julia> W = param(2)
2.0 (tracked)
grads = Tracker.gradient(() -> f(4), params(W, b))
julia> b = param(3)
3.0 (tracked)
grads[W] # 4.0
grads[b] # 1.0
julia> f(x) = W * x + b;
julia> grads = Tracker.gradient(() -> f(4), params(W, b));
julia> grads[W]
4.0 (tracked)
julia> grads[b]
1.0 (tracked)
```
There are a few things to notice here. Firstly, `W` and `b` now show up as *tracked*. Tracked things behave like normal numbers or arrays, but keep records of everything you do with them, allowing Flux to calculate their gradients. `gradient` takes a zero-argument function; no arguments are necessary because the `params` tell it what to differentiate.
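The same pattern scales up to array parameters. As a minimal sketch of a hand-written training step (assuming `Tracker.update!`, which applies a change to a tracked array and is what Flux's optimiser docs build on):

```julia
using Flux, Flux.Tracker
using Flux.Tracker: update!

W = param(rand(2, 5)); b = param(rand(2))
predict(x) = W*x .+ b
loss(x, y) = sum((predict(x) .- y).^2)

x, y = rand(5), rand(2)
grads = Tracker.gradient(() -> loss(x, y), params(W, b))
update!(W, -0.1 .* grads[W])  # one gradient-descent step on W
```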

docs/src/models/layers.md

@@ -5,14 +5,16 @@ These core layers form the foundation of almost all neural networks.
```@docs
Chain
Dense
```
## Convolution and Pooling Layers
These layers are used to build convolutional neural networks (CNNs).
```@docs
Conv
MaxPool
MeanPool
```
## Additional Convolution Layers
```@docs
DepthwiseConv
ConvTranspose
```
@@ -28,6 +30,25 @@ GRU
Flux.Recur
```
## Other General Purpose Layers
These are marginally more obscure than the basic layers.
In contrast to the layers described in the other sections, they are not readily grouped around a particular purpose (e.g. CNNs or RNNs).
```@docs
Maxout
```
## Normalisation & Regularisation
These layers don't affect the structure of the network but may improve training times or reduce overfitting.
```@docs
Flux.testmode!
BatchNorm
Dropout
LayerNorm
```
## Activation Functions
Non-linearities that go between layers of your model. Most of these functions are defined in [NNlib](https://github.com/FluxML/NNlib.jl) but are available by default in Flux.
@@ -50,5 +71,7 @@ These layers don't affect the structure of the network but may improve training
Flux.testmode!
BatchNorm
Dropout
AlphaDropout
LayerNorm
GroupNorm
```
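For example, a normalisation layer typically slots between the other layers of a `Chain` (a sketch mirroring the `InstanceNorm` example elsewhere in these docs):

```julia
m = Chain(
  Dense(28^2, 64),
  BatchNorm(64, relu),
  Dense(64, 10),
  BatchNorm(10),
  softmax)
```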

docs/src/models/recurrence.md

@@ -77,7 +77,7 @@ If you use the `RNN(10, 5)` constructor as opposed to `RNNCell` you'll s
```julia
julia> RNN(10, 5)
Recur(RNNCell(Dense(15, 5)))
Recur(RNNCell(10, 5, tanh))
```
## Sequences
@@ -114,3 +114,13 @@ truncate!(m)
Calling `truncate!` wipes the slate clean, so we can call the model with more inputs without building up an expensive gradient computation.
`truncate!` makes sense when you are working with multiple chunks of a large sequence, but we may also want to work with a set of independent sequences. In this case the hidden state should be completely reset to its original value, throwing away any accumulated information. `reset!` does this for you.
In general, when training with recurrent layers in your model, you'll want to call `reset!` or `truncate!` for each loss calculation:
```julia
function loss(x,y)
l = Flux.mse(m(x), y)
Flux.reset!(m)
return l
end
```

docs/src/performance.md (new file)

@@ -0,0 +1,76 @@
# Performance Tips
All the usual [Julia performance tips apply](https://docs.julialang.org/en/v1/manual/performance-tips/).
As always [profiling your code](https://docs.julialang.org/en/v1/manual/profile/#Profiling-1) is generally a useful way of finding bottlenecks.
Below are some Flux-specific tips and reminders.

## Don't use more precision than you need

Flux works great with all kinds of number types.
But often you do not need to be working with, say, `Float64` (let alone `BigFloat`).
Switching to `Float32` can give you a significant speedup, not because the operations themselves are faster, but because memory usage is halved, so allocations are cheaper and overall memory pressure is lower.
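For example (a sketch; `f32` is exported by Flux and converts a model's parameters):

```julia
using Flux

m = Chain(Dense(784, 32, relu), Dense(32, 10))
m32 = f32(m)                 # parameters converted to Float32
x = rand(Float32, 784, 100)  # better still, create the data as Float32 from the start
m32(x)
```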
## Make sure your custom activation functions preserve the type of their inputs
Not only should your activation functions be [type-stable](https://docs.julialang.org/en/v1/manual/performance-tips/#Write-%22type-stable%22-functions-1),
they should also preserve the type of their inputs.
A very artificial example: using an activation function like
```julia
my_tanh(x) = Float64(tanh(x))
```
will make performance on `Float32` input orders of magnitude slower than the normal `tanh` would,
because it forces the dense layers into slow mixed-type multiplication.
So if you change your data from, say, `Float64` to `Float32` (which should give a speedup: see above),
you will instead see a large slow-down.
This can occur sneakily, because type promotion can be triggered just by interacting with a numeric literal.
E.g. the following runs into the same problem as above:
```julia
leaky_tanh(x) = 0.01x + tanh(x)
```
While one could change the activation function (e.g. to use `0.01f0x`) to avoid this whenever the input type changes,
the idiomatic (and safe) way is to use `oftype`:
```julia
leaky_tanh(x) = oftype(x/1, 0.01) * x + tanh(x)
```
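A quick sanity check (a sketch) that the `Float32` input type is now preserved:

```julia
leaky_tanh(0.5f0) isa Float32  # true: no promotion to Float64
```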
## Evaluate batches as Matrices of features, rather than sequences of Vector features
While it can sometimes be tempting to process your observations (feature vectors) one at a time,
e.g.
```julia
function loss_total(xs::AbstractVector{<:Vector}, ys::AbstractVector{<:Vector})
sum(zip(xs, ys)) do (x, y_target)
y_pred = model(x) # evaluate the model
return loss(y_pred, y_target)
end
end
```
it is much faster to concatenate them into a matrix,
as this will hit BLAS matrix-matrix multiplication, which is much faster than the equivalent sequence of matrix-vector multiplications,
even though it means allocating new memory to store them contiguously.
```julia
x_batch = reduce(hcat, xs)
y_batch = reduce(hcat, ys)
...
function loss_total(x_batch::Matrix, y_batch::Matrix)
y_preds = model(x_batch)
sum(loss.(y_preds, y_batch))
end
```
When doing this kind of concatenation use `reduce(hcat, xs)` rather than `hcat(xs...)`.
This avoids the splatting penalty and hits the optimised `reduce` method.

docs/src/training/optimisers.md

@@ -49,5 +49,12 @@ All optimisers return an object that, when passed to `train!`, will update the p
Descent
Momentum
Nesterov
RMSProp
ADAM
AdaMax
ADAGrad
ADADelta
AMSGrad
NADAM
ADAMW
```
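Each optimiser is constructed with its hyperparameters and then handed to `train!`. A minimal sketch (assuming `loss`, `m` and `data` are defined as in the training docs):

```julia
opt = ADAM(0.001)                        # optimiser with a learning rate of 0.001
Flux.train!(loss, params(m), data, opt)  # updates m's parameters in place
```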

docs/src/training/training.md

@@ -93,3 +93,11 @@ evalcb() = @show(loss(test_x, test_y))
Flux.train!(objective, ps, data, opt,
cb = throttle(evalcb, 5))
```
Calling `Flux.stop()` in a callback will exit the training loop early.
```julia
cb = function ()
accuracy() > 0.9 && Flux.stop()
end
```
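As with other callbacks, this is then passed to `train!` (a sketch reusing the names from the snippet above):

```julia
Flux.train!(objective, ps, data, opt, cb = cb)
```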

paper/paper.bib

@@ -14,7 +14,7 @@
journal = {arXiv},
volume = {abs/1712.03112},
year = {2017},
url = {http://arxiv.org/abs/1712.03112},
url = {https://arxiv.org/abs/1712.03112},
}
@online{MLPL,
@@ -29,7 +29,7 @@
author = {Mike Innes and others},
title = {Generic GPU Kernels},
year = 2017,
url = {http://mikeinnes.github.io/2017/08/24/cudanative.html},
url = {https://mikeinnes.github.io/2017/08/24/cudanative.html},
urldate = {2018-02-16}
}

src/Flux.jl

@@ -6,15 +6,14 @@ using Base: tail
using MacroTools, Juno, Requires, Reexport, Statistics, Random
using MacroTools: @forward
export Chain, Dense, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
DepthwiseConv, Dropout, LayerNorm, BatchNorm,
export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
params, mapleaves, cpu, gpu, f32, f64
@reexport using NNlib
include("tracker/Tracker.jl")
using .Tracker
using .Tracker: data
using Tracker
using Tracker: data
export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param
include("optimise/Optimise.jl")

src/cuda/cuda.jl

@@ -1,17 +1,18 @@
module CUDA
using ..CuArrays
import ..CuArrays.CUDAdrv: CuPtr, CU_NULL
using Pkg.TOML
function version_check()
minor_version = 9
major_version = 1
project = joinpath(dirname(pathof(CuArrays)), "../Project.toml")
project = TOML.parse(String(read(project)))
version = VersionNumber(get(project, "version", "0.0.0"))
if !(version.major == 0 && version.minor == minor_version)
if version.major != major_version
@warn """
Flux is only supported with CuArrays v0.$minor_version.
Try running `] pin CuArrays@0.$minor_version`.
Flux is only supported with CuArrays v$major_version.x.
Try running `] pin CuArrays@$major_version`.
"""
end
end

src/cuda/cudnn.jl

@@ -17,7 +17,7 @@ function DropoutDesc(ρ::Real; seed::Integer=0)
@check ccall((:cudnnDropoutGetStatesSize,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Csize_t}),handle(),s)
states = CuArray{UInt8}(undef, s[]) # TODO: can we drop this when ρ=0?
desc = DropoutDesc(d[], states)
@check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Nothing},Cfloat,Ptr{Nothing},Csize_t,Culonglong),
@check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Nothing},Cfloat,CuPtr{Nothing},Csize_t,Culonglong),
desc,handle(),ρ,states,length(states),seed)
finalizer(desc) do x
@check ccall((:cudnnDestroyDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x)
@@ -79,18 +79,18 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
mean = zeros(CuArray{T}, dims...)
ivar = ones(CuArray{T}, dims...)
else
mean = C_NULL
ivar = C_NULL
mean = CU_NULL
ivar = CU_NULL
end
@check ccall((:cudnnBatchNormalizationForwardTraining, libcudnn), cudnnStatus_t,
(cudnnHandle_t,cudnnBatchNormMode_t,
Ptr{T}, Ptr{T},
Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Ptr{T}, Ptr{T},
Cdouble, Ptr{T}, Ptr{T},
Cdouble, Ptr{T}, Ptr{T}),
Ptr{Nothing}, CuPtr{T},
Ptr{Nothing}, CuPtr{T},
Ptr{Nothing}, CuPtr{T}, CuPtr{T},
Cdouble, CuPtr{T}, CuPtr{T},
Cdouble, CuPtr{T}, CuPtr{T}),
handle(), BATCHNORM_SPATIAL,
Ref(T(alpha)), Ref(T(beta)),
xd, x,
@@ -107,10 +107,10 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
@check ccall((:cudnnBatchNormalizationForwardInference, libcudnn), cudnnStatus_t,
(Ptr{cudnnHandle_t},cudnnBatchNormMode_t,
Ptr{T}, Ptr{T},
Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Ptr{T}, Ptr{T},
Ptr{T}, Ptr{T},
Ptr{Nothing}, CuPtr{T},
Ptr{Nothing}, CuPtr{T},
Ptr{Nothing}, CuPtr{T}, CuPtr{T},
CuPtr{T}, CuPtr{T},
Cdouble),
handle(), BATCHNORM_SPATIAL,
Ref(T(alpha)), Ref(T(beta)),
@@ -159,7 +159,7 @@ function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
mean, ivar = cache.mean, cache.ivar
info("mean and ivar are fetched from the cache")
else
mean, ivar = C_NULL, C_NULL
mean, ivar = CU_NULL, CU_NULL
end
if eps < BATCHNORM_MIN_EPS
@@ -170,11 +170,11 @@
(cudnnHandle_t,cudnnBatchNormMode_t,
Ptr{T}, Ptr{T},
Ptr{T}, Ptr{T},
Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Ptr{T}, Ptr{T}, Ptr{T},
Cdouble, Ptr{T}, Ptr{T}),
Ptr{Nothing}, CuPtr{T},
Ptr{Nothing}, CuPtr{T},
Ptr{Nothing}, CuPtr{T},
Ptr{Nothing}, CuPtr{T}, CuPtr{T}, CuPtr{T},
Cdouble, CuPtr{T}, CuPtr{T}),
handle(), BATCHNORM_SPATIAL,
Ref(T(alpha)), Ref(T(beta)),
Ref(T(dalpha)), Ref(T(dbeta)),

src/cuda/curnn.jl

@@ -101,18 +101,18 @@ function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd
if reserve == nothing
@check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t,
(Ptr{Nothing}, Ptr{Nothing}, Cint,
Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Csize_t),
Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
Ptr{Nothing}, CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
Ptr{Nothing}, CuPtr{T},
CuPtr{Nothing}, Csize_t),
handle(), rnn, seqlen,
xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
workspace, length(workspace))
else
@check ccall((:cudnnRNNForwardTraining, libcudnn), cudnnStatus_t,
(Ptr{Nothing}, Ptr{Nothing}, Cint,
Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t),
Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
CuPtr{Nothing}, Csize_t, CuPtr{Nothing}, Csize_t),
handle(), rnn, seqlen,
xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
workspace, length(workspace), reserve, length(reserve))
@@ -121,7 +121,7 @@ end
xDesc(x) = [TensorDesc(eltype(x), (1, size(x, 1), size(x, 2)))]
hDesc(h::Nothing) = C_NULL, C_NULL
hDesc(h::Nothing) = C_NULL, CU_NULL
hDesc(x::Integer) = (@assert x == 0; hDesc(nothing))
function hDesc(h::CuArray)
TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h
@@ -169,10 +169,10 @@ function cudnnRNNBackwardData(rnn::RNNDesc{T}, seqlen, yd, y, dyd, dy, dhod, dho
wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, rs) where T
@check ccall((:cudnnRNNBackwardData,libcudnn),cudnnStatus_t,
(Ptr{Nothing}, Ptr{Nothing}, Cint,
Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing},
Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t),
Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing},
CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
CuPtr{Nothing}, Csize_t, CuPtr{Nothing}, Csize_t),
handle(), rnn, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, length(ws), rs, length(rs))
end
@@ -199,12 +199,12 @@ function cudnnRNNBackwardWeights(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, yd, y, d
workspace, reserve) where T
@check ccall((:cudnnRNNBackwardWeights,libcudnn), cudnnStatus_t,
(Ptr{Nothing}, Ptr{Nothing}, Cint, # handle, rnnDesc, seqLength
Ptr{Ptr{Nothing}}, Ptr{T}, #x
Ptr{Nothing}, Ptr{T}, #hx
Ptr{Ptr{Nothing}}, Ptr{T}, #y
Ptr{Nothing}, Csize_t, #ws
Ptr{Nothing}, Ptr{T}, #dw
Ptr{Nothing}, Csize_t), #rs
Ptr{Ptr{Nothing}}, CuPtr{T}, #x
Ptr{Nothing}, CuPtr{T}, #hx
Ptr{Ptr{Nothing}}, CuPtr{T}, #y
CuPtr{Nothing}, Csize_t, #ws
Ptr{Nothing}, CuPtr{T}, #dw
CuPtr{Nothing}, Csize_t), #rs
handle(), rnn, seqlen, xd, x, hd, h, yd, y,
workspace, length(workspace), dwd, dw, reserve, length(reserve))
end

src/data/Data.jl

@@ -39,4 +39,7 @@ include("tree.jl")
include("sentiment.jl")
using .Sentiment
include("iris.jl")
export Iris
end

src/data/cmudict.jl

@@ -19,7 +19,7 @@ function load()
@info "Downloading CMUDict dataset"
mkpath(deps("cmudict"))
for (x, hash) in suffixes_and_hashes
download_and_verify("$cache_prefix/http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-$version$x",
download_and_verify("$cache_prefix/https://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-$version$x",
deps("cmudict", "cmudict$x"), hash)
end
end

src/data/iris.jl (new file)

@@ -0,0 +1,86 @@
"""
Iris
Fisher's classic iris dataset.
Measurements from 3 different species of iris: setosa, versicolor and
virginica. There are 50 examples of each species.
There are 4 measurements for each example: sepal length, sepal width, petal
length and petal width. The measurements are in centimeters.
The module retrieves the data from the [UCI Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/iris).
"""
module Iris
using DelimitedFiles
using ..Data: deps, download_and_verify
# Uncomment if the iris.data file is cached to cache.julialang.org.
const cache_prefix = "https://cache.julialang.org/"
function load()
isfile(deps("iris.data")) && return
@info "Downloading iris dataset."
download_and_verify("$(cache_prefix)https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data",
deps("iris.data"),
"6f608b71a7317216319b4d27b4d9bc84e6abd734eda7872b71a458569e2656c0")
end
"""
labels()
Get the labels of the iris dataset, a 150 element array of strings listing the
species of each example.
```jldoctest
julia> labels = Flux.Data.Iris.labels();
julia> summary(labels)
"150-element Array{String,1}"
julia> labels[1]
"Iris-setosa"
```
"""
function labels()
load()
iris = readdlm(deps("iris.data"), ',')
Vector{String}(iris[1:end, end])
end
"""
features()
Get the features of the iris dataset. This is a 4x150 matrix of Float64
elements. It has a row for each feature (sepal length, sepal width,
petal length, petal width) and a column for each example.
```jldoctest
julia> features = Flux.Data.Iris.features();
julia> summary(features)
"4×150 Array{Float64,2}"
julia> features[:, 1]
4-element Array{Float64,1}:
5.1
3.5
1.4
0.2
```
"""
function features()
load()
iris = readdlm(deps("iris.data"), ',')
Matrix{Float64}(iris[1:end, 1:4]')
end
end

src/layers/basic.jl

@@ -40,7 +40,24 @@ function Base.show(io::IO, c::Chain)
print(io, ")")
end
activations(c::Chain, x) = accumulate((x, m) -> m(x), c.layers, init = x)
# This is a temporary and naive implementation
# it might be replaced in the future for better performance
# see issue https://github.com/FluxML/Flux.jl/issues/702
# Johnny Chen -- @johnnychen94
"""
activations(c::Chain, input)
Calculate the forward results of each layer in Chain `c` with `input` as model input.
"""
function activations(c::Chain, input)
rst = []
for l in c
x = get(rst, length(rst), input)
push!(rst, l(x))
end
return rst
end
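# Usage sketch (hypothetical chain and input, not part of this commit):
#   m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
#   activations(m, rand(10))  # 3-element array holding each layer's output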
"""
Dense(in::Integer, out::Integer, σ = identity)
@@ -88,6 +105,14 @@ function Base.show(io::IO, l::Dense)
print(io, ")")
end
# Try to avoid hitting generic matmul in some simple cases
# Base's matmul is so slow that it's worth the extra conversion to hit BLAS
(a::Dense{<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
invoke(a, Tuple{AbstractArray}, x)
(a::Dense{<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
a(T.(x))
"""
Diagonal(in::Integer)
@@ -117,10 +142,50 @@ function Base.show(io::IO, l::Diagonal)
print(io, "Diagonal(", length(l.α), ")")
end
# Try to avoid hitting generic matmul in some simple cases
# Base's matmul is so slow that it's worth the extra conversion to hit BLAS
(a::Dense{<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
invoke(a, Tuple{AbstractArray}, x)
(a::Dense{<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
a(T.(x))
"""
Maxout(over)
`Maxout` is a neural network layer which has a number of internal layers
that all receive the same input. The layer returns the elementwise maximum
of the internal layers' outputs.
Maxout over linear dense layers satisfies the universal approximation theorem.
Reference:
Ian J. Goodfellow, David Warde-Farley, Mehdi Mirza, Aaron Courville, and Yoshua Bengio.
2013. Maxout networks.
In Proceedings of the 30th International Conference on International Conference on Machine Learning - Volume 28 (ICML'13),
Sanjoy Dasgupta and David McAllester (Eds.), Vol. 28. JMLR.org III-1319-III-1327.
https://arxiv.org/pdf/1302.4389.pdf
"""
struct Maxout{FS<:Tuple}
over::FS
end
"""
Maxout(f, n_alts)
Constructs a Maxout layer over `n_alts` instances of the layer given by `f`.
The function takes no arguments and should return some callable layer.
Conventionally this is a linear dense layer.
For example, the following constructs a `Maxout` layer over 4 internal
dense linear layers, each identical in structure (784 inputs, 128 outputs):
```julia
insize = 784
outsize = 128
Maxout(()->Dense(insize, outsize), 4)
```
"""
function Maxout(f, n_alts)
over = Tuple(f() for _ in 1:n_alts)
return Maxout(over)
end
@treelike Maxout
function (mo::Maxout)(input::AbstractArray)
mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
end

src/layers/conv.jl

@@ -1,10 +1,7 @@
using NNlib: conv, ∇conv_data, depthwiseconv
@generated sub2(::Val{N}) where N = :(Val($(N-2)))
expand(N, i::Tuple) = i
expand(N, i::Integer) = ntuple(_ -> i, N)
"""
Conv(size, in=>out)
Conv(size, in=>out, relu)
@@ -12,23 +9,36 @@ expand(N, i::Integer) = ntuple(_ -> i, N)
Standard convolutional layer. `size` should be a tuple like `(2, 2)`.
`in` and `out` specify the number of input and output channels respectively.
Data should be stored in WHCN order. In other words, a 100×100 RGB image would
be a `100×100×3×1` array, and a batch of 50 would be a `100×100×3×50` array.
Example: Applying Conv layer to a 1-channel input using a 2x2 window size,
giving us a 16-channel output. Output is activated with ReLU.
size = (2,2)
in = 1
out = 16
Conv((2, 2), 1=>16, relu)
Data should be stored in WHCN order (width, height, # channels, # batches).
In other words, a 100×100 RGB image would be a `100×100×3×1` array,
and a batch of 50 would be a `100×100×3×50` array.
Takes the keyword arguments `pad`, `stride` and `dilation`.
"""
struct Conv{N,F,A,V}
struct Conv{N,M,F,A,V}
σ::F
weight::A
bias::V
stride::NTuple{N,Int}
pad::NTuple{N,Int}
pad::NTuple{M,Int}
dilation::NTuple{N,Int}
end
Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N} =
Conv(σ, w, b, expand.(sub2(Val(N)), (stride, pad, dilation))...)
function Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N}
stride = expand(Val(N-2), stride)
pad = expand(Val(2*(N-2)), pad)
dilation = expand(Val(N-2), dilation)
return Conv(σ, w, b, stride, pad, dilation)
end
Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
@@ -41,7 +51,8 @@ function (c::Conv)(x::AbstractArray)
# TODO: breaks gpu broadcast :(
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
σ.(conv(x, c.weight, stride = c.stride, pad = c.pad, dilation = c.dilation) .+ b)
cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
σ.(conv(x, c.weight, cdims) .+ b)
end
function Base.show(io::IO, l::Conv)
@@ -67,18 +78,22 @@ Data should be stored in WHCN order. In other words, a 100×100 RGB image would
be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
Takes the keyword arguments `pad`, `stride` and `dilation`.
"""
struct ConvTranspose{N,F,A,V}
struct ConvTranspose{N,M,F,A,V}
σ::F
weight::A
bias::V
stride::NTuple{N,Int}
pad::NTuple{N,Int}
pad::NTuple{M,Int}
dilation::NTuple{N,Int}
end
ConvTranspose(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N} =
ConvTranspose(σ, w, b, expand.(sub2(Val(N)), (stride, pad, dilation))...)
function ConvTranspose(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N}
stride = expand(Val(N-2), stride)
pad = expand(Val(2*(N-2)), pad)
dilation = expand(Val(N-2), dilation)
return ConvTranspose(σ, w, b, stride, pad, dilation)
end
ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
@@ -87,10 +102,25 @@ ConvTranspose(param(init(k..., reverse(ch)...)), param(zeros(ch[2])), σ,
@treelike ConvTranspose
function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
# Calculate size of "input", from ∇conv_data()'s perspective...
combined_pad = (c.pad[1:2:end] .+ c.pad[2:2:end])
I = (size(x)[1:end-2] .- 1).*c.stride .+ 1 .+ (size(c.weight)[1:end-2] .- 1).*c.dilation .- combined_pad
C_in = size(c.weight)[end-1]
batch_size = size(x)[end]
# Create DenseConvDims() that looks like the corresponding conv()
return DenseConvDims((I..., C_in, batch_size), size(c.weight);
stride=c.stride,
padding=c.pad,
dilation=c.dilation,
)
end
function (c::ConvTranspose)(x::AbstractArray)
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
σ.(∇conv_data(x, c.weight, stride = c.stride, pad = c.pad, dilation = c.dilation) .+ b)
cdims = conv_transpose_dims(c, x)
return σ.(∇conv_data(x, c.weight, cdims) .+ b)
end
function Base.show(io::IO, l::ConvTranspose)
@@ -119,26 +149,32 @@ be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
Takes the keyword arguments `pad` and `stride`.
"""
struct DepthwiseConv{N,F,A,V}
struct DepthwiseConv{N,M,F,A,V}
σ::F
weight::A
bias::V
stride::NTuple{N,Int}
pad::NTuple{N,Int}
pad::NTuple{M,Int}
dilation::NTuple{N,Int}
end
DepthwiseConv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0) where {T,N} =
DepthwiseConv(σ, w, b, expand.(sub2(Val(N)), (stride, pad))...)
function DepthwiseConv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N}
stride = expand(Val(N-2), stride)
pad = expand(Val(2*(N-2)), pad)
dilation = expand(Val(N-2), dilation)
return DepthwiseConv(σ, w, b, stride, pad, dilation)
end
DepthwiseConv(k::NTuple{N,Integer}, ch::Integer, σ = identity; init = glorot_uniform,
stride = 1, pad = 0) where N =
stride = 1, pad = 0, dilation = 1) where N =
DepthwiseConv(param(init(k..., 1, ch)), param(zeros(ch)), σ,
stride = stride, pad = pad)
stride = stride, pad = pad, dilation=dilation)
DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = glorot_uniform,
stride::NTuple{N,Integer} = map(_->1,k),
pad::NTuple{N,Integer} = map(_->0,k)) where N =
pad::NTuple{N,Integer} = map(_->0,2 .* k),
dilation::NTuple{N,Integer} = map(_->1,k)) where N =
DepthwiseConv(param(init(k..., ch[2], ch[1])), param(zeros(ch[2]*ch[1])), σ,
stride = stride, pad = pad)
@@ -146,7 +182,8 @@ DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity
function (c::DepthwiseConv)(x)
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
σ.(depthwiseconv(x, c.weight, stride = c.stride, pad = c.pad) .+ b)
cdims = DepthwiseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
σ.(depthwiseconv(x, c.weight, cdims) .+ b)
end
function Base.show(io::IO, l::DepthwiseConv)
@@ -156,6 +193,12 @@ function Base.show(io::IO, l::DepthwiseConv)
print(io, ")")
end
(a::DepthwiseConv{<:Any,<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
invoke(a, Tuple{AbstractArray}, x)
(a::DepthwiseConv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
a(T.(x))
"""
MaxPool(k)
@@ -163,16 +206,23 @@ Max pooling layer. `k` stands for the size of the window for each dimension of t
Takes the keyword arguments `pad` and `stride`.
"""
struct MaxPool{N}
struct MaxPool{N,M}
k::NTuple{N,Int}
pad::NTuple{N,Int}
pad::NTuple{M,Int}
stride::NTuple{N,Int}
end
MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =
MaxPool(k, expand(Val(N), pad), expand(Val(N), stride))
function MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N
stride = expand(Val(N), stride)
pad = expand(Val(2*N), pad)
(m::MaxPool)(x) = maxpool(x, m.k; pad = m.pad, stride = m.stride)
return MaxPool(k, pad, stride)
end
function (m::MaxPool)(x)
pdims = PoolDims(x, m.k; padding=m.pad, stride=m.stride)
return maxpool(x, pdims)
end
function Base.show(io::IO, m::MaxPool)
print(io, "MaxPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
@@ -185,16 +235,22 @@ Mean pooling layer. `k` stands for the size of the window for each dimension of
Takes the keyword arguments `pad` and `stride`.
"""
struct MeanPool{N}
struct MeanPool{N,M}
k::NTuple{N,Int}
pad::NTuple{N,Int}
pad::NTuple{M,Int}
stride::NTuple{N,Int}
end
MeanPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =
MeanPool(k, expand(Val(N), pad), expand(Val(N), stride))
function MeanPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N
stride = expand(Val(N), stride)
pad = expand(Val(2*N), pad)
return MeanPool(k, pad, stride)
end
(m::MeanPool)(x) = meanpool(x, m.k; pad = m.pad, stride = m.stride)
function (m::MeanPool)(x)
pdims = PoolDims(x, m.k; padding=m.pad, stride=m.stride)
return meanpool(x, pdims)
end
function Base.show(io::IO, m::MeanPool)
print(io, "MeanPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")

src/layers/normalise.jl

@@ -43,6 +43,37 @@ end
_testmode!(a::Dropout, test) = (a.active = !test)
"""
AlphaDropout(p)
A dropout layer. It is used in Self-Normalizing Neural Networks.
(https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf)
The AlphaDropout layer ensures that the mean and variance of the activations remain the same as before.
"""
mutable struct AlphaDropout{F}
p::F
active::Bool
end
function AlphaDropout(p)
@assert 0 ≤ p ≤ 1
AlphaDropout(p,true)
end
function (a::AlphaDropout)(x)
a.active || return x
λ = eltype(x)(1.0507009873554804934193349852946)
α = eltype(x)(1.6732632423543772848170429916717)
α1 = eltype(x)(-λ*α)
noise = randn(eltype(x), size(x))
x = @. x*(noise > (1 - a.p)) + α1 * (noise <= (1 - a.p))
A = (a.p + a.p * (1 - a.p) * α1 ^ 2)^0.5
B = -A * α1 * (1 - a.p)
x = @. A * x + B
return x
end
_testmode!(a::AlphaDropout, test) = (a.active = !test)
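# Usage sketch (hypothetical, not part of this commit): noise is applied while
# `active` is true and disabled for evaluation via `testmode!`:
#   d = AlphaDropout(0.2)
#   y = d(randn(Float32, 10, 5))
#   Flux.testmode!(d)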
"""
LayerNorm(h::Integer)
@@ -113,34 +144,32 @@ BatchNorm(chs::Integer, λ = identity;
function (BN::BatchNorm)(x)
size(x, ndims(x)-1) == length(BN.β) ||
error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))")
γ, β = BN.γ, BN.β
dims = length(size(x))
channels = size(x, dims-1)
affine_shape = ones(Int, dims)
affine_shape[end-1] = channels
m = prod(size(x)[1:end-2]) * size(x)[end]
γ = reshape(BN.γ, affine_shape...)
β = reshape(BN.β, affine_shape...)
if !BN.active
μ = reshape(BN.μ, affine_shape...)
σ² = reshape(BN.σ², affine_shape...)
ϵ = BN.ϵ
else
T = eltype(x)
ϵ = data(convert(T, BN.ϵ))
axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
μ = mean(x, dims = axes)
σ² = sum((x .- μ) .^ 2, dims = axes) ./ m
ϵ = data(convert(T, BN.ϵ))
# update moving mean/std
mtm = data(convert(T, BN.momentum))
BN.μ = (1 - mtm) .* BN.μ .+ mtm .* reshape(data(μ), :)
BN.σ² = ((1 - mtm) .* BN.σ² .+ mtm .* reshape(data(σ²), :) .* m ./ (m - 1))
BN.σ² = (1 - mtm) .* BN.σ² .+ (mtm * m / (m - 1)) .* reshape(data(σ²), :)
end
let λ = BN.λ
temp = reshape(γ, affine_shape...) .* ((x .- μ) ./ sqrt.(σ² .+ BN.ϵ))
# This is intentionally not fused because of an extreme slowdown doing so
λ.(temp .+ reshape(β, affine_shape...))
x̂ = (x .- μ) ./ sqrt.(σ² .+ ϵ)
λ.(γ .* x̂ .+ β)
end
end
@@ -157,3 +186,209 @@ function Base.show(io::IO, l::BatchNorm)
(l.λ == identity) || print(io, ", λ = $(l.λ)")
print(io, ")")
end
"""
InstanceNorm(channels::Integer, σ = identity;
initβ = zeros, initγ = ones,
ϵ = 1e-8, momentum = .1)
Instance Normalization layer. The `channels` input should be the size of the
channel dimension in your data (see below).
Given an array with `N` dimensions, call the `N-1`th the channel dimension. (For
a batch of feature vectors this is just the data dimension, for `WHCN` images
it's the usual channel dimension.)
`InstanceNorm` computes the mean and variance for each `W×H×1×1` slice and
shifts them to have a new mean and variance (corresponding to the learnable,
per-channel `bias` and `scale` parameters).
See [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022).
Example:
```julia
m = Chain(
Dense(28^2, 64),
InstanceNorm(64, relu),
Dense(64, 10),
InstanceNorm(10),
softmax)
```
"""
expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...)
mutable struct InstanceNorm{F,V,W,N}
λ::F # activation function
β::V # bias
γ::V # scale
μ::W # moving mean
σ²::W # moving std
ϵ::N
momentum::N
active::Bool
end
InstanceNorm(chs::Integer, λ = identity;
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
InstanceNorm(λ, param(initβ(chs)), param(initγ(chs)),
zeros(chs), ones(chs), ϵ, momentum, true)
function (in::InstanceNorm)(x)
size(x, ndims(x)-1) == length(in.β) ||
error("InstanceNorm expected $(length(in.β)) channels, got $(size(x, ndims(x)-1))")
ndims(x) > 2 ||
error("InstanceNorm requires at least 3 dimensions. With 2 dimensions an array of zeros would be returned")
# these are repeated later on depending on the batch size
dims = length(size(x))
c = size(x, dims-1)
bs = size(x, dims)
affine_shape = ones(Int, dims)
affine_shape[end-1] = c
affine_shape[end] = bs
m = prod(size(x)[1:end-2])
γ, β = expand_inst(in.γ, affine_shape), expand_inst(in.β, affine_shape)
if !in.active
μ = expand_inst(in.μ, affine_shape)
σ² = expand_inst(in.σ², affine_shape)
ϵ = in.ϵ
else
T = eltype(x)
ϵ = data(convert(T, in.ϵ))
axes = 1:dims-2 # axes to reduce along (all but channels and batch size axes)
μ = mean(x, dims = axes)
σ² = mean((x .- μ) .^ 2, dims = axes)
# update moving mean/std
mtm = data(convert(T, in.momentum))
in.μ = dropdims(mean(repeat((1 - mtm) .* in.μ, outer=[1, bs]) .+ mtm .* reshape(data(μ), (c, bs)), dims = 2), dims=2)
in.σ² = dropdims(mean((repeat((1 - mtm) .* in.σ², outer=[1, bs]) .+ (mtm * m / (m - 1)) .* reshape(data(σ²), (c, bs))), dims = 2), dims=2)
end
let λ = in.λ
= (x .- μ) ./ sqrt.(σ² .+ ϵ)
λ.(γ .* .+ β)
end
end
children(in::InstanceNorm) =
(in.λ, in.β, in.γ, in.μ, in.σ², in.ϵ, in.momentum, in.active)
mapchildren(f, in::InstanceNorm) = # e.g. mapchildren(cu, in)
InstanceNorm(in.λ, f(in.β), f(in.γ), f(in.μ), f(in.σ²), in.ϵ, in.momentum, in.active)
_testmode!(in::InstanceNorm, test) = (in.active = !test)
function Base.show(io::IO, l::InstanceNorm)
print(io, "InstanceNorm($(join(size(l.β), ", "))")
(l.λ == identity) || print(io, ", λ = $(l.λ)")
print(io, ")")
end
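For intuition beyond the docstring, a minimal sketch (not part of the diff) of the training-mode behaviour on a `WHCN` batch, assuming the layer as defined above:

```julia
using Flux, Statistics

m = InstanceNorm(3)                    # 3 channels
x = randn(Float32, 8, 8, 3, 4)         # W×H×C×N image batch
y = Flux.data(m(x))                    # peel off tracking for inspection

# Each W×H slice is normalised per channel and per sample, so with the
# default β = 0, γ = 1 every slice has mean ≈ 0 and variance ≈ 1.
@show mean(y[:, :, 1, 1]) var(y[:, :, 1, 1])
```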
"""
Group Normalization.
This layer can outperform Batch-Normalization and Instance-Normalization.
GroupNorm(chs::Integer, G::Integer, λ = identity;
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
ϵ = 1f-5, momentum = 0.1f0)
``chs`` is the number of channels, the channel dimension of your input.
For an array of N dimensions, the (N-1)th index is the channel dimension.
``G`` is the number of groups along which the statistics are computed.
The number of channels must be an integer multiple of the number of groups.
Example:
```
m = Chain(Conv((3,3), 1=>32, leakyrelu;pad = 1),
GroupNorm(32,16)) # 32 channels, 16 groups (G = 16), thus 2 channels per group used
```
Link: [Group Normalization](https://arxiv.org/pdf/1803.08494.pdf)
"""
mutable struct GroupNorm{F,V,W,N,T}
G::T # number of groups
λ::F # activation function
β::V # bias
γ::V # scale
μ::W # moving mean
σ²::W # moving std
ϵ::N
momentum::N
active::Bool
end
GroupNorm(chs::Integer, G::Integer, λ = identity;
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
GroupNorm(G, λ, param(initβ(chs)), param(initγ(chs)),
zeros(G,1), ones(G,1), ϵ, momentum, true)
function(gn::GroupNorm)(x)
size(x,ndims(x)-1) == length(gn.β) || error("Group Norm expected $(length(gn.β)) channels, but got $(size(x,ndims(x)-1)) channels")
ndims(x) > 2 || error("GroupNorm requires at least 3 dimensions")
(size(x,ndims(x) -1))%gn.G == 0 || error("The number of groups ($(gn.G)) must divide the number of channels ($(size(x,ndims(x) -1)))")
dims = length(size(x))
groups = gn.G
channels = size(x, dims-1)
batches = size(x,dims)
channels_per_group = div(channels,groups)
affine_shape = ones(Int, dims)
# Output reshaped to (W,H...,C/G,G,N)
affine_shape[end-1] = channels
μ_affine_shape = ones(Int,dims + 1)
μ_affine_shape[end-1] = groups
m = prod(size(x)[1:end-2]) * channels_per_group
γ = reshape(gn.γ, affine_shape...)
β = reshape(gn.β, affine_shape...)
y = reshape(x,((size(x))[1:end-2]...,channels_per_group,groups,batches))
if !gn.active
og_shape = size(x)
μ = reshape(gn.μ, μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
σ² = reshape(gn.σ², μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
ϵ = gn.ϵ
else
T = eltype(x)
og_shape = size(x)
axes = [(1:ndims(y)-2)...] # axes to reduce along (all but the group and batch axes)
μ = mean(y, dims = axes)
σ² = mean((y .- μ) .^ 2, dims = axes)
ϵ = data(convert(T, gn.ϵ))
# update moving mean/std
mtm = data(convert(T, gn.momentum))
gn.μ = mean((1 - mtm) .* gn.μ .+ mtm .* reshape(data(μ), (groups,batches)),dims=2)
gn.σ² = mean((1 - mtm) .* gn.σ² .+ (mtm * m / (m - 1)) .* reshape(data(σ²), (groups,batches)),dims=2)
end
let λ = gn.λ
x̂ = (y .- μ) ./ sqrt.(σ² .+ ϵ)
# Reshape x̂ back to the original input shape
x̂ = reshape(x̂, og_shape)
λ.(γ .* x̂ .+ β)
end
end
children(gn::GroupNorm) =
(gn.λ, gn.β, gn.γ, gn.μ, gn.σ², gn.ϵ, gn.momentum, gn.active)
mapchildren(f, gn::GroupNorm) = # e.g. mapchildren(cu, BN)
GroupNorm(gn.G,gn.λ, f(gn.β), f(gn.γ), f(gn.μ), f(gn.σ²), gn.ϵ, gn.momentum, gn.active)
_testmode!(gn::GroupNorm, test) = (gn.active = !test)
function Base.show(io::IO, l::GroupNorm)
print(io, "GroupNorm($(join(size(l.β), ", "))")
(l.λ == identity) || print(io, ", λ = $(l.λ)")
print(io, ")")
end
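As a quick shape check (a sketch, not part of the diff): statistics are kept per `(group, batch)` pair and stored as a `(G, 1)` running estimate, assuming the layer as defined above:

```julia
using Flux

gn = GroupNorm(6, 3)                 # 6 channels split into 3 groups of 2
x  = randn(Float32, 4, 4, 6, 2)      # W×H×C×N
y  = gn(x)

@assert size(y) == size(x)           # output shape is preserved
@show size(gn.μ) size(gn.σ²)         # both (3, 1): one statistic per group
```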

View File

@ -153,7 +153,7 @@ Base.show(io::IO, l::LSTMCell) =
Long Short Term Memory recurrent layer. Behaves like an RNN but generally
exhibits a longer memory span over sequences.
See [this article](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.
"""
LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
@ -194,7 +194,7 @@ Base.show(io::IO, l::GRUCell) =
Gated Recurrent Unit layer. Behaves like an RNN but generally
exhibits a longer memory span over sequences.
See [this article](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.
"""
GRU(a...; ka...) = Recur(GRUCell(a...; ka...))

View File

@ -50,7 +50,7 @@ function normalise(x::AbstractArray; dims=1)
return (x .- μ′) ./ σ
end
function normalise(x::AbstractArray, dims=1)
function normalise(x::AbstractArray, dims)
Base.depwarn("`normalise(x::AbstractArray, dims)` is deprecated, use `normalise(a, dims=dims)` instead.", :normalise)
normalise(x, dims = dims)
end

View File

@ -44,6 +44,29 @@ adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data)
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
end
"""
onehot(l, labels[, unk])
Create an [`OneHotVector`](@ref) whose element corresponding to `l` is `true`, based on the possible `labels` set.
If `unk` is given, `onehot(unk, labels)` is returned when the input label `l` is not found in `labels`; otherwise
an error is raised.
## Examples
```jldoctest
julia> onehot(:b, [:a, :b, :c])
3-element Flux.OneHotVector:
false
true
false
julia> onehot(:c, [:a, :b, :c])
3-element Flux.OneHotVector:
false
false
true
```
"""
function onehot(l, labels)
i = something(findfirst(isequal(l), labels), 0)
i > 0 || error("Value $l is not in labels")
@ -56,11 +79,43 @@ function onehot(l, labels, unk)
OneHotVector(i, length(labels))
end
"""
onehotbatch(ls, labels[, unk...])
Create an [`OneHotMatrix`](@ref) from a batch of labels `ls` based on the possible `labels` set; if `unk` is
given, any label not found in `labels` is encoded as `onehot(unk, labels)`.
## Examples
```jldoctest
julia> onehotbatch([:b, :a, :b], [:a, :b, :c])
3×3 Flux.OneHotMatrix:
false true false
true false true
false false false
```
"""
onehotbatch(ls, labels, unk...) =
OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])
Base.argmax(xs::OneHotVector) = xs.ix
"""
onecold(y[, labels = 1:length(y)])
Inverse operation of [`onehot`](@ref): return the label corresponding to the largest element of `y`.
## Examples
```jldoctest
julia> onecold([true, false, false], [:a, :b, :c])
:a
julia> onecold([0.3, 0.2, 0.5], [:a, :b, :c])
:c
```
"""
onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)]
onecold(y::AbstractMatrix, labels...) =

View File

@ -37,7 +37,7 @@ Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict())
function apply!(o::Momentum, x, Δ)
η, ρ = o.eta, o.rho
v = get!(o.velocity, x, zero(x))::typeof(x)
v = get!(o.velocity, x, zero(x))::typeof(data(x))
@. v = ρ * v - η * Δ
@. Δ = -v
end
@ -57,7 +57,7 @@ Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict())
function apply!(o::Nesterov, x, Δ)
η, ρ = o.eta, o.rho
v = get!(o.velocity, x, zero(x))::typeof(x)
v = get!(o.velocity, x, zero(x))::typeof(data(x))
d = @. ρ^2 * v - (1+ρ) * η * Δ
@. v = ρ*v - η*Δ
@. Δ = -d
@ -66,7 +66,7 @@ end
"""
RMSProp(η = 0.001, ρ = 0.9)
[RMSProp](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
[RMSProp](https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
optimiser. Parameters other than learning rate don't need tuning. Often a good
choice for recurrent networks.
"""
@ -80,7 +80,7 @@ RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict())
function apply!(o::RMSProp, x, Δ)
η, ρ = o.eta, o.rho
acc = get!(o.acc, x, zero(x))::typeof(x)
acc = get!(o.acc, x, zero(x))::typeof(data(x))
@. acc = ρ * acc + (1 - ρ) * Δ^2
@. Δ *= η / (acc + ϵ)
end
@ -147,7 +147,7 @@ ADAGrad(η = 0.1) = ADAGrad(η, IdDict())
function apply!(o::ADAGrad, x, Δ)
η = o.eta
acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(x)
acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(data(x))
@. acc += Δ^2
@. Δ *= η / (acc + ϵ)
end
@ -155,7 +155,7 @@ end
"""
ADADelta(ρ = 0.9, ϵ = 1e-8)
[ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
[ADADelta](https://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
tuning.
"""
mutable struct ADADelta
@ -321,7 +321,7 @@ end
WeightDecay() = WeightDecay(0)
function apply!(o::WeightDecay, x, Δ)
function apply!(o::WeightDecay, x, Δ)
wd = o.wd
@. Δ += wd * x
@. Δ += wd * data(x)
end

View File

@ -1,16 +1,23 @@
using Juno
import Flux.Tracker: data, grad, back!, update!
import Flux.Tracker: Params, gradient, data, update!
import Base.depwarn
function update!(opt, x, x̄)
update!(x, apply!(opt, x, copy(data(x̄))))
update!(x, -apply!(opt, x, data(x̄)))
end
function _update_params!(opt, xs)
function update!(opt, xs::Params, gs)
for x in xs
Δ = apply!(opt, x.data, x.grad)
x.data .-= Δ
Δ .= 0
update!(opt, x, gs[x])
end
end
# Added as an internal API but everyone started using it.
function _update_params!(opt, xs)
depwarn("`_update_params!` is deprecated, use `update!` instead.", :stop)
for x in xs
update!(opt, x, Tracker.grad(x))
x.tracker.grad = Tracker.zero_grad!(x.tracker.grad)
end
end
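Taken together, the new API is: compute gradients for a `Params` set, then apply them in one call. A minimal sketch (assuming `update!` is reached via `Flux.Optimise`):

```julia
using Flux
using Flux.Tracker: Params, gradient

W, b = param(randn(2, 3)), param(zeros(2))
ps = Params([W, b])
loss(x) = sum(W * x .+ b)

gs = gradient(() -> loss(rand(3)), ps)      # Grads keyed by parameter
Flux.Optimise.update!(Descent(0.1), ps, gs) # step every parameter in place
```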
@ -19,16 +26,6 @@ call(f, xs...) = f(xs...)
runall(f) = f
runall(fs::AbstractVector) = () -> foreach(call, fs)
# The AD generates fairly large backtraces that are unhelpful if you interrupt
# while training; this just cleans that up.
macro interrupts(ex)
:(try $(esc(ex))
catch e
e isa InterruptException || rethrow()
throw(e)
end)
end
struct StopException <: Exception end
"""
stop()
@ -67,13 +64,14 @@ The callback can call `Flux.stop()` to interrupt the training loop.
Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
"""
function train!(loss, ps, data, opt; cb = () -> ())
ps = Params(ps)
cb = runall(cb)
opt = runall(opt)
@progress for d in data
try
l = loss(d...)
@interrupts back!(l)
_update_params!(opt, ps)
gs = gradient(ps) do
loss(d...)
end
update!(opt, ps, gs)
if cb() == :stop
depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
break
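In user code the whole loop above is typically driven through `train!`; a minimal sketch with a hypothetical model and data:

```julia
using Flux

m = Dense(10, 2)
loss(x, y) = Flux.mse(m(x), y)
data = [(rand(10), rand(2)) for _ = 1:16]

Flux.train!(loss, params(m), data, Descent(0.1);
            cb = () -> @show(loss(data[1]...)))
```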

View File

@ -1,113 +0,0 @@
module Tracker
using MacroTools
using MacroTools: @q, @forward
import Base: ==
export TrackedArray, TrackedVector, TrackedMatrix, Params, gradient,
jacobian, hessian, param, back!
tracker(x) = nothing
istracked(x) = tracker(x) ≠ nothing
isleaf(x) = !istracked(x) || isleaf(tracker(x))
grad(x) = grad(tracker(x))
grad(::Nothing) = nothing
data(x) = x
struct Call{F,As<:Tuple}
func::F
args::As
end
Call(f::F, args::T) where {F,T} = Call{F,T}(f, args)
Call() = Call(nothing, ())
# When deserialising, the object_id changes
a::Call == b::Call = a.func == b.func && a.args == b.args
@inline (c::Call)() = c.func(data.(c.args)...)
mutable struct Tracked{T}
ref::UInt32
f::Call
isleaf::Bool
grad::T
Tracked{T}(f::Call) where T = new(0, f, false)
Tracked{T}(f::Call, grad::T) where T = new(0, f, false, grad)
Tracked{T}(f::Call{Nothing}, grad::T) where T = new(0, f, true, grad)
end
istracked(x::Tracked) = true
isleaf(x::Tracked) = x.f == Call()
grad(x::Tracked) = x.grad
track(f::Call, x) = Tracked{typeof(x)}(f)
function _forward end
function track(f::F, xs...; kw...) where F
y, back = _forward(f, xs...; kw...)
track(Call(back, tracker.(xs)), y)
end
macro grad(ex)
@capture(shortdef(ex), (name_(args__) = body_) |
(name_(args__) where {T__} = body_)) || error("Need a function definition")
T == nothing && (T = [])
isexpr(name, :(::)) || (name = :(::typeof($name)))
insert!(args, 1+isexpr(args[1], :parameters) , name)
@q(Tracker._forward($(args...)) where $(T...) = $body) |> esc
end
include("idset.jl")
include("back.jl")
include("numeric.jl")
include("lib/real.jl")
include("lib/array.jl")
include("forward.jl")
"""
hook(f, x) -> x
Hook into gradient backpropagation. `x` is unmodified, but when backpropagating
`f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse
the sign of the gradient applied to `x`.
"""
hook(f, x) = istracked(x) ? track(hook, f, x) : x
@grad hook(f, x) = data(x), Δ -> (nothing, f(Δ))
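This file is being deleted here as the AD moves to the standalone Tracker.jl package (visible from `using Tracker` in the updated tests), but the API is unchanged; a small sketch of the behaviour documented above:

```julia
using Flux.Tracker: param, back!, grad, hook

x = param(2.0)
y = hook(-, x) * 3      # forward value is unchanged ...
back!(y)
@show grad(x)           # ... but the incoming gradient 3 arrives negated: -3
```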
"""
checkpoint(f, args...)
Behaves like `f(args...)`, but avoids storing the intermediate values needed for
calculating gradients. Instead, `f(args...)` will be called again during the
backward pass. This can be used to save memory in larger models.
"""
checkpoint(f, args...) = track(checkpoint, f, args...)
@grad function checkpoint(f, args...)
data(f(args...)), function (Δ)
y, back = forward(f, args...)
(nothing, back(Δ)...)
end
end
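A usage sketch: the value and gradient match the uncheckpointed computation, only the memory behaviour differs.

```julia
using Flux.Tracker: param, back!, grad, checkpoint

W = param([1.0 2.0; 3.0 4.0])
y = sum(checkpoint(*, W, [1.0, 1.0]))   # same result as sum(W * [1.0, 1.0])
back!(y)
@show grad(W)                           # == [1.0 1.0; 1.0 1.0]
```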
nobacksies(f, x) = track(nobacksies, f, x)
nobacksies(f, xs::Tuple) = map(x -> nobacksies(f, x), xs)
@grad nobacksies(f::Symbol, x) = data(x), Δ -> error("Nested AD not defined for $f")
@grad nobacksies(f::String, x) = data(x), Δ -> error(f)
param(x::Number) = TrackedReal(float(x))
param(xs::AbstractArray) = TrackedArray(float.(xs))
@grad identity(x) = data(x), Δ -> (Δ,)
param(x::TrackedReal) = track(identity, x)
param(x::TrackedArray) = track(identity, x)
import Adapt: adapt, adapt_structure
adapt_structure(T, xs::TrackedArray) = param(adapt(T, data(xs)))
end

View File

@ -1,210 +0,0 @@
init_grad(x) = zero(x)
zero_grad!(x) = zero(x)
zero_grad!(x::AbstractArray) = (x .= 0)
scan(c::Call) = foreach(scan, c.args)
function scan(x::Tracked)
x.isleaf && return
ref = x.ref += 1
if ref == 1
scan(x.f)
isdefined(x, :grad) && (x.grad = zero_grad!(x.grad))
end
return
end
function scan(x)
istracked(x) && scan(tracker(x))
return
end
function back_(c::Call, Δ, once)
Δs = c.func(Δ)
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
error("Gradient is not a tuple of length $(length(c.args))")
foreach((x, d) -> back(x, d, once), c.args, data.(Δs))
end
back_(::Call{Nothing}, Δ, once) = nothing
back_(::Call{Missing}, Δ, once) = error("`back!` was already used")
accum!(x, Δ) = x .+ Δ
accum!(x::AbstractArray, Δ) = (x .+= Δ)
function back(x::Tracked, Δ, once)
x.isleaf && (x.grad = accum!(x.grad, Δ); return)
ref = x.ref -= 1
grad = if isdefined(x, :grad)
x.grad = accum!(x.grad, Δ)
elseif ref > 0
x.grad = Δ
else
Δ
end
if ref == 0
back_(x.f, grad, once)
once && !x.isleaf && (x.f = Call(missing, ()))
end
return
end
back(::Nothing, Δ, once) = return
# Interface methods
# TODO: if an error occurs in `back` the refcounts will be broken
# and `back` will silently fail to update.
# (but only if you re-use intermediate values between passes)
# Refcounts are also probably not safe in some situations (e.g. back called
# from within a backpropagator)
function back!(x, Δ; once = true)
istracked(x) || return
scan(x)
back(tracker(x), Δ, once)
return
end
function gradient_(f, xs...)
xs = param.(data.(xs))
l = f(xs...)
losscheck(l)
back!(l)
nobacksies("Use `gradient(...; nest = true)` for nested derivatives",
grad.(xs))
end
# Out-of-place gradients
struct Params
order::Vector{Any}
params::IdSet{Any}
Params() = new([], IdSet())
end
@forward Params.order Base.iterate, Base.length
function Base.push!(ps::Params, x)
if !(x in ps.params)
push!(ps.order, x)
push!(ps.params, x)
end
return ps
end
Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps)
Params(xs) = push!(Params(), xs...)
function Base.show(io::IO, ps::Params)
print(io, "Params([")
join(io, ps.order, ", ")
print(io, "])")
end
struct Grads
grads::IdDict{Any,Any}
end
Base.show(io::IO, ps::Grads) = println(io, "Grads(...)")
Grads() = Grads(IdDict())
@forward Grads.grads Base.setindex!, Base.haskey, Base.length, Base.iterate
Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps))
Base.getindex(g::Grads, x::Tracked) = g.grads[x]
function Base.getindex(g::Grads, x)
istracked(x) || error("Object not tracked: $x")
g[tracker(x)]
end
accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
function back_(g::Grads, c::Call, Δ)
Δs = c.func(Δ)
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
error("Gradient is not a tuple of length $(length(c.args))")
foreach((x, Δ) -> back(g, x, Δ), c.args, Δs)
end
back_(g::Grads, ::Call{Nothing}, Δ) = nothing
function back(g::Grads, x::Tracked, Δ)
x.isleaf && (accum!(g, x, Δ); return)
ref = x.ref -= 1
if ref > 0 || haskey(g, x)
accum!(g, x, Δ)
ref == 0 && back_(g, x.f, g[x])
else
ref == 0 && back_(g, x.f, Δ)
end
return
end
back(::Grads, ::Nothing, _) = return
collectmemaybe(xs) = xs
function forward(f, ps::Params)
y = collectmemaybe(f())
y, function (Δ)
g = Grads(ps)
if istracked(y)
scan(y)
back(g, tracker(y), Δ)
end
return g
end
end
function forward(f, args...)
args = param.(args)
y, back = forward(() -> f(args...), Params(args))
y, Δ -> getindex.(Ref(back(Δ)), args)
end
function losscheck(x)
x isa Real || error("Function output is not scalar")
isinf(x) && error("Loss is infinite")
isnan(x) && error("Loss is NaN")
end
function gradient_nested(f, args...)
y, back = forward(f, args...)
losscheck(y)
return back(1)
end
gradient(f, xs...; nest = false) =
nest ? gradient_nested(f, xs...) : gradient_(f, xs...)
gradient(f, ps::Params) = gradient_nested(f, ps)
# Jacobians and Hessians
import ..Flux
"""
J = jacobian(m,x)
Calculate the output jacobian `J = d/dx m(x)` such that each row `i` of `J` corresponds to the gradient `J[i,:] = ∇ₓ(m(x)[i])`
"""
function jacobian(m,x)
xp = param(x)
y = m(xp)
k = length(y)
n = length(x)
J = Matrix{eltype(x)}(undef,k,n)
for i = 1:k
Flux.back!(y[i], once = false) # Populate gradient accumulator
J[i,:] = xp.grad
xp.grad .= 0 # Reset gradient accumulator
end
J
end
hessian(f, x) = jacobian(x -> gradient(f, x, nest=true)[1], x)
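A short sketch of `jacobian` on a toy map (hypothetical example):

```julia
using Flux.Tracker: jacobian

f(x) = [x[1]^2, x[1] * x[2]]
J = jacobian(f, [3.0, 2.0])
# Row i is ∇ₓ f(x)[i]:
# J == [6.0  0.0;
#       2.0  3.0]
```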

View File

@ -1,53 +0,0 @@
using ForwardDiff
seed(x::Real, ::Val) = Dual(x, true)
function seed(x, ::Val{N}, offset = 0) where N
map(x, reshape(1:length(x), size(x))) do x, i
Dual(x, ntuple(j -> j+offset == i, Val(N)))
end
end
extract(x::ForwardDiff.Dual) = x.value, [x.partials...]
function extract(xs::AbstractArray{ForwardDiff.Dual{T,V,N}}) where {T,V,N}
J = similar(xs, V, N, length(xs))
for i = 1:length(xs), j = 1:N
J[j, i] = xs[i].partials.values[j]
end
return map(x -> x.value, xs), J
end
function forward_jacobian(f, x, ::Val{N}) where N
y, _J = extract(f(seed(x, Val(N))))
J = similar(_J, length(x), length(y))
J[1:N,:] = _J
offset = 0
while offset + N < length(x)
offset += N
_, _J = extract(f(seed(x, Val(N), offset)))
range = (1+offset):min(N+offset,length(x))
J[range,:] = @view _J[range.-offset,:]
end
return y, J
end
function forward_jacobian(f, x)
if length(x) < ForwardDiff.DEFAULT_CHUNK_THRESHOLD
forward_jacobian(f, x, Val(length(x)))
else
forward_jacobian(f, x, Val(ForwardDiff.DEFAULT_CHUNK_THRESHOLD))
end
end
forwarddiff(f, x) = istracked(x) ? track(forwarddiff, f, x) : f(x)
vec_scalar(x) = vec(x)
vec_scalar(x::Real) = [x]
reshape_scalar(x, y) = reshape(y, size(x))
reshape_scalar(x::Real, y) = y[]
@grad function forwarddiff(f, x)
y, J = forward_jacobian(f, data(x))
return y, Δ -> (nothing, reshape_scalar(x, J*vec_scalar(Δ)))
end
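`forwarddiff` swaps reverse mode for a forward-mode jacobian on one sub-expression; a sketch of its effect:

```julia
using Flux.Tracker: param, back!, grad, forwarddiff

x = param([1.0, 2.0, 3.0])
y = sum(forwarddiff(z -> z .^ 2, x))   # this piece is differentiated
back!(y)                               # with ForwardDiff duals instead
@show grad(x)                          # == [2.0, 4.0, 6.0]
```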

View File

@ -1,28 +0,0 @@
struct IdSet{T} <: AbstractSet{T}
dict::IdDict{T,Nothing}
IdSet{T}() where T = new(IdDict{T,Nothing}())
end
Base.eltype(::IdSet{T}) where T = T
IdSet() = IdSet{Any}()
Base.push!(s::IdSet) = s
Base.push!(s::IdSet{T}, x::T) where T = (s.dict[x] = nothing; s)
Base.delete!(s::IdSet{T}, x::T) where T = (delete!(s.dict, x); s)
Base.in(x, s::IdSet) = haskey(s.dict, x)
IdSet{T}(xs) where T = push!(IdSet{T}(), xs...)
IdSet(xs) = IdSet{eltype(xs)}(xs)
Base.collect(s::IdSet) = Base.collect(keys(s.dict))
Base.similar(s::IdSet, T::Type) = IdSet{T}()
@forward IdSet.dict Base.length
function Base.iterate(v::IdSet, state...)
y = Base.iterate(keys(v.dict), state...)
y === nothing && return nothing
return (y[1], y[2])
end
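The point of keying on identity rather than `==`, given the definitions above (a sketch):

```julia
a, b = [1, 2], [1, 2]          # equal values, distinct objects
s = IdSet{Vector{Int}}()
push!(s, a)
@assert a in s && !(b in s)    # membership is by ===, not ==
```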

View File

@ -1,516 +0,0 @@
import Base: *
import LinearAlgebra
import LinearAlgebra: inv, det, logdet, logabsdet, \, /
using Statistics
using LinearAlgebra: Transpose, Adjoint, diagm, diag
struct TrackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
tracker::Tracked{A}
data::A
grad::A
TrackedArray{T,N,A}(t::Tracked{A}, data::A) where {T,N,A} = new(t, data)
TrackedArray{T,N,A}(t::Tracked{A}, data::A, grad::A) where {T,N,A} = new(t, data, grad)
end
data(x::TrackedArray) = x.data
tracker(x::TrackedArray) = x.tracker
TrackedVector{T,A} = TrackedArray{T,1,A}
TrackedMatrix{T,A} = TrackedArray{T,2,A}
TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}}
track(c::Call, x::AbstractArray) = TrackedArray(c, x)
TrackedArray(c::Call, x::A) where A <: AbstractArray =
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c), x)
TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, Δ), x, Δ)
TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, zero(x))
Base.eltype(x::Type{<:TrackedArray{T}}) where T <: Real = TrackedReal{T}
Base.convert(::Type{T}, x::S) where {T<:TrackedArray,S<:T} = x
Base.convert(::Type{<:TrackedArray}, x::TrackedArray) =
error("Not implemented: convert $(typeof(x)) to $T")
Base.convert(::Type{<:TrackedArray{T,N,A}}, x::AbstractArray) where {T,N,A} =
TrackedArray(convert(A, x))
Base.show(io::IO, t::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
@isdefined(A) ?
print(io, "TrackedArray{…,$A}") :
invoke(show, Tuple{IO,DataType}, io, t)
function Base.summary(io::IO, x::TrackedArray)
print(io, "Tracked ")
summary(io, data(x))
end
Base.print_array(io::IO, x::TrackedArray) = Base.print_array(io, data(x))
function Base.show(io::IO, x::TrackedArray)
show(io, data(x))
print(io, " (tracked)")
end
Base.copy(x::TrackedArray) = x
Base.setindex!(xs::TrackedArray, v, i...) =
error("Can't differentiate `setindex!`")
back!(::TrackedArray) = error("Value is not scalar; use `back!(sum(x))` or `back!(x, Δ)`")
function update!(x::TrackedArray, Δ)
x.data .+= data(Δ)
tracker(x).grad .= 0
return x
end
# Fallthrough methods
for f in :[Base.size, Base.ndims, Base.collect].args
@eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...)
end
Base.size(x::TrackedArray, i::Integer, j::Integer, is::Integer...) =
size(data(x), i, j, is...)
Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
similar(data(x), dims...)
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
for op in [:(==), :≈]
@eval Base.$op(x::TrackedArray, y::AbstractArray) = Base.$op(data(x), y)
@eval Base.$op(x::AbstractArray, y::TrackedArray) = Base.$op(x, data(y))
@eval Base.$op(x::TrackedArray, y::TrackedArray) = Base.$op(data(x), data(y))
end
# Array Stdlib
Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...)
@grad function getindex(xs::AbstractArray, i...)
data(xs)[i...], function (Δ)
Δ′ = zero(xs)
Δ′[i...] = data(Δ)
(nobacksies(:getindex, Δ′), map(_->nothing, i)...)
end
end
Base.view(x::TrackedArray, inds...) = track(Base.view, x, inds...)
@grad function view(x::AbstractArray, inds...)
view(data(x), inds...), function (Δ)
grad_output = zero(x)
subgrad = view(grad_output, inds...)
subgrad[:] = data(Δ)
(nobacksies(:view, grad_output), map(_->nothing, inds)...)
end
end
Base.:-(xs::TrackedArray) = track(-, xs)
@grad -(xs) = -data(xs), Δ -> (-Δ,)
Base.transpose(xs::TrackedArray) = track(transpose, xs)
Base.adjoint(xs::TrackedArray) = track(adjoint, xs)
@grad transpose(xs) = transpose(data(xs)), Δ -> (trim(xs, transpose(Δ)),)
@grad adjoint(xs) = data(xs)', Δ -> (trim(xs, Δ'),)
det(xs::TrackedArray) = track(det, xs)
@grad det(xs) = det(data(xs)), Δ -> (Δ * det(xs) * transpose(inv(xs)),)
logdet(xs::TrackedArray) = track(logdet, xs)
@grad logdet(xs) = logdet(data(xs)), Δ -> (Δ * transpose(inv(xs)),)
logabsdet(xs::TrackedArray) = track(logabsdet, xs)
@grad logabsdet(xs) = logabsdet(data(xs)), Δ -> (Δ[1] * transpose(inv(xs)),)
Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...)
@grad function repeat(xs; inner=ntuple(x->1, ndims(xs)), outer=ntuple(x->1, ndims(xs)))
repeat(data(xs), inner = inner, outer = outer), function (Δ)
Δ′ = zero(xs)
S = size(xs)
# Loop through each element of Δ, calculate source dimensions, accumulate into Δ′
for (dest_idx, val) in pairs(IndexCartesian(), data(Δ))
# First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then
# wrap around based on original size S.
src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)]
Δ′[src_idx...] += val
end
(nobacksies(:repeat, Δ′),)
end
end
function combinations(xs, n)
n < 1 && return [[]]
cs = combinations(xs, n-1)
[[x, c...] for x in xs, c in cs]
end
for i = 0:2, c = combinations([:AbstractArray, :TrackedArray, :Number], i), f = [:hcat, :vcat]
cnames = map(_ -> gensym(), c)
@eval Base.$f($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::Union{TrackedArray,TrackedReal}, xs::Union{AbstractArray,Number}...) =
track($f, $(cnames...), x, xs...)
end
for i = 0:2, c = combinations([:AbstractVecOrMat, :TrackedVecOrMat], i), f = [:hcat, :vcat]
cnames = map(_ -> gensym(), c)
@eval Base.$f($([:($x::$c{T}) for (x, c) in zip(cnames, c)]...), x::TrackedVecOrMat{T}, xs::AbstractVecOrMat{T}...) where T =
track($f, $(cnames...), x, xs...)
end
for i = 0:2, c = combinations([:AbstractVector, :TrackedVector], i), f = [:hcat, :vcat]
cnames = map(_ -> gensym(), c)
@eval Base.$f($([:($x::$c{T}) for (x, c) in zip(cnames, c)]...), x::TrackedVector{T}, xs::AbstractVector{T}...) where T =
track($f, $(cnames...), x, xs...)
end
@grad function vcat(xs...)
vcat(data.(xs)...), function (Δ)
start = 0
Δs = [begin
i = map(_ -> :, size(xsi)) |> Base.tail
d = Δ[start+1:start+size(xsi,1), i...]
start += size(xsi, 1)
d
end for xsi in xs]
return (Δs...,)
end
end
@grad function hcat(xs...)
hcat(data.(xs)...), function (Δ)
start = 0
Δs = [begin
d = if ndims(xsi) == 1
Δ[:, start+1]
else
i = map(_ -> :, size(xsi)) |> Base.tail |> Base.tail
Δ[:, start+1:start+size(xsi,2), i...]
end
start += size(xsi, 2)
d
end for xsi in xs]
return (Δs...,)
end
end
for i = 0:2, c = combinations([:AbstractArray, :TrackedArray], i)
cnames = map(_ -> gensym(), c)
@eval Base.cat($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::TrackedArray, xs::AbstractArray...; dims) =
track(cat, $(cnames...), x, xs..., dims = dims)
end
@grad function cat(Xs...; dims)
cat(data.(Xs)..., dims = dims), function (Δ)
start = ntuple(i -> 0, Val(ndims(Δ)))
Δs = [begin
dim_xs = 1:ndims(xs)
till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val(ndims(Δ)))
xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val(ndims(Δ)))
d = reshape(Δ[xs_in_Δ...],size(xs))
start = start .+ till_xs
d
end for xs in Xs]
return (Δs...,)
end
end
Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) = reshape(xs, dims)
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Union{Int64,Colon}}}) = reshape(xs, Base._reshape_uncolon(xs, dims))
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64}}) = track(reshape, xs, dims)
@grad reshape(xs, dims) = reshape(data(xs), dims), Δ -> (reshape(Δ, size(xs)),nothing)
Base.permutedims(xs::TrackedArray, perm) = track(permutedims, xs, perm)
@grad permutedims(xs, perm) = permutedims(data(xs), perm), Δ -> (permutedims(Δ, invperm(perm)),nothing)
Base.PermutedDimsArray(xs::TrackedArray, perm) = track(PermutedDimsArray, xs, perm)
@grad PermutedDimsArray(xs, perm) = PermutedDimsArray(data(xs), perm), Δ -> (PermutedDimsArray(Δ, invperm(perm)),nothing)
function _kron(mat1::AbstractMatrix,mat2::AbstractMatrix)
m1, n1 = size(mat1)
mat1_rsh = reshape(mat1,(1,m1,1,n1))
m2, n2 = size(mat2)
mat2_rsh = reshape(mat2,(m2,1,n2,1))
return reshape(mat1_rsh.*mat2_rsh, (m1*m2,n1*n2))
end
Base.kron(a::TrackedMatrix, b::TrackedMatrix) = _kron(a, b)
Base.kron(a::TrackedMatrix, b::AbstractMatrix) = _kron(a, b)
Base.kron(a::AbstractMatrix, b::TrackedMatrix) = _kron(a, b)
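The reshape/broadcast formulation agrees with the stdlib `kron`; a quick standalone check, given `_kron` as defined above:

```julia
using LinearAlgebra

A, B = rand(2, 3), rand(4, 5)
@assert _kron(A, B) ≈ kron(A, B)   # broadcasting trick vs. LinearAlgebra.kron
```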
inv(A::TrackedArray) = Tracker.track(inv, A)
@grad function inv(A)
return inv(Tracker.data(A)), function (Δ)
Ainv = inv(A)
∇A = - Ainv' * Δ * Ainv'
return (∇A, )
end
end
# (/) rdivide
A::TrackedArray / B::TrackedArray = Tracker.track(/, A, B)
A::AbstractVecOrMat / B::TrackedArray = Tracker.track(/, A, B)
A::TrackedArray / B::AbstractVecOrMat = Tracker.track(/, A, B)
@grad function (A / B)
return Tracker.data(A) / Tracker.data(B), function (Δ)
Binv = inv(B)
∇B = - Binv' * A' * Δ * Binv'
return (Δ * Binv', ∇B)
end
end
# (\) ldivide (left vec divide needs more work to resolve dispatch ambiguity)
A::TrackedArray \ B::TrackedArray = Tracker.track(\, A, B)
A::AbstractArray \ B::TrackedArray = Tracker.track(\, A, B)
A::TrackedArray \ B::AbstractVecOrMat = Tracker.track(\, A, B)
@grad function (A \ B)
return Tracker.data(A) \ Tracker.data(B), function (Δ)
Ainv = inv(A)
∇A = - Ainv' * Δ * B' * Ainv'
return (∇A, Ainv' * Δ)
end
end
# Reductions
Base.sum(xs::TrackedArray; dims = :) = track(sum, xs, dims = dims)
Base.sum(f::Union{Function,Type},xs::TrackedArray) = sum(f.(xs))
@grad sum(xs; dims = :) = sum(data(xs), dims = dims),
Δ -> (zero(xs) .+ Δ, )
Base.prod(xs::TrackedArray, dim) = track(prod, xs, dim)
Base.prod(xs::TrackedArray) = track(prod, xs)
Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs))
@grad prod(xs) = prod(data(xs)), Δ -> (prod(xs) ./ xs .* Δ,)
@grad prod(xs, dim) = prod(data(xs), dims = dim),
Δ -> (nobacksies(:sum,
reshape(.*(circshift.([reshape(data(xs), length(xs))], 1:length(xs)-1)...), size(xs)) .* Δ),
nothing)
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
Statistics.mean(xs::TrackedArray; dims = :) = track(mean, xs, dims = dims)
Base.maximum(xs::TrackedArray; dims = :) = track(maximum, xs, dims = dims)
Base.minimum(xs::TrackedArray; dims = :) = track(minimum, xs, dims = dims)
import LinearAlgebra: dot
dot(xs::TrackedVector, ys::TrackedVector) = track(dot, xs, ys)
dot(xs::AbstractVector, ys::TrackedVector) = track(dot, xs, ys)
dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys)
@grad dot(xs, ys) = dot(data(xs), data(ys)), Δ -> (Δ .* ys, Δ .* xs)
# Hacks to get std working
Statistics.std(x::TrackedArray; dims = :, mean = Statistics.mean(x, dims = dims), corrected::Bool = true) = _std(x,mean,dims,corrected)
_std(x::TrackedArray, mean, dims, corrected) = sqrt.(sum((x .- mean).^2, dims = dims) ./ (mapreduce(i -> size(x,i),*, dims) - corrected))
_std(x::TrackedArray, mean, ::Colon, corrected) = sqrt.(sum((x .- mean).^2) ./ (length(x) - corrected))
LinearAlgebra.norm(x::TrackedArray, p::Real = 2) =
sum(abs.(x).^p .+ eps(0f0))^(1/p) # avoid d(sqrt(x))/dx == Inf at 0
@grad mean(xs; dims = :) = mean(data(xs), dims=dims), Δ -> (_backmean(xs,Δ,dims),)
_backmean(xs, Δ, ::Colon) = zero(xs) .+ Δ ./ length(xs)
_backmean(xs, Δ, dims) = zero(xs) .+ Δ ./ mapreduce(i -> size(data(xs),i),*,dims)
@grad function maximum(xs; dims = dims)
maximum(data(xs), dims = dims), function (Δ)
Δ′ = zero(xs)
_, i = findmax(data(xs), dims = dims)
Δ′[i] = data(Δ)
return (nobacksies(:maximum, Δ′),)
end
end
@grad function minimum(xs; dims = dims)
minimum(data(xs), dims = dims), function (Δ)
Δ′ = zero(xs)
_, i = findmin(data(xs), dims = dims)
Δ′[i] = data(Δ)
return (nobacksies(:minimum, Δ′),)
end
end
# BLAS
LinearAlgebra.diagm(x::Pair{<:Integer, <:TrackedVector}) = track(diagm, x...)
@grad diagm(i, x) = diagm(i => data(x)), Δ -> (nothing, diag(Δ, i))
x::TrackedMatrix * y::AbstractMatrix = track(*, x, y)
x::AbstractMatrix * y::TrackedMatrix = track(*, x, y)
x::TrackedMatrix * y::TrackedMatrix = track(*, x, y)
x::TrackedMatrix * y::AbstractVector = track(*, x, y)
x::AbstractMatrix * y::TrackedVector = track(*, x, y)
x::TrackedMatrix * y::TrackedVector = track(*, x, y)
x::TrackedVector * y::AbstractVector = track(*, x, y)
x::AbstractVector * y::TrackedVector = track(*, x, y)
x::TrackedVector * y::TrackedVector = track(*, x, y)
@grad a::AbstractMatrix * b::AbstractVecOrMat =
data(a)*data(b), Δ -> (Δ * transpose(b), transpose(a) * Δ)
# NNlib
using NNlib
import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, ∇conv_data, depthwiseconv, maxpool, meanpool
softmax(xs::TrackedArray) = track(softmax, xs)
@grad softmax(xs) = softmax(data(xs)), Δ -> (nobacksies(:softmax, ∇softmax(data(Δ), data(xs))),)
logsoftmax(xs::TrackedArray) = track(logsoftmax, xs)
@grad logsoftmax(xs) = logsoftmax(data(xs)), Δ -> (nobacksies(:logsoftmax, ∇logsoftmax(data(Δ), data(xs))),)
depthwiseconv(x::TrackedArray, w::TrackedArray; kw...) = track(depthwiseconv, x, w; kw...)
depthwiseconv(x::AbstractArray, w::TrackedArray; kw...) = track(depthwiseconv, x, w; kw...)
depthwiseconv(x::TrackedArray, w::AbstractArray; kw...) = track(depthwiseconv, x, w; kw...)
@grad depthwiseconv(x, w; kw...) =
depthwiseconv(data(x), data(w); kw...),
Δ -> nobacksies(:depthwiseconv,
(NNlib.∇depthwiseconv_data(data.((Δ, x, w))...; kw...),
NNlib.∇depthwiseconv_filter(data.((Δ, x, w))...; kw...)))
conv(x::TrackedArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
conv(x::AbstractArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...)
@grad conv(x, w; kw...) =
conv(data(x), data(w); kw...),
Δ -> nobacksies(:conv,
(NNlib.∇conv_data(data.((Δ, w))...; size=size(x), kw...),
NNlib.∇conv_filter(data.((Δ, x))...; size=size(w), kw...)))
∇conv_data(x::TrackedArray, w::TrackedArray; kw...) = track(∇conv_data, x, w; kw...)
∇conv_data(x::AbstractArray, w::TrackedArray; kw...) = track(∇conv_data, x, w; kw...)
∇conv_data(x::TrackedArray, w::AbstractArray; kw...) = track(∇conv_data, x, w; kw...)
@grad ∇conv_data(x, w; kw...) =
∇conv_data(data(x), data(w); kw...),
Δ -> nobacksies(:conv,
(NNlib.conv(data.((Δ, w))...; size=size(x), kw...),
NNlib.∇conv_filter(data.((x, Δ))...; size=size(w), kw...)))
maxpool(x::TrackedArray, k; kw...) = track(maxpool, x, k; kw...)
@grad function maxpool(x, k; kw...)
y = maxpool(data(x), k; kw...)
y, Δ -> (nobacksies(:maxpool, NNlib.∇maxpool(data.((Δ, y, x))..., k; kw...)), nothing)
end
meanpool(x::TrackedArray, k; kw...) = track(meanpool, x, k; kw...)
@grad function meanpool(x, k; kw...)
y = meanpool(data(x), k; kw...)
y, Δ -> (nobacksies(:maxpool, NNlib.∇meanpool(data.((Δ, y, x))..., k; kw...)), nothing)
end
# Broadcasting
using ForwardDiff: Dual, partials, value
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
unbroadcast(x::AbstractArray, Δ) =
size(x) == size(Δ) ? Δ :
length(x) == length(Δ) ? trim(x, Δ) :
trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
unbroadcast(x::Number, Δ) = sum(Δ)
unbroadcast(x::Base.RefValue, _) = nothing
dual(x, p) = x
dual(x::Real, p) = Dual(x, p)
function partial(f::F, Δ, i, args::Vararg{Any,N}) where {F,N}
dargs = ntuple(j -> dual(args[j], i==j), Val(N))
return Δ * f(dargs...).partials[1]
end
@inline function ∇broadcast(f::F, args::Vararg{Any,N}) where {F,N}
y = broadcast(f, data.(args)...)
eltype(y) <: Real || return y
eltype(y) == Bool && return y
function back(Δ)
Δargs = ntuple(i -> partial.(f, Δ, i, args...), Val(N))
dxs = map(unbroadcast, args, Δargs)
return dxs
end
# So we can return non-tracked arrays
track(Call(back, tracker.(args)), y)
end
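The per-element partials come from seeding one argument with a unit dual; the same trick in isolation (a sketch):

```julia
using ForwardDiff: Dual

f(a, b) = a * b
# Seed argument 1 of f(3, 4) with a unit perturbation and read off ∂f/∂a:
d = f(Dual(3.0, true), 4.0)
@assert d.partials[1] == 4.0
```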
using Base.Broadcast: BroadcastStyle, ArrayStyle, Broadcasted, broadcasted
struct TrackedStyle <: BroadcastStyle end
Broadcast.BroadcastStyle(::Type{<:Union{TrackedArray,TrackedReal}}) = TrackedStyle()
Broadcast.BroadcastStyle(::TrackedStyle, ::BroadcastStyle) = TrackedStyle()
# We have to re-build the original broadcast struct to get the appropriate array
# style. We need this primarily to support CuArrays' broadcasting fixes.
broadcast_rebuild(xs) = data(xs)
broadcast_rebuild(bc::Broadcasted) =
broadcasted(bc.f, broadcast_rebuild.(bc.args)...)
preprocess(x) = x
function Base.Broadcast.materialize(bc::Broadcasted{TrackedStyle})
bc1 = Broadcast.flatten(bc)
bc2 = Broadcast.flatten(broadcast_rebuild(bc))
∇broadcast(bc2.f, bc1.args...)
end
using Requires
# https://github.com/FluxML/Flux.jl/issues/353
if VERSION < v"1.1.0-DEV.548"
@init Requires.isprecompiling() || @eval Base.Broadcast begin
function flatten(bc::Broadcasted{Style}) where {Style}
isflat(bc) && return bc
args = cat_nested(bc)
let makeargs = make_makeargs(bc), f = bc.f
newf = @inline function(args::Vararg{Any,N}) where N
f(makeargs(args...)...)
end
return Broadcasted{Style}(newf, args, bc.axes)
end
end
@inline function make_makeargs(makeargs, t::Tuple{<:Broadcasted,Vararg{Any}})
bc = t[1]
let makeargs = make_makeargs(makeargs, tail(t)), f = bc.f
let makeargs = make_makeargs(makeargs, bc.args)
headargs, tailargs = make_headargs(bc.args), make_tailargs(bc.args)
return @inline function(args::Vararg{Any,N}) where N
args1 = makeargs(args...)
a, b = headargs(args1...), tailargs(args1...)
(f(a...), b...)
end
end
end
end
end
end

View File

@ -1,160 +0,0 @@
mutable struct TrackedReal{T<:Real} <: Real
data::T
tracker::Tracked{T}
end
TrackedReal(x::Real) = TrackedReal(x, Tracked{typeof(x)}(Call(), zero(x)))
data(x::TrackedReal) = x.data
tracker(x::TrackedReal) = x.tracker
track(f::Call, x::Real) = TrackedReal(x, Tracked{typeof(x)}(f, zero(x)))
function back!(x::TrackedReal; once = true)
isinf(x) && error("Loss is Inf")
isnan(x) && error("Loss is NaN")
return back!(x, 1, once = once)
end
function update!(x::TrackedReal, Δ)
x.data += data(Δ)
tracker(x).grad = 0
return x
end
function Base.show(io::IO, x::TrackedReal)
T = get(io, :typeinfo, Any)
show(io, data(x))
T <: TrackedReal || print(io, " (tracked)")
end
Base.decompose(x::TrackedReal) = Base.decompose(data(x))
Base.copy(x::TrackedReal) = x
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{T}) where T = x
Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x))
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} =
error("Not implemented: convert tracked $S to tracked $T")
(T::Type{<:TrackedReal})(x::Real) = convert(T, x)
for op in [:(==), :≈, :<, :(<=)]
@eval Base.$op(x::TrackedReal, y::Real) = Base.$op(data(x), y)
@eval Base.$op(x::Real, y::TrackedReal) = Base.$op(x, data(y))
@eval Base.$op(x::TrackedReal, y::TrackedReal) = Base.$op(data(x), data(y))
end
Base.eps(x::TrackedReal) = eps(data(x))
Base.eps(::Type{TrackedReal{T}}) where T = eps(T)
for f in :[isinf, isnan, isfinite].args
@eval Base.$f(x::TrackedReal) = Base.$f(data(x))
end
Base.Printf.fix_dec(x::TrackedReal, n::Int, a...) = Base.Printf.fix_dec(data(x), n, a...)
Base.float(x::TrackedReal) = x
Base.promote_rule(::Type{TrackedReal{S}},::Type{T}) where {S,T} =
TrackedReal{promote_type(S,T)}
using Random
for f in :[rand, randn, randexp].args
@eval Random.$f(rng::AbstractRNG,::Type{TrackedReal{T}}) where {T} = param(rand(rng,T))
end
using DiffRules, SpecialFunctions, NaNMath
for (M, f, arity) in DiffRules.diffrules()
arity == 1 || continue
@eval begin
@grad $M.$f(a::Real) =
$M.$f(data(a)), Δ -> (Δ * $(DiffRules.diffrule(M, f, :a)),)
$M.$f(a::TrackedReal) = track($M.$f, a)
end
end
# Work around zero(π) not working, for some reason
_zero(::Irrational) = nothing
_zero(x) = zero(x)
for (M, f, arity) in DiffRules.diffrules()
arity == 2 || continue
da, db = DiffRules.diffrule(M, f, :a, :b)
f = :($M.$f)
@eval begin
@grad $f(a::TrackedReal, b::TrackedReal) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db)
@grad $f(a::TrackedReal, b::Real) = $f(data(a), b), Δ -> (Δ * $da, _zero(b))
@grad $f(a::Real, b::TrackedReal) = $f(a, data(b)), Δ -> (_zero(a), Δ * $db)
$f(a::TrackedReal, b::TrackedReal) = track($f, a, b)
$f(a::TrackedReal, b::Real) = track($f, a, b)
$f(a::Real, b::TrackedReal) = track($f, a, b)
end
end
# Eliminating ambiguity
import Base:^
^(a::TrackedReal, b::Integer) = track(^, a, b)
# Hack for conversions
using ForwardDiff: Dual
(T::Type{<:Real})(x::Dual) = Dual(T(x.value), map(T, x.partials.values))
(Dual{T,V,N})(x::Dual) where {T,V,N} = invoke(Dual{T,V,N}, Tuple{Number}, x)
# Tuples
struct TrackedTuple{T<:Tuple}
data::T
tracker::Tracked{T}
end
data(xs::TrackedTuple) = xs.data
tracker(xs::TrackedTuple) = xs.tracker
accum!(x::Tuple, Δ::Tuple) = accum!.(x, Δ)
init_grad(x::Tuple) = init_grad.(x)
zero_grad!(x::Tuple) = zero_grad!.(x)
track(f::Call, xs::Tuple) = TrackedTuple(xs, Tracked{typeof(xs)}(f, zero.(xs)))
function Base.show(io::IO, xs::TrackedTuple)
show(io, data(xs))
print(io, " (tracked)")
end
Base.length(x::TrackedTuple) = length(data(x))
Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i)
@grad function getindex(xs::TrackedTuple, i)
data(xs)[i], Δ -> (ntuple(j -> i == j ? Δ : 0, length(xs)), nothing)
end
# Array collection
function collect(xs)
xs = Base.collect(xs)
track(Call(collect, (tracker.(xs),)), data.(xs))
end
function scan(c::Call{typeof(collect)})
foreach(scan, c.args[1])
end
function back_(c::Call{typeof(collect)}, Δ, once)
foreach((x, d) -> back(x, d, once), c.args[1], data(Δ))
end
function back_(g::Grads, c::Call{typeof(collect)}, Δ)
foreach((x, Δ) -> back(g, x, Δ), c.args[1], Δ)
end
collectmemaybe(xs::AbstractArray{>:TrackedReal}) = collect(xs)
collectmemaybe(xs::AbstractArray{<:TrackedReal}) = collect(xs)

View File

@ -1,18 +0,0 @@
function ngradient(f, xs::AbstractArray...)
grads = zero.(xs)
for (x, Δ) in zip(xs, grads), i in 1:length(x)
δ = sqrt(eps())
tmp = x[i]
x[i] = tmp - δ/2
y1 = f(xs...)
x[i] = tmp + δ/2
y2 = f(xs...)
x[i] = tmp
Δ[i] = (y2-y1)/δ
end
return grads
end
gradcheck(f, xs...) =
all(isapprox.(ngradient(f, xs...),
data.(gradient(f, xs...)), rtol = 1e-5, atol = 1e-5))
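`ngradient` is the central-difference reference that `gradcheck` compares against; standalone:

```julia
f(x) = sum(x .^ 2)
x = rand(3)
g = ngradient(f, x)[1]                 # finite-difference gradient
@assert isapprox(g, 2 .* x; atol = 1e-4)
```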

View File

@ -47,5 +47,7 @@ end
if CuArrays.libcudnn != nothing
@info "Testing Flux/CUDNN"
include("cudnn.jl")
include("curnn.jl")
if !haskey(ENV, "CI_DISABLE_CURNN_TEST")
include("curnn.jl")
end
end

View File

@ -14,3 +14,9 @@ using Test
@test FashionMNIST.labels() isa Vector{Int64}
@test Data.Sentiment.train() isa Vector{Data.Tree{Any}}
@test Iris.features() isa Matrix
@test size(Iris.features()) == (4,150)
@test Iris.labels() isa Vector{String}
@test size(Iris.labels()) == (150,)

View File

@ -1,33 +1,75 @@
using Test, Random
import Flux: activations
@testset "basic" begin
@testset "Chain" begin
@test_nowarn Chain(Dense(10, 5, σ), Dense(5, 2))(randn(10))
@test_throws DimensionMismatch Chain(Dense(10, 5, σ),Dense(2, 1))(randn(10))
# numeric test should be put into testset of corresponding layer
@testset "helpers" begin
@testset "activations" begin
dummy_model = Chain(Dense(10,5,σ),Dense(5,2),softmax)
x = rand(10)
@test activations(Chain(), x) == []
@test activations(dummy_model, x)[1] == dummy_model[1](x)
@test activations(dummy_model, x)[2] == x |> dummy_model[1] |> dummy_model[2]
@test activations(Chain(identity, x->:foo), x)[2] == :foo # results include `Any` type
end
end
@testset "Chain" begin
@test_nowarn Chain(Dense(10, 5, σ), Dense(5, 2))(randn(10))
@test_throws DimensionMismatch Chain(Dense(10, 5, σ),Dense(2, 1))(randn(10))
# numeric test should be put into testset of corresponding layer
end
@testset "Dense" begin
@test length(Dense(10, 5)(randn(10))) == 5
@test_throws DimensionMismatch Dense(10, 5)(randn(1))
@test_throws MethodError Dense(10, 5)(1) # avoid broadcasting
@test_throws MethodError Dense(10, 5).(randn(10)) # avoid broadcasting
@test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(1, 1)
@test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,2)) == 10*ones(1, 2)
@test Dense(10, 2, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(2, 1)
@test Dense(10, 2, identity, initW = ones, initb = zeros)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20]
end
@testset "Diagonal" begin
@test length(Flux.Diagonal(10)(randn(10))) == 10
@test length(Flux.Diagonal(10)(1)) == 10
@test length(Flux.Diagonal(10)(randn(1))) == 10
@test_throws DimensionMismatch Flux.Diagonal(10)(randn(2))
@test Flux.Diagonal(2)([1 2]) == [1 2; 1 2]
@test Flux.Diagonal(2)([1,2]) == [1,2]
@test Flux.Diagonal(2)([1 2; 3 4]) == [1 2; 3 4]
end
@testset "Maxout" begin
# Note that the normal common usage of Maxout is as per the docstring
# These are abnormal constructors used for testing purposes
@testset "Constructor" begin
mo = Maxout(() -> identity, 4)
input = rand(40)
@test mo(input) == input
end
@testset "Dense" begin
@test length(Dense(10, 5)(randn(10))) == 5
@test_throws DimensionMismatch Dense(10, 5)(randn(1))
@test_throws MethodError Dense(10, 5)(1) # avoid broadcasting
@test_throws MethodError Dense(10, 5).(randn(10)) # avoid broadcasting
@test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(1, 1)
@test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,2)) == 10*ones(1, 2)
@test Dense(10, 2, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(2, 1)
@test Dense(10, 2, identity, initW = ones, initb = zeros)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20]
@testset "simple alternatives" begin
mo = Maxout((x -> x, x -> 2x, x -> 0.5x))
input = rand(40)
@test mo(input) == 2*input
end
@testset "Diagonal" begin
@test length(Flux.Diagonal(10)(randn(10))) == 10
@test length(Flux.Diagonal(10)(1)) == 10
@test length(Flux.Diagonal(10)(randn(1))) == 10
@test_throws DimensionMismatch Flux.Diagonal(10)(randn(2))
@test Flux.Diagonal(2)([1 2]) == [1 2; 1 2]
@test Flux.Diagonal(2)([1,2]) == [1,2]
@test Flux.Diagonal(2)([1 2; 3 4]) == [1 2; 3 4]
@testset "complex alternatives" begin
mo = Maxout((x -> [0.5; 0.1]*x, x -> [0.2; 0.7]*x))
input = [3.0 2.0]
target = [0.5, 0.7].*input
@test mo(input) == target
end
@testset "params" begin
mo = Maxout(()->Dense(32, 64), 4)
ps = params(mo)
@test length(ps) == 8 #4 alts, each with weight and bias
end
end
end

View File

@ -4,9 +4,9 @@ using Flux: maxpool, meanpool
@testset "Pooling" begin
x = randn(Float32, 10, 10, 3, 2)
mp = MaxPool((2, 2))
@test mp(x) == maxpool(x, (2,2))
@test mp(x) == maxpool(x, PoolDims(x, 2))
mp = MeanPool((2, 2))
@test mp(x) == meanpool(x, (2,2))
@test mp(x) == meanpool(x, PoolDims(x, 2))
end
@testset "CNN" begin
@ -22,14 +22,42 @@ end
@test size(m(r)) == (10, 5)
end
@testset "asymmetric padding" begin
r = ones(Float32, 28, 28, 1, 1)
m = Conv((3, 3), 1=>1, relu; pad=(0,1,1,2))
m.weight.data[:] .= 1.0
m.bias.data[:] .= 0.0
y_hat = Flux.data(m(r))[:,:,1,1]
@test size(y_hat) == (27, 29)
@test y_hat[1, 1] ≈ 6.0
@test y_hat[2, 2] ≈ 9.0
@test y_hat[end, 1] ≈ 4.0
@test y_hat[1, end] ≈ 3.0
@test y_hat[1, end-1] ≈ 6.0
@test y_hat[end, end] ≈ 2.0
end
@testset "Depthwise Conv" begin
r = zeros(Float32, 28, 28, 3, 5)
m1 = DepthwiseConv((2, 2), 3=>5)
@test size(m1(r), 3) == 15
m2 = DepthwiseConv((2, 2), 3)
@test size(m2(r), 3) == 3
x = zeros(Float64, 28, 28, 3, 5)
m3 = DepthwiseConv((2, 2), 3 => 5)
@test size(m3(r), 3) == 15
m4 = DepthwiseConv((2, 2), 3)
@test size(m4(r), 3) == 3
end
@testset "ConvTranspose" begin
x = zeros(Float32, 28, 28, 1, 1)
y = Conv((3,3), 1 => 1)(x)
x_hat = ConvTranspose((3, 3), 1 => 1)(y)
@test size(x_hat) == size(x)
end

View File

@ -104,3 +104,210 @@ end
@test (@allocated m(x)) < 100_000_000
end
end
@testset "InstanceNorm" begin
# helper functions
expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...)
# begin tests
let m = InstanceNorm(2), sizes = (3, 2, 2),
x = param(reshape(collect(1:prod(sizes)), sizes))
@test m.β.data == [0, 0] # initβ(2)
@test m.γ.data == [1, 1] # initγ(2)
@test m.active
m(x)
#julia> x
#[:, :, 1] =
# 1.0 4.0
# 2.0 5.0
# 3.0 6.0
#
#[:, :, 2] =
# 7.0 10.0
# 8.0 11.0
# 9.0 12.0
#
# μ will be
# (1. + 2. + 3.) / 3 = 2.
# (4. + 5. + 6.) / 3 = 5.
#
# (7. + 8. + 9.) / 3 = 8.
# (10. + 11. + 12.) / 3 = 11.
#
# ∴ update rule with momentum:
# (1. - .1) * 0 + .1 * (2. + 8.) / 2 = .5
# (1. - .1) * 0 + .1 * (5. + 11.) / 2 = .8
@test m.μ ≈ [0.5, 0.8]
# momentum * var * num_items / (num_items - 1) + (1 - momentum) * sigma_sq
# julia> reshape(mean(.1 .* var(x.data, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1.
# 2-element Array{Float64,1}:
# 1.
# 1.
@test m.σ² ≈ reshape(mean(.1 .* var(x.data, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1.
testmode!(m)
@test !m.active
x = m(x).data
@test isapprox(x[1], (1 - 0.5) / sqrt(1. + 1f-5), atol = 1.0e-5)
end
# with activation function
let m = InstanceNorm(2, sigmoid), sizes = (3, 2, 2),
x = param(reshape(collect(1:prod(sizes)), sizes))
affine_shape = collect(sizes)
affine_shape[1] = 1
@test m.active
m(x)
testmode!(m)
@test !m.active
y = m(x).data
@test isapprox(y, data(sigmoid.((x .- expand_inst(m.μ, affine_shape)) ./ sqrt.(expand_inst(m.σ², affine_shape) .+ m.ϵ))), atol = 1.0e-7)
end
let m = InstanceNorm(2), sizes = (2, 4, 1, 2, 3),
x = param(reshape(collect(1:prod(sizes)), sizes))
y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
y = reshape(m(y), sizes...)
@test m(x) == y
end
# check that μ, σ², and the output are the correct size for higher rank tensors
let m = InstanceNorm(2), sizes = (5, 5, 3, 4, 2, 6),
x = param(reshape(collect(1:prod(sizes)), sizes))
y = m(x)
@test size(m.μ) == (sizes[end - 1], )
@test size(m.σ²) == (sizes[end - 1], )
@test size(y) == sizes
end
# show that instance norm is equal to batch norm when channel and batch dims are squashed
let m_inorm = InstanceNorm(2), m_bnorm = BatchNorm(12), sizes = (5, 5, 3, 4, 2, 6),
x = param(reshape(collect(1:prod(sizes)), sizes))
@test m_inorm(x) == reshape(m_bnorm(reshape(x, (sizes[1:end - 2]..., :, 1))), sizes)
end
let m = InstanceNorm(32), x = randn(Float32, 416, 416, 32, 1);
m(x)
@test (@allocated m(x)) < 100_000_000
end
end
@testset "GroupNorm" begin
# begin tests
squeeze(x) = dropdims(x, dims = tuple(findall(size(x) .== 1)...)) # To remove all singular dimensions
let m = GroupNorm(4,2), sizes = (3,4,2),
x = param(reshape(collect(1:prod(sizes)), sizes))
@test m.β.data == [0, 0, 0, 0] # initβ(32)
@test m.γ.data == [1, 1, 1, 1] # initγ(32)
@test m.active
m(x)
#julia> x
#[:, :, 1] =
# 1.0 4.0 7.0 10.0
# 2.0 5.0 8.0 11.0
# 3.0 6.0 9.0 12.0
#
#[:, :, 2] =
# 13.0 16.0 19.0 22.0
# 14.0 17.0 20.0 23.0
# 15.0 18.0 21.0 24.0
#
# μ will be
# (1. + 2. + 3. + 4. + 5. + 6.) / 6 = 3.5
# (7. + 8. + 9. + 10. + 11. + 12.) / 6 = 9.5
#
# (13. + 14. + 15. + 16. + 17. + 18.) / 6 = 15.5
# (19. + 20. + 21. + 22. + 23. + 24.) / 6 = 21.5
#
# μ =
# 3.5 15.5
# 9.5 21.5
#
# ∴ update rule with momentum:
# (1. - .1) * 0 + .1 * (3.5 + 15.5) / 2 = 0.95
# (1. - .1) * 0 + .1 * (9.5 + 21.5) / 2 = 1.55
@test m.μ ≈ [0.95, 1.55]
# julia> mean(var(reshape(x,3,2,2,2),dims=(1,2)).* .1,dims=2) .+ .9*1.
# 2-element Array{Tracker.TrackedReal{Float64},1}:
# 1.25
# 1.25
@test m.σ² ≈ mean(squeeze(var(reshape(x,3,2,2,2),dims=(1,2))).*.1,dims=2) .+ .9*1.
testmode!(m)
@test !m.active
x = m(x).data
println(x[1])
@test isapprox(x[1], (1 - 0.95) / sqrt(1.25 + 1f-5), atol = 1.0e-5)
end
# with activation function
let m = GroupNorm(4,2, sigmoid), sizes = (3, 4, 2),
x = param(reshape(collect(1:prod(sizes)), sizes))
μ_affine_shape = ones(Int,length(sizes) + 1)
μ_affine_shape[end-1] = 2 # Number of groups
affine_shape = ones(Int,length(sizes) + 1)
affine_shape[end-2] = 2 # Channels per group
affine_shape[end-1] = 2 # Number of groups
affine_shape[1] = sizes[1]
affine_shape[end] = sizes[end]
og_shape = size(x)
@test m.active
m(x)
testmode!(m)
@test !m.active
y = m(x)
x_ = reshape(x,affine_shape...)
out = reshape(data(sigmoid.((x_ .- reshape(m.μ,μ_affine_shape...)) ./ sqrt.(reshape(m.σ²,μ_affine_shape...) .+ m.ϵ))),og_shape)
@test isapprox(y, out, atol = 1.0e-7)
end
let m = GroupNorm(2,2), sizes = (2, 4, 1, 2, 3),
x = param(reshape(collect(1:prod(sizes)), sizes))
y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
y = reshape(m(y), sizes...)
@test m(x) == y
end
# check that μ, σ², and the output are the correct size for higher rank tensors
let m = GroupNorm(4,2), sizes = (5, 5, 3, 4, 4, 6),
x = param(reshape(collect(1:prod(sizes)), sizes))
y = m(x)
@test size(m.μ) == (m.G,1)
@test size(m.σ²) == (m.G,1)
@test size(y) == sizes
end
# show that group norm is the same as instance norm when the group size is the same as the number of channels
let IN = InstanceNorm(4), GN = GroupNorm(4,4), sizes = (2,2,3,4,5),
x = param(reshape(collect(1:prod(sizes)), sizes))
@test IN(x) ≈ GN(x)
end
# show that group norm is the same as batch norm for a group of size 1 and batch of size 1
let BN = BatchNorm(4), GN = GroupNorm(4,4), sizes = (2,2,3,4,1),
x = param(reshape(collect(1:prod(sizes)), sizes))
@test BN(x) ≈ GN(x)
end
end

View File

@ -4,21 +4,15 @@ using Flux.Tracker
using Test
@testset "Optimise" begin
w = randn(10, 10)
@testset for Opt in [ADAMW, ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, Descent, ADAM, Nesterov, RMSProp, Momentum]
@testset for opt in [ADAMW(), ADAGrad(0.1), AdaMax(), ADADelta(0.9), AMSGrad(),
NADAM(), Descent(0.1), ADAM(), Nesterov(), RMSProp(),
Momentum()]
w = param(randn(10, 10))
loss(x) = Flux.mse(w*x, w*x)
opt = Opt(0.001)
if opt isa Descent || opt isa ADAGrad
opt = Opt(0.1)
end
if opt isa ADADelta
opt = Opt(0.9)
end
for t = 1: 10^5
l = loss(rand(10))
back!(l)
delta = Optimise.apply!(opt, w.data, w.grad)
w.data .-= delta
θ = Params([w])
θ̄ = gradient(() -> loss(rand(10)), θ)
Optimise.update!(opt, θ, θ̄)
end
@test Flux.mse(w, w) < 0.01
end

View File

@ -1,347 +1,15 @@
using Flux
using Flux.Tracker, Test, NNlib
using Flux.Tracker: TrackedReal, gradient, gradcheck, grad, checkpoint, forwarddiff
using NNlib: conv, ∇conv_data, depthwiseconv
using Printf: @sprintf
using LinearAlgebra: diagm, dot, LowerTriangular, norm, det, logdet, logabsdet
using Statistics: mean, std
using Random
# using StatsBase
using Flux, Test
using Tracker: gradcheck
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
@testset "Tracker" begin
@test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2)
@test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)
@test gradtest((x, W, b) -> logσ.(W*x .+ b), 5, (2,5), 2)
@test gradtest((x, W, b) -> logσ.(W*x .+ b), (5,3), (2,5), 2)
@test gradtest((w, x) -> w'*x, randn(Float64,10, 2), randn(Float64,10))
@test gradtest((w, x) -> w*x', randn(Float64,5,5), randn(Float64,5,5))
@test gradtest(x -> sum(x, dims = (2, 3)), (3,4,5))
@test gradtest(x -> sum(x, dims = 1), randn(Float64,2,3))
@test gradtest(x -> sum(x, dims = [1,2]), randn(Float64,2,3))
@test gradtest(x -> sum(x), randn(Float64,2,3))
@test gradtest(x -> prod(x, dims=(2, 3)), (3,4,5))
@test gradtest(x -> prod(x), (3,4,5))
@test gradtest(x -> softmax(x).*(1:3), 3)
@test gradtest(x -> softmax(x).*(1:3), (3,5))
@test gradtest(x -> logsoftmax(x).*(1:3), 3)
@test gradtest(x -> logsoftmax(x).*(1:3), (3,5))
@testset "Tracker" begin
@test gradtest(Flux.mse, rand(5,5), rand(5, 5))
@test gradtest(Flux.crossentropy, rand(5,5), rand(5, 5))
@test gradtest(x -> x', rand(5))
@test gradtest(det, (4, 4))
@test gradtest(logdet, map((x) -> x*x', (rand(4, 4),))[1])
@test gradtest((x) -> logabsdet(x)[1], (4, 4))
@testset "indexing & slicing" begin
gradtest(x->view(x, 1:2, 1:2), rand(4, 4))
end
function promotiontest(f, A, B, C)
r0 = f(A, B, C)
r1 = f(param(A), B, C)
r2 = f(A, param(B), C)
r3 = f(A, B, param(C))
r4 = f(param(A), param(B), param(C))
@test !isa(r0, TrackedArray)
@test all(isa.([r1,r2,r3,r4], TrackedArray))
@test r1 == r2 == r3 == r4
@test r0 == Flux.data(r4)
end
@testset "concat" begin
cat1(x...) = cat(x..., dims = 1)
cat2(x...) = cat(x..., dims = 2)
@testset for vcatf in [vcat, cat1]
@test gradtest(vcatf, rand(5), rand(3))
@test gradtest(vcatf, rand(5), rand(3), rand(8))
@test gradtest(vcatf, rand(5)', rand(5)')
@test gradtest(vcatf, rand(5,2), rand(3,2), rand(8,2))
@test gradtest(vcatf, rand(5,2,3), rand(3,2,3), rand(8,2,3))
@test gradtest(vcatf, rand(5), rand(3,1))
@test gradtest(vcatf, rand(5)', rand(2,5))
end
@testset for hcatf in [hcat, cat2]
@test gradtest(hcatf, rand(5), rand(5))
@test gradtest(hcatf, rand(5)', rand(5)')
@test gradtest(hcatf, rand(2,5), rand(2,3), rand(2,8))
@test gradtest(hcatf, rand(2,5,3), rand(2,3,3), rand(2,8,3))
@test gradtest(hcatf, rand(5), rand(5), rand(5,2))
@test gradtest(hcatf, rand(5)', rand(1,3))
@test gradtest(hcatf, rand(5), rand(5,2))
end
@testset for catf in [vcat, cat1, hcat, cat2, (x...) -> cat(x..., dims = 3), (x...) -> cat(x..., dims = (1,2))]
@test gradtest(catf, rand(5))
@test gradtest(catf, rand(5)')
@test gradtest(catf, rand(2,5))
@test gradtest(catf, rand(2,5,3))
end
@test gradtest((x...) -> cat(x..., dims = 3), rand(2,5,2), rand(2,5,3), rand(2,5,4))
@testset "cat($dim, ...)" for dim in 3:5
catdim = (x...) -> cat(x..., dims = dim)
@test gradtest(catdim, rand(5), rand(5), rand(5))
@test gradtest(catdim, rand(2,5), rand(2,5), rand(2,5))
@test gradtest(catdim, rand(2,5,3), rand(2,5,3), rand(2,5,3))
end
@test !isa(vcat(rand(2)), TrackedArray)
@test !isa(hcat(rand(2)), TrackedArray)
@test !isa(cat(rand(2), dims=1), TrackedArray)
@test gradtest((a,b)->cat(a, b, dims = (2,3,5)), rand(2,3), rand(2,4,2,1))
@testset "promotiontest" begin
@testset for fcat in [hcat, vcat, (x...) -> cat(x..., dims = 3), (x...) -> cat(x..., dims = (1,2))]
promotiontest(fcat, rand(2), rand(2), rand(2))
promotiontest(fcat, rand(2)', rand(2)', rand(2)')
promotiontest(fcat, rand(2,2), rand(2,2), rand(2,2))
promotiontest(fcat, rand(2,2,2), rand(2,2,2), rand(2,2,2))
end
promotiontest(vcat, rand(1,2), rand(2)', rand(2,2))
promotiontest(hcat, rand(2,1), rand(2), rand(2,2))
promotiontest(vcat, rand(3,4,5), rand(1,4,5), rand(2,4,5))
promotiontest(hcat, rand(4,3,5), rand(4,1,5), rand(4,2,5))
promotiontest((x...) -> cat(x..., dims = 3), rand(4,5,3), rand(4,5,1), rand(4,5,2))
end
@testset "scalars" begin
@test vcat(param([1, 2, 3]), 1) isa TrackedArray
@test vcat(1, param([1, 2, 3])) isa TrackedArray
@test hcat(1, param([1 2 3;])) isa TrackedArray
@test vcat(param(1), 2) isa TrackedArray
end
end
@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))
@test gradtest(x -> PermutedDimsArray(x, [3,1,2]), rand(4,5,6))
@test gradtest(x -> repeat(x; inner=2), rand(5))
@test gradtest(x -> repeat(x; inner=2, outer=3), rand(5))
@test gradtest(x -> repeat(x; inner=(2,2,1), outer=(1,1,3)), rand(5,4,3))
@test gradtest(kron, rand(5), rand(3))
@test gradtest(kron, rand(5), rand(3), rand(8))
@test gradtest(kron, rand(5,1), rand(3,1))
@test gradtest(kron, rand(5,1), rand(3,1), rand(8,1))
@test gradtest(kron, rand(5,2), rand(3,2), rand(8,2))
@test gradtest(x -> diagm(0 => x), rand(3))
@test gradtest(W -> inv(log.(W * W)), (5,5))
@test gradtest((A, B) -> A / B , (1,5), (5,5))
@test gradtest((A, B) -> log.(A * A) / exp.(B * B), (5,5), (5,5))
@test gradtest((A, B) -> log.(A * A) \ exp.(B * B), (5,5), (5,5))
@testset "mean" begin
@test gradtest(mean, rand(2, 3))
@test gradtest(x -> mean(x, dims=1), rand(2, 3))
@test gradtest(x -> mean(x, dims=2), rand(2, 3))
@test gradtest(x -> mean(x, dims=3), rand(2, 3, 4))
@test gradtest(x -> mean(x, dims=[1, 2]), rand(2, 3, 4))
end
@testset "maximum" begin
@test gradtest(maximum, rand(2, 3))
@test gradtest(x -> maximum(x, dims=1), rand(2, 3))
@test gradtest(x -> maximum(x, dims=2), rand(2, 3))
@test gradtest(x -> maximum(x, dims=3), rand(2, 3, 4))
@test gradtest(x -> maximum(x, dims=[1, 2]), rand(2, 3, 4))
end
@testset "minimum" begin
@test gradtest(minimum, rand(2, 3))
@test gradtest(x -> minimum(x, dims=1), rand(2, 3))
@test gradtest(x -> minimum(x, dims=2), rand(2, 3))
@test gradtest(x -> minimum(x, dims=3), rand(2, 3, 4))
@test gradtest(x -> minimum(x, dims=[1, 2]), rand(2, 3, 4))
end
@test gradtest(x -> std(x), rand(5,5))
@test gradtest(x -> std(x, dims = 1), rand(5,5))
@test gradtest(x -> std(x, dims = 1, corrected = false), rand(5,5))
@test gradtest(x -> Flux.normalise(x), rand(4,3))
@test gradtest(x -> Flux.normalise(x, dims = 2), rand(3,4))
@test gradtest((x, y) -> x .* y, rand(5), rand(5))
@test gradtest(dot, rand(5), rand(5))
@test gradtest(norm, rand(5))
@test gradtest(rand(5)) do x
y = x.^2
2y + x
end
@test gradtest(conv, rand(10, 3, 2), randn(Float64, 2, 3, 2))
@test gradtest(conv, rand(10, 10, 3, 2), randn(Float64, 2, 2, 3, 2))
@test gradtest(conv, rand(10, 10, 10, 3, 2), randn(Float64, 2, 2, 2, 3, 2))
@test gradtest(∇conv_data, rand(10, 3, 2), randn(Float64, 2, 2, 3))
@test gradtest(∇conv_data, rand(10, 10, 3, 2), randn(Float64, 2, 2, 2, 3))
@test gradtest(∇conv_data, rand(10, 10, 10, 3, 2), randn(Float64, 2, 2, 2, 2, 3))
@test gradtest(depthwiseconv, rand(10, 10, 3, 2), randn(Float64, 2, 2, 2, 3))
@test gradtest(x -> maxpool(x, (2,2)), rand(10, 10, 3, 2))
@test gradtest(x -> maxpool(x, (2,2,2)), rand(10, 10, 10, 3, 2))
@test gradtest(x -> meanpool(x, (2,2)), rand(10, 10, 3, 2))
@test gradtest(x -> meanpool(x, (2,2,2)), rand(5, 5, 5, 3, 2))
@test gradtest(x -> Float64.(x), 5)
@testset "equality & order" begin
# TrackedReal
@test param(2)^2 == param(4)
@test param(2)^2 == 4
@test 4 == param(2)^2
@test param(2)^2 ≈ param(4)
@test param(2)^2 ≈ 4
@test 4 ≈ param(2)^2
@test (param([1,2,3]) .< 2) == [true, false, false]
@test (param([1,2,3]) .<= 2) == [true, true, false]
@test (2 .> param([1,2,3])) == [true, false, false]
@test (2 .>= param([1,2,3])) == [true, true, false]
# TrackedArray
@test param([1,2,3]).^2 == param([1,4,9])
@test [1,2,3].^2 == param([1,4,9])
@test param([1,2,3]).^2 == [1,4,9]
@test param([1,2,3]).^2 ≈ param([1,4,9])
@test [1,2,3].^2 ≈ param([1,4,9])
@test param([1,2,3]).^2 ≈ [1,4,9]
end
@testset "reshape" begin
x = reshape(param(rand(2,2,2)), 4, 2)
@test x isa TrackedArray
@test size(x) == (4,2)
x = reshape(param([1]), (1,:))
@test x isa TrackedArray
@test size(x) == (1,1)
x = reshape(param(rand(2)), (2,:))
@test x isa TrackedArray
@test size(x) == (2,1)
x = reshape(param(rand(2,2)), (1,:,2))
@test x isa TrackedArray
@test size(x) == (1,2,2)
end
@testset "Intermediates" begin
x = param([1])
l = sum((x .+ x).^2)
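# l = sum((2x).^2), so dl/dx = 8x and the gradient at x = [1] is [8];
# `once = false` keeps the graph alive so `back!` can run a second time.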
Flux.back!(l, once = false)
@test x.grad == [8]
x.grad .= 0
Flux.back!(l, once = false)
@test x.grad == [8]
end
@testset "Fallbacks" begin
xs = param([1 2; 3 4])
@test similar(xs) isa Matrix{Float64}
end
@test @sprintf("%.2f", sum(param([1,2,3]))) == "6.00"
@inferred NNlib.conv(param(rand(10, 10, 3, 2)), randn(Float64, 2, 2, 3, 4))
b = param(rand())
Tracker.back!(b)
@test Tracker.grad(b) == 1
@testset "collect" begin
x, y = param(2), param(3)
xy = Tracker.collect([x, y])
@test xy isa TrackedArray{Float64}
z = xy[1]*xy[2]
back!(z)
@test grad.((x,y)) == (3, 2)
@test gradient(2, 3) do x, y
xy = Tracker.collect([x, y])
xy[1]*xy[2]
end == (3, 2)
end
# Gradient Hooks
@testset "Hooks" begin
x = param(2)
y = Tracker.hook(-, x)
back!(y)
@test grad(x) == -1
end
@testset "Checkpointing" begin
count = 0
function mul(a, b)
count += 1
a * b
end
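# Plain `mul` runs once per gradient call; `checkpoint` stores no intermediates
# and reruns the forward pass when the pullback fires, so the counter climbs
# from 1 to 3 rather than to 2.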
@test gradient(x -> mul(5, x), 3)[1] == 5
@test count == 1
@test gradient(x -> checkpoint(mul, 5, x), 3)[1] == 5
@test count == 3
end
@testset "Updates" begin
xs = param([1, 2, 3])
Tracker.update!(xs, param([4, 5, 6]))
@test xs == [5, 7, 9]
x = param(3)
Tracker.update!(x, param(4))
@test x == 7
end
@testset "Params" begin
W = param(randn(5, 10))
x = rand(10)
dW = gradient(W -> sum(W*x), W)[1]
gs = gradient(() -> sum(W*x), Tracker.Params([W]))
@test gs[W] == dW
end
@testset "Forward" begin
@test @inferred(Tracker.forward_jacobian(x -> [sum(x)], rand(5,5), Val(12)))[2] ==
reshape(ones(25), :, 1)
@test gradient([2, 3]) do x
forwarddiff(x) do x
x[1]*x[2]
end
end == ([3, 2],)
end
@testset "Custom Sensitivities" begin
y, back = Tracker.forward(x -> [3x^2, 2x], 5)
@test back([1, 1]) == (32,)
end
end #testset

View File

@ -87,6 +87,12 @@ end
@test size.(params(m)) == [(5, 10), (5, 5), (5,), (5,)]
end
@testset "Basic Stacking" begin
x = randn(3,3)
stacked = stack([x, x], 2)
@test size(stacked) == (3,2,3)
end
@testset "Precision" begin
m = Chain(Dense(10, 5, relu), Dense(5, 2))
x = rand(10)