Merge branch 'master' into onecold

commit 9bbbd17e4b

.gitlab-ci.yml (new file, 37 lines)
@@ -0,0 +1,37 @@
before_script:
  - export CI_DISABLE_CURNN_TEST=true

variables:
  CI_IMAGE_TAG: 'cuda'

include:
  - 'https://raw.githubusercontent.com/JuliaGPU/gitlab-ci/master/templates/v3/common.yml'

.flux:
  extends: .test
  script:
    - julia -e 'using InteractiveUtils;
                versioninfo()'
    - mkdir $JULIA_DEPOT_PATH # Pkg3.jl#325
    - julia -e 'using Pkg;
                Pkg.add("CuArrays");'
    - julia --project -e 'using Pkg;
                          Pkg.instantiate();
                          Pkg.build();
                          Pkg.test(; coverage=true);'

test:v1.0:
  extends: .flux
  variables:
    CI_VERSION_TAG: 'v1.0'
  only:
    - staging
    - trying

test:v1.1:
  extends: .flux
  variables:
    CI_VERSION_TAG: 'v1.1'
  only:
    - staging
    - trying
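The `script` block above is essentially a sequence of Pkg operations. A rough local equivalent is sketched below — a sketch only, assuming a local clone of Flux.jl and a CUDA-capable machine, since the job pulls in CuArrays:

```julia
# Sketch of what the CI job's `script` steps do, run from a local clone of
# Flux.jl. Assumes a CUDA-capable machine, since the job installs CuArrays.
using Pkg, InteractiveUtils

versioninfo()                            # same diagnostic step as the CI job
ENV["CI_DISABLE_CURNN_TEST"] = "true"    # mirrors `before_script` above

Pkg.add("CuArrays")                      # GPU array backend used by the tests
Pkg.activate(".")                        # equivalent of `julia --project`
Pkg.instantiate()                        # install Manifest.toml dependencies
Pkg.build()
Pkg.test(; coverage=true)                # run the Flux test suite with coverage
```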
@@ -1,6 +1,6 @@
 The Flux.jl package is licensed under the MIT "Expat" License:
 
-> Copyright (c) 2016: Mike Innes.
+> Copyright (c) 2016-19: Julia Computing, Inc., Mike Innes and Contributors
 >
 > Permission is hereby granted, free of charge, to any person obtaining
 > a copy of this software and associated documentation files (the
Manifest.toml (100 lines changed)

@@ -1,3 +1,5 @@
+# This file is machine-generated - editing it directly is not advised
+
 [[AbstractTrees]]
 deps = ["Markdown", "Test"]
 git-tree-sha1 = "6621d9645702c1c4e6970cc6a3eae440c768000b"
@@ -25,11 +27,17 @@ git-tree-sha1 = "055eb2690182ebc31087859c3dd8598371d3ef9e"
 uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
 version = "0.5.3"
 
+[[CSTParser]]
+deps = ["LibGit2", "Test", "Tokenize"]
+git-tree-sha1 = "437c93bc191cd55957b3f8dee7794b6131997c56"
+uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f"
+version = "0.5.2"
+
 [[CodecZlib]]
 deps = ["BinaryProvider", "Libdl", "Test", "TranscodingStreams"]
-git-tree-sha1 = "e3df104c84dfc108f0ca203fd7f5bbdc98641ae9"
+git-tree-sha1 = "36bbf5374c661054d41410dc53ff752972583b9b"
 uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
-version = "0.5.1"
+version = "0.5.2"
 
 [[ColorTypes]]
 deps = ["FixedPointNumbers", "Random", "Test"]
@@ -51,9 +59,15 @@ version = "0.2.0"
 
 [[Compat]]
 deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
-git-tree-sha1 = "49269e311ffe11ac5b334681d212329002a9832a"
+git-tree-sha1 = "84aa74986c5b9b898b0d1acaf3258741ee64754f"
 uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
-version = "1.5.1"
+version = "2.1.0"
 
+[[Crayons]]
+deps = ["Test"]
+git-tree-sha1 = "f621b8ef51fd2004c7cf157ea47f027fdeac5523"
+uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
+version = "4.0.0"
+
 [[DataStructures]]
 deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
@@ -71,18 +85,18 @@ uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
 
 [[DiffResults]]
 deps = ["Compat", "StaticArrays"]
-git-tree-sha1 = "db8acf46717b13d6c48deb7a12007c7f85a70cf7"
+git-tree-sha1 = "34a4a1e8be7bc99bc9c611b895b5baf37a80584c"
 uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
-version = "0.0.3"
+version = "0.0.4"
 
 [[DiffRules]]
 deps = ["Random", "Test"]
-git-tree-sha1 = "09d69da75967ec48a8b1ad0897ec9144ee052bf9"
+git-tree-sha1 = "dc0869fb2f5b23466b32ea799bd82c76480167f7"
 uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
-version = "0.0.8"
+version = "0.0.10"
 
 [[Distributed]]
-deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"]
+deps = ["Random", "Serialization", "Sockets"]
 uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
 
 [[FixedPointNumbers]]
@@ -93,19 +107,19 @@ version = "0.5.3"
 
 [[ForwardDiff]]
 deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"]
-git-tree-sha1 = "e393bd3b9102659fb24fe88caedec41f2bc2e7de"
+git-tree-sha1 = "4c4d727f1b7e0092134fabfab6396b8945c1ea5b"
 uuid = "f6369f11-7733-5829-9624-2563aa707210"
-version = "0.10.2"
+version = "0.10.3"
 
 [[InteractiveUtils]]
-deps = ["LinearAlgebra", "Markdown"]
+deps = ["Markdown"]
 uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
 
 [[Juno]]
 deps = ["Base64", "Logging", "Media", "Profile", "Test"]
-git-tree-sha1 = "ce6246e19061e36cbdce954caaae717498daeed8"
+git-tree-sha1 = "4e4a8d43aa7ecec66cadaf311fbd1e5c9d7b9175"
 uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
-version = "0.5.4"
+version = "0.7.0"
 
 [[LibGit2]]
 uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
@@ -121,10 +135,10 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
 
 [[MacroTools]]
-deps = ["Compat"]
+deps = ["CSTParser", "Compat", "DataStructures", "Test"]
-git-tree-sha1 = "c443e1c8d58a4e9f61b708ad0a88286c7042145b"
+git-tree-sha1 = "daecd9e452f38297c686eba90dba2a6d5da52162"
 uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
-version = "0.4.4"
+version = "0.5.0"
 
 [[Markdown]]
 deps = ["Base64"]
@@ -146,12 +160,10 @@ version = "0.4.0"
 uuid = "a63ad114-7e13-5084-954f-fe012c677804"
 
 [[NNlib]]
-deps = ["Libdl", "LinearAlgebra", "MacroTools", "Requires", "Test"]
+deps = ["Libdl", "LinearAlgebra", "Requires", "Statistics", "TimerOutputs"]
-git-tree-sha1 = "5a8ed87d61b1ccb71d99235c2a96287addebbb9f"
+git-tree-sha1 = "0c667371391fc6bb31f7f12f96a56a17098b3de8"
-repo-rev = "master"
-repo-url = "https://github.com/FluxML/NNlib.jl.git"
 uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-version = "0.4.3+"
+version = "0.6.0"
 
 [[NaNMath]]
 deps = ["Compat"]
@@ -161,9 +173,9 @@ version = "0.3.2"
 
 [[OrderedCollections]]
 deps = ["Random", "Serialization", "Test"]
-git-tree-sha1 = "85619a3f3e17bb4761fe1b1fd47f0e979f964d5b"
+git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1"
 uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
-version = "1.0.2"
+version = "1.1.0"
 
 [[Pkg]]
 deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
@@ -228,29 +240,47 @@ version = "0.7.2"
 
 [[StaticArrays]]
 deps = ["InteractiveUtils", "LinearAlgebra", "Random", "Statistics", "Test"]
-git-tree-sha1 = "1eb114d6e23a817cd3e99abc3226190876d7c898"
+git-tree-sha1 = "3841b39ed5f047db1162627bf5f80a9cd3e39ae2"
 uuid = "90137ffa-7385-5640-81b9-e52037218182"
-version = "0.10.2"
+version = "0.10.3"
 
 [[Statistics]]
 deps = ["LinearAlgebra", "SparseArrays"]
 uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
 [[StatsBase]]
-deps = ["DataStructures", "DelimitedFiles", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "Test"]
+deps = ["DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"]
-git-tree-sha1 = "7b596062316c7d846b67bf625d5963a832528598"
+git-tree-sha1 = "8a0f4b09c7426478ab677245ab2b0b68552143c7"
 uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
-version = "0.27.0"
+version = "0.30.0"
 
 [[Test]]
 deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
 uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
+[[TimerOutputs]]
+deps = ["Crayons", "Printf", "Test", "Unicode"]
+git-tree-sha1 = "b80671c06f8f8bae08c55d67b5ce292c5ae2660c"
+uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
+version = "0.5.0"
+
+[[Tokenize]]
+deps = ["Printf", "Test"]
+git-tree-sha1 = "3e83f60b74911d3042d3550884ca2776386a02b8"
+uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624"
+version = "0.5.3"
+
+[[Tracker]]
+deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics", "Test"]
+git-tree-sha1 = "0bec1b68c63a0e8a58d3944261cbf4cc9577c8a1"
+uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
+version = "0.2.0"
+
 [[TranscodingStreams]]
-deps = ["Pkg", "Random", "Test"]
+deps = ["Random", "Test"]
-git-tree-sha1 = "a34a2d588e2d2825602bf14a24216d5c8b0921ec"
+git-tree-sha1 = "a25d8e5a28c3b1b06d3859f30757d43106791919"
 uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
-version = "0.8.1"
+version = "0.9.4"
 
 [[URIParser]]
 deps = ["Test", "Unicode"]
@@ -259,7 +289,7 @@ uuid = "30578b45-9adc-5946-b283-645ec420af67"
 version = "0.4.0"
 
 [[UUIDs]]
-deps = ["Random"]
+deps = ["Random", "SHA"]
 uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
 
 [[Unicode]]
@@ -267,6 +297,6 @@ uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
 
 [[ZipFile]]
 deps = ["BinaryProvider", "Libdl", "Printf", "Test"]
-git-tree-sha1 = "4000c633efe994b2e10b31b6d91382c4b7412dac"
+git-tree-sha1 = "5f6f663890dfb9bad6af75a86a43f67904e5050e"
 uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
-version = "0.8.0"
+version = "0.8.1"
NEWS.md (new file, 25 lines)
@@ -0,0 +1,25 @@
# v0.8.0

* New [ConvTranspose layer](https://github.com/FluxML/Flux.jl/pull/311).
* New [Maxout layer](https://github.com/FluxML/Flux.jl/pull/647)
* Datasets are now [hash verified on download](https://github.com/FluxML/Flux.jl/pull/585) to avoid corruption.
* We now [zero the initial state for RNNs](https://github.com/FluxML/Flux.jl/pull/590/).
* [Normalisation can now work on arbitrary `dims`.](https://github.com/FluxML/Flux.jl/pull/592)
* Many docs and bugfixes thanks to @KristofferC and others.
* [NamedTuples now work like Tuples](https://github.com/FluxML/Flux.jl/pull/603) when doing `mapleaves`.
* New "performance tips" [section of the docs](https://github.com/FluxML/Flux.jl/pull/615).
* The training loop is [now more readable](https://github.com/FluxML/Flux.jl/pull/651) and better shows how to use the lower-level APIs.
* New [AlphaDropout](https://github.com/FluxML/Flux.jl/pull/656).
* [Data.Iris](https://github.com/FluxML/Flux.jl/pull/652) makes Fisher's Iris dataset available with `Iris.labels` and `Iris.features`.
* New [InstanceNorm](https://github.com/FluxML/Flux.jl/pull/634), as popularized by [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022).
* New [GroupNorm](https://github.com/FluxML/Flux.jl/pull/696), as described in [Group Normalization](https://arxiv.org/abs/1803.08494).

AD Changes:

* `det`, `logdet` and `logabsdet` [now have adjoints](https://github.com/FluxML/Flux.jl/pull/596/files).
* Support for [PermuteDimsArray](https://github.com/FluxML/Flux.jl/pull/576).
* Flux.Tracker is now its [own package](https://github.com/FluxML/Tracker.jl), in preparation for replacing it with Zygote.

# v0.7.0

Despite the heroic efforts of scholars and archeologists, pre-0.7 history is lost to the sands of time.
Project.toml (18 lines changed)
@@ -6,21 +6,29 @@ AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193"
 Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
-DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
+DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
-ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 SHA = "ea8e919c-243c-51af-8825-aaa63cd721ce"
-SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
-Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
 ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
 
+[compat]
+NNlib = "0.6"
+Tracker = "0.2"
+julia = "0.7, 1"
+
+[extras]
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+
+[targets]
+test = ["Test"]
@@ -2,7 +2,7 @@
 <img width="400px" src="https://raw.githubusercontent.com/FluxML/fluxml.github.io/master/logo.png"/>
 </p>
 
 [](https://travis-ci.org/FluxML/Flux.jl) [](https://fluxml.github.io/Flux.jl/stable/) [](https://slackinvite.julialang.org/) [](https://doi.org/10.21105/joss.00602)
 
 Flux is an elegant approach to machine learning. It's a 100% pure-Julia stack, and provides lightweight abstractions on top of Julia's native GPU and AD support. Flux makes the easy things easy while remaining fully hackable.
 
@@ -10,7 +10,7 @@ Flux is an elegant approach to machine learning. It's a 100% pure-Julia stack, a
 julia> Pkg.add("Flux")
 ```
 
-See the [documentation](http://fluxml.github.io/Flux.jl/) or the [model zoo](https://github.com/FluxML/model-zoo/) for examples.
+See the [documentation](https://fluxml.github.io/Flux.jl/) or the [model zoo](https://github.com/FluxML/model-zoo/) for examples.
 
 If you use Flux in research, please cite the following paper:
 
REQUIRE (7 lines changed)
@@ -10,9 +10,4 @@ ZipFile
 AbstractTrees
 Reexport
 StatsBase
-
+Tracker
-# AD
-ForwardDiff 0.5.0
-DiffRules
-SpecialFunctions
-NaNMath
@@ -1,3 +1,5 @@
+# This file is machine-generated - editing it directly is not advised
+
 [[AbstractTrees]]
 deps = ["Markdown", "Test"]
 git-tree-sha1 = "6621d9645702c1c4e6970cc6a3eae440c768000b"
@@ -6,9 +8,9 @@ version = "0.2.1"
 
 [[Adapt]]
 deps = ["LinearAlgebra", "Test"]
-git-tree-sha1 = "04d15700419b6949d76be1428ab6e0277ff43b06"
+git-tree-sha1 = "53d8fec4f662088c1202530e338a11a919407f3b"
 uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
-version = "0.4.1"
+version = "0.4.2"
 
 [[Base64]]
 uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
@@ -25,11 +27,17 @@ git-tree-sha1 = "055eb2690182ebc31087859c3dd8598371d3ef9e"
 uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
 version = "0.5.3"
 
+[[CSTParser]]
+deps = ["LibGit2", "Test", "Tokenize"]
+git-tree-sha1 = "437c93bc191cd55957b3f8dee7794b6131997c56"
+uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f"
+version = "0.5.2"
+
 [[CodecZlib]]
 deps = ["BinaryProvider", "Libdl", "Test", "TranscodingStreams"]
-git-tree-sha1 = "e3df104c84dfc108f0ca203fd7f5bbdc98641ae9"
+git-tree-sha1 = "36bbf5374c661054d41410dc53ff752972583b9b"
 uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
-version = "0.5.1"
+version = "0.5.2"
 
 [[ColorTypes]]
 deps = ["FixedPointNumbers", "Random", "Test"]
@@ -51,9 +59,15 @@ version = "0.2.0"
 
 [[Compat]]
 deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
-git-tree-sha1 = "ec61a16eed883ad0cfa002d7489b3ce6d039bb9a"
+git-tree-sha1 = "84aa74986c5b9b898b0d1acaf3258741ee64754f"
 uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
-version = "1.4.0"
+version = "2.1.0"
 
+[[Crayons]]
+deps = ["Test"]
+git-tree-sha1 = "f621b8ef51fd2004c7cf157ea47f027fdeac5523"
+uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
+version = "4.0.0"
+
 [[DataStructures]]
 deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
@@ -71,31 +85,31 @@ uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
 
 [[DiffResults]]
 deps = ["Compat", "StaticArrays"]
-git-tree-sha1 = "db8acf46717b13d6c48deb7a12007c7f85a70cf7"
+git-tree-sha1 = "34a4a1e8be7bc99bc9c611b895b5baf37a80584c"
 uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
-version = "0.0.3"
+version = "0.0.4"
 
 [[DiffRules]]
 deps = ["Random", "Test"]
-git-tree-sha1 = "c49ec69428ffea0c1d1bbdc63d1a70f5df5860ad"
+git-tree-sha1 = "dc0869fb2f5b23466b32ea799bd82c76480167f7"
 uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
-version = "0.0.7"
+version = "0.0.10"
 
 [[Distributed]]
-deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"]
+deps = ["Random", "Serialization", "Sockets"]
 uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
 
 [[DocStringExtensions]]
 deps = ["LibGit2", "Markdown", "Pkg", "Test"]
-git-tree-sha1 = "1df01539a1c952cef21f2d2d1c092c2bcf0177d7"
+git-tree-sha1 = "4d30e889c9f106a51ffa4791a88ffd4765bf20c3"
 uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
-version = "0.6.0"
+version = "0.7.0"
 
 [[Documenter]]
-deps = ["Base64", "DocStringExtensions", "InteractiveUtils", "LibGit2", "Logging", "Markdown", "Pkg", "REPL", "Random", "Test", "Unicode"]
+deps = ["Base64", "DocStringExtensions", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "Pkg", "REPL", "Random", "Test", "Unicode"]
-git-tree-sha1 = "a6db1c69925cdc53aafb38caec4446be26e0c617"
+git-tree-sha1 = "13a6d15102410d8e70146533b759fc48d844a1d0"
 uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
-version = "0.21.0"
+version = "0.22.3"
 
 [[FixedPointNumbers]]
 deps = ["Test"]
@@ -104,26 +118,32 @@ uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
 version = "0.5.3"
 
 [[Flux]]
-deps = ["AbstractTrees", "Adapt", "CodecZlib", "Colors", "DiffRules", "ForwardDiff", "Juno", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsBase", "Test", "ZipFile"]
+deps = ["AbstractTrees", "Adapt", "CodecZlib", "Colors", "DelimitedFiles", "Juno", "LinearAlgebra", "MacroTools", "NNlib", "Pkg", "Printf", "Random", "Reexport", "Requires", "SHA", "Statistics", "StatsBase", "Tracker", "ZipFile"]
 path = ".."
 uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
-version = "0.6.10+"
+version = "0.8.2+"
 
 [[ForwardDiff]]
 deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"]
-git-tree-sha1 = "b91250044374764e7c29af59a774c4b8d6100b6e"
+git-tree-sha1 = "4c4d727f1b7e0092134fabfab6396b8945c1ea5b"
 uuid = "f6369f11-7733-5829-9624-2563aa707210"
-version = "0.10.1"
+version = "0.10.3"
 
 [[InteractiveUtils]]
-deps = ["LinearAlgebra", "Markdown"]
+deps = ["Markdown"]
 uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
 
+[[JSON]]
+deps = ["Dates", "Distributed", "Mmap", "Sockets", "Test", "Unicode"]
+git-tree-sha1 = "1f7a25b53ec67f5e9422f1f551ee216503f4a0fa"
+uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
+version = "0.20.0"
+
 [[Juno]]
 deps = ["Base64", "Logging", "Media", "Profile", "Test"]
-git-tree-sha1 = "3c29a199713e7ec62cfdc11f44d7760219d5f658"
+git-tree-sha1 = "4e4a8d43aa7ecec66cadaf311fbd1e5c9d7b9175"
 uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
-version = "0.5.3"
+version = "0.7.0"
 
 [[LibGit2]]
 uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
@@ -139,10 +159,10 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
 
 [[MacroTools]]
-deps = ["Compat"]
+deps = ["CSTParser", "Compat", "DataStructures", "Test"]
-git-tree-sha1 = "c443e1c8d58a4e9f61b708ad0a88286c7042145b"
+git-tree-sha1 = "daecd9e452f38297c686eba90dba2a6d5da52162"
 uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
-version = "0.4.4"
+version = "0.5.0"
 
 [[Markdown]]
 deps = ["Base64"]
@@ -156,18 +176,18 @@ version = "0.5.0"
 
 [[Missings]]
 deps = ["Dates", "InteractiveUtils", "SparseArrays", "Test"]
-git-tree-sha1 = "adc26d2ee85a49c413464110d922cf21efc9d233"
+git-tree-sha1 = "d1d2585677f2bd93a97cfeb8faa7a0de0f982042"
 uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
-version = "0.3.1"
+version = "0.4.0"
 
 [[Mmap]]
 uuid = "a63ad114-7e13-5084-954f-fe012c677804"
 
 [[NNlib]]
-deps = ["Libdl", "LinearAlgebra", "MacroTools", "Requires", "Test"]
+deps = ["Libdl", "LinearAlgebra", "Requires", "Statistics", "TimerOutputs"]
-git-tree-sha1 = "51330bb45927379007e089997bf548fbe232589d"
+git-tree-sha1 = "0c667371391fc6bb31f7f12f96a56a17098b3de8"
 uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-version = "0.4.3"
+version = "0.6.0"
 
 [[NaNMath]]
 deps = ["Compat"]
@@ -177,9 +197,9 @@ version = "0.3.2"
 
 [[OrderedCollections]]
 deps = ["Random", "Serialization", "Test"]
-git-tree-sha1 = "85619a3f3e17bb4761fe1b1fd47f0e979f964d5b"
+git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1"
 uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
-version = "1.0.2"
+version = "1.1.0"
 
 [[Pkg]]
 deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
@@ -244,29 +264,47 @@ version = "0.7.2"
 
 [[StaticArrays]]
 deps = ["InteractiveUtils", "LinearAlgebra", "Random", "Statistics", "Test"]
-git-tree-sha1 = "1eb114d6e23a817cd3e99abc3226190876d7c898"
+git-tree-sha1 = "3841b39ed5f047db1162627bf5f80a9cd3e39ae2"
 uuid = "90137ffa-7385-5640-81b9-e52037218182"
-version = "0.10.2"
+version = "0.10.3"
 
 [[Statistics]]
 deps = ["LinearAlgebra", "SparseArrays"]
 uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
 [[StatsBase]]
-deps = ["DataStructures", "DelimitedFiles", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "Test"]
+deps = ["DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"]
-git-tree-sha1 = "7b596062316c7d846b67bf625d5963a832528598"
+git-tree-sha1 = "8a0f4b09c7426478ab677245ab2b0b68552143c7"
 uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
-version = "0.27.0"
+version = "0.30.0"
 
 [[Test]]
 deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
 uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
+[[TimerOutputs]]
+deps = ["Crayons", "Printf", "Test", "Unicode"]
+git-tree-sha1 = "b80671c06f8f8bae08c55d67b5ce292c5ae2660c"
+uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
+version = "0.5.0"
+
+[[Tokenize]]
+deps = ["Printf", "Test"]
+git-tree-sha1 = "3e83f60b74911d3042d3550884ca2776386a02b8"
+uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624"
+version = "0.5.3"
+
+[[Tracker]]
+deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics", "Test"]
+git-tree-sha1 = "0bec1b68c63a0e8a58d3944261cbf4cc9577c8a1"
+uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
+version = "0.2.0"
+
 [[TranscodingStreams]]
-deps = ["Pkg", "Random", "Test"]
+deps = ["Random", "Test"]
-git-tree-sha1 = "a34a2d588e2d2825602bf14a24216d5c8b0921ec"
+git-tree-sha1 = "a25d8e5a28c3b1b06d3859f30757d43106791919"
 uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
-version = "0.8.1"
+version = "0.9.4"
 
 [[URIParser]]
 deps = ["Test", "Unicode"]
@@ -275,7 +313,7 @@ uuid = "30578b45-9adc-5946-b283-645ec420af67"
 version = "0.4.0"
 
 [[UUIDs]]
-deps = ["Random"]
+deps = ["Random", "SHA"]
 uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
 
 [[Unicode]]
@@ -283,6 +321,6 @@ uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
 
 [[ZipFile]]
 deps = ["BinaryProvider", "Libdl", "Printf", "Test"]
-git-tree-sha1 = "4000c633efe994b2e10b31b6d91382c4b7412dac"
+git-tree-sha1 = "5f6f663890dfb9bad6af75a86a43f67904e5050e"
 uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
-version = "0.8.0"
+version = "0.8.1"
@@ -1,7 +1,7 @@
 using Documenter, Flux, NNlib
 
 makedocs(modules=[Flux, NNlib],
-         doctest = false,
+         doctest = true,
          analytics = "UA-36890222-9",
          sitename = "Flux",
          # Uncomment below for local build
@@ -19,6 +19,7 @@ makedocs(modules=[Flux, NNlib],
          "One-Hot Encoding" => "data/onehot.md",
          "GPU Support" => "gpu.md",
          "Saving & Loading" => "saving.md",
+         "Performance Tips" => "performance.md",
          "Internals" =>
            ["Backpropagation" => "internals/tracker.md"],
          "Community" => "community.md"])
@@ -4,45 +4,53 @@
 
 Flux's core feature is taking gradients of Julia code. The `gradient` function takes another Julia function `f` and a set of arguments, and returns the gradient with respect to each argument. (It's a good idea to try pasting these examples in the Julia terminal.)
 
-```julia
+```jldoctest basics
-using Flux.Tracker
+julia> using Flux.Tracker
 
-f(x) = 3x^2 + 2x + 1
+julia> f(x) = 3x^2 + 2x + 1;
 
-# df/dx = 6x + 2
-df(x) = Tracker.gradient(f, x; nest = true)[1]
+julia> df(x) = Tracker.gradient(f, x; nest = true)[1]; # df/dx = 6x + 2
 
-df(2) # 14.0 (tracked)
+julia> df(2)
+14.0 (tracked)
 
-# d²f/dx² = 6
-d2f(x) = Tracker.gradient(df, x; nest = true)[1]
+julia> d2f(x) = Tracker.gradient(df, x; nest = true)[1]; # d²f/dx² = 6
 
-d2f(2) # 6.0 (tracked)
+julia> d2f(2)
+6.0 (tracked)
 ```
 
 (We'll learn more about why these numbers show up as `(tracked)` below.)
 
 When a function has many parameters, we can pass them all in explicitly:
 
-```julia
+```jldoctest basics
-f(W, b, x) = W * x + b
+julia> f(W, b, x) = W * x + b;
 
-Tracker.gradient(f, 2, 3, 4)
-# (4.0 (tracked), 1.0 (tracked), 2.0 (tracked))
+julia> Tracker.gradient(f, 2, 3, 4)
+(4.0 (tracked), 1.0 (tracked), 2.0 (tracked))
 ```
 
 But machine learning models can have *hundreds* of parameters! Flux offers a nice way to handle this. We can tell Flux to treat something as a parameter via `param`. Then we can collect these together and tell `gradient` to collect the gradients of all `params` at once.
 
-```julia
+```jldoctest basics
-W = param(2) # 2.0 (tracked)
-b = param(3) # 3.0 (tracked)
-
-f(x) = W * x + b
-
-grads = Tracker.gradient(() -> f(4), params(W, b))
-
-grads[W] # 4.0
-grads[b] # 1.0
+julia> using Flux
+
+julia> W = param(2)
+2.0 (tracked)
+
+julia> b = param(3)
+3.0 (tracked)
+
+julia> f(x) = W * x + b;
+
+julia> grads = Tracker.gradient(() -> f(4), params(W, b));
+
+julia> grads[W]
+4.0 (tracked)
+
+julia> grads[b]
+1.0 (tracked)
 ```
 
 There are a few things to notice here. Firstly, `W` and `b` now show up as *tracked*. Tracked things behave like normal numbers or arrays, but keep records of everything you do with them, allowing Flux to calculate their gradients. `gradient` takes a zero-argument function; no arguments are necessary because the `params` tell it what to differentiate.
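To make the role of those tracked gradients concrete, here is a minimal follow-on sketch of a single gradient-descent step. It assumes the `Tracker.update!` helper from the Tracker package referenced above, and array-valued parameters for illustration:

```julia
using Flux
using Flux.Tracker: update!

# Tracked parameters and a small model/loss.
W = param(rand(2, 5))
b = param(rand(2))
predict(x) = W * x + b
loss(x, y) = sum((predict(x) .- y).^2)

x, y = rand(5), rand(2)
grads = Tracker.gradient(() -> loss(x, y), params(W, b))

# One gradient-descent step on each tracked parameter.
η = 0.1
update!(W, -η * grads[W])
update!(b, -η * grads[b])
```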
@@ -5,14 +5,16 @@ These core layers form the foundation of almost all neural networks.
 ```@docs
 Chain
 Dense
+```
+
+## Convolution and Pooling Layers
+
+These layers are used to build convolutional neural networks (CNNs).
+
+```@docs
 Conv
 MaxPool
 MeanPool
-```
-
-## Additional Convolution Layers
-
-```@docs
 DepthwiseConv
 ConvTranspose
 ```
@@ -28,6 +30,25 @@ GRU
 Flux.Recur
 ```
 
+## Other General Purpose Layers
+
+These are marginally more obscure than the Basic Layers.
+But in contrast to the layers described in the other sections, they are not readily grouped around a particular purpose (e.g. CNNs or RNNs).
+
+```@docs
+Maxout
+```
+
+# Normalisation & Regularisation
+
+These layers don't affect the structure of the network but may improve training times or reduce overfitting.
+
+```@docs
+Flux.testmode!
+BatchNorm
+Dropout
+LayerNorm
+```
+
 ## Activation Functions
 
 Non-linearities that go between layers of your model. Most of these functions are defined in [NNlib](https://github.com/FluxML/NNlib.jl) but are available by default in Flux.
@@ -50,5 +71,7 @@ These layers don't affect the structure of the network but may improve training
 Flux.testmode!
 BatchNorm
 Dropout
+AlphaDropout
 LayerNorm
+GroupNorm
 ```
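As a quick illustration of the newly documented `Maxout` layer — a sketch only, assuming the `Maxout(f, n_alts)` constructor that builds `n_alts` copies of a layer and takes their element-wise maximum:

```julia
using Flux

# Four independently initialised Dense layers; the forward pass returns the
# element-wise maximum of their four outputs.
m = Maxout(() -> Dense(10, 5), 4)

x = rand(10)
y = m(x)   # 5-element output
```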
@@ -77,7 +77,7 @@ If you use the `RNN(10, 5)` constructor – as opposed to `RNNCell` – you'll s
 
 ```julia
 julia> RNN(10, 5)
-Recur(RNNCell(Dense(15, 5)))
+Recur(RNNCell(10, 5, tanh))
 ```
 
 ## Sequences
@@ -114,3 +114,13 @@ truncate!(m)
 Calling `truncate!` wipes the slate clean, so we can call the model with more inputs without building up an expensive gradient computation.
 
 `truncate!` makes sense when you are working with multiple chunks of a large sequence, but we may also want to work with a set of independent sequences. In this case the hidden state should be completely reset to its original value, throwing away any accumulated information. `reset!` does this for you.
 
+In general, when training with recurrent layers in your model, you'll want to call `reset!` or `truncate!` for each loss calculation:
+
+```julia
+function loss(x, y)
+  l = Flux.mse(m(x), y)
+  Flux.reset!(m)
+  return l
+end
+```
docs/src/performance.md (new file, 76 lines)
@@ -0,0 +1,76 @@
# Performance Tips

All the usual [Julia performance tips apply](https://docs.julialang.org/en/v1/manual/performance-tips/).
As always, [profiling your code](https://docs.julialang.org/en/v1/manual/profile/#Profiling-1) is generally a useful way of finding bottlenecks.
Below follow some Flux-specific tips/reminders.

## Don't use more precision than you need

Flux works great with all kinds of number types.
But often you do not need to be working with, say, `Float64` (let alone `BigFloat`).
Switching to `Float32` can give you a significant speed up,
not because the operations are faster, but because the memory usage is halved,
which means allocations happen faster and you use less memory.
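A minimal sketch of what that switch looks like in practice, assuming the `f32` helper exported by Flux for converting model parameters and ordinary broadcasting for the data:

```julia
using Flux

m = Chain(Dense(784, 32, relu), Dense(32, 10))

x64 = rand(784)        # Float64 data ...
x32 = Float32.(x64)    # ... converted to Float32
m32 = f32(m)           # convert the model's parameters to match

y = m32(x32)           # forward pass stays entirely in Float32
```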
## Make sure your custom activation functions preserve the type of their inputs

Not only should your activation functions be [type-stable](https://docs.julialang.org/en/v1/manual/performance-tips/#Write-%22type-stable%22-functions-1),
they should also preserve the type of their inputs.

A very artificial example using an activation function like

```julia
my_tanh(x) = Float64(tanh(x))
```

will result in performance on `Float32` input orders of magnitude slower than the normal `tanh` would,
because it results in having to use slow mixed-type multiplication in the dense layers.
This means that if you change your data, say, from `Float64` to `Float32` (which should give a speedup: see above),
you will see a large slow-down.

This can occur sneakily, because you can cause type-promotion by interacting with numeric literals.
E.g. the following will run into the same problem as above:

```julia
leaky_tanh(x) = 0.01x + tanh(x)
```

While one could change the activation function (e.g. to use `0.01f0x`) whenever the input type changes,
the idiomatic (and safe) way is to use `oftype`:

```julia
leaky_tanh(x) = oftype(x/1, 0.01)*x + tanh(x)
```

## Evaluate batches as Matrices of features, rather than sequences of Vector features

While it can sometimes be tempting to process your observations (feature vectors) one at a time, e.g.

```julia
function loss_total(xs::AbstractVector{<:Vector}, ys::AbstractVector{<:Vector})
  sum(zip(xs, ys)) do (x, y_target)
    y_pred = model(x) # evaluate the model
    return loss(y_pred, y_target)
  end
end
```

it is much faster to concatenate them into a matrix,
as this will hit BLAS matrix-matrix multiplication, which is much faster than the equivalent sequence of matrix-vector multiplications,
even though this means allocating new memory to store them contiguously.

```julia
x_batch = reduce(hcat, xs)
y_batch = reduce(hcat, ys)
...
function loss_total(x_batch::Matrix, y_batch::Matrix)
  y_preds = model(x_batch)
  sum(loss.(y_preds, y_batch))
end
```

When doing this kind of concatenation use `reduce(hcat, xs)` rather than `hcat(xs...)`.
This will avoid the splatting penalty, and will hit the optimised `reduce` method.
@@ -49,5 +49,12 @@ All optimisers return an object that, when passed to `train!`, will update the p
 Descent
 Momentum
 Nesterov
+RMSProp
 ADAM
+AdaMax
+ADAGrad
+ADADelta
+AMSGrad
+NADAM
+ADAMW
 ```
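A short usage sketch for one of the optimisers listed above, assuming the `train!` signature used elsewhere in these docs:

```julia
using Flux

m = Dense(10, 2)
loss(x, y) = Flux.mse(m(x), y)

opt = ADAM(0.001)                           # any optimiser from the list above
data = [(rand(10), rand(2)) for _ in 1:16]  # toy dataset of (input, target) pairs

Flux.train!(loss, params(m), data, opt)     # one pass over `data`, updating `params(m)`
```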
@@ -93,3 +93,11 @@ evalcb() = @show(loss(test_x, test_y))
 Flux.train!(objective, ps, data, opt,
             cb = throttle(evalcb, 5))
 ```
 
+Calling `Flux.stop()` in a callback will exit the training loop early.
+
+```julia
+cb = function ()
+  accuracy() > 0.9 && Flux.stop()
+end
+```
@@ -14,7 +14,7 @@
 journal = {arXiv},
 volume = {abs/11712.03112},
 year = {2017},
-url = {http://arxiv.org/abs/1712.03112},
+url = {https://arxiv.org/abs/1712.03112},
 }
 
 @online{MLPL,
@@ -29,7 +29,7 @@
 author = {Mike Innes and others},
 title = {Generic GPU Kernels},
 year = 2017,
-url = {http://mikeinnes.github.io/2017/08/24/cudanative.html},
+url = {https://mikeinnes.github.io/2017/08/24/cudanative.html},
 urldate = {2018-02-16}
 }
 
|
|||||||
using MacroTools, Juno, Requires, Reexport, Statistics, Random
|
using MacroTools, Juno, Requires, Reexport, Statistics, Random
|
||||||
using MacroTools: @forward
|
using MacroTools: @forward
|
||||||
|
|
||||||
export Chain, Dense, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
|
export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
|
||||||
DepthwiseConv, Dropout, LayerNorm, BatchNorm,
|
DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
|
||||||
params, mapleaves, cpu, gpu, f32, f64
|
params, mapleaves, cpu, gpu, f32, f64
|
||||||
|
|
||||||
@reexport using NNlib
|
@reexport using NNlib
|
||||||
|
|
||||||
include("tracker/Tracker.jl")
|
using Tracker
|
||||||
using .Tracker
|
using Tracker: data
|
||||||
using .Tracker: data
|
|
||||||
export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param
|
export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param
|
||||||
|
|
||||||
include("optimise/Optimise.jl")
|
include("optimise/Optimise.jl")
|
||||||
|
@@ -1,17 +1,18 @@
 module CUDA
 
 using ..CuArrays
+import ..CuArrays.CUDAdrv: CuPtr, CU_NULL
 using Pkg.TOML
 
 function version_check()
-  minor_version = 9
+  major_version = 1
   project = joinpath(dirname(pathof(CuArrays)), "../Project.toml")
   project = TOML.parse(String(read(project)))
   version = VersionNumber(get(project, "version", "0.0.0"))
-  if !(version.major == 0 && version.minor == minor_version)
+  if version.major != major_version
    @warn """
-    Flux is only supported with CuArrays v0.$minor_version.
-    Try running `] pin CuArrays@0.$minor_version`.
+    Flux is only supported with CuArrays v$major_version.x.
+    Try running `] pin CuArrays@$major_version`.
     """
   end
 end
@@ -17,7 +17,7 @@ function DropoutDesc(ρ::Real; seed::Integer=0)
   @check ccall((:cudnnDropoutGetStatesSize,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Csize_t}),handle(),s)
   states = CuArray{UInt8}(undef, s[]) # TODO: can we drop this when ρ=0?
   desc = DropoutDesc(d[], states)
-  @check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Nothing},Cfloat,Ptr{Nothing},Csize_t,Culonglong),
+  @check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Nothing},Cfloat,CuPtr{Nothing},Csize_t,Culonglong),
               desc,handle(),ρ,states,length(states),seed)
   finalizer(desc) do x
     @check ccall((:cudnnDestroyDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x)
@@ -79,18 +79,18 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
     mean = zeros(CuArray{T}, dims...)
     ivar = ones(CuArray{T}, dims...)
   else
-    mean = C_NULL
+    mean = CU_NULL
-    ivar = C_NULL
+    ivar = CU_NULL
   end
 
   @check ccall((:cudnnBatchNormalizationForwardTraining, libcudnn), cudnnStatus_t,
               (cudnnHandle_t,cudnnBatchNormMode_t,
               Ptr{T}, Ptr{T},
-              Ptr{Nothing}, Ptr{T},
+              Ptr{Nothing}, CuPtr{T},
-              Ptr{Nothing}, Ptr{T},
+              Ptr{Nothing}, CuPtr{T},
-              Ptr{Nothing}, Ptr{T}, Ptr{T},
+              Ptr{Nothing}, CuPtr{T}, CuPtr{T},
-              Cdouble, Ptr{T}, Ptr{T},
+              Cdouble, CuPtr{T}, CuPtr{T},
-              Cdouble, Ptr{T}, Ptr{T}),
+              Cdouble, CuPtr{T}, CuPtr{T}),
              handle(), BATCHNORM_SPATIAL,
              Ref(T(alpha)), Ref(T(beta)),
              xd, x,
@@ -107,10 +107,10 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
     @check ccall((:cudnnBatchNormalizationForwardInference, libcudnn), cudnnStatus_t,
                 (Ptr{cudnnHandle_t},cudnnBatchNormMode_t,
                 Ptr{T}, Ptr{T},
-                Ptr{Nothing}, Ptr{T},
+                Ptr{Nothing}, CuPtr{T},
-                Ptr{Nothing}, Ptr{T},
+                Ptr{Nothing}, CuPtr{T},
-                Ptr{Nothing}, Ptr{T}, Ptr{T},
+                Ptr{Nothing}, CuPtr{T}, CuPtr{T},
-                Ptr{T}, Ptr{T},
+                CuPtr{T}, CuPtr{T},
                 Cdouble),
                handle(), BATCHNORM_SPATIAL,
                Ref(T(alpha)), Ref(T(beta)),
@@ -159,7 +159,7 @@ function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
     mean, ivar = cache.mean, cache.ivar
     info("mean and ivar are fetched from the cache")
   else
-    mean, ivar = C_NULL, C_NULL
+    mean, ivar = CU_NULL, CU_NULL
   end
 
   if eps < BATCHNORM_MIN_EPS
@@ -170,11 +170,11 @@ function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
              (cudnnHandle_t,cudnnBatchNormMode_t,
              Ptr{T}, Ptr{T},
              Ptr{T}, Ptr{T},
-             Ptr{Nothing}, Ptr{T},
+             Ptr{Nothing}, CuPtr{T},
-             Ptr{Nothing}, Ptr{T},
+             Ptr{Nothing}, CuPtr{T},
-             Ptr{Nothing}, Ptr{T},
+             Ptr{Nothing}, CuPtr{T},
-             Ptr{Nothing}, Ptr{T}, Ptr{T}, Ptr{T},
+             Ptr{Nothing}, CuPtr{T}, CuPtr{T}, CuPtr{T},
-             Cdouble, Ptr{T}, Ptr{T}),
+             Cdouble, CuPtr{T}, CuPtr{T}),
             handle(), BATCHNORM_SPATIAL,
             Ref(T(alpha)), Ref(T(beta)),
            Ref(T(dalpha)), Ref(T(dbeta)),
@@ -101,18 +101,18 @@ function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd
   if reserve == nothing
     @check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t,
                 (Ptr{Nothing}, Ptr{Nothing}, Cint,
-                Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
+                Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
-                Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T},
+                Ptr{Nothing}, CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
-                Ptr{Nothing}, Ptr{T},
+                Ptr{Nothing}, CuPtr{T},
-                Ptr{Nothing}, Csize_t),
+                CuPtr{Nothing}, Csize_t),
                handle(), rnn, seqlen,
                xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
                workspace, length(workspace))
   else
     @check ccall((:cudnnRNNForwardTraining, libcudnn), cudnnStatus_t,
                 (Ptr{Nothing}, Ptr{Nothing}, Cint,
-                Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
+                Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
-                Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t),
+                CuPtr{Nothing}, Csize_t, CuPtr{Nothing}, Csize_t),
                handle(), rnn, seqlen,
                xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
                workspace, length(workspace), reserve, length(reserve))
@@ -121,7 +121,7 @@ end
 
 xDesc(x) = [TensorDesc(eltype(x), (1, size(x, 1), size(x, 2)))]
 
-hDesc(h::Nothing) = C_NULL, C_NULL
+hDesc(h::Nothing) = C_NULL, CU_NULL
 hDesc(x::Integer) = (@assert x == 0; hDesc(nothing))
 function hDesc(h::CuArray)
   TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h
@@ -169,10 +169,10 @@ function cudnnRNNBackwardData(rnn::RNNDesc{T}, seqlen, yd, y, dyd, dy, dhod, dho
                               wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, rs) where T
   @check ccall((:cudnnRNNBackwardData,libcudnn),cudnnStatus_t,
               (Ptr{Nothing}, Ptr{Nothing}, Cint,
-              Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T},
+              Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
-              Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing},
+              Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing},
-              Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
+              CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
-              Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t),
+              CuPtr{Nothing}, Csize_t, CuPtr{Nothing}, Csize_t),
              handle(), rnn, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
              wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, length(ws), rs, length(rs))
 end
@@ -199,12 +199,12 @@ function cudnnRNNBackwardWeights(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, yd, y, d
                                  workspace, reserve) where T
   @check ccall((:cudnnRNNBackwardWeights,libcudnn), cudnnStatus_t,
               (Ptr{Nothing}, Ptr{Nothing}, Cint, # handle, rnnDesc, seqLength
|
(Ptr{Nothing}, Ptr{Nothing}, Cint, # handle, rnnDesc, seqLength
|
||||||
Ptr{Ptr{Nothing}}, Ptr{T}, #x
|
Ptr{Ptr{Nothing}}, CuPtr{T}, #x
|
||||||
Ptr{Nothing}, Ptr{T}, #hx
|
Ptr{Nothing}, CuPtr{T}, #hx
|
||||||
Ptr{Ptr{Nothing}}, Ptr{T}, #y
|
Ptr{Ptr{Nothing}}, CuPtr{T}, #y
|
||||||
Ptr{Nothing}, Csize_t, #ws
|
CuPtr{Nothing}, Csize_t, #ws
|
||||||
Ptr{Nothing}, Ptr{T}, #dw
|
Ptr{Nothing}, CuPtr{T}, #dw
|
||||||
Ptr{Nothing}, Csize_t), #rs
|
CuPtr{Nothing}, Csize_t), #rs
|
||||||
handle(), rnn, seqlen, xd, x, hd, h, yd, y,
|
handle(), rnn, seqlen, xd, x, hd, h, yd, y,
|
||||||
workspace, length(workspace), dwd, dw, reserve, length(reserve))
|
workspace, length(workspace), dwd, dw, reserve, length(reserve))
|
||||||
end
|
end
|
||||||
@ -39,4 +39,7 @@ include("tree.jl")
include("sentiment.jl")
using .Sentiment

+include("iris.jl")
+export Iris

end
@ -19,7 +19,7 @@ function load()
@info "Downloading CMUDict dataset"
mkpath(deps("cmudict"))
for (x, hash) in suffixes_and_hashes
-download_and_verify("$cache_prefix/http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-$version$x",
+download_and_verify("$cache_prefix/https://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-$version$x",
deps("cmudict", "cmudict$x"), hash)
end
end
86
src/data/iris.jl
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
Iris
|
||||||
|
|
||||||
|
Fisher's classic iris dataset.
|
||||||
|
|
||||||
|
Measurements from 3 different species of iris: setosa, versicolor and
|
||||||
|
virginica. There are 50 examples of each species.
|
||||||
|
|
||||||
|
There are 4 measurements for each example: sepal length, sepal width, petal
|
||||||
|
length and petal width. The measurements are in centimeters.
|
||||||
|
|
||||||
|
The module retrieves the data from the [UCI Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/iris).
|
||||||
|
|
||||||
|
"""
|
||||||
|
module Iris
|
||||||
|
|
||||||
|
using DelimitedFiles
|
||||||
|
using ..Data: deps, download_and_verify
|
||||||
|
|
||||||
|
# Uncomment if the iris.data file is cached to cache.julialang.org.
|
||||||
|
const cache_prefix = "https://cache.julialang.org/"
|
||||||
|
|
||||||
|
function load()
|
||||||
|
isfile(deps("iris.data")) && return
|
||||||
|
|
||||||
|
@info "Downloading iris dataset."
|
||||||
|
download_and_verify("$(cache_prefix)https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data",
|
||||||
|
deps("iris.data"),
|
||||||
|
"6f608b71a7317216319b4d27b4d9bc84e6abd734eda7872b71a458569e2656c0")
|
||||||
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
labels()
|
||||||
|
|
||||||
|
Get the labels of the iris dataset, a 150 element array of strings listing the
|
||||||
|
species of each example.
|
||||||
|
|
||||||
|
```jldoctest
|
||||||
|
julia> labels = Flux.Data.Iris.labels();
|
||||||
|
|
||||||
|
julia> summary(labels)
|
||||||
|
"150-element Array{String,1}"
|
||||||
|
|
||||||
|
julia> labels[1]
|
||||||
|
"Iris-setosa"
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
function labels()
|
||||||
|
load()
|
||||||
|
iris = readdlm(deps("iris.data"), ',')
|
||||||
|
Vector{String}(iris[1:end, end])
|
||||||
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
features()
|
||||||
|
|
||||||
|
Get the features of the iris dataset. This is a 4x150 matrix of Float64
|
||||||
|
elements. It has a row for each feature (sepal length, sepal width,
|
||||||
|
petal length, petal width) and a column for each example.
|
||||||
|
|
||||||
|
```jldoctest
|
||||||
|
julia> features = Flux.Data.Iris.features();
|
||||||
|
|
||||||
|
julia> summary(features)
|
||||||
|
"4×150 Array{Float64,2}"
|
||||||
|
|
||||||
|
julia> features[:, 1]
|
||||||
|
4-element Array{Float64,1}:
|
||||||
|
5.1
|
||||||
|
3.5
|
||||||
|
1.4
|
||||||
|
0.2
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
function features()
|
||||||
|
load()
|
||||||
|
iris = readdlm(deps("iris.data"), ',')
|
||||||
|
Matrix{Float64}(iris[1:end, 1:4]')
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
|
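The new `src/data/iris.jl` module above exposes the dataset through `load()`, `labels()` and `features()`. A small usage sketch combining it with the one-hot utilities elsewhere in this diff (the data is downloaded on first use):

```julia
using Flux

X = Flux.Data.Iris.features()             # 4×150 Matrix{Float64}
y = Flux.Data.Iris.labels()               # 150-element Vector{String}

Y = Flux.onehotbatch(y, sort(unique(y)))  # 3×150 one-hot targets
size(X), size(Y)
```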
@ -40,7 +40,24 @@ function Base.show(io::IO, c::Chain)
print(io, ")")
end

-activations(c::Chain, x) = accumulate((x, m) -> m(x), c.layers, init = x)
+# This is a temporary and naive implementation
+# it might be replaced in the future for better performance
+# see issue https://github.com/FluxML/Flux.jl/issues/702
+# Johnny Chen -- @johnnychen94
+"""
+    activations(c::Chain, input)
+Calculate the forward results of each layer in Chain `c` with `input` as model input.
+"""
+function activations(c::Chain, input)
+  rst = []
+  for l in c
+    x = get(rst, length(rst), input)
+    push!(rst, l(x))
+  end
+  return rst
+end
+
"""
    Dense(in::Integer, out::Integer, σ = identity)
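A short sketch of the new `activations` helper on a small chain (the model and sizes here are illustrative only):

```julia
using Flux

m  = Chain(Dense(10, 5, relu), Dense(5, 2), softmax)
xs = Flux.activations(m, rand(10))   # forward result of every layer in order

length(xs)    # 3 – one entry per layer
size(xs[1])   # (5,)
```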
@ -88,6 +105,14 @@ function Base.show(io::IO, l::Dense)
print(io, ")")
end

+# Try to avoid hitting generic matmul in some simple cases
+# Base's matmul is so slow that it's worth the extra conversion to hit BLAS
+(a::Dense{<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
+  invoke(a, Tuple{AbstractArray}, x)
+
+(a::Dense{<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
+  a(T.(x))
+
"""
    Diagonal(in::Integer)
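A sketch of what the two `Dense` methods added above buy: when the input eltype does not match the weight eltype, the input is converted with `T.(x)` so the multiply stays on the BLAS path instead of generic matmul. The explicit `Dense(W, b, identity)` construction below is only to pin the weight eltype for the illustration:

```julia
using Flux

W = param(rand(Float32, 2, 3))
b = param(zeros(Float32, 2))
d = Dense(W, b, identity)

d(rand(Float64, 3))   # second method: input converted to Float32, then BLAS matmul
d(rand(Float32, 3))   # first method: types already match, no conversion needed
```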
@ -117,10 +142,50 @@ function Base.show(io::IO, l::Diagonal)
print(io, "Diagonal(", length(l.α), ")")
end

-# Try to avoid hitting generic matmul in some simple cases
-# Base's matmul is so slow that it's worth the extra conversion to hit BLAS
-(a::Dense{<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
-  invoke(a, Tuple{AbstractArray}, x)
-
-(a::Dense{<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
-  a(T.(x))
-
+"""
+    Maxout(over)
+
+`Maxout` is a neural network layer which has a number of internal layers
+that all receive the same input; the maxout returns the elementwise maximum
+of the internal layers' outputs.
+
+Maxout over linear dense layers satisfies the universal approximation theorem.
+
+Reference:
+Ian J. Goodfellow, David Warde-Farley, Mehdi Mirza, Aaron Courville, and Yoshua Bengio.
+2013. Maxout networks.
+In Proceedings of the 30th International Conference on International Conference on Machine Learning - Volume 28 (ICML'13),
+Sanjoy Dasgupta and David McAllester (Eds.), Vol. 28. JMLR.org III-1319-III-1327.
+https://arxiv.org/pdf/1302.4389.pdf
+"""
+struct Maxout{FS<:Tuple}
+  over::FS
+end
+
+"""
+    Maxout(f, n_alts)
+
+Constructs a Maxout layer over `n_alts` instances of the layer given by `f`.
+The function takes no argument and should return some callable layer.
+Conventionally this is a linear dense layer.
+
+For example, the following constructs a `Maxout` layer over 4 internal dense
+linear layers, each identical in structure (784 inputs, 128 outputs):
+```julia
+  insize = 784
+  outsize = 128
+  Maxout(()->Dense(insize, outsize), 4)
+```
+"""
+function Maxout(f, n_alts)
+  over = Tuple(f() for _ in 1:n_alts)
+  return Maxout(over)
+end
+
+@treelike Maxout
+
+function (mo::Maxout)(input::AbstractArray)
+  mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
+end
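A usage sketch for the new layer, mirroring the docstring above (sizes are illustrative):

```julia
using Flux

mo = Flux.Maxout(() -> Dense(784, 128), 4)   # four Dense "pieces" sharing one input
x  = rand(784)
size(mo(x))                                  # (128,) – elementwise max over the pieces
```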
@ -1,10 +1,7 @@
using NNlib: conv, ∇conv_data, depthwiseconv

-@generated sub2(::Val{N}) where N = :(Val($(N-2)))
-
expand(N, i::Tuple) = i
expand(N, i::Integer) = ntuple(_ -> i, N)

"""
    Conv(size, in=>out)
    Conv(size, in=>out, relu)
|
|||||||
Standard convolutional layer. `size` should be a tuple like `(2, 2)`.
|
Standard convolutional layer. `size` should be a tuple like `(2, 2)`.
|
||||||
`in` and `out` specify the number of input and output channels respectively.
|
`in` and `out` specify the number of input and output channels respectively.
|
||||||
|
|
||||||
Data should be stored in WHCN order. In other words, a 100×100 RGB image would
|
Example: Applying Conv layer to a 1-channel input using a 2x2 window size,
|
||||||
be a `100×100×3×1` array, and a batch of 50 would be a `100×100×3×50` array.
|
giving us a 16-channel output. Output is activated with ReLU.
|
||||||
|
|
||||||
|
size = (2,2)
|
||||||
|
in = 1
|
||||||
|
out = 16
|
||||||
|
Conv((2, 2), 1=>16, relu)
|
||||||
|
|
||||||
|
Data should be stored in WHCN order (width, height, # channels, # batches).
|
||||||
|
In other words, a 100×100 RGB image would be a `100×100×3×1` array,
|
||||||
|
and a batch of 50 would be a `100×100×3×50` array.
|
||||||
|
|
||||||
Takes the keyword arguments `pad`, `stride` and `dilation`.
|
Takes the keyword arguments `pad`, `stride` and `dilation`.
|
||||||
"""
|
"""
|
||||||
struct Conv{N,F,A,V}
|
struct Conv{N,M,F,A,V}
|
||||||
σ::F
|
σ::F
|
||||||
weight::A
|
weight::A
|
||||||
bias::V
|
bias::V
|
||||||
stride::NTuple{N,Int}
|
stride::NTuple{N,Int}
|
||||||
pad::NTuple{N,Int}
|
pad::NTuple{M,Int}
|
||||||
dilation::NTuple{N,Int}
|
dilation::NTuple{N,Int}
|
||||||
end
|
end
|
||||||
|
|
||||||
Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
|
function Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
|
||||||
stride = 1, pad = 0, dilation = 1) where {T,N} =
|
stride = 1, pad = 0, dilation = 1) where {T,N}
|
||||||
Conv(σ, w, b, expand.(sub2(Val(N)), (stride, pad, dilation))...)
|
stride = expand(Val(N-2), stride)
|
||||||
|
pad = expand(Val(2*(N-2)), pad)
|
||||||
|
dilation = expand(Val(N-2), dilation)
|
||||||
|
return Conv(σ, w, b, stride, pad, dilation)
|
||||||
|
end
|
||||||
|
|
||||||
Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
|
Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
|
||||||
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
|
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
|
||||||
@ -41,7 +51,8 @@ function (c::Conv)(x::AbstractArray)
|
|||||||
# TODO: breaks gpu broadcast :(
|
# TODO: breaks gpu broadcast :(
|
||||||
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
|
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
|
||||||
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
|
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
|
||||||
σ.(conv(x, c.weight, stride = c.stride, pad = c.pad, dilation = c.dilation) .+ b)
|
cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
|
||||||
|
σ.(conv(x, c.weight, cdims) .+ b)
|
||||||
end
|
end
|
||||||
|
|
||||||
function Base.show(io::IO, l::Conv)
|
function Base.show(io::IO, l::Conv)
|
||||||
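The forward pass now builds an NNlib `DenseConvDims` object instead of passing keywords straight to `conv`. A minimal sketch of that call pattern, assuming the NNlib version this diff targets:

```julia
using NNlib

x = rand(Float32, 28, 28, 1, 8)    # WHCN input
w = rand(Float32, 3, 3, 1, 16)     # 3×3 kernel, 1=>16 channels
cdims = DenseConvDims(x, w; stride=(1, 1), padding=(1, 1, 1, 1), dilation=(1, 1))
size(conv(x, w, cdims))            # (28, 28, 16, 8) with this padding
```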
@ -67,18 +78,22 @@ Data should be stored in WHCN order. In other words, a 100×100 RGB image would
|
|||||||
be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
|
be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
|
||||||
Takes the keyword arguments `pad`, `stride` and `dilation`.
|
Takes the keyword arguments `pad`, `stride` and `dilation`.
|
||||||
"""
|
"""
|
||||||
struct ConvTranspose{N,F,A,V}
|
struct ConvTranspose{N,M,F,A,V}
|
||||||
σ::F
|
σ::F
|
||||||
weight::A
|
weight::A
|
||||||
bias::V
|
bias::V
|
||||||
stride::NTuple{N,Int}
|
stride::NTuple{N,Int}
|
||||||
pad::NTuple{N,Int}
|
pad::NTuple{M,Int}
|
||||||
dilation::NTuple{N,Int}
|
dilation::NTuple{N,Int}
|
||||||
end
|
end
|
||||||
|
|
||||||
ConvTranspose(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
|
function ConvTranspose(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
|
||||||
stride = 1, pad = 0, dilation = 1) where {T,N} =
|
stride = 1, pad = 0, dilation = 1) where {T,N}
|
||||||
ConvTranspose(σ, w, b, expand.(sub2(Val(N)), (stride, pad, dilation))...)
|
stride = expand(Val(N-2), stride)
|
||||||
|
pad = expand(Val(2*(N-2)), pad)
|
||||||
|
dilation = expand(Val(N-2), dilation)
|
||||||
|
return ConvTranspose(σ, w, b, stride, pad, dilation)
|
||||||
|
end
|
||||||
|
|
||||||
ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
|
ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
|
||||||
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
|
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
|
||||||
@ -87,10 +102,25 @@ ConvTranspose(param(init(k..., reverse(ch)...)), param(zeros(ch[2])), σ,
|
|||||||
|
|
||||||
@treelike ConvTranspose
|
@treelike ConvTranspose
|
||||||
|
|
||||||
|
function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
|
||||||
|
# Calculate size of "input", from ∇conv_data()'s perspective...
|
||||||
|
combined_pad = (c.pad[1:2:end] .+ c.pad[2:2:end])
|
||||||
|
I = (size(x)[1:end-2] .- 1).*c.stride .+ 1 .+ (size(c.weight)[1:end-2] .- 1).*c.dilation .- combined_pad
|
||||||
|
C_in = size(c.weight)[end-1]
|
||||||
|
batch_size = size(x)[end]
|
||||||
|
# Create DenseConvDims() that looks like the corresponding conv()
|
||||||
|
return DenseConvDims((I..., C_in, batch_size), size(c.weight);
|
||||||
|
stride=c.stride,
|
||||||
|
padding=c.pad,
|
||||||
|
dilation=c.dilation,
|
||||||
|
)
|
||||||
|
end
|
||||||
|
|
||||||
function (c::ConvTranspose)(x::AbstractArray)
|
function (c::ConvTranspose)(x::AbstractArray)
|
||||||
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
|
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
|
||||||
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
|
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
|
||||||
σ.(∇conv_data(x, c.weight, stride = c.stride, pad = c.pad, dilation = c.dilation) .+ b)
|
cdims = conv_transpose_dims(c, x)
|
||||||
|
return σ.(∇conv_data(x, c.weight, cdims) .+ b)
|
||||||
end
|
end
|
||||||
|
|
||||||
function Base.show(io::IO, l::ConvTranspose)
|
function Base.show(io::IO, l::ConvTranspose)
|
||||||
@ -119,26 +149,32 @@ be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
|
|||||||
|
|
||||||
Takes the keyword arguments `pad` and `stride`.
|
Takes the keyword arguments `pad` and `stride`.
|
||||||
"""
|
"""
|
||||||
struct DepthwiseConv{N,F,A,V}
|
struct DepthwiseConv{N,M,F,A,V}
|
||||||
σ::F
|
σ::F
|
||||||
weight::A
|
weight::A
|
||||||
bias::V
|
bias::V
|
||||||
stride::NTuple{N,Int}
|
stride::NTuple{N,Int}
|
||||||
pad::NTuple{N,Int}
|
pad::NTuple{M,Int}
|
||||||
|
dilation::NTuple{N,Int}
|
||||||
end
|
end
|
||||||
|
|
||||||
DepthwiseConv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
|
function DepthwiseConv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
|
||||||
stride = 1, pad = 0) where {T,N} =
|
stride = 1, pad = 0, dilation = 1) where {T,N}
|
||||||
DepthwiseConv(σ, w, b, expand.(sub2(Val(N)), (stride, pad))...)
|
stride = expand(Val(N-2), stride)
|
||||||
|
pad = expand(Val(2*(N-2)), pad)
|
||||||
|
dilation = expand(Val(N-2), dilation)
|
||||||
|
return DepthwiseConv(σ, w, b, stride, pad, dilation)
|
||||||
|
end
|
||||||
|
|
||||||
DepthwiseConv(k::NTuple{N,Integer}, ch::Integer, σ = identity; init = glorot_uniform,
|
DepthwiseConv(k::NTuple{N,Integer}, ch::Integer, σ = identity; init = glorot_uniform,
|
||||||
stride = 1, pad = 0) where N =
|
stride = 1, pad = 0, dilation = 1) where N =
|
||||||
DepthwiseConv(param(init(k..., 1, ch)), param(zeros(ch)), σ,
|
DepthwiseConv(param(init(k..., 1, ch)), param(zeros(ch)), σ,
|
||||||
stride = stride, pad = pad)
|
stride = stride, pad = pad, dilation=dilation)
|
||||||
|
|
||||||
DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = glorot_uniform,
|
DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = glorot_uniform,
|
||||||
stride::NTuple{N,Integer} = map(_->1,k),
|
stride::NTuple{N,Integer} = map(_->1,k),
|
||||||
pad::NTuple{N,Integer} = map(_->0,k)) where N =
|
pad::NTuple{N,Integer} = map(_->0,2 .* k),
|
||||||
|
dilation::NTuple{N,Integer} = map(_->1,k)) where N =
|
||||||
DepthwiseConv(param(init(k..., ch[2], ch[1])), param(zeros(ch[2]*ch[1])), σ,
|
DepthwiseConv(param(init(k..., ch[2], ch[1])), param(zeros(ch[2]*ch[1])), σ,
|
||||||
stride = stride, pad = pad)
|
stride = stride, pad = pad)
|
||||||
|
|
||||||
@ -146,7 +182,8 @@ DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity
|
|||||||
|
|
||||||
function (c::DepthwiseConv)(x)
|
function (c::DepthwiseConv)(x)
|
||||||
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
|
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
|
||||||
σ.(depthwiseconv(x, c.weight, stride = c.stride, pad = c.pad) .+ b)
|
cdims = DepthwiseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
|
||||||
|
σ.(depthwiseconv(x, c.weight, cdims) .+ b)
|
||||||
end
|
end
|
||||||
|
|
||||||
function Base.show(io::IO, l::DepthwiseConv)
|
function Base.show(io::IO, l::DepthwiseConv)
|
||||||
@ -156,6 +193,12 @@ function Base.show(io::IO, l::DepthwiseConv)
|
|||||||
print(io, ")")
|
print(io, ")")
|
||||||
end
|
end
|
||||||
|
|
||||||
|
(a::DepthwiseConv{<:Any,<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
|
||||||
|
invoke(a, Tuple{AbstractArray}, x)
|
||||||
|
|
||||||
|
(a::DepthwiseConv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
|
||||||
|
a(T.(x))
|
||||||
|
|
||||||
"""
|
"""
|
||||||
MaxPool(k)
|
MaxPool(k)
|
||||||
|
|
||||||
@ -163,16 +206,23 @@ Max pooling layer. `k` stands for the size of the window for each dimension of t

Takes the keyword arguments `pad` and `stride`.
"""
-struct MaxPool{N}
+struct MaxPool{N,M}
  k::NTuple{N,Int}
- pad::NTuple{N,Int}
+ pad::NTuple{M,Int}
  stride::NTuple{N,Int}
end

-MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =
-  MaxPool(k, expand(Val(N), pad), expand(Val(N), stride))
-
-(m::MaxPool)(x) = maxpool(x, m.k; pad = m.pad, stride = m.stride)
+function MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N
+  stride = expand(Val(N), stride)
+  pad = expand(Val(2*N), pad)
+  return MaxPool(k, pad, stride)
+end
+
+function (m::MaxPool)(x)
+  pdims = PoolDims(x, m.k; padding=m.pad, stride=m.stride)
+  return maxpool(x, pdims)
+end

function Base.show(io::IO, m::MaxPool)
  print(io, "MaxPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
|
|||||||
|
|
||||||
Takes the keyword arguments `pad` and `stride`.
|
Takes the keyword arguments `pad` and `stride`.
|
||||||
"""
|
"""
|
||||||
struct MeanPool{N}
|
struct MeanPool{N,M}
|
||||||
k::NTuple{N,Int}
|
k::NTuple{N,Int}
|
||||||
pad::NTuple{N,Int}
|
pad::NTuple{M,Int}
|
||||||
stride::NTuple{N,Int}
|
stride::NTuple{N,Int}
|
||||||
end
|
end
|
||||||
|
|
||||||
MeanPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =
|
function MeanPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N
|
||||||
MeanPool(k, expand(Val(N), pad), expand(Val(N), stride))
|
stride = expand(Val(N), stride)
|
||||||
|
pad = expand(Val(2*N), pad)
|
||||||
|
return MeanPool(k, pad, stride)
|
||||||
|
end
|
||||||
|
|
||||||
(m::MeanPool)(x) = meanpool(x, m.k; pad = m.pad, stride = m.stride)
|
function (m::MeanPool)(x)
|
||||||
|
pdims = PoolDims(x, m.k; padding=m.pad, stride=m.stride)
|
||||||
|
return meanpool(x, pdims)
|
||||||
|
end
|
||||||
|
|
||||||
function Base.show(io::IO, m::MeanPool)
|
function Base.show(io::IO, m::MeanPool)
|
||||||
print(io, "MeanPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
|
print(io, "MeanPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
|
||||||
|
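With the constructors above, pooling layers now carry a per-edge `pad` tuple and build a `PoolDims` on each call. A quick usage sketch:

```julia
using Flux

x = rand(Float32, 28, 28, 3, 1)
size(MaxPool((2, 2))(x))    # (14, 14, 3, 1)
size(MeanPool((2, 2))(x))   # (14, 14, 3, 1)
```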
@ -43,6 +43,37 @@ end

_testmode!(a::Dropout, test) = (a.active = !test)

+"""
+    AlphaDropout(p)
+A dropout layer, used in Self-Normalizing Neural Networks
+(https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf).
+The AlphaDropout layer ensures that the mean and variance of the activations remain the same as before.
+"""
+mutable struct AlphaDropout{F}
+  p::F
+  active::Bool
+end
+
+function AlphaDropout(p)
+  @assert 0 ≤ p ≤ 1
+  AlphaDropout(p,true)
+end
+
+function (a::AlphaDropout)(x)
+  a.active || return x
+  λ = eltype(x)(1.0507009873554804934193349852946)
+  α = eltype(x)(1.6732632423543772848170429916717)
+  α1 = eltype(x)(-λ*α)
+  noise = randn(eltype(x), size(x))
+  x = @. x*(noise > (1 - a.p)) + α1 * (noise <= (1 - a.p))
+  A = (a.p + a.p * (1 - a.p) * α1 ^ 2)^0.5
+  B = -A * α1 * (1 - a.p)
+  x = @. A * x + B
+  return x
+end
+
+_testmode!(a::AlphaDropout, test) = (a.active = !test)
+
"""
    LayerNorm(h::Integer)
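A usage sketch for the new `AlphaDropout` layer; it is meant to be paired with `selu` activations as in self-normalizing networks (the model below is illustrative only):

```julia
using Flux

m = Chain(Dense(10, 5, selu), Flux.AlphaDropout(0.2), Dense(5, 2))
y = m(rand(10))    # dropout is active by default; Flux.testmode!(m) disables it for evaluation
```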
@ -113,34 +144,32 @@ BatchNorm(chs::Integer, λ = identity;
|
|||||||
function (BN::BatchNorm)(x)
|
function (BN::BatchNorm)(x)
|
||||||
size(x, ndims(x)-1) == length(BN.β) ||
|
size(x, ndims(x)-1) == length(BN.β) ||
|
||||||
error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))")
|
error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))")
|
||||||
γ, β = BN.γ, BN.β
|
|
||||||
dims = length(size(x))
|
dims = length(size(x))
|
||||||
channels = size(x, dims-1)
|
channels = size(x, dims-1)
|
||||||
affine_shape = ones(Int, dims)
|
affine_shape = ones(Int, dims)
|
||||||
affine_shape[end-1] = channels
|
affine_shape[end-1] = channels
|
||||||
m = prod(size(x)[1:end-2]) * size(x)[end]
|
m = prod(size(x)[1:end-2]) * size(x)[end]
|
||||||
|
γ = reshape(BN.γ, affine_shape...)
|
||||||
|
β = reshape(BN.β, affine_shape...)
|
||||||
if !BN.active
|
if !BN.active
|
||||||
μ = reshape(BN.μ, affine_shape...)
|
μ = reshape(BN.μ, affine_shape...)
|
||||||
σ² = reshape(BN.σ², affine_shape...)
|
σ² = reshape(BN.σ², affine_shape...)
|
||||||
|
ϵ = BN.ϵ
|
||||||
else
|
else
|
||||||
T = eltype(x)
|
T = eltype(x)
|
||||||
|
|
||||||
ϵ = data(convert(T, BN.ϵ))
|
|
||||||
axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
|
axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
|
||||||
μ = mean(x, dims = axes)
|
μ = mean(x, dims = axes)
|
||||||
σ² = sum((x .- μ) .^ 2, dims = axes) ./ m
|
σ² = sum((x .- μ) .^ 2, dims = axes) ./ m
|
||||||
|
ϵ = data(convert(T, BN.ϵ))
|
||||||
# update moving mean/std
|
# update moving mean/std
|
||||||
mtm = data(convert(T, BN.momentum))
|
mtm = data(convert(T, BN.momentum))
|
||||||
BN.μ = (1 - mtm) .* BN.μ .+ mtm .* reshape(data(μ), :)
|
BN.μ = (1 - mtm) .* BN.μ .+ mtm .* reshape(data(μ), :)
|
||||||
BN.σ² = ((1 - mtm) .* BN.σ² .+ mtm .* reshape(data(σ²), :) .* m ./ (m - 1))
|
BN.σ² = (1 - mtm) .* BN.σ² .+ (mtm * m / (m - 1)) .* reshape(data(σ²), :)
|
||||||
end
|
end
|
||||||
|
|
||||||
let λ = BN.λ
|
let λ = BN.λ
|
||||||
temp = reshape(γ, affine_shape...) .* ((x .- μ) ./ sqrt.(σ² .+ BN.ϵ))
|
x̂ = (x .- μ) ./ sqrt.(σ² .+ ϵ)
|
||||||
# This is intentionally not fused because of an extreme slowdown doing so
|
λ.(γ .* x̂ .+ β)
|
||||||
λ.(temp .+ reshape(β, affine_shape...))
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -157,3 +186,209 @@ function Base.show(io::IO, l::BatchNorm)
|
|||||||
(l.λ == identity) || print(io, ", λ = $(l.λ)")
|
(l.λ == identity) || print(io, ", λ = $(l.λ)")
|
||||||
print(io, ")")
|
print(io, ")")
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
InstanceNorm(channels::Integer, σ = identity;
|
||||||
|
initβ = zeros, initγ = ones,
|
||||||
|
ϵ = 1e-8, momentum = .1)
|
||||||
|
|
||||||
|
Instance Normalization layer. The `channels` input should be the size of the
|
||||||
|
channel dimension in your data (see below).
|
||||||
|
|
||||||
|
Given an array with `N` dimensions, call the `N-1`th the channel dimension. (For
|
||||||
|
a batch of feature vectors this is just the data dimension, for `WHCN` images
|
||||||
|
it's the usual channel dimension.)
|
||||||
|
|
||||||
|
`InstanceNorm` computes the mean and variance for each each `W×H×1×1` slice and
|
||||||
|
shifts them to have a new mean and variance (corresponding to the learnable,
|
||||||
|
per-channel `bias` and `scale` parameters).
|
||||||
|
|
||||||
|
See [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022).
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```julia
|
||||||
|
m = Chain(
|
||||||
|
Dense(28^2, 64),
|
||||||
|
InstanceNorm(64, relu),
|
||||||
|
Dense(64, 10),
|
||||||
|
InstanceNorm(10),
|
||||||
|
softmax)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...)
|
||||||
|
|
||||||
|
mutable struct InstanceNorm{F,V,W,N}
|
||||||
|
λ::F # activation function
|
||||||
|
β::V # bias
|
||||||
|
γ::V # scale
|
||||||
|
μ::W # moving mean
|
||||||
|
σ²::W # moving std
|
||||||
|
ϵ::N
|
||||||
|
momentum::N
|
||||||
|
active::Bool
|
||||||
|
end
|
||||||
|
|
||||||
|
InstanceNorm(chs::Integer, λ = identity;
|
||||||
|
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
|
||||||
|
InstanceNorm(λ, param(initβ(chs)), param(initγ(chs)),
|
||||||
|
zeros(chs), ones(chs), ϵ, momentum, true)
|
||||||
|
|
||||||
|
function (in::InstanceNorm)(x)
|
||||||
|
size(x, ndims(x)-1) == length(in.β) ||
|
||||||
|
error("InstanceNorm expected $(length(in.β)) channels, got $(size(x, ndims(x)-1))")
|
||||||
|
ndims(x) > 2 ||
|
||||||
|
error("InstanceNorm requires at least 3 dimensions. With 2 dimensions an array of zeros would be returned")
|
||||||
|
# these are repeated later on depending on the batch size
|
||||||
|
dims = length(size(x))
|
||||||
|
c = size(x, dims-1)
|
||||||
|
bs = size(x, dims)
|
||||||
|
affine_shape = ones(Int, dims)
|
||||||
|
affine_shape[end-1] = c
|
||||||
|
affine_shape[end] = bs
|
||||||
|
m = prod(size(x)[1:end-2])
|
||||||
|
γ, β = expand_inst(in.γ, affine_shape), expand_inst(in.β, affine_shape)
|
||||||
|
|
||||||
|
if !in.active
|
||||||
|
μ = expand_inst(in.μ, affine_shape)
|
||||||
|
σ² = expand_inst(in.σ², affine_shape)
|
||||||
|
ϵ = in.ϵ
|
||||||
|
else
|
||||||
|
T = eltype(x)
|
||||||
|
|
||||||
|
ϵ = data(convert(T, in.ϵ))
|
||||||
|
axes = 1:dims-2 # axes to reduce along (all but channels and batch size axes)
|
||||||
|
μ = mean(x, dims = axes)
|
||||||
|
σ² = mean((x .- μ) .^ 2, dims = axes)
|
||||||
|
|
||||||
|
# update moving mean/std
|
||||||
|
mtm = data(convert(T, in.momentum))
|
||||||
|
in.μ = dropdims(mean(repeat((1 - mtm) .* in.μ, outer=[1, bs]) .+ mtm .* reshape(data(μ), (c, bs)), dims = 2), dims=2)
|
||||||
|
in.σ² = dropdims(mean((repeat((1 - mtm) .* in.σ², outer=[1, bs]) .+ (mtm * m / (m - 1)) .* reshape(data(σ²), (c, bs))), dims = 2), dims=2)
|
||||||
|
end
|
||||||
|
|
||||||
|
let λ = in.λ
|
||||||
|
x̂ = (x .- μ) ./ sqrt.(σ² .+ ϵ)
|
||||||
|
λ.(γ .* x̂ .+ β)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
children(in::InstanceNorm) =
|
||||||
|
(in.λ, in.β, in.γ, in.μ, in.σ², in.ϵ, in.momentum, in.active)
|
||||||
|
|
||||||
|
mapchildren(f, in::InstanceNorm) = # e.g. mapchildren(cu, in)
|
||||||
|
InstanceNorm(in.λ, f(in.β), f(in.γ), f(in.μ), f(in.σ²), in.ϵ, in.momentum, in.active)
|
||||||
|
|
||||||
|
_testmode!(in::InstanceNorm, test) = (in.active = !test)
|
||||||
|
|
||||||
|
function Base.show(io::IO, l::InstanceNorm)
|
||||||
|
print(io, "InstanceNorm($(join(size(l.β), ", "))")
|
||||||
|
(l.λ == identity) || print(io, ", λ = $(l.λ)")
|
||||||
|
print(io, ")")
|
||||||
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
Group Normalization.
|
||||||
|
This layer can outperform Batch-Normalization and Instance-Normalization.
|
||||||
|
|
||||||
|
GroupNorm(chs::Integer, G::Integer, λ = identity;
|
||||||
|
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
|
||||||
|
ϵ = 1f-5, momentum = 0.1f0)
|
||||||
|
|
||||||
|
``chs`` is the number of channels, the channel dimension of your input.
|
||||||
|
For an array of N dimensions, the (N-1)th index is the channel dimension.
|
||||||
|
|
||||||
|
``G`` is the number of groups along which the statistics would be computed.
|
||||||
|
The number of channels must be an integer multiple of the number of groups.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```
|
||||||
|
m = Chain(Conv((3,3), 1=>32, leakyrelu;pad = 1),
|
||||||
|
GroupNorm(32,16)) # 32 channels, 16 groups (G = 16), thus 2 channels per group used
|
||||||
|
```
|
||||||
|
|
||||||
|
Link : https://arxiv.org/pdf/1803.08494.pdf
|
||||||
|
"""
|
||||||
|
|
||||||
|
mutable struct GroupNorm{F,V,W,N,T}
|
||||||
|
G::T # number of groups
|
||||||
|
λ::F # activation function
|
||||||
|
β::V # bias
|
||||||
|
γ::V # scale
|
||||||
|
μ::W # moving mean
|
||||||
|
σ²::W # moving std
|
||||||
|
ϵ::N
|
||||||
|
momentum::N
|
||||||
|
active::Bool
|
||||||
|
end
|
||||||
|
|
||||||
|
GroupNorm(chs::Integer, G::Integer, λ = identity;
|
||||||
|
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
|
||||||
|
GroupNorm(G, λ, param(initβ(chs)), param(initγ(chs)),
|
||||||
|
zeros(G,1), ones(G,1), ϵ, momentum, true)
|
||||||
|
|
||||||
|
function(gn::GroupNorm)(x)
|
||||||
|
size(x,ndims(x)-1) == length(gn.β) || error("Group Norm expected $(length(gn.β)) channels, but got $(size(x,ndims(x)-1)) channels")
|
||||||
|
ndims(x) > 2 || error("Need to pass at least 3 channels for Group Norm to work")
|
||||||
|
(size(x,ndims(x) -1))%gn.G == 0 || error("The number of groups ($(gn.G)) must divide the number of channels ($(size(x,ndims(x) -1)))")
|
||||||
|
|
||||||
|
dims = length(size(x))
|
||||||
|
groups = gn.G
|
||||||
|
channels = size(x, dims-1)
|
||||||
|
batches = size(x,dims)
|
||||||
|
channels_per_group = div(channels,groups)
|
||||||
|
affine_shape = ones(Int, dims)
|
||||||
|
|
||||||
|
# Output reshaped to (W,H...,C/G,G,N)
|
||||||
|
affine_shape[end-1] = channels
|
||||||
|
|
||||||
|
μ_affine_shape = ones(Int,dims + 1)
|
||||||
|
μ_affine_shape[end-1] = groups
|
||||||
|
|
||||||
|
m = prod(size(x)[1:end-2]) * channels_per_group
|
||||||
|
γ = reshape(gn.γ, affine_shape...)
|
||||||
|
β = reshape(gn.β, affine_shape...)
|
||||||
|
|
||||||
|
y = reshape(x,((size(x))[1:end-2]...,channels_per_group,groups,batches))
|
||||||
|
if !gn.active
|
||||||
|
og_shape = size(x)
|
||||||
|
μ = reshape(gn.μ, μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
|
||||||
|
σ² = reshape(gn.σ², μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
|
||||||
|
ϵ = gn.ϵ
|
||||||
|
else
|
||||||
|
T = eltype(x)
|
||||||
|
og_shape = size(x)
|
||||||
|
axes = [(1:ndims(y)-2)...] # axes to reduce along (all but channels axis)
|
||||||
|
μ = mean(y, dims = axes)
|
||||||
|
σ² = mean((y .- μ) .^ 2, dims = axes)
|
||||||
|
|
||||||
|
ϵ = data(convert(T, gn.ϵ))
|
||||||
|
# update moving mean/std
|
||||||
|
mtm = data(convert(T, gn.momentum))
|
||||||
|
|
||||||
|
gn.μ = mean((1 - mtm) .* gn.μ .+ mtm .* reshape(data(μ), (groups,batches)),dims=2)
|
||||||
|
gn.σ² = mean((1 - mtm) .* gn.σ² .+ (mtm * m / (m - 1)) .* reshape(data(σ²), (groups,batches)),dims=2)
|
||||||
|
end
|
||||||
|
|
||||||
|
let λ = gn.λ
|
||||||
|
x̂ = (y .- μ) ./ sqrt.(σ² .+ ϵ)
|
||||||
|
|
||||||
|
# Reshape x̂
|
||||||
|
x̂ = reshape(x̂,og_shape)
|
||||||
|
λ.(γ .* x̂ .+ β)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
children(gn::GroupNorm) =
|
||||||
|
(gn.λ, gn.β, gn.γ, gn.μ, gn.σ², gn.ϵ, gn.momentum, gn.active)
|
||||||
|
|
||||||
|
mapchildren(f, gn::GroupNorm) = # e.g. mapchildren(cu, BN)
|
||||||
|
GroupNorm(gn.G,gn.λ, f(gn.β), f(gn.γ), f(gn.μ), f(gn.σ²), gn.ϵ, gn.momentum, gn.active)
|
||||||
|
|
||||||
|
_testmode!(gn::GroupNorm, test) = (gn.active = !test)
|
||||||
|
|
||||||
|
function Base.show(io::IO, l::GroupNorm)
|
||||||
|
print(io, "GroupNorm($(join(size(l.β), ", "))")
|
||||||
|
(l.λ == identity) || print(io, ", λ = $(l.λ)")
|
||||||
|
print(io, ")")
|
||||||
|
end
|
||||||
|
@ -153,7 +153,7 @@ Base.show(io::IO, l::LSTMCell) =
Long Short Term Memory recurrent layer. Behaves like an RNN but generally
exhibits a longer memory span over sequences.

-See [this article](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
+See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.
"""
LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
@ -194,7 +194,7 @@ Base.show(io::IO, l::GRUCell) =
Gated Recurrent Unit layer. Behaves like an RNN but generally
exhibits a longer memory span over sequences.

-See [this article](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
+See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.
"""
GRU(a...; ka...) = Recur(GRUCell(a...; ka...))
@ -50,7 +50,7 @@ function normalise(x::AbstractArray; dims=1)
return (x .- μ′) ./ σ′
end

-function normalise(x::AbstractArray, dims=1)
+function normalise(x::AbstractArray, dims)
  Base.depwarn("`normalise(x::AbstractArray, dims)` is deprecated, use `normalise(a, dims=dims)` instead.", :normalise)
  normalise(x, dims = dims)
end
@ -44,6 +44,29 @@ adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data)
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
end

+"""
+    onehot(l, labels[, unk])
+
+Create a [`OneHotVector`](@ref) that is `true` at the position of `l` in the possible set `labels`.
+If `unk` is given, `onehot(unk, labels)` is returned when the input label `l` is not found in
+`labels`; otherwise an error is thrown.
+
+## Examples
+
+```jldoctest
+julia> onehot(:b, [:a, :b, :c])
+3-element Flux.OneHotVector:
+ false
+  true
+ false
+
+julia> onehot(:c, [:a, :b, :c])
+3-element Flux.OneHotVector:
+ false
+ false
+  true
+```
+"""
function onehot(l, labels)
  i = something(findfirst(isequal(l), labels), 0)
  i > 0 || error("Value $l is not in labels")
@ -56,11 +79,43 @@ function onehot(l, labels, unk)
  OneHotVector(i, length(labels))
end

+"""
+    onehotbatch(ls, labels[, unk...])
+
+Create a [`OneHotMatrix`](@ref) from a batch of labels `ls`, based on the possible set `labels`;
+any label in `ls` that is not found in `labels` is encoded as `onehot(unk, labels)` when `unk` is given.
+
+## Examples
+
+```jldoctest
+julia> onehotbatch([:b, :a, :b], [:a, :b, :c])
+3×3 Flux.OneHotMatrix:
+ false   true  false
+  true  false   true
+ false  false  false
+```
+"""
onehotbatch(ls, labels, unk...) =
  OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])

Base.argmax(xs::OneHotVector) = xs.ix

+"""
+    onecold(y[, labels = 1:length(y)])
+
+The inverse operation of [`onehot`](@ref): return the label corresponding to the highest entry of `y`.
+
+## Examples
+
+```jldoctest
+julia> onecold([true, false, false], [:a, :b, :c])
+:a
+
+julia> onecold([0.3, 0.2, 0.5], [:a, :b, :c])
+:c
+```
+"""
onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)]

onecold(y::AbstractMatrix, labels...) =
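The docstrings above cover the vector case; `onecold` also works column-wise on matrices, which is handy for batched predictions. A small sketch:

```julia
using Flux

ys = onehotbatch([:b, :a, :c], [:a, :b, :c])   # 3×3 one-hot batch
onecold(ys, [:a, :b, :c])                      # [:b, :a, :c]

scores = [0.1 0.8; 0.7 0.1; 0.2 0.1]           # one column of class scores per sample
onecold(scores, [:a, :b, :c])                  # [:b, :a]
```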
@ -37,7 +37,7 @@ Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict())
|
|||||||
|
|
||||||
function apply!(o::Momentum, x, Δ)
|
function apply!(o::Momentum, x, Δ)
|
||||||
η, ρ = o.eta, o.rho
|
η, ρ = o.eta, o.rho
|
||||||
v = get!(o.velocity, x, zero(x))::typeof(x)
|
v = get!(o.velocity, x, zero(x))::typeof(data(x))
|
||||||
@. v = ρ * v - η * Δ
|
@. v = ρ * v - η * Δ
|
||||||
@. Δ = -v
|
@. Δ = -v
|
||||||
end
|
end
|
||||||
@ -57,7 +57,7 @@ Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict())
|
|||||||
|
|
||||||
function apply!(o::Nesterov, x, Δ)
|
function apply!(o::Nesterov, x, Δ)
|
||||||
η, ρ = o.eta, o.rho
|
η, ρ = o.eta, o.rho
|
||||||
v = get!(o.velocity, x, zero(x))::typeof(x)
|
v = get!(o.velocity, x, zero(x))::typeof(data(x))
|
||||||
d = @. ρ^2 * v - (1+ρ) * η * Δ
|
d = @. ρ^2 * v - (1+ρ) * η * Δ
|
||||||
@. v = ρ*v - η*Δ
|
@. v = ρ*v - η*Δ
|
||||||
@. Δ = -d
|
@. Δ = -d
|
||||||
@ -66,7 +66,7 @@ end
|
|||||||
"""
|
"""
|
||||||
RMSProp(η = 0.001, ρ = 0.9)
|
RMSProp(η = 0.001, ρ = 0.9)
|
||||||
|
|
||||||
[RMSProp](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
|
[RMSProp](https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
|
||||||
optimiser. Parameters other than learning rate don't need tuning. Often a good
|
optimiser. Parameters other than learning rate don't need tuning. Often a good
|
||||||
choice for recurrent networks.
|
choice for recurrent networks.
|
||||||
"""
|
"""
|
||||||
@ -80,7 +80,7 @@ RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict())
|
|||||||
|
|
||||||
function apply!(o::RMSProp, x, Δ)
|
function apply!(o::RMSProp, x, Δ)
|
||||||
η, ρ = o.eta, o.rho
|
η, ρ = o.eta, o.rho
|
||||||
acc = get!(o.acc, x, zero(x))::typeof(x)
|
acc = get!(o.acc, x, zero(x))::typeof(data(x))
|
||||||
@. acc = ρ * acc + (1 - ρ) * Δ^2
|
@. acc = ρ * acc + (1 - ρ) * Δ^2
|
||||||
@. Δ *= η / (√acc + ϵ)
|
@. Δ *= η / (√acc + ϵ)
|
||||||
end
|
end
|
||||||
@ -147,7 +147,7 @@ ADAGrad(η = 0.1) = ADAGrad(η, IdDict())
|
|||||||
|
|
||||||
function apply!(o::ADAGrad, x, Δ)
|
function apply!(o::ADAGrad, x, Δ)
|
||||||
η = o.eta
|
η = o.eta
|
||||||
acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(x)
|
acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(data(x))
|
||||||
@. acc += Δ^2
|
@. acc += Δ^2
|
||||||
@. Δ *= η / (√acc + ϵ)
|
@. Δ *= η / (√acc + ϵ)
|
||||||
end
|
end
|
||||||
@ -155,7 +155,7 @@ end
|
|||||||
"""
|
"""
|
||||||
ADADelta(ρ = 0.9, ϵ = 1e-8)
|
ADADelta(ρ = 0.9, ϵ = 1e-8)
|
||||||
|
|
||||||
[ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
|
[ADADelta](https://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
|
||||||
tuning.
|
tuning.
|
||||||
"""
|
"""
|
||||||
mutable struct ADADelta
|
mutable struct ADADelta
|
||||||
@ -323,5 +323,5 @@ WeightDecay() = WeightDecay(0)
|
|||||||
|
|
||||||
function apply!(o::WeightDecay, x, Δ)
|
function apply!(o::WeightDecay, x, Δ)
|
||||||
wd = o.wd
|
wd = o.wd
|
||||||
@. Δ += wd * x
|
@. Δ += wd * data(x)
|
||||||
end
|
end
|
||||||
|
@ -1,16 +1,23 @@
using Juno
-import Flux.Tracker: data, grad, back!, update!
+import Flux.Tracker: Params, gradient, data, update!
import Base.depwarn

function update!(opt, x, x̄)
-  update!(x, apply!(opt, x, copy(data(x̄))))
+  update!(x, -apply!(opt, x, data(x̄)))
end

-function _update_params!(opt, xs)
+function update!(opt, xs::Params, gs)
  for x in xs
-    Δ = apply!(opt, x.data, x.grad)
-    x.data .-= Δ
-    Δ .= 0
+    update!(opt, x, gs[x])
+  end
+end
+
+# Added as an internal API but everyone started using it.
+function _update_params!(opt, xs)
+  depwarn("`_update_params!` is deprecated, use `update!` instead.", :stop)
+  for x in xs
+    update!(opt, x, Tracker.grad(x))
+    x.tracker.grad = Tracker.zero_grad!(x.tracker.grad)
  end
end

@ -19,16 +26,6 @@ call(f, xs...) = f(xs...)
runall(f) = f
runall(fs::AbstractVector) = () -> foreach(call, fs)

-# The AD generates fairly large backtraces that are unhelpful if you interrupt
-# while training; this just cleans that up.
-macro interrupts(ex)
-  :(try $(esc(ex))
-    catch e
-      e isa InterruptException || rethrow()
-      throw(e)
-    end)
-end
-
struct StopException <: Exception end
"""
    stop()
@ -67,13 +64,14 @@ The callback can call `Flux.stop()` to interrupt the training loop.
Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
"""
function train!(loss, ps, data, opt; cb = () -> ())
+  ps = Params(ps)
  cb = runall(cb)
-  opt = runall(opt)
  @progress for d in data
    try
-      l = loss(d...)
-      @interrupts back!(l)
-      _update_params!(opt, ps)
+      gs = gradient(ps) do
+        loss(d...)
+      end
+      update!(opt, ps, gs)
      if cb() == :stop
        depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
        break
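With the changes above, `train!` now takes gradients through `gradient(ps) do ... end` and applies them with the new `update!(opt, ps, gs)` method. A minimal end-to-end sketch of the updated API (the tiny model and single data batch are illustrative only):

```julia
using Flux

W = param(rand(2, 5))
b = param(zeros(2))
predict(x) = W*x .+ b
loss(x, y) = sum((predict(x) .- y).^2)

data = [(rand(5), rand(2))]            # a single (x, y) batch
opt  = ADAM()
Flux.train!(loss, [W, b], data, opt)   # train! wraps [W, b] in Params internally
```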
@ -1,113 +0,0 @@
|
|||||||
module Tracker
|
|
||||||
|
|
||||||
using MacroTools
|
|
||||||
using MacroTools: @q, @forward
|
|
||||||
|
|
||||||
import Base: ==
|
|
||||||
|
|
||||||
export TrackedArray, TrackedVector, TrackedMatrix, Params, gradient,
|
|
||||||
jacobian, hessian, param, back!
|
|
||||||
|
|
||||||
tracker(x) = nothing
|
|
||||||
|
|
||||||
istracked(x) = tracker(x) ≠ nothing
|
|
||||||
isleaf(x) = !istracked(x) || isleaf(tracker(x))
|
|
||||||
grad(x) = grad(tracker(x))
|
|
||||||
grad(::Nothing) = nothing
|
|
||||||
data(x) = x
|
|
||||||
|
|
||||||
struct Call{F,As<:Tuple}
|
|
||||||
func::F
|
|
||||||
args::As
|
|
||||||
end
|
|
||||||
|
|
||||||
Call(f::F, args::T) where {F,T} = Call{F,T}(f, args)
|
|
||||||
Call() = Call(nothing, ())
|
|
||||||
|
|
||||||
# When deserialising, the object_id changes
|
|
||||||
a::Call == b::Call = a.func == b.func && a.args == b.args
|
|
||||||
|
|
||||||
@inline (c::Call)() = c.func(data.(c.args)...)
|
|
||||||
|
|
||||||
mutable struct Tracked{T}
|
|
||||||
ref::UInt32
|
|
||||||
f::Call
|
|
||||||
isleaf::Bool
|
|
||||||
grad::T
|
|
||||||
Tracked{T}(f::Call) where T = new(0, f, false)
|
|
||||||
Tracked{T}(f::Call, grad::T) where T = new(0, f, false, grad)
|
|
||||||
Tracked{T}(f::Call{Nothing}, grad::T) where T = new(0, f, true, grad)
|
|
||||||
end
|
|
||||||
|
|
||||||
istracked(x::Tracked) = true
|
|
||||||
isleaf(x::Tracked) = x.f == Call()
|
|
||||||
grad(x::Tracked) = x.grad
|
|
||||||
|
|
||||||
track(f::Call, x) = Tracked{typeof(x)}(f)
|
|
||||||
|
|
||||||
function _forward end
|
|
||||||
|
|
||||||
function track(f::F, xs...; kw...) where F
|
|
||||||
y, back = _forward(f, xs...; kw...)
|
|
||||||
track(Call(back, tracker.(xs)), y)
|
|
||||||
end
|
|
||||||
|
|
||||||
macro grad(ex)
|
|
||||||
@capture(shortdef(ex), (name_(args__) = body_) |
|
|
||||||
(name_(args__) where {T__} = body_)) || error("Need a function definition")
|
|
||||||
T == nothing && (T = [])
|
|
||||||
isexpr(name, :(::)) || (name = :(::typeof($name)))
|
|
||||||
insert!(args, 1+isexpr(args[1], :parameters) , name)
|
|
||||||
@q(Tracker._forward($(args...)) where $(T...) = $body) |> esc
|
|
||||||
end
|
|
||||||
|
|
||||||
include("idset.jl")
|
|
||||||
include("back.jl")
|
|
||||||
include("numeric.jl")
|
|
||||||
include("lib/real.jl")
|
|
||||||
include("lib/array.jl")
|
|
||||||
include("forward.jl")
|
|
||||||
|
|
||||||
"""
|
|
||||||
hook(f, x) -> x′
|
|
||||||
|
|
||||||
Hook into gradient backpropagation. `x` is unmodified, but when backpropagating
|
|
||||||
`f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse
|
|
||||||
the sign of the gradient applied to `x`.
|
|
||||||
"""
|
|
||||||
hook(f, x) = istracked(x) ? track(hook, f, x) : x
|
|
||||||
@grad hook(f, x) = data(x), Δ -> (nothing, f(Δ))
|
|
||||||
|
|
||||||
"""
|
|
||||||
checkpoint(f, args...)
|
|
||||||
|
|
||||||
Behaves like `f(args...)`, but avoids storing the intermediate values needed for
|
|
||||||
calculating gradients. Instead, `f(args...)` will be called again during the
|
|
||||||
backward pass. This can be used to save memory in larger models.
|
|
||||||
"""
|
|
||||||
checkpoint(f, args...) = track(checkpoint, f, args...)
|
|
||||||
|
|
||||||
@grad function checkpoint(f, args...)
|
|
||||||
data(f(args...)), function (Δ)
|
|
||||||
y, back = forward(f, args...)
|
|
||||||
(nothing, back(Δ)...)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
nobacksies(f, x) = track(nobacksies, f, x)
|
|
||||||
nobacksies(f, xs::Tuple) = map(x -> nobacksies(f, x), xs)
|
|
||||||
@grad nobacksies(f::Symbol, x) = data(x), Δ -> error("Nested AD not defined for $f")
|
|
||||||
@grad nobacksies(f::String, x) = data(x), Δ -> error(f)
|
|
||||||
|
|
||||||
param(x::Number) = TrackedReal(float(x))
|
|
||||||
param(xs::AbstractArray) = TrackedArray(float.(xs))
|
|
||||||
|
|
||||||
@grad identity(x) = data(x), Δ -> (Δ,)
|
|
||||||
param(x::TrackedReal) = track(identity, x)
|
|
||||||
param(x::TrackedArray) = track(identity, x)
|
|
||||||
|
|
||||||
import Adapt: adapt, adapt_structure
|
|
||||||
|
|
||||||
adapt_structure(T, xs::TrackedArray) = param(adapt(T, data(xs)))
|
|
||||||
|
|
||||||
end
|
|
@ -1,210 +0,0 @@
|
|||||||
init_grad(x) = zero(x)
|
|
||||||
zero_grad!(x) = zero(x)
|
|
||||||
zero_grad!(x::AbstractArray) = (x .= 0)
|
|
||||||
|
|
||||||
scan(c::Call) = foreach(scan, c.args)
|
|
||||||
|
|
||||||
function scan(x::Tracked)
|
|
||||||
x.isleaf && return
|
|
||||||
ref = x.ref += 1
|
|
||||||
if ref == 1
|
|
||||||
scan(x.f)
|
|
||||||
isdefined(x, :grad) && (x.grad = zero_grad!(x.grad))
|
|
||||||
end
|
|
||||||
return
|
|
||||||
end
|
|
||||||
|
|
||||||
function scan(x)
|
|
||||||
istracked(x) && scan(tracker(x))
|
|
||||||
return
|
|
||||||
end
|
|
||||||
|
|
||||||
function back_(c::Call, Δ, once)
|
|
||||||
Δs = c.func(Δ)
|
|
||||||
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
|
|
||||||
error("Gradient is not a tuple of length $(length(c.args))")
|
|
||||||
foreach((x, d) -> back(x, d, once), c.args, data.(Δs))
|
|
||||||
end
|
|
||||||
|
|
||||||
back_(::Call{Nothing}, Δ, once) = nothing
|
|
||||||
back_(::Call{Missing}, Δ, once) = error("`back!` was already used")
|
|
||||||
|
|
||||||
accum!(x, Δ) = x .+ Δ
|
|
||||||
accum!(x::AbstractArray, Δ) = (x .+= Δ)
|
|
||||||
|
|
||||||
function back(x::Tracked, Δ, once)
|
|
||||||
x.isleaf && (x.grad = accum!(x.grad, Δ); return)
|
|
||||||
ref = x.ref -= 1
|
|
||||||
grad = if isdefined(x, :grad)
|
|
||||||
x.grad = accum!(x.grad, Δ)
|
|
||||||
elseif ref > 0
|
|
||||||
x.grad = Δ
|
|
||||||
else
|
|
||||||
Δ
|
|
||||||
end
|
|
||||||
if ref == 0
|
|
||||||
back_(x.f, grad, once)
|
|
||||||
once && !x.isleaf && (x.f = Call(missing, ()))
|
|
||||||
end
|
|
||||||
return
|
|
||||||
end
|
|
||||||
|
|
||||||
back(::Nothing, Δ, once) = return
|
|
||||||
|
|
||||||
# Interface methods
|
|
||||||
|
|
||||||
# TODO: if an error occurs in `back` the refcounts will be broken
|
|
||||||
# and `back` will silently fail to update.
|
|
||||||
# (but only if you re-use intermediate values between passes)
|
|
||||||
# Refcounts are also probably not safe in some situations (e.g. back called
|
|
||||||
# from within a backpropagator)
|
|
||||||
|
|
||||||
function back!(x, Δ; once = true)
|
|
||||||
istracked(x) || return
|
|
||||||
scan(x)
|
|
||||||
back(tracker(x), Δ, once)
|
|
||||||
return
|
|
||||||
end
|
|
||||||
|
|
||||||
function gradient_(f, xs...)
|
|
||||||
xs = param.(data.(xs))
|
|
||||||
l = f(xs...)
|
|
||||||
losscheck(l)
|
|
||||||
back!(l)
|
|
||||||
nobacksies("Use `gradient(...; nest = true)` for nested derivatives",
|
|
||||||
grad.(xs))
|
|
||||||
end
|
|
||||||
|
|
||||||
# Out-of-place gradients
|
|
||||||
|
|
||||||
struct Params
|
|
||||||
order::Vector{Any}
|
|
||||||
params::IdSet{Any}
|
|
||||||
Params() = new([], IdSet())
|
|
||||||
end
|
|
||||||
|
|
||||||
@forward Params.order Base.iterate, Base.length
|
|
||||||
|
|
||||||
function Base.push!(ps::Params, x)
|
|
||||||
if !(x in ps.params)
|
|
||||||
push!(ps.order, x)
|
|
||||||
push!(ps.params, x)
|
|
||||||
end
|
|
||||||
return ps
|
|
||||||
end
|
|
||||||
|
|
||||||
Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps)
|
|
||||||
|
|
||||||
Params(xs) = push!(Params(), xs...)
|
|
||||||
|
|
||||||
function Base.show(io::IO, ps::Params)
|
|
||||||
print(io, "Params([")
|
|
||||||
join(io, ps.order, ", ")
|
|
||||||
print(io, "])")
|
|
||||||
end
|
|
||||||
|
|
||||||
struct Grads
|
|
||||||
grads::IdDict{Any,Any}
|
|
||||||
end
|
|
||||||
|
|
||||||
Base.show(io::IO, ps::Grads) = println(io, "Grads(...)")
|
|
||||||
|
|
||||||
Grads() = Grads(IdDict())
|
|
||||||
|
|
||||||
@forward Grads.grads Base.setindex!, Base.haskey, Base.length, Base.iterate
|
|
||||||
|
|
||||||
Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps))
|
|
||||||
|
|
||||||
Base.getindex(g::Grads, x::Tracked) = g.grads[x]
|
|
||||||
|
|
||||||
function Base.getindex(g::Grads, x)
|
|
||||||
istracked(x) || error("Object not tracked: $x")
|
|
||||||
g[tracker(x)]
|
|
||||||
end
|
|
||||||
|
|
||||||
accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ

function back_(g::Grads, c::Call, Δ)
  Δs = c.func(Δ)
  (Δs isa Tuple && length(Δs) >= length(c.args)) ||
    error("Gradient is not a tuple of length $(length(c.args))")
  foreach((x, Δ) -> back(g, x, Δ), c.args, Δs)
end

back_(g::Grads, ::Call{Nothing}, Δ) = nothing

function back(g::Grads, x::Tracked, Δ)
  x.isleaf && (accum!(g, x, Δ); return)
  ref = x.ref -= 1
  if ref > 0 || haskey(g, x)
    accum!(g, x, Δ)
    ref == 0 && back_(g, x.f, g[x])
  else
    ref == 0 && back_(g, x.f, Δ)
  end
  return
end

back(::Grads, ::Nothing, _) = return

collectmemaybe(xs) = xs

function forward(f, ps::Params)
  y = collectmemaybe(f())
  y, function (Δ)
    g = Grads(ps)
    if istracked(y)
      scan(y)
      back(g, tracker(y), Δ)
    end
    return g
  end
end

function forward(f, args...)
  args = param.(args)
  y, back = forward(() -> f(args...), Params(args))
  y, Δ -> getindex.(Ref(back(Δ)), args)
end

function losscheck(x)
  x isa Real || error("Function output is not scalar")
  isinf(x) && error("Loss is infinite")
  isnan(x) && error("Loss is NaN")
end

function gradient_nested(f, args...)
  y, back = forward(f, args...)
  losscheck(y)
  return back(1)
end

gradient(f, xs...; nest = false) =
  nest ? gradient_nested(f, xs...) : gradient_(f, xs...)

gradient(f, ps::Params) = gradient_nested(f, ps)
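
# Usage sketch (illustrative only, not part of the original source): the
# out-of-place API collects parameters into `Params` and returns a `Grads`
# dictionary keyed by each parameter. The helper name
# `_params_gradient_example` is hypothetical.
function _params_gradient_example()
  W, b = param(rand(2, 3)), param(rand(2))
  x = rand(3)
  gs = gradient(() -> sum(W * x .+ b), Params([W, b]))
  return gs[W], gs[b]   # gradients of the loss w.r.t. W and b
end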

# Jacobians and Hessians

import ..Flux

"""
|
|
||||||
J = jacobian(m,x)
|
|
||||||
|
|
||||||
Calculate the output jacobian `J = d/dx m(x)` such that each row `i` of `J` corresponds to the gradient `J[i,:] = ∇ₓ(m(x)[i])`
|
|
||||||
"""
|
|
||||||
function jacobian(m,x)
|
|
||||||
xp = param(x)
|
|
||||||
y = m(xp)
|
|
||||||
k = length(y)
|
|
||||||
n = length(x)
|
|
||||||
J = Matrix{eltype(x)}(undef,k,n)
|
|
||||||
for i = 1:k
|
|
||||||
Flux.back!(y[i], once = false) # Populate gradient accumulator
|
|
||||||
J[i,:] = xp.grad
|
|
||||||
xp.grad .= 0 # Reset gradient accumulator
|
|
||||||
end
|
|
||||||
J
|
|
||||||
end
|
|
||||||
|
|
||||||
hessian(f, x) = jacobian(x -> gradient(f, x, nest=true)[1], x)
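
# Usage sketch (illustrative only): for a linear map the Jacobian recovers the
# weight matrix, and `hessian` composes `jacobian` with a nested `gradient`.
# The helper name `_jacobian_example` is hypothetical.
function _jacobian_example()
  A = rand(3, 4)
  J = jacobian(x -> A * x, rand(4))       # J ≈ A
  H = hessian(x -> sum(x .^ 2), rand(5))  # H ≈ 2I
  return J, H
end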
@ -1,53 +0,0 @@
|
|||||||
using ForwardDiff
|
|
||||||
|
|
||||||
seed(x::Real, ::Val) = Dual(x, true)
|
|
||||||
|
|
||||||
function seed(x, ::Val{N}, offset = 0) where N
|
|
||||||
map(x, reshape(1:length(x), size(x))) do x, i
|
|
||||||
Dual(x, ntuple(j -> j+offset == i, Val(N)))
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
extract(x::ForwardDiff.Dual) = x.value, [x.partials...]
|
|
||||||
|
|
||||||
function extract(xs::AbstractArray{ForwardDiff.Dual{T,V,N}}) where {T,V,N}
|
|
||||||
J = similar(xs, V, N, length(xs))
|
|
||||||
for i = 1:length(xs), j = 1:N
|
|
||||||
J[j, i] = xs[i].partials.values[j]
|
|
||||||
end
|
|
||||||
return map(x -> x.value, xs), J
|
|
||||||
end
|
|
||||||
|
|
||||||
function forward_jacobian(f, x, ::Val{N}) where N
|
|
||||||
y, _J = extract(f(seed(x, Val(N))))
|
|
||||||
J = similar(_J, length(x), length(y))
|
|
||||||
J[1:N,:] = _J
|
|
||||||
offset = 0
|
|
||||||
while offset + N < length(x)
|
|
||||||
offset += N
|
|
||||||
_, _J = extract(f(seed(x, Val(N), offset)))
|
|
||||||
range = (1+offset):min(N+offset,length(x))
|
|
||||||
J[range,:] = @view _J[range.-offset,:]
|
|
||||||
end
|
|
||||||
return y, J
|
|
||||||
end
|
|
||||||
|
|
||||||
function forward_jacobian(f, x)
|
|
||||||
if length(x) < ForwardDiff.DEFAULT_CHUNK_THRESHOLD
|
|
||||||
forward_jacobian(f, x, Val(length(x)))
|
|
||||||
else
|
|
||||||
forward_jacobian(f, x, Val(ForwardDiff.DEFAULT_CHUNK_THRESHOLD))
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
forwarddiff(f, x) = istracked(x) ? track(forwarddiff, f, x) : f(x)
|
|
||||||
|
|
||||||
vec_scalar(x) = vec(x)
|
|
||||||
vec_scalar(x::Real) = [x]
|
|
||||||
reshape_scalar(x, y) = reshape(y, size(x))
|
|
||||||
reshape_scalar(x::Real, y) = y[]
|
|
||||||
|
|
||||||
@grad function forwarddiff(f, x)
|
|
||||||
y, J = forward_jacobian(f, data(x))
|
|
||||||
return y, ȳ -> (nothing, reshape_scalar(x, J*vec_scalar(ȳ)))
|
|
||||||
end
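
# Usage sketch (illustrative only): `forwarddiff(f, x)` evaluates `f` in
# forward mode but still exposes a reverse rule, so the result can sit inside a
# larger reverse-mode graph. The helper name `_forwarddiff_example` is
# hypothetical.
function _forwarddiff_example()
  x = param(rand(3))
  y = forwarddiff(v -> sum(exp.(v)), x)  # scalar computed via ForwardDiff
  back!(y)
  return x.grad                          # ≈ exp.(data(x))
end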
|
|
@ -1,28 +0,0 @@
|
|||||||
struct IdSet{T} <: AbstractSet{T}
|
|
||||||
dict::IdDict{T,Nothing}
|
|
||||||
IdSet{T}() where T = new(IdDict{T,Nothing}())
|
|
||||||
end
|
|
||||||
|
|
||||||
Base.eltype(::IdSet{T}) where T = T
|
|
||||||
|
|
||||||
IdSet() = IdSet{Any}()
|
|
||||||
|
|
||||||
Base.push!(s::IdSet) = s
|
|
||||||
Base.push!(s::IdSet{T}, x::T) where T = (s.dict[x] = nothing; s)
|
|
||||||
Base.delete!(s::IdSet{T}, x::T) where T = (delete!(s.dict, x); s)
|
|
||||||
Base.in(x, s::IdSet) = haskey(s.dict, x)
|
|
||||||
|
|
||||||
IdSet{T}(xs) where T = push!(IdSet{T}(), xs...)
|
|
||||||
|
|
||||||
IdSet(xs) = IdSet{eltype(xs)}(xs)
|
|
||||||
|
|
||||||
Base.collect(s::IdSet) = Base.collect(keys(s.dict))
|
|
||||||
Base.similar(s::IdSet, T::Type) = IdSet{T}()
|
|
||||||
|
|
||||||
@forward IdSet.dict Base.length
|
|
||||||
|
|
||||||
function Base.iterate(v::IdSet, state...)
|
|
||||||
y = Base.iterate(keys(v.dict), state...)
|
|
||||||
y === nothing && return nothing
|
|
||||||
return (y[1], y[2])
|
|
||||||
end
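
# Usage sketch (illustrative only): `IdSet` deduplicates by object identity
# (`===`) rather than value equality, which is what parameter collection needs.
# The helper name `_idset_example` is hypothetical.
function _idset_example()
  a, b = [1.0], [1.0]     # equal values, but distinct objects
  s = IdSet{Any}()
  push!(s, a); push!(s, a); push!(s, b)
  return length(s)        # 2: `a` is only stored once
end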
|
|
@ -1,516 +0,0 @@
|
|||||||
import Base: *
|
|
||||||
|
|
||||||
import LinearAlgebra
|
|
||||||
import LinearAlgebra: inv, det, logdet, logabsdet, \, /
|
|
||||||
|
|
||||||
using Statistics
|
|
||||||
using LinearAlgebra: Transpose, Adjoint, diagm, diag
|
|
||||||
|
|
||||||
struct TrackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
|
|
||||||
tracker::Tracked{A}
|
|
||||||
data::A
|
|
||||||
grad::A
|
|
||||||
TrackedArray{T,N,A}(t::Tracked{A}, data::A) where {T,N,A} = new(t, data)
|
|
||||||
TrackedArray{T,N,A}(t::Tracked{A}, data::A, grad::A) where {T,N,A} = new(t, data, grad)
|
|
||||||
end
|
|
||||||
|
|
||||||
data(x::TrackedArray) = x.data
|
|
||||||
tracker(x::TrackedArray) = x.tracker
|
|
||||||
|
|
||||||
TrackedVector{T,A} = TrackedArray{T,1,A}
|
|
||||||
TrackedMatrix{T,A} = TrackedArray{T,2,A}
|
|
||||||
TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}}
|
|
||||||
|
|
||||||
track(c::Call, x::AbstractArray) = TrackedArray(c, x)
|
|
||||||
|
|
||||||
TrackedArray(c::Call, x::A) where A <: AbstractArray =
|
|
||||||
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c), x)
|
|
||||||
|
|
||||||
TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
|
|
||||||
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, Δ), x, Δ)
|
|
||||||
|
|
||||||
TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, zero(x))
|
|
||||||
|
|
||||||
Base.eltype(x::Type{<:TrackedArray{T}}) where T <: Real = TrackedReal{T}
|
|
||||||
|
|
||||||
Base.convert(::Type{T}, x::S) where {T<:TrackedArray,S<:T} = x
|
|
||||||
|
|
||||||
Base.convert(T::Type{<:TrackedArray}, x::TrackedArray) =
  error("Not implemented: convert $(typeof(x)) to $T")
|
|
||||||
|
|
||||||
Base.convert(::Type{<:TrackedArray{T,N,A}}, x::AbstractArray) where {T,N,A} =
|
|
||||||
TrackedArray(convert(A, x))
|
|
||||||
|
|
||||||
Base.show(io::IO, t::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
|
|
||||||
@isdefined(A) ?
|
|
||||||
print(io, "TrackedArray{…,$A}") :
|
|
||||||
invoke(show, Tuple{IO,DataType}, io, t)
|
|
||||||
|
|
||||||
function Base.summary(io::IO, x::TrackedArray)
|
|
||||||
print(io, "Tracked ")
|
|
||||||
summary(io, data(x))
|
|
||||||
end
|
|
||||||
|
|
||||||
Base.print_array(io::IO, x::TrackedArray) = Base.print_array(io, data(x))
|
|
||||||
|
|
||||||
function Base.show(io::IO, x::TrackedArray)
|
|
||||||
show(io, data(x))
|
|
||||||
print(io, " (tracked)")
|
|
||||||
end
|
|
||||||
|
|
||||||
Base.copy(x::TrackedArray) = x
|
|
||||||
|
|
||||||
Base.setindex!(xs::TrackedArray, v, i...) =
|
|
||||||
error("Can't differentiate `setindex!`")
|
|
||||||
|
|
||||||
back!(::TrackedArray) = error("Value is not scalar; use `back!(sum(x))` or `back!(x, Δ)`")
|
|
||||||
|
|
||||||
function update!(x::TrackedArray, Δ)
|
|
||||||
x.data .+= data(Δ)
|
|
||||||
tracker(x).grad .= 0
|
|
||||||
return x
|
|
||||||
end
|
|
||||||
|
|
||||||
# Fallthrough methods
|
|
||||||
|
|
||||||
for f in :[Base.size, Base.ndims, Base.collect].args
|
|
||||||
@eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...)
|
|
||||||
end
|
|
||||||
|
|
||||||
Base.size(x::TrackedArray, i::Integer, j::Integer, is::Integer...) =
|
|
||||||
size(data(x), i, j, is...)
|
|
||||||
|
|
||||||
Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
|
|
||||||
similar(data(x), dims...)
|
|
||||||
|
|
||||||
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
|
|
||||||
|
|
||||||
for op in [:(==), :≈]
|
|
||||||
@eval Base.$op(x::TrackedArray, y::AbstractArray) = Base.$op(data(x), y)
|
|
||||||
@eval Base.$op(x::AbstractArray, y::TrackedArray) = Base.$op(x, data(y))
|
|
||||||
@eval Base.$op(x::TrackedArray, y::TrackedArray) = Base.$op(data(x), data(y))
|
|
||||||
end
|
|
||||||
|
|
||||||
# Array Stdlib
|
|
||||||
|
|
||||||
Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...)
|
|
||||||
|
|
||||||
@grad function getindex(xs::AbstractArray, i...)
|
|
||||||
data(xs)[i...], function (Δ)
|
|
||||||
Δ′ = zero(xs)
|
|
||||||
Δ′[i...] = data(Δ)
|
|
||||||
(nobacksies(:getindex, Δ′), map(_->nothing, i)...)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
Base.view(x::TrackedArray, inds...) = track(Base.view, x, inds...)
|
|
||||||
|
|
||||||
@grad function view(x::AbstractArray, inds...)
|
|
||||||
view(data(x), inds...), function (Δ)
|
|
||||||
grad_output = zero(x)
|
|
||||||
subgrad = view(grad_output, inds...)
|
|
||||||
subgrad[:] = data(Δ)
|
|
||||||
(nobacksies(:view, grad_output), map(_->nothing, inds)...)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
Base.:-(xs::TrackedArray) = track(-, xs)
|
|
||||||
|
|
||||||
@grad -(xs) = -data(xs), Δ -> (-Δ,)
|
|
||||||
|
|
||||||
Base.transpose(xs::TrackedArray) = track(transpose, xs)
|
|
||||||
Base.adjoint(xs::TrackedArray) = track(adjoint, xs)
|
|
||||||
|
|
||||||
@grad transpose(xs) = transpose(data(xs)), Δ -> (trim(xs, transpose(Δ)),)
|
|
||||||
@grad adjoint(xs) = data(xs)', Δ -> (trim(xs, Δ'),)
|
|
||||||
|
|
||||||
det(xs::TrackedArray) = track(det, xs)
|
|
||||||
@grad det(xs) = det(data(xs)), Δ -> (Δ * det(xs) * transpose(inv(xs)),)
|
|
||||||
|
|
||||||
logdet(xs::TrackedArray) = track(logdet, xs)
|
|
||||||
@grad logdet(xs) = logdet(data(xs)), Δ -> (Δ * transpose(inv(xs)),)
|
|
||||||
|
|
||||||
logabsdet(xs::TrackedArray) = track(logabsdet, xs)
|
|
||||||
@grad logabsdet(xs) = logabsdet(data(xs)), Δ -> (Δ[1] * transpose(inv(xs)),)
|
|
||||||
|
|
||||||
Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...)
|
|
||||||
|
|
||||||
@grad function repeat(xs; inner=ntuple(x->1, ndims(xs)), outer=ntuple(x->1, ndims(xs)))
|
|
||||||
repeat(data(xs), inner = inner, outer = outer), function (Δ)
|
|
||||||
Δ′ = zero(xs)
|
|
||||||
S = size(xs)
|
|
||||||
|
|
||||||
# Loop through each element of Δ, calculate source dimensions, accumulate into Δ′
|
|
||||||
for (dest_idx, val) in pairs(IndexCartesian(), data(Δ))
|
|
||||||
# First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then
|
|
||||||
# wrap around based on original size S.
|
|
||||||
src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)]
|
|
||||||
Δ′[src_idx...] += val
|
|
||||||
end
|
|
||||||
(nobacksies(:repeat, Δ′),)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
function combinations(xs, n)
|
|
||||||
n < 1 && return [[]]
|
|
||||||
cs = combinations(xs, n-1)
|
|
||||||
[[x, c...] for x in xs, c in cs]
|
|
||||||
end
|
|
||||||
|
|
||||||
for i = 0:2, c = combinations([:AbstractArray, :TrackedArray, :Number], i), f = [:hcat, :vcat]
|
|
||||||
cnames = map(_ -> gensym(), c)
|
|
||||||
@eval Base.$f($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::Union{TrackedArray,TrackedReal}, xs::Union{AbstractArray,Number}...) =
|
|
||||||
track($f, $(cnames...), x, xs...)
|
|
||||||
end
|
|
||||||
|
|
||||||
for i = 0:2, c = combinations([:AbstractVecOrMat, :TrackedVecOrMat], i), f = [:hcat, :vcat]
|
|
||||||
cnames = map(_ -> gensym(), c)
|
|
||||||
@eval Base.$f($([:($x::$c{T}) for (x, c) in zip(cnames, c)]...), x::TrackedVecOrMat{T}, xs::AbstractVecOrMat{T}...) where T =
|
|
||||||
track($f, $(cnames...), x, xs...)
|
|
||||||
end
|
|
||||||
|
|
||||||
for i = 0:2, c = combinations([:AbstractVector, :TrackedVector], i), f = [:hcat, :vcat]
|
|
||||||
cnames = map(_ -> gensym(), c)
|
|
||||||
@eval Base.$f($([:($x::$c{T}) for (x, c) in zip(cnames, c)]...), x::TrackedVector{T}, xs::AbstractVector{T}...) where T =
|
|
||||||
track($f, $(cnames...), x, xs...)
|
|
||||||
end
|
|
||||||
|
|
||||||
@grad function vcat(xs...)
|
|
||||||
vcat(data.(xs)...), function (Δ)
|
|
||||||
start = 0
|
|
||||||
Δs = [begin
|
|
||||||
i = map(_ -> :, size(xsi)) |> Base.tail
|
|
||||||
d = Δ[start+1:start+size(xsi,1), i...]
|
|
||||||
start += size(xsi, 1)
|
|
||||||
d
|
|
||||||
end for xsi in xs]
|
|
||||||
return (Δs...,)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
@grad function hcat(xs...)
|
|
||||||
hcat(data.(xs)...), function (Δ)
|
|
||||||
start = 0
|
|
||||||
Δs = [begin
|
|
||||||
d = if ndims(xsi) == 1
|
|
||||||
Δ[:, start+1]
|
|
||||||
else
|
|
||||||
i = map(_ -> :, size(xsi)) |> Base.tail |> Base.tail
|
|
||||||
Δ[:, start+1:start+size(xsi,2), i...]
|
|
||||||
end
|
|
||||||
start += size(xsi, 2)
|
|
||||||
d
|
|
||||||
end for xsi in xs]
|
|
||||||
return (Δs...,)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
for i = 0:2, c = combinations([:AbstractArray, :TrackedArray], i)
|
|
||||||
cnames = map(_ -> gensym(), c)
|
|
||||||
@eval Base.cat($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::TrackedArray, xs::AbstractArray...; dims) =
|
|
||||||
track(cat, $(cnames...), x, xs..., dims = dims)
|
|
||||||
end
|
|
||||||
|
|
||||||
@grad function cat(Xs...; dims)
|
|
||||||
cat(data.(Xs)..., dims = dims), function (Δ)
|
|
||||||
start = ntuple(i -> 0, Val(ndims(Δ)))
|
|
||||||
Δs = [begin
|
|
||||||
dim_xs = 1:ndims(xs)
|
|
||||||
till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val(ndims(Δ)))
|
|
||||||
xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val(ndims(Δ)))
|
|
||||||
d = reshape(Δ[xs_in_Δ...],size(xs))
|
|
||||||
start = start .+ till_xs
|
|
||||||
d
|
|
||||||
end for xs in Xs]
|
|
||||||
return (Δs...,)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) = reshape(xs, dims)
|
|
||||||
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Union{Int64,Colon}}}) = reshape(xs, Base._reshape_uncolon(xs, dims))
|
|
||||||
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64}}) = track(reshape, xs, dims)
|
|
||||||
|
|
||||||
@grad reshape(xs, dims) = reshape(data(xs), dims), Δ -> (reshape(Δ, size(xs)),nothing)
|
|
||||||
|
|
||||||
Base.permutedims(xs::TrackedArray, perm) = track(permutedims, xs, perm)
|
|
||||||
@grad permutedims(xs, perm) = permutedims(data(xs), perm), Δ -> (permutedims(Δ, invperm(perm)),nothing)
|
|
||||||
|
|
||||||
Base.PermutedDimsArray(xs::TrackedArray, perm) = track(PermutedDimsArray, xs, perm)
|
|
||||||
@grad PermutedDimsArray(xs, perm) = PermutedDimsArray(data(xs), perm), Δ -> (PermutedDimsArray(Δ, invperm(perm)),nothing)
|
|
||||||
|
|
||||||
function _kron(mat1::AbstractMatrix,mat2::AbstractMatrix)
|
|
||||||
m1, n1 = size(mat1)
|
|
||||||
mat1_rsh = reshape(mat1,(1,m1,1,n1))
|
|
||||||
|
|
||||||
m2, n2 = size(mat2)
|
|
||||||
mat2_rsh = reshape(mat2,(m2,1,n2,1))
|
|
||||||
|
|
||||||
return reshape(mat1_rsh.*mat2_rsh, (m1*m2,n1*n2))
|
|
||||||
end
|
|
||||||
|
|
||||||
Base.kron(a::TrackedMatrix, b::TrackedMatrix) = _kron(a, b)
|
|
||||||
Base.kron(a::TrackedMatrix, b::AbstractMatrix) = _kron(a, b)
|
|
||||||
Base.kron(a::AbstractMatrix, b::TrackedMatrix) = _kron(a, b)
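
# Illustrative check (not part of the original source): `_kron` reproduces
# `LinearAlgebra.kron` for plain matrices; writing it with `reshape` and
# broadcasting is what makes it differentiable through the tracker. The helper
# name `_kron_matches_base` is hypothetical.
_kron_matches_base() =
  _kron([1 2; 3 4], [0 1; 1 0]) == LinearAlgebra.kron([1 2; 3 4], [0 1; 1 0])  # true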
|
|
||||||
|
|
||||||
|
|
||||||
inv(A::TrackedArray) = Tracker.track(inv, A)
|
|
||||||
@grad function inv(A)
|
|
||||||
return inv(Tracker.data(A)), function (Δ)
|
|
||||||
Ainv = inv(A)
|
|
||||||
∇A = - Ainv' * Δ * Ainv'
|
|
||||||
return (∇A, )
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
# (/) rdivide
|
|
||||||
A::TrackedArray / B::TrackedArray = Tracker.track(/, A, B)
|
|
||||||
A::AbstractVecOrMat / B::TrackedArray = Tracker.track(/, A, B)
|
|
||||||
A::TrackedArray / B::AbstractVecOrMat = Tracker.track(/, A, B)
|
|
||||||
@grad function (A / B)
|
|
||||||
return Tracker.data(A) / Tracker.data(B), function (Δ)
|
|
||||||
Binv = inv(B)
|
|
||||||
∇B = - Binv' * A' * Δ * Binv'
|
|
||||||
return (Δ * Binv', ∇B)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
# (\) ldivide (left vec divide needs more work to resolve dispatch ambiguity)
|
|
||||||
A::TrackedArray \ B::TrackedArray = Tracker.track(\, A, B)
|
|
||||||
A::AbstractArray \ B::TrackedArray = Tracker.track(\, A, B)
|
|
||||||
A::TrackedArray \ B::AbstractVecOrMat = Tracker.track(\, A, B)
|
|
||||||
@grad function (A \ B)
|
|
||||||
return Tracker.data(A) \ Tracker.data(B), function (Δ)
|
|
||||||
Ainv = inv(A)
|
|
||||||
∇A = - Ainv' * Δ * B' * Ainv'
|
|
||||||
return (∇A, Ainv' * Δ)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
|
|
||||||
# Reductions
|
|
||||||
|
|
||||||
Base.sum(xs::TrackedArray; dims = :) = track(sum, xs, dims = dims)
|
|
||||||
Base.sum(f::Union{Function,Type},xs::TrackedArray) = sum(f.(xs))
|
|
||||||
|
|
||||||
@grad sum(xs; dims = :) = sum(data(xs), dims = dims),
|
|
||||||
Δ -> (zero(xs) .+ Δ, )
|
|
||||||
|
|
||||||
Base.prod(xs::TrackedArray, dim) = track(prod, xs, dim)
|
|
||||||
Base.prod(xs::TrackedArray) = track(prod, xs)
|
|
||||||
Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs))
|
|
||||||
|
|
||||||
@grad prod(xs) = prod(data(xs)), Δ -> (prod(xs) ./ xs .* Δ,)
|
|
||||||
@grad prod(xs, dim) = prod(data(xs), dims = dim),
|
|
||||||
Δ -> (nobacksies(:sum,
|
|
||||||
reshape(.*(circshift.([reshape(data(xs), length(xs))], 1:length(xs)-1)...), size(xs)) .* Δ),
|
|
||||||
nothing)
|
|
||||||
|
|
||||||
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
|
|
||||||
|
|
||||||
Statistics.mean(xs::TrackedArray; dims = :) = track(mean, xs, dims = dims)
|
|
||||||
|
|
||||||
Base.maximum(xs::TrackedArray; dims = :) = track(maximum, xs, dims = dims)
|
|
||||||
Base.minimum(xs::TrackedArray; dims = :) = track(minimum, xs, dims = dims)
|
|
||||||
|
|
||||||
import LinearAlgebra: dot
|
|
||||||
|
|
||||||
dot(xs::TrackedVector, ys::TrackedVector) = track(dot, xs, ys)
|
|
||||||
dot(xs::AbstractVector, ys::TrackedVector) = track(dot, xs, ys)
|
|
||||||
dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys)
|
|
||||||
|
|
||||||
@grad dot(xs, ys) = dot(data(xs), data(ys)), Δ -> (Δ .* ys, Δ .* xs)
|
|
||||||
|
|
||||||
# Hacks to get std working
|
|
||||||
Statistics.std(x::TrackedArray; dims = :, mean = Statistics.mean(x, dims = dims), corrected::Bool = true) = _std(x,mean,dims,corrected)
|
|
||||||
_std(x::TrackedArray, mean, dims, corrected) = sqrt.(sum((x .- mean).^2, dims = dims) ./ (mapreduce(i -> size(x,i),*, dims) - corrected))
|
|
||||||
_std(x::TrackedArray, mean, ::Colon, corrected) = sqrt.(sum((x .- mean).^2) ./ (length(x) - corrected))
|
|
||||||
|
|
||||||
LinearAlgebra.norm(x::TrackedArray, p::Real = 2) =
|
|
||||||
sum(abs.(x).^p .+ eps(0f0))^(1/p) # avoid d(sqrt(x))/dx == Inf at 0
|
|
||||||
|
|
||||||
@grad mean(xs; dims = :) = mean(data(xs), dims=dims), Δ -> (_backmean(xs,Δ,dims),)
|
|
||||||
_backmean(xs, Δ, ::Colon) = zero(xs) .+ Δ ./ length(xs)
|
|
||||||
_backmean(xs, Δ, dims) = zero(xs) .+ Δ ./ mapreduce(i -> size(data(xs),i),*,dims)
|
|
||||||
|
|
||||||
@grad function maximum(xs; dims = dims)
|
|
||||||
maximum(data(xs), dims = dims), function (Δ)
|
|
||||||
Δ′ = zero(xs)
|
|
||||||
_, i = findmax(data(xs), dims = dims)
|
|
||||||
Δ′[i] = data(Δ)
|
|
||||||
return (nobacksies(:maximum, Δ′),)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
@grad function minimum(xs; dims = dims)
|
|
||||||
minimum(data(xs), dims = dims), function (Δ)
|
|
||||||
Δ′ = zero(xs)
|
|
||||||
_, i = findmin(data(xs), dims = dims)
|
|
||||||
Δ′[i] = data(Δ)
|
|
||||||
return (nobacksies(:minimum, Δ′),)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
# BLAS
|
|
||||||
|
|
||||||
LinearAlgebra.diagm(x::Pair{<:Integer, <:TrackedVector}) = track(diagm, x...)
|
|
||||||
@grad diagm(i, x) = diagm(i => data(x)), Δ -> (nothing, diag(Δ, i))
|
|
||||||
|
|
||||||
x::TrackedMatrix * y::AbstractMatrix = track(*, x, y)
|
|
||||||
x::AbstractMatrix * y::TrackedMatrix = track(*, x, y)
|
|
||||||
x::TrackedMatrix * y::TrackedMatrix = track(*, x, y)
|
|
||||||
|
|
||||||
x::TrackedMatrix * y::AbstractVector = track(*, x, y)
|
|
||||||
x::AbstractMatrix * y::TrackedVector = track(*, x, y)
|
|
||||||
x::TrackedMatrix * y::TrackedVector = track(*, x, y)
|
|
||||||
|
|
||||||
x::TrackedVector * y::AbstractVector = track(*, x, y)
|
|
||||||
x::AbstractVector * y::TrackedVector = track(*, x, y)
|
|
||||||
x::TrackedVector * y::TrackedVector = track(*, x, y)
|
|
||||||
|
|
||||||
@grad a::AbstractMatrix * b::AbstractVecOrMat =
|
|
||||||
data(a)*data(b), Δ -> (Δ * transpose(b), transpose(a) * Δ)
|
|
||||||
|
|
||||||
# NNlib
|
|
||||||
|
|
||||||
using NNlib
|
|
||||||
import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, ∇conv_data, depthwiseconv, maxpool, meanpool
|
|
||||||
|
|
||||||
softmax(xs::TrackedArray) = track(softmax, xs)
|
|
||||||
|
|
||||||
@grad softmax(xs) = softmax(data(xs)), Δ -> (nobacksies(:softmax, ∇softmax(data(Δ), data(xs))),)
|
|
||||||
|
|
||||||
logsoftmax(xs::TrackedArray) = track(logsoftmax, xs)
|
|
||||||
|
|
||||||
@grad logsoftmax(xs) = logsoftmax(data(xs)), Δ -> (nobacksies(:logsoftmax, ∇logsoftmax(data(Δ), data(xs))),)
|
|
||||||
|
|
||||||
depthwiseconv(x::TrackedArray, w::TrackedArray; kw...) = track(depthwiseconv, x, w; kw...)
|
|
||||||
depthwiseconv(x::AbstractArray, w::TrackedArray; kw...) = track(depthwiseconv, x, w; kw...)
|
|
||||||
depthwiseconv(x::TrackedArray, w::AbstractArray; kw...) = track(depthwiseconv, x, w; kw...)
|
|
||||||
|
|
||||||
@grad depthwiseconv(x, w; kw...) =
|
|
||||||
depthwiseconv(data(x), data(w); kw...),
|
|
||||||
Δ -> nobacksies(:depthwiseconv,
|
|
||||||
(NNlib.∇depthwiseconv_data(data.((Δ, x, w))...; kw...),
|
|
||||||
NNlib.∇depthwiseconv_filter(data.((Δ, x, w))...; kw...)))
|
|
||||||
|
|
||||||
conv(x::TrackedArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
|
|
||||||
conv(x::AbstractArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
|
|
||||||
conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...)
|
|
||||||
|
|
||||||
@grad conv(x, w; kw...) =
|
|
||||||
conv(data(x), data(w); kw...),
|
|
||||||
Δ -> nobacksies(:conv,
|
|
||||||
(NNlib.∇conv_data(data.((Δ, w))...; size=size(x), kw...),
|
|
||||||
NNlib.∇conv_filter(data.((Δ, x))...; size=size(w), kw...)))
|
|
||||||
|
|
||||||
∇conv_data(x::TrackedArray, w::TrackedArray; kw...) = track(∇conv_data, x, w; kw...)
|
|
||||||
∇conv_data(x::AbstractArray, w::TrackedArray; kw...) = track(∇conv_data, x, w; kw...)
|
|
||||||
∇conv_data(x::TrackedArray, w::AbstractArray; kw...) = track(∇conv_data, x, w; kw...)
|
|
||||||
|
|
||||||
@grad ∇conv_data(x, w; kw...) =
|
|
||||||
∇conv_data(data(x), data(w); kw...),
|
|
||||||
Δ -> nobacksies(:conv,
|
|
||||||
(NNlib.conv(data.((Δ, w))...; size=size(x), kw...),
|
|
||||||
NNlib.∇conv_filter(data.((x, Δ))...; size=size(w), kw...)))
|
|
||||||
|
|
||||||
maxpool(x::TrackedArray, k; kw...) = track(maxpool, x, k; kw...)
|
|
||||||
|
|
||||||
@grad function maxpool(x, k; kw...)
|
|
||||||
y = maxpool(data(x), k; kw...)
|
|
||||||
y, Δ -> (nobacksies(:maxpool, NNlib.∇maxpool(data.((Δ, y, x))..., k; kw...)), nothing)
|
|
||||||
end
|
|
||||||
|
|
||||||
meanpool(x::TrackedArray, k; kw...) = track(meanpool, x, k; kw...)
|
|
||||||
|
|
||||||
@grad function meanpool(x, k; kw...)
|
|
||||||
y = meanpool(data(x), k; kw...)
|
|
||||||
y, Δ -> (nobacksies(:maxpool, NNlib.∇meanpool(data.((Δ, y, x))..., k; kw...)), nothing)
|
|
||||||
end
|
|
||||||
|
|
||||||
# Broadcasting
|
|
||||||
|
|
||||||
using ForwardDiff: Dual, partials, value
|
|
||||||
|
|
||||||
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
|
|
||||||
|
|
||||||
unbroadcast(x::AbstractArray, Δ) =
|
|
||||||
size(x) == size(Δ) ? Δ :
|
|
||||||
length(x) == length(Δ) ? trim(x, Δ) :
|
|
||||||
trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
|
|
||||||
|
|
||||||
unbroadcast(x::Number, Δ) = sum(Δ)
|
|
||||||
unbroadcast(x::Base.RefValue, _) = nothing
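
# Illustrative check (not part of the original source): `unbroadcast` sums a
# broadcast gradient back down to the shape of the original argument. The
# helper name `_unbroadcast_check` is hypothetical.
_unbroadcast_check() = unbroadcast(ones(3, 1), ones(3, 4)) == fill(4.0, 3, 1)  # true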
|
|
||||||
|
|
||||||
dual(x, p) = x
|
|
||||||
dual(x::Real, p) = Dual(x, p)
|
|
||||||
|
|
||||||
function partial(f::F, Δ, i, args::Vararg{Any,N}) where {F,N}
|
|
||||||
dargs = ntuple(j -> dual(args[j], i==j), Val(N))
|
|
||||||
return Δ * f(dargs...).partials[1]
|
|
||||||
end
|
|
||||||
|
|
||||||
@inline function ∇broadcast(f::F, args::Vararg{Any,N}) where {F,N}
|
|
||||||
y = broadcast(f, data.(args)...)
|
|
||||||
eltype(y) <: Real || return y
|
|
||||||
eltype(y) == Bool && return y
|
|
||||||
function back(Δ)
|
|
||||||
Δargs = ntuple(i -> partial.(f, Δ, i, args...), Val(N))
|
|
||||||
dxs = map(unbroadcast, args, Δargs)
|
|
||||||
return dxs
|
|
||||||
end
|
|
||||||
# So we can return non-tracked arrays
|
|
||||||
track(Call(back, tracker.(args)), y)
|
|
||||||
end
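
# Usage sketch (illustrative only): broadcasting over tracked arrays routes
# through `∇broadcast`, with per-argument partials computed via ForwardDiff
# duals. The helper name `_broadcast_grad_example` is hypothetical.
function _broadcast_grad_example()
  a, b = param(rand(3)), param(rand(3))
  back!(sum(a .* b))
  return a.grad ≈ data(b) && b.grad ≈ data(a)   # true
end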
|
|
||||||
|
|
||||||
using Base.Broadcast: BroadcastStyle, ArrayStyle, Broadcasted, broadcasted
|
|
||||||
|
|
||||||
struct TrackedStyle <: BroadcastStyle end
|
|
||||||
|
|
||||||
Broadcast.BroadcastStyle(::Type{<:Union{TrackedArray,TrackedReal}}) = TrackedStyle()
|
|
||||||
Broadcast.BroadcastStyle(::TrackedStyle, ::BroadcastStyle) = TrackedStyle()
|
|
||||||
|
|
||||||
# We have to re-build the original broadcast struct to get the appropriate array
|
|
||||||
# style. We need this primarily to support CuArrays' broadcasting fixes.
|
|
||||||
broadcast_rebuild(xs) = data(xs)
|
|
||||||
|
|
||||||
broadcast_rebuild(bc::Broadcasted) =
|
|
||||||
broadcasted(bc.f, broadcast_rebuild.(bc.args)...)
|
|
||||||
|
|
||||||
preprocess(x) = x
|
|
||||||
|
|
||||||
function Base.Broadcast.materialize(bc::Broadcasted{TrackedStyle})
|
|
||||||
bc1 = Broadcast.flatten(bc)
|
|
||||||
bc2 = Broadcast.flatten(broadcast_rebuild(bc))
|
|
||||||
∇broadcast(bc2.f, bc1.args...)
|
|
||||||
end
|
|
||||||
|
|
||||||
using Requires
|
|
||||||
|
|
||||||
# https://github.com/FluxML/Flux.jl/issues/353
|
|
||||||
if VERSION < v"1.1.0-DEV.548"
|
|
||||||
@init Requires.isprecompiling() || @eval Base.Broadcast begin
|
|
||||||
function flatten(bc::Broadcasted{Style}) where {Style}
|
|
||||||
isflat(bc) && return bc
|
|
||||||
args = cat_nested(bc)
|
|
||||||
let makeargs = make_makeargs(bc), f = bc.f
|
|
||||||
newf = @inline function(args::Vararg{Any,N}) where N
|
|
||||||
f(makeargs(args...)...)
|
|
||||||
end
|
|
||||||
return Broadcasted{Style}(newf, args, bc.axes)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
@inline function make_makeargs(makeargs, t::Tuple{<:Broadcasted,Vararg{Any}})
|
|
||||||
bc = t[1]
|
|
||||||
let makeargs = make_makeargs(makeargs, tail(t)), f = bc.f
|
|
||||||
let makeargs = make_makeargs(makeargs, bc.args)
|
|
||||||
headargs, tailargs = make_headargs(bc.args), make_tailargs(bc.args)
|
|
||||||
return @inline function(args::Vararg{Any,N}) where N
|
|
||||||
args1 = makeargs(args...)
|
|
||||||
a, b = headargs(args1...), tailargs(args1...)
|
|
||||||
(f(a...), b...)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
@ -1,160 +0,0 @@
|
|||||||
mutable struct TrackedReal{T<:Real} <: Real
|
|
||||||
data::T
|
|
||||||
tracker::Tracked{T}
|
|
||||||
end
|
|
||||||
|
|
||||||
TrackedReal(x::Real) = TrackedReal(x, Tracked{typeof(x)}(Call(), zero(x)))
|
|
||||||
|
|
||||||
data(x::TrackedReal) = x.data
|
|
||||||
tracker(x::TrackedReal) = x.tracker
|
|
||||||
|
|
||||||
track(f::Call, x::Real) = TrackedReal(x, Tracked{typeof(x)}(f, zero(x)))
|
|
||||||
|
|
||||||
function back!(x::TrackedReal; once = true)
|
|
||||||
isinf(x) && error("Loss is Inf")
|
|
||||||
isnan(x) && error("Loss is NaN")
|
|
||||||
return back!(x, 1, once = once)
|
|
||||||
end
|
|
||||||
|
|
||||||
function update!(x::TrackedReal, Δ)
|
|
||||||
x.data += data(Δ)
|
|
||||||
tracker(x).grad = 0
|
|
||||||
return x
|
|
||||||
end
|
|
||||||
|
|
||||||
function Base.show(io::IO, x::TrackedReal)
|
|
||||||
T = get(io, :typeinfo, Any)
|
|
||||||
show(io, data(x))
|
|
||||||
T <: TrackedReal || print(io, " (tracked)")
|
|
||||||
end
|
|
||||||
|
|
||||||
Base.decompose(x::TrackedReal) = Base.decompose(data(x))
|
|
||||||
|
|
||||||
Base.copy(x::TrackedReal) = x
|
|
||||||
|
|
||||||
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{T}) where T = x
|
|
||||||
|
|
||||||
Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x))
|
|
||||||
|
|
||||||
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} =
|
|
||||||
error("Not implemented: convert tracked $S to tracked $T")
|
|
||||||
|
|
||||||
(T::Type{<:TrackedReal})(x::Real) = convert(T, x)
|
|
||||||
|
|
||||||
for op in [:(==), :≈, :<, :(<=)]
|
|
||||||
@eval Base.$op(x::TrackedReal, y::Real) = Base.$op(data(x), y)
|
|
||||||
@eval Base.$op(x::Real, y::TrackedReal) = Base.$op(x, data(y))
|
|
||||||
@eval Base.$op(x::TrackedReal, y::TrackedReal) = Base.$op(data(x), data(y))
|
|
||||||
end
|
|
||||||
|
|
||||||
Base.eps(x::TrackedReal) = eps(data(x))
|
|
||||||
Base.eps(::Type{TrackedReal{T}}) where T = eps(T)
|
|
||||||
|
|
||||||
for f in :[isinf, isnan, isfinite].args
|
|
||||||
@eval Base.$f(x::TrackedReal) = Base.$f(data(x))
|
|
||||||
end
|
|
||||||
|
|
||||||
Base.Printf.fix_dec(x::TrackedReal, n::Int, a...) = Base.Printf.fix_dec(data(x), n, a...)
|
|
||||||
|
|
||||||
Base.float(x::TrackedReal) = x
|
|
||||||
|
|
||||||
Base.promote_rule(::Type{TrackedReal{S}},::Type{T}) where {S,T} =
|
|
||||||
TrackedReal{promote_type(S,T)}
|
|
||||||
|
|
||||||
using Random
|
|
||||||
|
|
||||||
for f in :[rand, randn, randexp].args
|
|
||||||
@eval Random.$f(rng::AbstractRNG,::Type{TrackedReal{T}}) where {T} = param(rand(rng,T))
|
|
||||||
end
|
|
||||||
|
|
||||||
using DiffRules, SpecialFunctions, NaNMath
|
|
||||||
|
|
||||||
for (M, f, arity) in DiffRules.diffrules()
|
|
||||||
arity == 1 || continue
|
|
||||||
@eval begin
|
|
||||||
@grad $M.$f(a::Real) =
|
|
||||||
$M.$f(data(a)), Δ -> (Δ * $(DiffRules.diffrule(M, f, :a)),)
|
|
||||||
$M.$f(a::TrackedReal) = track($M.$f, a)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
# Work around zero(π) not working, for some reason
|
|
||||||
_zero(::Irrational) = nothing
|
|
||||||
_zero(x) = zero(x)
|
|
||||||
|
|
||||||
for (M, f, arity) in DiffRules.diffrules()
|
|
||||||
arity == 2 || continue
|
|
||||||
da, db = DiffRules.diffrule(M, f, :a, :b)
|
|
||||||
f = :($M.$f)
|
|
||||||
@eval begin
|
|
||||||
@grad $f(a::TrackedReal, b::TrackedReal) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db)
|
|
||||||
@grad $f(a::TrackedReal, b::Real) = $f(data(a), b), Δ -> (Δ * $da, _zero(b))
|
|
||||||
@grad $f(a::Real, b::TrackedReal) = $f(a, data(b)), Δ -> (_zero(a), Δ * $db)
|
|
||||||
$f(a::TrackedReal, b::TrackedReal) = track($f, a, b)
|
|
||||||
$f(a::TrackedReal, b::Real) = track($f, a, b)
|
|
||||||
$f(a::Real, b::TrackedReal) = track($f, a, b)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
# Eliminating ambiguity
|
|
||||||
import Base:^
|
|
||||||
|
|
||||||
^(a::TrackedReal, b::Integer) = track(^, a, b)
|
|
||||||
|
|
||||||
# Hack for conversions
|
|
||||||
|
|
||||||
using ForwardDiff: Dual
|
|
||||||
|
|
||||||
(T::Type{<:Real})(x::Dual) = Dual(T(x.value), map(T, x.partials.values))
|
|
||||||
(Dual{T,V,N})(x::Dual) where {T,V,N} = invoke(Dual{T,V,N}, Tuple{Number}, x)
|
|
||||||
|
|
||||||
# Tuples
|
|
||||||
|
|
||||||
struct TrackedTuple{T<:Tuple}
|
|
||||||
data::T
|
|
||||||
tracker::Tracked{T}
|
|
||||||
end
|
|
||||||
|
|
||||||
data(xs::TrackedTuple) = xs.data
|
|
||||||
tracker(xs::TrackedTuple) = xs.tracker
|
|
||||||
|
|
||||||
accum!(x::Tuple, Δ::Tuple) = accum!.(x, Δ)
|
|
||||||
init_grad(x::Tuple) = init_grad.(x)
|
|
||||||
zero_grad!(x::Tuple) = zero_grad!.(x)
|
|
||||||
|
|
||||||
track(f::Call, xs::Tuple) = TrackedTuple(xs, Tracked{typeof(xs)}(f, zero.(xs)))
|
|
||||||
|
|
||||||
function Base.show(io::IO, xs::TrackedTuple)
|
|
||||||
show(io, data(xs))
|
|
||||||
print(io, " (tracked)")
|
|
||||||
end
|
|
||||||
|
|
||||||
Base.length(x::TrackedTuple) = length(data(x))
|
|
||||||
|
|
||||||
Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i)
|
|
||||||
|
|
||||||
@grad function getindex(xs::TrackedTuple, i)
|
|
||||||
data(xs)[i], Δ -> (ntuple(j -> i == j ? Δ : 0, length(xs)), nothing)
|
|
||||||
end
|
|
||||||
|
|
||||||
# Array collection
|
|
||||||
|
|
||||||
function collect(xs)
|
|
||||||
xs = Base.collect(xs)
|
|
||||||
track(Call(collect, (tracker.(xs),)), data.(xs))
|
|
||||||
end
|
|
||||||
|
|
||||||
function scan(c::Call{typeof(collect)})
|
|
||||||
foreach(scan, c.args[1])
|
|
||||||
end
|
|
||||||
|
|
||||||
function back_(c::Call{typeof(collect)}, Δ, once)
|
|
||||||
foreach((x, d) -> back(x, d, once), c.args[1], data(Δ))
|
|
||||||
end
|
|
||||||
|
|
||||||
function back_(g::Grads, c::Call{typeof(collect)}, Δ)
|
|
||||||
foreach((x, Δ) -> back(g, x, Δ), c.args[1], Δ)
|
|
||||||
end
|
|
||||||
|
|
||||||
collectmemaybe(xs::AbstractArray{>:TrackedReal}) = collect(xs)
|
|
||||||
collectmemaybe(xs::AbstractArray{<:TrackedReal}) = collect(xs)
|
|
@ -1,18 +0,0 @@
|
|||||||
function ngradient(f, xs::AbstractArray...)
|
|
||||||
grads = zero.(xs)
|
|
||||||
for (x, Δ) in zip(xs, grads), i in 1:length(x)
|
|
||||||
δ = sqrt(eps())
|
|
||||||
tmp = x[i]
|
|
||||||
x[i] = tmp - δ/2
|
|
||||||
y1 = f(xs...)
|
|
||||||
x[i] = tmp + δ/2
|
|
||||||
y2 = f(xs...)
|
|
||||||
x[i] = tmp
|
|
||||||
Δ[i] = (y2-y1)/δ
|
|
||||||
end
|
|
||||||
return grads
|
|
||||||
end
|
|
||||||
|
|
||||||
gradcheck(f, xs...) =
|
|
||||||
all(isapprox.(ngradient(f, xs...),
|
|
||||||
data.(gradient(f, xs...)), rtol = 1e-5, atol = 1e-5))
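
# Usage sketch (illustrative only): `gradcheck` compares reverse-mode gradients
# against the central-difference estimate from `ngradient`. The call below is a
# hypothetical example, not part of the original tests.
_gradcheck_example() = gradcheck((W, x) -> sum(W * x), rand(2, 3), rand(3))  # should be true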
|
|
@ -47,5 +47,7 @@ end
|
|||||||
if CuArrays.libcudnn != nothing
|
if CuArrays.libcudnn != nothing
|
||||||
@info "Testing Flux/CUDNN"
|
@info "Testing Flux/CUDNN"
|
||||||
include("cudnn.jl")
|
include("cudnn.jl")
|
||||||
|
if !haskey(ENV, "CI_DISABLE_CURNN_TEST")
|
||||||
include("curnn.jl")
|
include("curnn.jl")
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
@ -14,3 +14,9 @@ using Test
|
|||||||
@test FashionMNIST.labels() isa Vector{Int64}
|
@test FashionMNIST.labels() isa Vector{Int64}
|
||||||
|
|
||||||
@test Data.Sentiment.train() isa Vector{Data.Tree{Any}}
|
@test Data.Sentiment.train() isa Vector{Data.Tree{Any}}
|
||||||
|
|
||||||
|
@test Iris.features() isa Matrix
|
||||||
|
@test size(Iris.features()) == (4,150)
|
||||||
|
|
||||||
|
@test Iris.labels() isa Vector{String}
|
||||||
|
@test size(Iris.labels()) == (150,)
|
||||||
|
@ -1,6 +1,18 @@
|
|||||||
using Test, Random
|
using Test, Random
|
||||||
|
import Flux: activations
|
||||||
|
|
||||||
@testset "basic" begin
|
@testset "basic" begin
|
||||||
|
@testset "helpers" begin
|
||||||
|
@testset "activations" begin
|
||||||
|
dummy_model = Chain(Dense(10,5,σ),Dense(5,2),softmax)
|
||||||
|
x = rand(10)
|
||||||
|
@test activations(Chain(), x) == []
|
||||||
|
@test activations(dummy_model, x)[1] == dummy_model[1](x)
|
||||||
|
@test activations(dummy_model, x)[2] == x |> dummy_model[1] |> dummy_model[2]
|
||||||
|
@test activations(Chain(identity, x->:foo), x)[2] == :foo # results include `Any` type
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
@testset "Chain" begin
|
@testset "Chain" begin
|
||||||
@test_nowarn Chain(Dense(10, 5, σ), Dense(5, 2))(randn(10))
|
@test_nowarn Chain(Dense(10, 5, σ), Dense(5, 2))(randn(10))
|
||||||
@test_throws DimensionMismatch Chain(Dense(10, 5, σ),Dense(2, 1))(randn(10))
|
@test_throws DimensionMismatch Chain(Dense(10, 5, σ),Dense(2, 1))(randn(10))
|
||||||
@ -30,4 +42,34 @@ using Test, Random
|
|||||||
@test Flux.Diagonal(2)([1,2]) == [1,2]
|
@test Flux.Diagonal(2)([1,2]) == [1,2]
|
||||||
@test Flux.Diagonal(2)([1 2; 3 4]) == [1 2; 3 4]
|
@test Flux.Diagonal(2)([1 2; 3 4]) == [1 2; 3 4]
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@testset "Maxout" begin
|
||||||
|
# Note that the normal common usage of Maxout is as per the docstring
|
||||||
|
# These are abnormal constructors used for testing purposes
|
||||||
|
|
||||||
|
@testset "Constructor" begin
|
||||||
|
mo = Maxout(() -> identity, 4)
|
||||||
|
input = rand(40)
|
||||||
|
@test mo(input) == input
|
||||||
|
end
|
||||||
|
|
||||||
|
@testset "simple alternatives" begin
|
||||||
|
mo = Maxout((x -> x, x -> 2x, x -> 0.5x))
|
||||||
|
input = rand(40)
|
||||||
|
@test mo(input) == 2*input
|
||||||
|
end
|
||||||
|
|
||||||
|
@testset "complex alternatives" begin
|
||||||
|
mo = Maxout((x -> [0.5; 0.1]*x, x -> [0.2; 0.7]*x))
|
||||||
|
input = [3.0 2.0]
|
||||||
|
target = [0.5, 0.7].*input
|
||||||
|
@test mo(input) == target
|
||||||
|
end
|
||||||
|
|
||||||
|
@testset "params" begin
|
||||||
|
mo = Maxout(()->Dense(32, 64), 4)
|
||||||
|
ps = params(mo)
|
||||||
|
@test length(ps) == 8 #4 alts, each with weight and bias
|
||||||
|
end
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
@ -4,9 +4,9 @@ using Flux: maxpool, meanpool
|
|||||||
@testset "Pooling" begin
|
@testset "Pooling" begin
|
||||||
x = randn(Float32, 10, 10, 3, 2)
|
x = randn(Float32, 10, 10, 3, 2)
|
||||||
mp = MaxPool((2, 2))
|
mp = MaxPool((2, 2))
|
||||||
@test mp(x) == maxpool(x, (2,2))
|
@test mp(x) == maxpool(x, PoolDims(x, 2))
|
||||||
mp = MeanPool((2, 2))
|
mp = MeanPool((2, 2))
|
||||||
@test mp(x) == meanpool(x, (2,2))
|
@test mp(x) == meanpool(x, PoolDims(x, 2))
|
||||||
end
|
end
|
||||||
|
|
||||||
@testset "CNN" begin
|
@testset "CNN" begin
|
||||||
@ -22,14 +22,42 @@ end
|
|||||||
@test size(m(r)) == (10, 5)
|
@test size(m(r)) == (10, 5)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@testset "asymmetric padding" begin
|
||||||
|
r = ones(Float32, 28, 28, 1, 1)
|
||||||
|
m = Conv((3, 3), 1=>1, relu; pad=(0,1,1,2))
|
||||||
|
m.weight.data[:] .= 1.0
|
||||||
|
m.bias.data[:] .= 0.0
|
||||||
|
y_hat = Flux.data(m(r))[:,:,1,1]
|
||||||
|
@test size(y_hat) == (27, 29)
|
||||||
|
@test y_hat[1, 1] ≈ 6.0
|
||||||
|
@test y_hat[2, 2] ≈ 9.0
|
||||||
|
@test y_hat[end, 1] ≈ 4.0
|
||||||
|
@test y_hat[1, end] ≈ 3.0
|
||||||
|
@test y_hat[1, end-1] ≈ 6.0
|
||||||
|
@test y_hat[end, end] ≈ 2.0
|
||||||
|
end
|
||||||
|
|
||||||
@testset "Depthwise Conv" begin
|
@testset "Depthwise Conv" begin
|
||||||
r = zeros(Float32, 28, 28, 3, 5)
|
r = zeros(Float32, 28, 28, 3, 5)
|
||||||
|
|
||||||
m1 = DepthwiseConv((2, 2), 3=>5)
|
m1 = DepthwiseConv((2, 2), 3=>5)
|
||||||
|
|
||||||
@test size(m1(r), 3) == 15
|
@test size(m1(r), 3) == 15
|
||||||
|
|
||||||
m2 = DepthwiseConv((2, 2), 3)
|
m2 = DepthwiseConv((2, 2), 3)
|
||||||
|
|
||||||
@test size(m2(r), 3) == 3
|
@test size(m2(r), 3) == 3
|
||||||
|
|
||||||
|
x = zeros(Float64, 28, 28, 3, 5)
|
||||||
|
|
||||||
|
m3 = DepthwiseConv((2, 2), 3 => 5)
|
||||||
|
|
||||||
|
@test size(m3(r), 3) == 15
|
||||||
|
|
||||||
|
m4 = DepthwiseConv((2, 2), 3)
|
||||||
|
|
||||||
|
@test size(m4(r), 3) == 3
|
||||||
|
end
|
||||||
|
|
||||||
|
@testset "ConvTranspose" begin
|
||||||
|
x = zeros(Float32, 28, 28, 1, 1)
|
||||||
|
y = Conv((3,3), 1 => 1)(x)
|
||||||
|
x_hat = ConvTranspose((3, 3), 1 => 1)(y)
|
||||||
|
@test size(x_hat) == size(x)
|
||||||
end
|
end
|
||||||
|
@ -104,3 +104,210 @@ end
|
|||||||
@test (@allocated m(x)) < 100_000_000
|
@test (@allocated m(x)) < 100_000_000
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
|
@testset "InstanceNorm" begin
|
||||||
|
# helper functions
|
||||||
|
expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...)
|
||||||
|
# begin tests
|
||||||
|
let m = InstanceNorm(2), sizes = (3, 2, 2),
|
||||||
|
x = param(reshape(collect(1:prod(sizes)), sizes))
|
||||||
|
|
||||||
|
@test m.β.data == [0, 0] # initβ(2)
|
||||||
|
@test m.γ.data == [1, 1] # initγ(2)
|
||||||
|
|
||||||
|
@test m.active
|
||||||
|
|
||||||
|
m(x)
|
||||||
|
|
||||||
|
#julia> x
|
||||||
|
#[:, :, 1] =
|
||||||
|
# 1.0 4.0
|
||||||
|
# 2.0 5.0
|
||||||
|
# 3.0 6.0
|
||||||
|
#
|
||||||
|
#[:, :, 2] =
|
||||||
|
# 7.0 10.0
|
||||||
|
# 8.0 11.0
|
||||||
|
# 9.0 12.0
|
||||||
|
#
|
||||||
|
# μ will be
|
||||||
|
# (1. + 2. + 3.) / 3 = 2.
|
||||||
|
# (4. + 5. + 6.) / 3 = 5.
|
||||||
|
#
|
||||||
|
# (7. + 8. + 9.) / 3 = 8.
|
||||||
|
# (10. + 11. + 12.) / 3 = 11.
|
||||||
|
#
|
||||||
|
# ∴ update rule with momentum:
|
||||||
|
# (1. - .1) * 0 + .1 * (2. + 8.) / 2 = .5
|
||||||
|
# (1. - .1) * 0 + .1 * (5. + 11.) / 2 = .8
|
||||||
|
@test m.μ ≈ [0.5, 0.8]
|
||||||
|
# momentum * var * num_items / (num_items - 1) + (1 - momentum) * sigma_sq
|
||||||
|
# julia> reshape(mean(.1 .* var(x.data, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1.
|
||||||
|
# 2-element Array{Float64,1}:
|
||||||
|
# 1.
|
||||||
|
# 1.
|
||||||
|
@test m.σ² ≈ reshape(mean(.1 .* var(x.data, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1.
|
||||||
|
|
||||||
|
testmode!(m)
|
||||||
|
@test !m.active
|
||||||
|
|
||||||
|
x′ = m(x).data
|
||||||
|
@test isapprox(x′[1], (1 - 0.5) / sqrt(1. + 1f-5), atol = 1.0e-5)
|
||||||
|
end
|
||||||
|
# with activation function
|
||||||
|
let m = InstanceNorm(2, sigmoid), sizes = (3, 2, 2),
|
||||||
|
x = param(reshape(collect(1:prod(sizes)), sizes))
|
||||||
|
|
||||||
|
affine_shape = collect(sizes)
|
||||||
|
affine_shape[1] = 1
|
||||||
|
|
||||||
|
@test m.active
|
||||||
|
m(x)
|
||||||
|
|
||||||
|
testmode!(m)
|
||||||
|
@test !m.active
|
||||||
|
|
||||||
|
y = m(x).data
|
||||||
|
@test isapprox(y, data(sigmoid.((x .- expand_inst(m.μ, affine_shape)) ./ sqrt.(expand_inst(m.σ², affine_shape) .+ m.ϵ))), atol = 1.0e-7)
|
||||||
|
end
|
||||||
|
|
||||||
|
let m = InstanceNorm(2), sizes = (2, 4, 1, 2, 3),
|
||||||
|
x = param(reshape(collect(1:prod(sizes)), sizes))
|
||||||
|
y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
|
||||||
|
y = reshape(m(y), sizes...)
|
||||||
|
@test m(x) == y
|
||||||
|
end
|
||||||
|
|
||||||
|
# check that μ, σ², and the output are the correct size for higher rank tensors
|
||||||
|
let m = InstanceNorm(2), sizes = (5, 5, 3, 4, 2, 6),
|
||||||
|
x = param(reshape(collect(1:prod(sizes)), sizes))
|
||||||
|
y = m(x)
|
||||||
|
@test size(m.μ) == (sizes[end - 1], )
|
||||||
|
@test size(m.σ²) == (sizes[end - 1], )
|
||||||
|
@test size(y) == sizes
|
||||||
|
end
|
||||||
|
|
||||||
|
# show that instance norm is equal to batch norm when channel and batch dims are squashed
|
||||||
|
let m_inorm = InstanceNorm(2), m_bnorm = BatchNorm(12), sizes = (5, 5, 3, 4, 2, 6),
|
||||||
|
x = param(reshape(collect(1:prod(sizes)), sizes))
|
||||||
|
@test m_inorm(x) == reshape(m_bnorm(reshape(x, (sizes[1:end - 2]..., :, 1))), sizes)
|
||||||
|
end
|
||||||
|
|
||||||
|
let m = InstanceNorm(32), x = randn(Float32, 416, 416, 32, 1);
|
||||||
|
m(x)
|
||||||
|
@test (@allocated m(x)) < 100_000_000
|
||||||
|
end
|
||||||
|
|
||||||
|
end
|
||||||
|
|
||||||
|
@testset "GroupNorm" begin
|
||||||
|
# begin tests
|
||||||
|
squeeze(x) = dropdims(x, dims = tuple(findall(size(x) .== 1)...)) # To remove all singular dimensions
|
||||||
|
|
||||||
|
let m = GroupNorm(4,2), sizes = (3,4,2),
|
||||||
|
x = param(reshape(collect(1:prod(sizes)), sizes))
|
||||||
|
|
||||||
|
@test m.β.data == [0, 0, 0, 0] # initβ(32)
|
||||||
|
@test m.γ.data == [1, 1, 1, 1] # initγ(32)
|
||||||
|
|
||||||
|
@test m.active
|
||||||
|
|
||||||
|
m(x)
|
||||||
|
|
||||||
|
#julia> x
|
||||||
|
#[:, :, 1] =
|
||||||
|
# 1.0 4.0 7.0 10.0
|
||||||
|
# 2.0 5.0 8.0 11.0
|
||||||
|
# 3.0 6.0 9.0 12.0
|
||||||
|
#
|
||||||
|
#[:, :, 2] =
|
||||||
|
# 13.0 16.0 19.0 22.0
|
||||||
|
# 14.0 17.0 20.0 23.0
|
||||||
|
# 15.0 18.0 21.0 24.0
|
||||||
|
#
|
||||||
|
# μ will be
|
||||||
|
# (1. + 2. + 3. + 4. + 5. + 6.) / 6 = 3.5
|
||||||
|
# (7. + 8. + 9. + 10. + 11. + 12.) / 6 = 9.5
|
||||||
|
#
|
||||||
|
# (13. + 14. + 15. + 16. + 17. + 18.) / 6 = 15.5
|
||||||
|
# (19. + 20. + 21. + 22. + 23. + 24.) / 6 = 21.5
|
||||||
|
#
|
||||||
|
# μ =
|
||||||
|
# 3.5 15.5
|
||||||
|
# 9.5 21.5
|
||||||
|
#
|
||||||
|
# ∴ update rule with momentum:
|
||||||
|
# (1. - .1) * 0 + .1 * (3.5 + 15.5) / 2 = 0.95
|
||||||
|
# (1. - .1) * 0 + .1 * (9.5 + 21.5) / 2 = 1.55
|
||||||
|
@test m.μ ≈ [0.95, 1.55]
|
||||||
|
|
||||||
|
# julia> mean(var(reshape(x,3,2,2,2),dims=(1,2)).* .1,dims=2) .+ .9*1.
|
||||||
|
# 2-element Array{Tracker.TrackedReal{Float64},1}:
|
||||||
|
# 1.25
|
||||||
|
# 1.25
|
||||||
|
@test m.σ² ≈ mean(squeeze(var(reshape(x,3,2,2,2),dims=(1,2))).*.1,dims=2) .+ .9*1.
|
||||||
|
|
||||||
|
testmode!(m)
|
||||||
|
@test !m.active
|
||||||
|
|
||||||
|
x′ = m(x).data
|
||||||
|
println(x′[1])
|
||||||
|
@test isapprox(x′[1], (1 - 0.95) / sqrt(1.25 + 1f-5), atol = 1.0e-5)
|
||||||
|
end
|
||||||
|
# with activation function
|
||||||
|
let m = GroupNorm(4,2, sigmoid), sizes = (3, 4, 2),
|
||||||
|
x = param(reshape(collect(1:prod(sizes)), sizes))
|
||||||
|
|
||||||
|
μ_affine_shape = ones(Int,length(sizes) + 1)
|
||||||
|
μ_affine_shape[end-1] = 2 # Number of groups
|
||||||
|
|
||||||
|
affine_shape = ones(Int,length(sizes) + 1)
|
||||||
|
affine_shape[end-2] = 2 # Channels per group
|
||||||
|
affine_shape[end-1] = 2 # Number of groups
|
||||||
|
affine_shape[1] = sizes[1]
|
||||||
|
affine_shape[end] = sizes[end]
|
||||||
|
|
||||||
|
og_shape = size(x)
|
||||||
|
|
||||||
|
@test m.active
|
||||||
|
m(x)
|
||||||
|
|
||||||
|
testmode!(m)
|
||||||
|
@test !m.active
|
||||||
|
|
||||||
|
y = m(x)
|
||||||
|
x_ = reshape(x,affine_shape...)
|
||||||
|
out = reshape(data(sigmoid.((x_ .- reshape(m.μ,μ_affine_shape...)) ./ sqrt.(reshape(m.σ²,μ_affine_shape...) .+ m.ϵ))),og_shape)
|
||||||
|
@test isapprox(y, out, atol = 1.0e-7)
|
||||||
|
end
|
||||||
|
|
||||||
|
let m = GroupNorm(2,2), sizes = (2, 4, 1, 2, 3),
|
||||||
|
x = param(reshape(collect(1:prod(sizes)), sizes))
|
||||||
|
y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
|
||||||
|
y = reshape(m(y), sizes...)
|
||||||
|
@test m(x) == y
|
||||||
|
end
|
||||||
|
|
||||||
|
# check that μ, σ², and the output are the correct size for higher rank tensors
|
||||||
|
let m = GroupNorm(4,2), sizes = (5, 5, 3, 4, 4, 6),
|
||||||
|
x = param(reshape(collect(1:prod(sizes)), sizes))
|
||||||
|
y = m(x)
|
||||||
|
@test size(m.μ) == (m.G,1)
|
||||||
|
@test size(m.σ²) == (m.G,1)
|
||||||
|
@test size(y) == sizes
|
||||||
|
end
|
||||||
|
|
||||||
|
# show that group norm is the same as instance norm when the group size is the same as the number of channels
|
||||||
|
let IN = InstanceNorm(4), GN = GroupNorm(4,4), sizes = (2,2,3,4,5),
|
||||||
|
x = param(reshape(collect(1:prod(sizes)), sizes))
|
||||||
|
@test IN(x) ≈ GN(x)
|
||||||
|
end
|
||||||
|
|
||||||
|
# show that group norm is the same as batch norm for a group of size 1 and batch of size 1
|
||||||
|
let BN = BatchNorm(4), GN = GroupNorm(4,4), sizes = (2,2,3,4,1),
|
||||||
|
x = param(reshape(collect(1:prod(sizes)), sizes))
|
||||||
|
@test BN(x) ≈ GN(x)
|
||||||
|
end
|
||||||
|
|
||||||
|
end
|
||||||
|
@ -4,21 +4,15 @@ using Flux.Tracker
|
|||||||
using Test
|
using Test
|
||||||
@testset "Optimise" begin
|
@testset "Optimise" begin
|
||||||
w = randn(10, 10)
|
w = randn(10, 10)
|
||||||
@testset for Opt in [ADAMW, ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, Descent, ADAM, Nesterov, RMSProp, Momentum]
|
@testset for opt in [ADAMW(), ADAGrad(0.1), AdaMax(), ADADelta(0.9), AMSGrad(),
|
||||||
|
NADAM(), Descent(0.1), ADAM(), Nesterov(), RMSProp(),
|
||||||
|
Momentum()]
|
||||||
w′ = param(randn(10, 10))
|
w′ = param(randn(10, 10))
|
||||||
loss(x) = Flux.mse(w*x, w′*x)
|
loss(x) = Flux.mse(w*x, w′*x)
|
||||||
opt = Opt(0.001)
|
|
||||||
if opt isa Descent || opt isa ADAGrad
|
|
||||||
opt = Opt(0.1)
|
|
||||||
end
|
|
||||||
if opt isa ADADelta
|
|
||||||
opt = Opt(0.9)
|
|
||||||
end
|
|
||||||
for t = 1: 10^5
|
for t = 1: 10^5
|
||||||
l = loss(rand(10))
|
θ = Params([w′])
|
||||||
back!(l)
|
θ̄ = gradient(() -> loss(rand(10)), θ)
|
||||||
delta = Optimise.apply!(opt, w′.data, w′.grad)
|
Optimise.update!(opt, θ, θ̄)
|
||||||
w′.data .-= delta
|
|
||||||
end
|
end
|
||||||
@test Flux.mse(w, w′) < 0.01
|
@test Flux.mse(w, w′) < 0.01
|
||||||
end
|
end
|
||||||
|
338
test/tracker.jl
338
test/tracker.jl
@ -1,347 +1,15 @@
|
|||||||
using Flux
|
using Flux, Test
|
||||||
using Flux.Tracker, Test, NNlib
|
using Tracker: gradcheck
|
||||||
using Flux.Tracker: TrackedReal, gradient, gradcheck, grad, checkpoint, forwarddiff
|
|
||||||
using NNlib: conv, ∇conv_data, depthwiseconv
|
|
||||||
using Printf: @sprintf
|
|
||||||
using LinearAlgebra: diagm, dot, LowerTriangular, norm, det, logdet, logabsdet
|
|
||||||
using Statistics: mean, std
|
|
||||||
using Random
|
|
||||||
# using StatsBase
|
|
||||||
|
|
||||||
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
|
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
|
||||||
gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
|
gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
|
||||||
@testset "Tracker" begin
|
|
||||||
@test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2)
|
|
||||||
@test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)
|
|
||||||
@test gradtest((x, W, b) -> logσ.(W*x .+ b), 5, (2,5), 2)
|
|
||||||
@test gradtest((x, W, b) -> logσ.(W*x .+ b), (5,3), (2,5), 2)
|
|
||||||
@test gradtest((w, x) -> w'*x, randn(Float64,10, 2), randn(Float64,10))
|
|
||||||
@test gradtest((w, x) -> w*x', randn(Float64,5,5), randn(Float64,5,5))
|
|
||||||
@test gradtest(x -> sum(x, dims = (2, 3)), (3,4,5))
|
|
||||||
@test gradtest(x -> sum(x, dims = 1), randn(Float64,2,3))
|
|
||||||
@test gradtest(x -> sum(x, dims = [1,2]), randn(Float64,2,3))
|
|
||||||
@test gradtest(x -> sum(x), randn(Float64,2,3))
|
|
||||||
@test gradtest(x -> prod(x, dims=(2, 3)), (3,4,5))
|
|
||||||
@test gradtest(x -> prod(x), (3,4,5))
|
|
||||||
|
|
||||||
@test gradtest(x -> softmax(x).*(1:3), 3)
|
@testset "Tracker" begin
|
||||||
@test gradtest(x -> softmax(x).*(1:3), (3,5))
|
|
||||||
@test gradtest(x -> logsoftmax(x).*(1:3), 3)
|
|
||||||
@test gradtest(x -> logsoftmax(x).*(1:3), (3,5))
|
|
||||||
|
|
||||||
@test gradtest(Flux.mse, rand(5,5), rand(5, 5))
|
@test gradtest(Flux.mse, rand(5,5), rand(5, 5))
|
||||||
@test gradtest(Flux.crossentropy, rand(5,5), rand(5, 5))
|
@test gradtest(Flux.crossentropy, rand(5,5), rand(5, 5))
|
||||||
|
|
||||||
@test gradtest(x -> x', rand(5))
|
|
||||||
|
|
||||||
@test gradtest(det, (4, 4))
|
|
||||||
@test gradtest(logdet, map((x) -> x*x', (rand(4, 4),))[1])
|
|
||||||
@test gradtest((x) -> logabsdet(x)[1], (4, 4))
|
|
||||||
|
|
||||||
@testset "indexing & slicing" begin
|
|
||||||
gradtest(x->view(x, 1:2, 1:2), rand(4, 4))
|
|
||||||
end
|
|
||||||
|
|
||||||
function promotiontest(f, A, B, C)
|
|
||||||
r0 = f(A, B, C)
|
|
||||||
r1 = f(param(A), B, C)
|
|
||||||
r2 = f(A, param(B), C)
|
|
||||||
r3 = f(A, B, param(C))
|
|
||||||
r4 = f(param(A), param(B), param(C))
|
|
||||||
|
|
||||||
@test !isa(r0, TrackedArray)
|
|
||||||
@test all(isa.([r1,r2,r3,r4], TrackedArray))
|
|
||||||
@test r1 == r2 == r3 == r4
|
|
||||||
@test r0 == Flux.data(r4)
|
|
||||||
end
|
|
||||||
|
|
||||||
@testset "concat" begin
|
|
||||||
cat1(x...) = cat(x..., dims = 1)
|
|
||||||
cat2(x...) = cat(x..., dims = 2)
|
|
||||||
|
|
||||||
@testset for vcatf in [vcat, cat1]
|
|
||||||
@test gradtest(vcatf, rand(5), rand(3))
|
|
||||||
@test gradtest(vcatf, rand(5), rand(3), rand(8))
|
|
||||||
@test gradtest(vcatf, rand(5)', rand(5)')
|
|
||||||
@test gradtest(vcatf, rand(5,2), rand(3,2), rand(8,2))
|
|
||||||
@test gradtest(vcatf, rand(5,2,3), rand(3,2,3), rand(8,2,3))
|
|
||||||
@test gradtest(vcatf, rand(5), rand(3,1))
|
|
||||||
@test gradtest(vcatf, rand(5)', rand(2,5))
|
|
||||||
end
|
|
||||||
|
|
||||||
|
|
||||||
@testset for hcatf in [hcat, cat2]
|
|
||||||
@test gradtest(hcatf, rand(5), rand(5))
|
|
||||||
@test gradtest(hcatf, rand(5)', rand(5)')
|
|
||||||
@test gradtest(hcatf, rand(2,5), rand(2,3), rand(2,8))
|
|
||||||
@test gradtest(hcatf, rand(2,5,3), rand(2,3,3), rand(2,8,3))
|
|
||||||
@test gradtest(hcatf, rand(5), rand(5), rand(5,2))
|
|
||||||
@test gradtest(hcatf, rand(5)', rand(1,3))
|
|
||||||
@test gradtest(hcatf, rand(5), rand(5,2))
|
|
||||||
end
|
|
||||||
|
|
||||||
@testset for catf in [vcat, cat1, hcat, cat2, (x...) -> cat(x..., dims = 3), (x...) -> cat(x..., dims = (1,2))]
|
|
||||||
@test gradtest(catf, rand(5))
|
|
||||||
@test gradtest(catf, rand(5)')
|
|
||||||
@test gradtest(catf, rand(2,5))
|
|
||||||
@test gradtest(catf, rand(2,5,3))
|
|
||||||
end
|
|
||||||
|
|
||||||
@test gradtest((x...) -> cat(x..., dims = 3), rand(2,5,2), rand(2,5,3), rand(2,5,4))
|
|
||||||
|
|
||||||
@testset "cat($dim, ...)" for dim in 3:5
|
|
||||||
catdim = (x...) -> cat(x..., dims = dim)
|
|
||||||
@test gradtest(catdim, rand(5), rand(5), rand(5))
|
|
||||||
@test gradtest(catdim, rand(2,5), rand(2,5), rand(2,5))
|
|
||||||
@test gradtest(catdim, rand(2,5,3), rand(2,5,3), rand(2,5,3))
|
|
||||||
end
|
|
||||||
|
|
||||||
@test !isa(vcat(rand(2)), TrackedArray)
|
|
||||||
@test !isa(hcat(rand(2)), TrackedArray)
|
|
||||||
@test !isa(cat(rand(2), dims=1), TrackedArray)
|
|
||||||
|
|
||||||
@test gradtest((a,b)->cat(a, b, dims = (2,3,5)), rand(2,3), rand(2,4,2,1))
|
|
||||||
|
|
||||||
@testset "promotiontest" begin
|
|
||||||
@testset for fcat in [hcat, vcat, (x...) -> cat(x..., dims = 3), (x...) -> cat(x..., dims = (1,2))]
|
|
||||||
promotiontest(fcat, rand(2), rand(2), rand(2))
|
|
||||||
promotiontest(fcat, rand(2)', rand(2)', rand(2)')
|
|
||||||
promotiontest(fcat, rand(2,2), rand(2,2), rand(2,2))
|
|
||||||
promotiontest(fcat, rand(2,2,2), rand(2,2,2), rand(2,2,2))
|
|
||||||
end
|
|
||||||
|
|
||||||
promotiontest(vcat, rand(1,2), rand(2)', rand(2,2))
|
|
||||||
promotiontest(hcat, rand(2,1), rand(2), rand(2,2))
|
|
||||||
promotiontest(vcat, rand(3,4,5), rand(1,4,5), rand(2,4,5))
|
|
||||||
promotiontest(hcat, rand(4,3,5), rand(4,1,5), rand(4,2,5))
|
|
||||||
promotiontest((x...) -> cat(x..., dims = 3), rand(4,5,3), rand(4,5,1), rand(4,5,2))
|
|
||||||
end
|
|
||||||
|
|
||||||
@testset "scalars" begin
|
|
||||||
@test vcat(param([1, 2, 3]), 1) isa TrackedArray
|
|
||||||
@test vcat(1, param([1, 2, 3])) isa TrackedArray
|
|
||||||
@test hcat(1, param([1 2 3;])) isa TrackedArray
|
|
||||||
@test vcat(param(1), 2) isa TrackedArray
|
|
||||||
end
|
|
||||||
|
|
||||||
end
|
|
||||||
|
|
||||||
@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))
|
|
||||||
@test gradtest(x -> PermutedDimsArray(x, [3,1,2]), rand(4,5,6))
|
|
||||||
|
|
||||||
@test gradtest(x -> repeat(x; inner=2), rand(5))
|
|
||||||
@test gradtest(x -> repeat(x; inner=2, outer=3), rand(5))
|
|
||||||
@test gradtest(x -> repeat(x; inner=(2,2,1), outer=(1,1,3)), rand(5,4,3))
|
|
||||||
|
|
||||||
@test gradtest(kron, rand(5), rand(3))
|
|
||||||
@test gradtest(kron, rand(5), rand(3), rand(8))
|
|
||||||
@test gradtest(kron, rand(5,1), rand(3,1))
|
|
||||||
@test gradtest(kron, rand(5,1), rand(3,1), rand(8,1))
|
|
||||||
@test gradtest(kron, rand(5,2), rand(3,2), rand(8,2))
|
|
||||||
|
|
||||||
@test gradtest(x -> diagm(0 => x), rand(3))
|
|
||||||
|
|
||||||
@test gradtest(W -> inv(log.(W * W)), (5,5))
|
|
||||||
@test gradtest((A, B) -> A / B , (1,5), (5,5))
|
|
||||||
@test gradtest((A, B) -> log.(A * A) / exp.(B * B), (5,5), (5,5))
|
|
||||||
@test gradtest((A, B) -> log.(A * A) \ exp.(B * B), (5,5), (5,5))

@testset "mean" begin
  @test gradtest(mean, rand(2, 3))

  @test gradtest(x -> mean(x, dims=1), rand(2, 3))
  @test gradtest(x -> mean(x, dims=2), rand(2, 3))
  @test gradtest(x -> mean(x, dims=3), rand(2, 3, 4))

  @test gradtest(x -> mean(x, dims=[1, 2]), rand(2, 3, 4))
end

@testset "maximum" begin
  @test gradtest(maximum, rand(2, 3))

  @test gradtest(x -> maximum(x, dims=1), rand(2, 3))
  @test gradtest(x -> maximum(x, dims=2), rand(2, 3))
  @test gradtest(x -> maximum(x, dims=3), rand(2, 3, 4))

  @test gradtest(x -> maximum(x, dims=[1, 2]), rand(2, 3, 4))
end

@testset "minimum" begin
  @test gradtest(minimum, rand(2, 3))

  @test gradtest(x -> minimum(x, dims=1), rand(2, 3))
  @test gradtest(x -> minimum(x, dims=2), rand(2, 3))
  @test gradtest(x -> minimum(x, dims=3), rand(2, 3, 4))

  @test gradtest(x -> minimum(x, dims=[1, 2]), rand(2, 3, 4))
end

@test gradtest(x -> std(x), rand(5,5))
@test gradtest(x -> std(x, dims = 1), rand(5,5))
@test gradtest(x -> std(x, dims = 1, corrected = false), rand(5,5))

@test gradtest(x -> Flux.normalise(x), rand(4,3))
@test gradtest(x -> Flux.normalise(x, dims = 2), rand(3,4))

@test gradtest((x, y) -> x .* y, rand(5), rand(5))
@test gradtest(dot, rand(5), rand(5))

@test gradtest(norm, rand(5))

@test gradtest(rand(5)) do x
  y = x.^2
  2y + x
end

@test gradtest(conv, rand(10, 3, 2), randn(Float64, 2, 3, 2))
@test gradtest(conv, rand(10, 10, 3, 2), randn(Float64, 2, 2, 3, 2))
@test gradtest(conv, rand(10, 10, 10, 3, 2), randn(Float64, 2, 2, 2, 3, 2))

@test gradtest(∇conv_data, rand(10, 3, 2), randn(Float64, 2, 2, 3))
@test gradtest(∇conv_data, rand(10, 10, 3, 2), randn(Float64, 2, 2, 2, 3))
@test gradtest(∇conv_data, rand(10, 10, 10, 3, 2), randn(Float64, 2, 2, 2, 2, 3))

@test gradtest(depthwiseconv, rand(10,10,3,2), randn(2, 2, 2, 3))
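
# Note (illustrative aside, not from the diff): NNlib lays inputs out as
# (spatial dims..., channels, batch) and kernels as (spatial dims..., in channels,
# out channels), which is why a 10×10 image batch pairs with a 2×2×3×2 weight above.
x = rand(Float64, 10, 10, 3, 2)    # width × height × channels × batch
w = randn(Float64, 2, 2, 3, 4)     # kernel width × height × in channels × out channels
size(conv(x, w))                   # (9, 9, 4, 2) with unit stride and no padding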

@test gradtest(x -> maxpool(x, (2,2)), rand(10, 10, 3, 2))
@test gradtest(x -> maxpool(x, (2,2,2)), rand(10, 10, 10, 3, 2))

@test gradtest(x -> meanpool(x, (2,2)), rand(10, 10, 3, 2))
@test gradtest(x -> meanpool(x, (2,2,2)), rand(5, 5, 5, 3, 2))

@test gradtest(x -> Float64.(x), 5)

@testset "equality & order" begin
  # TrackedReal
  @test param(2)^2 == param(4)
  @test param(2)^2 == 4
  @test 4 == param(2)^2

  @test param(2)^2 ≈ param(4)
  @test param(2)^2 ≈ 4
  @test 4 ≈ param(2)^2

  @test (param([1,2,3]) .< 2) == [true, false, false]
  @test (param([1,2,3]) .<= 2) == [true, true, false]
  @test (2 .> param([1,2,3])) == [true, false, false]
  @test (2 .>= param([1,2,3])) == [true, true, false]

  # TrackedArray
  @test param([1,2,3]).^2 == param([1,4,9])
  @test [1,2,3].^2 == param([1,4,9])
  @test param([1,2,3]).^2 == [1,4,9]

  @test param([1,2,3]).^2 ≈ param([1,4,9])
  @test [1,2,3].^2 ≈ param([1,4,9])
  @test param([1,2,3]).^2 ≈ [1,4,9]
end

@testset "reshape" begin
  x = reshape(param(rand(2,2,2)), 4, 2)
  @test x isa TrackedArray
  @test size(x) == (4,2)
  x = reshape(param([1]), (1,:))
  @test x isa TrackedArray
  @test size(x) == (1,1)
  x = reshape(param(rand(2)), (2,:))
  @test x isa TrackedArray
  @test size(x) == (2,1)
  x = reshape(param(rand(2,2)), (1,:,2))
  @test x isa TrackedArray
  @test size(x) == (1,2,2)
end

@testset "Intermediates" begin
  x = param([1])
  l = sum((x .+ x).^2)
  Flux.back!(l, once = false)
  @test x.grad == [8]
  x.grad .= 0
  Flux.back!(l, once = false)
  @test x.grad == [8]
end

@testset "Fallbacks" begin
  xs = param([1 2; 3 4])
  @test similar(xs) isa Matrix{Float64}
end

@test @sprintf("%.2f", sum(param([1,2,3]))) == "6.00"

@inferred NNlib.conv(param(rand(10,10,3,2)), randn(Float64,2,2,3,4))

b = param(rand())
Tracker.back!(b)
@test Tracker.grad(b) == 1

@testset "collect" begin
  x, y = param(2), param(3)
  xy = Tracker.collect([x, y])
  @test xy isa TrackedArray{Float64}
  z = xy[1]*xy[2]
  back!(z)
  @test grad.((x,y)) == (3, 2)

  @test gradient(2, 3) do x, y
    xy = Tracker.collect([x, y])
    xy[1]*xy[2]
  end == (3, 2)
end

# Gradient Hooks
@testset "Hooks" begin
  x = param(2)
  y = Tracker.hook(-, x)
  back!(y)
  @test grad(x) == -1
end
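
# Note (illustrative aside, not from the diff): `Tracker.hook(f, x)` leaves the
# forward value of `x` unchanged and applies `f` to the gradient flowing back into
# it, which is why `grad(x) == -1` above. A common use is gradient clipping; a
# minimal sketch using only the API exercised in this testset:
θ = param(2)
clipped = Tracker.hook(Δ -> clamp(Δ, -1, 1), θ)
back!(3 * clipped)   # the incoming gradient 3 is clamped on the way back to θ
grad(θ)              # == 1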

@testset "Checkpointing" begin
  count = 0
  function mul(a, b)
    count += 1
    a * b
  end
  @test gradient(x -> mul(5, x), 3)[1] == 5
  @test count == 1
  @test gradient(x -> checkpoint(mul, 5, x), 3)[1] == 5
  @test count == 3
end
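
# Note (illustrative aside, not from the diff): the count of 3 above is expected
# because `checkpoint` does not store intermediates from the forward pass; it re-runs
# the checkpointed function during back-propagation, trading compute for memory.
# The same bookkeeping with a fresh counter:
calls = Ref(0)
times(a, b) = (calls[] += 1; a * b)
gradient(x -> times(5, x), 3)               # forward pass only: calls[] == 1
gradient(x -> checkpoint(times, 5, x), 3)   # forward + recomputation on backward: calls[] == 3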

@testset "Updates" begin
  xs = param([1, 2, 3])
  Tracker.update!(xs, param([4, 5, 6]))
  @test xs == [5, 7, 9]
  x = param(3)
  Tracker.update!(x, param(4))
  @test x == 7
end

@testset "Params" begin
  W = param(randn(5, 10))
  x = rand(10)
  dW = gradient(W -> sum(W*x), W)[1]
  gs = gradient(() -> sum(W*x), Tracker.Params([W]))
  @test gs[W] == dW
end
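
# Note (illustrative aside, not from the diff): the second form takes "implicit"
# parameters and returns gradients keyed by each tracked array, which pairs naturally
# with `Tracker.update!` from the testset above. A minimal sketch (the step size is an
# arbitrary choice for illustration):
W, x = param(randn(5, 10)), rand(10)
gs = gradient(() -> sum(W * x), Tracker.Params([W]))
Tracker.update!(W, -0.1 .* gs[W])   # one gradient-descent step; update! adds in place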

@testset "Forward" begin
  @test @inferred(Tracker.forward_jacobian(x -> [sum(x)], rand(5,5), Val(12)))[2] ==
    reshape(ones(25), :, 1)
  @test gradient([2, 3]) do x
    forwarddiff(x) do x
      x[1]*x[2]
    end
  end == ([3, 2],)
end
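
# Note (illustrative aside, not from the diff): for f(x) = [sum(x)] every entry of the
# 5×5 input has ∂f₁/∂xᵢ = 1, so the forward-mode Jacobian is the 25×1 column of ones
# tested above. The helper can be inspected directly:
J = Tracker.forward_jacobian(x -> [sum(x)], rand(5, 5), Val(12))[2]
size(J)        # (25, 1)
all(J .== 1)   # true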

@testset "Custom Sensitivities" begin
  y, back = Tracker.forward(x -> [3x^2, 2x], 5)
  @test back([1, 1]) == (32,)
end
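
# Note (illustrative aside, not from the diff): the expected (32,) is the
# vector–Jacobian product at x = 5: the Jacobian of x -> [3x^2, 2x] is [6x, 2] = [30, 2],
# and [1, 1] ⋅ [30, 2] = 32. A different sensitivity picks out a single component:
y, back = Tracker.forward(x -> [3x^2, 2x], 5)
back([1, 0])   # (30,) — only the first output's derivative survives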

end #testset

@ -87,6 +87,12 @@ end
@test size.(params(m)) == [(5, 10), (5, 5), (5,), (5,)]
end

@testset "Basic Stacking" begin
  x = randn(3,3)
  stacked = stack([x, x], 2)
  @test size(stacked) == (3,2,3)
end
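
# Note (illustrative aside, not from the diff): `stack(xs, dim)` concatenates the
# arrays along a new axis inserted at `dim`, so two 3×3 matrices become a 3×2×3 array
# here and each slice along that axis recovers an input. Assuming the matching
# `unstack` helper is available as the inverse:
x = randn(3, 3)
s = stack([x, x], 2)
size(s)            # (3, 2, 3)
s[:, 1, :] == x    # true
unstack(s, 2)      # [x, x], presumably recovering the original slices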

@testset "Precision" begin
  m = Chain(Dense(10, 5, relu), Dense(5, 2))
  x = rand(10)