Merge remote-tracking branch 'upstream/master' into samepad
This commit is contained in:
commit
755536bf5e
|
@ -6,6 +6,7 @@ os:
|
||||||
# - osx
|
# - osx
|
||||||
|
|
||||||
julia:
|
julia:
|
||||||
|
- 1.0
|
||||||
- 1.2
|
- 1.2
|
||||||
- 1.3
|
- 1.3
|
||||||
- nightly
|
- nightly
|
||||||
|
|
|
@ -1,5 +1,3 @@
|
||||||
# This file is machine-generated - editing it directly is not advised
|
|
||||||
|
|
||||||
[[AbstractFFTs]]
|
[[AbstractFFTs]]
|
||||||
deps = ["LinearAlgebra"]
|
deps = ["LinearAlgebra"]
|
||||||
git-tree-sha1 = "380e36c66edfa099cd90116b24c1ce8cafccac40"
|
git-tree-sha1 = "380e36c66edfa099cd90116b24c1ce8cafccac40"
|
||||||
|
@ -38,29 +36,23 @@ git-tree-sha1 = "62847acab40e6855a9b5905ccb99c2b5cf6b3ebb"
|
||||||
uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
|
uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
|
||||||
version = "0.2.0"
|
version = "0.2.0"
|
||||||
|
|
||||||
[[CSTParser]]
|
|
||||||
deps = ["Tokenize"]
|
|
||||||
git-tree-sha1 = "99dda94f5af21a4565dc2b97edf6a95485f116c3"
|
|
||||||
uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f"
|
|
||||||
version = "1.0.0"
|
|
||||||
|
|
||||||
[[CUDAapi]]
|
[[CUDAapi]]
|
||||||
deps = ["Libdl", "Logging"]
|
deps = ["Libdl", "Logging"]
|
||||||
git-tree-sha1 = "e063efb91cfefd7e6afd92c435d01398107a500b"
|
git-tree-sha1 = "6eee47385c81ed3b3f716b745697869c712c2df3"
|
||||||
uuid = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
|
uuid = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
|
||||||
version = "1.2.0"
|
version = "2.0.0"
|
||||||
|
|
||||||
[[CUDAdrv]]
|
[[CUDAdrv]]
|
||||||
deps = ["CEnum", "Printf"]
|
deps = ["CEnum", "CUDAapi", "Printf"]
|
||||||
git-tree-sha1 = "96eabc95ebb83e361311330ffb574a3e2df73251"
|
git-tree-sha1 = "0f39fddace3324707469ace7fbcbc7b28d5cf921"
|
||||||
uuid = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
|
uuid = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
|
||||||
version = "4.0.2"
|
version = "4.0.4"
|
||||||
|
|
||||||
[[CUDAnative]]
|
[[CUDAnative]]
|
||||||
deps = ["Adapt", "CEnum", "CUDAapi", "CUDAdrv", "DataStructures", "InteractiveUtils", "LLVM", "Libdl", "Printf", "TimerOutputs"]
|
deps = ["Adapt", "CEnum", "CUDAapi", "CUDAdrv", "DataStructures", "InteractiveUtils", "LLVM", "Libdl", "Printf", "TimerOutputs"]
|
||||||
git-tree-sha1 = "dd642afe5fd6633663a8c3d42f3b7638f2210b79"
|
git-tree-sha1 = "93f6c917ab2a9b5bb54f8f738f4ec1a6693cb716"
|
||||||
uuid = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
|
uuid = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
|
||||||
version = "2.5.3"
|
version = "2.5.5"
|
||||||
|
|
||||||
[[CodecZlib]]
|
[[CodecZlib]]
|
||||||
deps = ["BinaryProvider", "Libdl", "TranscodingStreams"]
|
deps = ["BinaryProvider", "Libdl", "TranscodingStreams"]
|
||||||
|
@ -98,17 +90,13 @@ git-tree-sha1 = "9a11d428dcdc425072af4aea19ab1e8c3e01c032"
|
||||||
uuid = "8f4d0f93-b110-5947-807f-2305c1781a2d"
|
uuid = "8f4d0f93-b110-5947-807f-2305c1781a2d"
|
||||||
version = "1.3.0"
|
version = "1.3.0"
|
||||||
|
|
||||||
[[Crayons]]
|
|
||||||
deps = ["Test"]
|
|
||||||
git-tree-sha1 = "f621b8ef51fd2004c7cf157ea47f027fdeac5523"
|
|
||||||
uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
|
|
||||||
version = "4.0.0"
|
|
||||||
|
|
||||||
[[CuArrays]]
|
[[CuArrays]]
|
||||||
deps = ["AbstractFFTs", "Adapt", "CEnum", "CUDAapi", "CUDAdrv", "CUDAnative", "DataStructures", "GPUArrays", "Libdl", "LinearAlgebra", "MacroTools", "NNlib", "Printf", "Random", "Requires", "SparseArrays", "TimerOutputs"]
|
deps = ["AbstractFFTs", "Adapt", "CEnum", "CUDAapi", "CUDAdrv", "CUDAnative", "DataStructures", "GPUArrays", "Libdl", "LinearAlgebra", "MacroTools", "NNlib", "Printf", "Random", "Requires", "SparseArrays", "TimerOutputs"]
|
||||||
git-tree-sha1 = "bc94d6cb335d418088f12641751aab63ff56509d"
|
git-tree-sha1 = "7e00178b18672ee2cf37244ac2a273b6b0701b04"
|
||||||
|
repo-rev = "master"
|
||||||
|
repo-url = "https://github.com/JuliaGPU/CuArrays.jl.git"
|
||||||
uuid = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
|
uuid = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
|
||||||
version = "1.4.2"
|
version = "1.4.7"
|
||||||
|
|
||||||
[[DataAPI]]
|
[[DataAPI]]
|
||||||
git-tree-sha1 = "674b67f344687a88310213ddfa8a2b3c76cc4252"
|
git-tree-sha1 = "674b67f344687a88310213ddfa8a2b3c76cc4252"
|
||||||
|
@ -117,9 +105,9 @@ version = "1.1.0"
|
||||||
|
|
||||||
[[DataStructures]]
|
[[DataStructures]]
|
||||||
deps = ["InteractiveUtils", "OrderedCollections"]
|
deps = ["InteractiveUtils", "OrderedCollections"]
|
||||||
git-tree-sha1 = "1fe8fad5fc84686dcbc674aa255bc867a64f8132"
|
git-tree-sha1 = "a1b652fb77ae8ca7ea328fa7ba5aa151036e5c10"
|
||||||
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
|
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
|
||||||
version = "0.17.5"
|
version = "0.17.6"
|
||||||
|
|
||||||
[[Dates]]
|
[[Dates]]
|
||||||
deps = ["Printf"]
|
deps = ["Printf"]
|
||||||
|
@ -136,13 +124,13 @@ uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
|
||||||
version = "0.0.4"
|
version = "0.0.4"
|
||||||
|
|
||||||
[[DiffRules]]
|
[[DiffRules]]
|
||||||
deps = ["Random", "Test"]
|
deps = ["NaNMath", "Random", "SpecialFunctions"]
|
||||||
git-tree-sha1 = "dc0869fb2f5b23466b32ea799bd82c76480167f7"
|
git-tree-sha1 = "f734b5f6bc9c909027ef99f6d91d5d9e4b111eed"
|
||||||
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
|
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
|
||||||
version = "0.0.10"
|
version = "0.1.0"
|
||||||
|
|
||||||
[[Distributed]]
|
[[Distributed]]
|
||||||
deps = ["Random", "Serialization", "Sockets"]
|
deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"]
|
||||||
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
|
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
|
||||||
|
|
||||||
[[FFTW]]
|
[[FFTW]]
|
||||||
|
@ -153,9 +141,9 @@ version = "1.0.1"
|
||||||
|
|
||||||
[[FillArrays]]
|
[[FillArrays]]
|
||||||
deps = ["LinearAlgebra", "Random", "SparseArrays"]
|
deps = ["LinearAlgebra", "Random", "SparseArrays"]
|
||||||
git-tree-sha1 = "6827a8f73ff12707f209c920d204238a16892b55"
|
git-tree-sha1 = "1a9fe4e1323f38de0ba4da49eafd15b25ec62298"
|
||||||
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
|
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
|
||||||
version = "0.8.0"
|
version = "0.8.2"
|
||||||
|
|
||||||
[[FixedPointNumbers]]
|
[[FixedPointNumbers]]
|
||||||
git-tree-sha1 = "d14a6fa5890ea3a7e5dcab6811114f132fec2b4b"
|
git-tree-sha1 = "d14a6fa5890ea3a7e5dcab6811114f132fec2b4b"
|
||||||
|
@ -164,9 +152,9 @@ version = "0.6.1"
|
||||||
|
|
||||||
[[ForwardDiff]]
|
[[ForwardDiff]]
|
||||||
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "NaNMath", "Random", "SpecialFunctions", "StaticArrays"]
|
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "NaNMath", "Random", "SpecialFunctions", "StaticArrays"]
|
||||||
git-tree-sha1 = "adf88d6da1f0294058f38295becf8807986bb7d0"
|
git-tree-sha1 = "da46ac97b17793eba44ff366dc6cb70f1238a738"
|
||||||
uuid = "f6369f11-7733-5829-9624-2563aa707210"
|
uuid = "f6369f11-7733-5829-9624-2563aa707210"
|
||||||
version = "0.10.5"
|
version = "0.10.7"
|
||||||
|
|
||||||
[[GPUArrays]]
|
[[GPUArrays]]
|
||||||
deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization"]
|
deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization"]
|
||||||
|
@ -181,7 +169,7 @@ uuid = "7869d1d1-7146-5819-86e3-90919afe41df"
|
||||||
version = "0.3.0"
|
version = "0.3.0"
|
||||||
|
|
||||||
[[InteractiveUtils]]
|
[[InteractiveUtils]]
|
||||||
deps = ["Markdown"]
|
deps = ["LinearAlgebra", "Markdown"]
|
||||||
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
|
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
|
||||||
|
|
||||||
[[JSON]]
|
[[JSON]]
|
||||||
|
@ -216,10 +204,10 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
|
||||||
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
|
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
|
||||||
|
|
||||||
[[MacroTools]]
|
[[MacroTools]]
|
||||||
deps = ["CSTParser", "Compat", "DataStructures", "Test", "Tokenize"]
|
deps = ["Compat", "DataStructures", "Test"]
|
||||||
git-tree-sha1 = "d6e9dedb8c92c3465575442da456aec15a89ff76"
|
git-tree-sha1 = "82921f0e3bde6aebb8e524efc20f4042373c0c06"
|
||||||
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
|
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
|
||||||
version = "0.5.1"
|
version = "0.5.2"
|
||||||
|
|
||||||
[[Markdown]]
|
[[Markdown]]
|
||||||
deps = ["Base64"]
|
deps = ["Base64"]
|
||||||
|
@ -247,10 +235,9 @@ uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
|
||||||
version = "0.6.0"
|
version = "0.6.0"
|
||||||
|
|
||||||
[[NaNMath]]
|
[[NaNMath]]
|
||||||
deps = ["Compat"]
|
git-tree-sha1 = "928b8ca9b2791081dc71a51c55347c27c618760f"
|
||||||
git-tree-sha1 = "ce3b85e484a5d4c71dd5316215069311135fa9f2"
|
|
||||||
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
|
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
|
||||||
version = "0.3.2"
|
version = "0.3.3"
|
||||||
|
|
||||||
[[OrderedCollections]]
|
[[OrderedCollections]]
|
||||||
deps = ["Random", "Serialization", "Test"]
|
deps = ["Random", "Serialization", "Test"]
|
||||||
|
@ -260,12 +247,12 @@ version = "1.1.0"
|
||||||
|
|
||||||
[[Parsers]]
|
[[Parsers]]
|
||||||
deps = ["Dates", "Test"]
|
deps = ["Dates", "Test"]
|
||||||
git-tree-sha1 = "c56ecb484f286639f161e712b8311f5ab77e8d32"
|
git-tree-sha1 = "0139ba59ce9bc680e2925aec5b7db79065d60556"
|
||||||
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
|
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
|
||||||
version = "0.3.8"
|
version = "0.3.10"
|
||||||
|
|
||||||
[[Pkg]]
|
[[Pkg]]
|
||||||
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
|
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
|
||||||
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
|
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
|
||||||
|
|
||||||
[[Printf]]
|
[[Printf]]
|
||||||
|
@ -327,9 +314,9 @@ version = "0.8.0"
|
||||||
|
|
||||||
[[StaticArrays]]
|
[[StaticArrays]]
|
||||||
deps = ["LinearAlgebra", "Random", "Statistics"]
|
deps = ["LinearAlgebra", "Random", "Statistics"]
|
||||||
git-tree-sha1 = "1e9c5d89cba8047d518f1ffef432906ef1a3e8bd"
|
git-tree-sha1 = "5a3bcb6233adabde68ebc97be66e95dcb787424c"
|
||||||
uuid = "90137ffa-7385-5640-81b9-e52037218182"
|
uuid = "90137ffa-7385-5640-81b9-e52037218182"
|
||||||
version = "0.12.0"
|
version = "0.12.1"
|
||||||
|
|
||||||
[[Statistics]]
|
[[Statistics]]
|
||||||
deps = ["LinearAlgebra", "SparseArrays"]
|
deps = ["LinearAlgebra", "SparseArrays"]
|
||||||
|
@ -346,15 +333,10 @@ deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
|
||||||
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
|
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
|
||||||
|
|
||||||
[[TimerOutputs]]
|
[[TimerOutputs]]
|
||||||
deps = ["Crayons", "Printf", "Test", "Unicode"]
|
deps = ["Printf"]
|
||||||
git-tree-sha1 = "b80671c06f8f8bae08c55d67b5ce292c5ae2660c"
|
git-tree-sha1 = "311765af81bbb48d7bad01fb016d9c328c6ede03"
|
||||||
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
|
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
|
||||||
version = "0.5.0"
|
version = "0.5.3"
|
||||||
|
|
||||||
[[Tokenize]]
|
|
||||||
git-tree-sha1 = "dfcdbbfb2d0370716c815cbd6f8a364efb6f42cf"
|
|
||||||
uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624"
|
|
||||||
version = "0.5.6"
|
|
||||||
|
|
||||||
[[TranscodingStreams]]
|
[[TranscodingStreams]]
|
||||||
deps = ["Random", "Test"]
|
deps = ["Random", "Test"]
|
||||||
|
@ -369,7 +351,7 @@ uuid = "30578b45-9adc-5946-b283-645ec420af67"
|
||||||
version = "0.4.0"
|
version = "0.4.0"
|
||||||
|
|
||||||
[[UUIDs]]
|
[[UUIDs]]
|
||||||
deps = ["Random", "SHA"]
|
deps = ["Random"]
|
||||||
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
|
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
|
||||||
|
|
||||||
[[Unicode]]
|
[[Unicode]]
|
||||||
|
@ -389,9 +371,9 @@ version = "0.8.3"
|
||||||
|
|
||||||
[[Zygote]]
|
[[Zygote]]
|
||||||
deps = ["DiffRules", "FFTW", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
|
deps = ["DiffRules", "FFTW", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
|
||||||
git-tree-sha1 = "b2e42a21dc3d1ecd3cbe8c83a454ca56fbf423c4"
|
git-tree-sha1 = "e4245b9c5362346e154b62842a89a18e0210b92b"
|
||||||
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
|
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
|
||||||
version = "0.4.0"
|
version = "0.4.1"
|
||||||
|
|
||||||
[[ZygoteRules]]
|
[[ZygoteRules]]
|
||||||
deps = ["MacroTools"]
|
deps = ["MacroTools"]
|
||||||
|
|
13
NEWS.md
13
NEWS.md
|
@ -1,3 +1,16 @@
|
||||||
|
# v0.10.0
|
||||||
|
* The default AD engine has switched from [Tracker to Zygote.jl](https://github.com/FluxML/Flux.jl/pull/669)
|
||||||
|
- The dependency on Tracker.jl has been removed.
|
||||||
|
- This means Flux now does not depend on using a specialised `TrackedArray` type, and can be used with normal Array implementations directly.
|
||||||
|
- Tracker compatibility is maintained in most common cases, but Zygote will be the preferred AD backend for Flux from now on.
|
||||||
|
* The CUDNN wrappers have been [moved from Flux into CuArrays](https://github.com/FluxML/Flux.jl/pull/874), to allow for better supporting the CUDA backend, and improve user experience, not to mention making Flux lean.
|
||||||
|
* `*crossentropy` functions now [work as expected with CuArrays](https://github.com/FluxML/Flux.jl/pull/926). [PR for binarycrossentropy](https://github.com/FluxML/Flux.jl/pull/940).
|
||||||
|
* Added [clearer docs](https://github.com/FluxML/Flux.jl/pull/904) around training and the Optimiser interface.
|
||||||
|
* [Layer initialisations](https://github.com/FluxML/Flux.jl/pull/937) have been improved with a clearer API on how to extend it for other purposes.
|
||||||
|
* [Better messaging around CUDA availability](https://github.com/FluxML/Flux.jl/pull/924), with hooks to initialize the GPU as default where possible.
|
||||||
|
* `@treelike` has been formalised as a [functor](https://github.com/FluxML/Flux.jl/pull/865), with an effective deprecation.
|
||||||
|
* `testmode!` is deprecated in favour of [istraining](https://github.com/FluxML/Flux.jl/pull/669)
|
||||||
|
|
||||||
# v0.9.0
|
# v0.9.0
|
||||||
* [Depthwise convolutional layer API changes](https://github.com/FluxML/Flux.jl/pull/756) from `in => mult` channel specification to `in => out` channel specification, and deprecates implicit `out` constructor.
|
* [Depthwise convolutional layer API changes](https://github.com/FluxML/Flux.jl/pull/756) from `in => mult` channel specification to `in => out` channel specification, and deprecates implicit `out` constructor.
|
||||||
* New [SkipConnection](https://github.com/FluxML/Flux.jl/pull/446), which can be used to train residual neural network architectures.
|
* New [SkipConnection](https://github.com/FluxML/Flux.jl/pull/446), which can be used to train residual neural network architectures.
|
||||||
|
|
15
Project.toml
15
Project.toml
|
@ -1,11 +1,10 @@
|
||||||
name = "Flux"
|
name = "Flux"
|
||||||
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
|
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
|
||||||
version = "0.9.0"
|
version = "0.10.0"
|
||||||
|
|
||||||
[deps]
|
[deps]
|
||||||
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
|
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
|
||||||
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
|
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
|
||||||
CUDAdrv = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
|
|
||||||
CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193"
|
CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193"
|
||||||
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
|
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
|
||||||
CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
|
CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
|
||||||
|
@ -25,9 +24,17 @@ ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
|
||||||
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
|
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
|
||||||
|
|
||||||
[compat]
|
[compat]
|
||||||
CUDAdrv = "4.0.1"
|
AbstractTrees = "0.2"
|
||||||
CuArrays = "1.4.2"
|
Adapt = "1"
|
||||||
|
CodecZlib = "0.5, 0.6"
|
||||||
|
Colors = "0.8, 0.9"
|
||||||
|
CuArrays = "1.4.3"
|
||||||
|
Juno = "0.5, 0.6, 0.7"
|
||||||
|
MacroTools = "0.3, 0.4, 0.5"
|
||||||
NNlib = "0.6"
|
NNlib = "0.6"
|
||||||
|
Reexport = "0.2"
|
||||||
|
StatsBase = "0"
|
||||||
|
ZipFile = "0.7, 0.8"
|
||||||
Zygote = "0.4"
|
Zygote = "0.4"
|
||||||
julia = "1"
|
julia = "1"
|
||||||
|
|
||||||
|
|
88
README.md
88
README.md
|
@ -7,93 +7,9 @@
|
||||||
Flux is an elegant approach to machine learning. It's a 100% pure-Julia stack, and provides lightweight abstractions on top of Julia's native GPU and AD support. Flux makes the easy things easy while remaining fully hackable.
|
Flux is an elegant approach to machine learning. It's a 100% pure-Julia stack, and provides lightweight abstractions on top of Julia's native GPU and AD support. Flux makes the easy things easy while remaining fully hackable.
|
||||||
|
|
||||||
```julia
|
```julia
|
||||||
julia> Pkg.add("Flux")
|
] add Flux
|
||||||
```
|
```
|
||||||
|
|
||||||
See the [documentation](https://fluxml.github.io/Flux.jl/) or the [model zoo](https://github.com/FluxML/model-zoo/) for examples.
|
See the [documentation](https://fluxml.github.io/Flux.jl/) or the [model zoo](https://github.com/FluxML/model-zoo/) for examples.
|
||||||
|
|
||||||
If you use Flux in research, please cite the following paper:
|
If you use Flux in research, please see [our papers](CITATION.bib) for appropriate citations.
|
||||||
|
|
||||||
```
|
|
||||||
@article{innes:2018,
|
|
||||||
author = {Mike Innes},
|
|
||||||
title = {Flux: Elegant Machine Learning with Julia},
|
|
||||||
journal = {Journal of Open Source Software},
|
|
||||||
year = {2018},
|
|
||||||
doi = {10.21105/joss.00602},
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## Features
|
|
||||||
|
|
||||||
Flux has powerful high-level features, and common architectures can be defined in a few lines.
|
|
||||||
|
|
||||||
```julia
|
|
||||||
model = Chain(
|
|
||||||
Dense(768, 128, σ),
|
|
||||||
LSTM(128, 256),
|
|
||||||
LSTM(256, 128),
|
|
||||||
Dense(128, 10),
|
|
||||||
softmax)
|
|
||||||
|
|
||||||
loss(x, y) = crossentropy(model(x), y)
|
|
||||||
|
|
||||||
Flux.train!(loss, params(model), data, ADAM(...))
|
|
||||||
```
|
|
||||||
|
|
||||||
Yet you can easily strip away the layers, and directly write the mathematics for your problem. Flux will seamlessly take gradients of any Julia code, so your model looks just like the paper.
|
|
||||||
|
|
||||||
```julia
|
|
||||||
W = param(randn(2, 10))
|
|
||||||
b = param(randn(2))
|
|
||||||
|
|
||||||
y(x) = σ.(W * x .+ b)
|
|
||||||
```
|
|
||||||
|
|
||||||
If that's *still* not enough, you can go as deep as you want, even writing your own CUDA kernels with [CUDAnative](https://github.com/JuliaGPU/CUDAnative.jl)! All this can be freely mixed-and-matched in a single model or script, and it all runs interactively via Jupyter or Juno.
|
|
||||||
|
|
||||||
```julia
|
|
||||||
function gpu_add(a, b, c)
|
|
||||||
i = (blockIdx().x-1) * blockDim().x + threadIdx().x
|
|
||||||
c[i] = a[i] + b[i]
|
|
||||||
return nothing
|
|
||||||
end
|
|
||||||
```
|
|
||||||
|
|
||||||
Unusual architectures are no problem in Flux, as you can use all the loops, control flow and even macros that you're used to. Here's a Tree RNN in 4 lines.
|
|
||||||
|
|
||||||
```julia
|
|
||||||
tree() = rand() < 0.5 ? rand(10) : (tree(), tree()) # dummy data
|
|
||||||
|
|
||||||
shrink = Dense(20, 10)
|
|
||||||
combine(a, b) = shrink([a; b])
|
|
||||||
|
|
||||||
model(x) = x
|
|
||||||
model(x::Tuple) = combine(model(x[1]), model(x[2]))
|
|
||||||
|
|
||||||
model(tree()) # Sample output
|
|
||||||
```
|
|
||||||
|
|
||||||
Despite this flexibility, Julia's advanced compiler lets us do some powerful optimisations. For example, this definition of `sigmoid` automatically gets fused into a *single* GPU kernel – so it's really fast.
|
|
||||||
|
|
||||||
```julia
|
|
||||||
sigmoid(xs) = 1 ./ (1 .+ exp.(.-xs))
|
|
||||||
```
|
|
||||||
|
|
||||||
Similarly, Flux is the first dynamic framework to support [compiling to the browser](https://fluxml.github.io/experiments/) and model import via [formats like ONNX](https://github.com/FluxML/ONNX.jl/), both of which are thinly-veiled compiler problems.
|
|
||||||
|
|
||||||
For more on our philosophy on machine learning, check out our article [On Machine Learning & Programming Languages](https://julialang.org/blog/2017/12/ml&pl).
|
|
||||||
|
|
||||||
## Contributing & Help
|
|
||||||
|
|
||||||
For general questions and help, check out Julia's [community forum](https://discourse.julialang.org/c/domain/ML).
|
|
||||||
|
|
||||||
Flux development is carried out via our [GitHub issues](https://github.com/FluxML/Flux.jl/issues), so feel free to open feature requests or PRs here.
|
|
||||||
|
|
||||||
For more informal discussions we'd love to have you on the [Julia slack](https://slackinvite.julialang.org/), where we hang out on the #machine-learning channel.
|
|
||||||
|
|
||||||
## Related Packages
|
|
||||||
|
|
||||||
Check out [Metalhead.jl](https://github.com/FluxML/Metalhead.jl) for common computer vision datasets and trained models.
|
|
||||||
|
|
||||||
[MLDatasets.jl](https://github.com/JuliaML/MLDatasets.jl) provides further common datasets.
|
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
# Training
|
# Training
|
||||||
|
|
||||||
To actually train a model we need three things:
|
To actually train a model we need four things:
|
||||||
|
|
||||||
* A *objective function*, that evaluates how well a model is doing given some input data.
|
* A *objective function*, that evaluates how well a model is doing given some input data.
|
||||||
|
* The trainable parameters of the model.
|
||||||
* A collection of data points that will be provided to the objective function.
|
* A collection of data points that will be provided to the objective function.
|
||||||
* An [optimiser](optimisers.md) that will update the model parameters appropriately.
|
* An [optimiser](optimisers.md) that will update the model parameters appropriately.
|
||||||
|
|
||||||
|
@ -32,6 +33,14 @@ Flux.train!(loss, ps, data, opt)
|
||||||
|
|
||||||
The objective will almost always be defined in terms of some *cost function* that measures the distance of the prediction `m(x)` from the target `y`. Flux has several of these built in, like `mse` for mean squared error or `crossentropy` for cross entropy loss, but you can calculate it however you want.
|
The objective will almost always be defined in terms of some *cost function* that measures the distance of the prediction `m(x)` from the target `y`. Flux has several of these built in, like `mse` for mean squared error or `crossentropy` for cross entropy loss, but you can calculate it however you want.
|
||||||
|
|
||||||
|
At first glance it may seem strange that the model that we want to train is not part of the input arguments of `Flux.train!` too. However the target of the optimizer is not the model itself, but the objective function that represents the departure between modelled and observed data. In other words, the model is implicitly defined in the objective function, and there is no need to give it explicitly. Passing the objective function instead of the model and a cost function separately provides more flexibility, and the possibility of optimizing the calculations.
|
||||||
|
|
||||||
|
## Model parameters
|
||||||
|
|
||||||
|
The model to be trained must have a set of tracked parameters that are used to calculate the gradients of the objective function. In the [basics](../models/basics.md) section it is explained how to create models with such parameters. The second argument of the function `Flux.train!` must be an object containing those parameters, which can be obtained from a model `m` as `params(m)`.
|
||||||
|
|
||||||
|
Such an object contains a reference to the model's parameters, not a copy, such that after their training, the model behaves according to their updated values.
|
||||||
|
|
||||||
## Datasets
|
## Datasets
|
||||||
|
|
||||||
The `data` argument provides a collection of data to train with (usually a set of inputs `x` and target outputs `y`). For example, here's a dummy data set with only one data point:
|
The `data` argument provides a collection of data to train with (usually a set of inputs `x` and target outputs `y`). For example, here's a dummy data set with only one data point:
|
||||||
|
|
21
src/Flux.jl
21
src/Flux.jl
|
@ -6,7 +6,7 @@ using Base: tail
|
||||||
using Zygote, MacroTools, Juno, Reexport, Statistics, Random
|
using Zygote, MacroTools, Juno, Reexport, Statistics, Random
|
||||||
using MacroTools: @forward
|
using MacroTools: @forward
|
||||||
@reexport using NNlib
|
@reexport using NNlib
|
||||||
using Zygote: Params, @adjoint, gradient, pullback
|
using Zygote: Params, @adjoint, gradient, pullback, @nograd
|
||||||
export gradient
|
export gradient
|
||||||
|
|
||||||
export Chain, Dense, Maxout, RNN, LSTM, GRU, SamePad, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
|
export Chain, Dense, Maxout, RNN, LSTM, GRU, SamePad, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
|
||||||
|
@ -21,8 +21,7 @@ export SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
|
||||||
ADAMW, RADAM, InvDecay, ExpDecay, WeightDecay
|
ADAMW, RADAM, InvDecay, ExpDecay, WeightDecay
|
||||||
|
|
||||||
|
|
||||||
ENV["CUDA_INIT_SILENT"] = true
|
using CuArrays
|
||||||
using CUDAdrv, CuArrays
|
|
||||||
const use_cuda = Ref(false)
|
const use_cuda = Ref(false)
|
||||||
|
|
||||||
include("utils.jl")
|
include("utils.jl")
|
||||||
|
@ -40,12 +39,14 @@ include("data/Data.jl")
|
||||||
include("deprecations.jl")
|
include("deprecations.jl")
|
||||||
|
|
||||||
function __init__()
|
function __init__()
|
||||||
if !CUDAdrv.functional()
|
precompiling = ccall(:jl_generating_output, Cint, ()) != 0
|
||||||
@warn "CUDA available, but CUDAdrv.jl failed to load"
|
|
||||||
elseif length(devices()) == 0
|
# we don't want to include the CUDA module when precompiling,
|
||||||
@warn "CUDA available, but no GPU detected"
|
# or we could end up replacing it at run time (triggering a warning)
|
||||||
elseif !CuArrays.functional()
|
precompiling && return
|
||||||
@warn "CUDA GPU available, but CuArrays.jl failed to load"
|
|
||||||
|
if !CuArrays.functional()
|
||||||
|
# nothing to do here, and either CuArrays or one of its dependencies will have warned
|
||||||
else
|
else
|
||||||
use_cuda[] = true
|
use_cuda[] = true
|
||||||
|
|
||||||
|
@ -54,7 +55,7 @@ function __init__()
|
||||||
if CuArrays.has_cudnn()
|
if CuArrays.has_cudnn()
|
||||||
include(joinpath(@__DIR__, "cuda/cuda.jl"))
|
include(joinpath(@__DIR__, "cuda/cuda.jl"))
|
||||||
else
|
else
|
||||||
@warn "CUDA GPU available, but CuArrays.jl did not find libcudnn. Some functionality will not be available."
|
@warn "CuArrays.jl did not find libcudnn. Some functionality will not be available."
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -44,19 +44,23 @@ end
|
||||||
# it might be replaced in the future for better performance
|
# it might be replaced in the future for better performance
|
||||||
# see issue https://github.com/FluxML/Flux.jl/issues/702
|
# see issue https://github.com/FluxML/Flux.jl/issues/702
|
||||||
# Johnny Chen -- @johnnychen94
|
# Johnny Chen -- @johnnychen94
|
||||||
|
# only slightly changed to better handle interaction with Zygote @dsweber2
|
||||||
"""
|
"""
|
||||||
activations(c::Chain, input)
|
activations(c::Chain, input)
|
||||||
Calculate the forward results of each layers in Chain `c` with `input` as model input.
|
Calculate the forward results of each layers in Chain `c` with `input` as model input.
|
||||||
"""
|
"""
|
||||||
function activations(c::Chain, input)
|
function activations(c::Chain, input)
|
||||||
rst = []
|
extraChain(c.layers, input)
|
||||||
for l in c
|
|
||||||
x = get(rst, length(rst), input)
|
|
||||||
push!(rst, l(x))
|
|
||||||
end
|
|
||||||
return rst
|
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function extraChain(fs::Tuple, x)
|
||||||
|
res = first(fs)(x)
|
||||||
|
return (res, extraChain(Base.tail(fs), res)...)
|
||||||
|
end
|
||||||
|
|
||||||
|
extraChain(::Tuple{}, x) = ()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Dense(in::Integer, out::Integer, σ = identity)
|
Dense(in::Integer, out::Integer, σ = identity)
|
||||||
|
|
|
@ -144,6 +144,9 @@ function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
|
||||||
)
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# TODO: Find proper fix for https://github.com/FluxML/Flux.jl/issues/900
|
||||||
|
@nograd conv_transpose_dims
|
||||||
|
|
||||||
function (c::ConvTranspose)(x::AbstractArray)
|
function (c::ConvTranspose)(x::AbstractArray)
|
||||||
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
|
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
|
||||||
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
|
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
using CuArrays
|
||||||
using NNlib: logsoftmax, logσ
|
using NNlib: logsoftmax, logσ
|
||||||
|
|
||||||
# Cost functions
|
# Cost functions
|
||||||
|
@ -35,6 +36,9 @@ Return `-y*log(ŷ + ϵ) - (1-y)*log(1-ŷ + ϵ)`. The ϵ term provides numerica
|
||||||
"""
|
"""
|
||||||
binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ)
|
binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ)
|
||||||
|
|
||||||
|
# Re-definition to fix interaction with CuArrays.
|
||||||
|
CuArrays.@cufunc binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
logitbinarycrossentropy(logŷ, y)
|
logitbinarycrossentropy(logŷ, y)
|
||||||
|
|
||||||
|
@ -49,6 +53,9 @@ but it is more numerically stable.
|
||||||
"""
|
"""
|
||||||
logitbinarycrossentropy(logŷ, y) = (1 - y)*logŷ - logσ(logŷ)
|
logitbinarycrossentropy(logŷ, y) = (1 - y)*logŷ - logσ(logŷ)
|
||||||
|
|
||||||
|
# Re-definition to fix interaction with CuArrays.
|
||||||
|
CuArrays.@cufunc logitbinarycrossentropy(logŷ, y) = (1 - y)*logŷ - logσ(logŷ)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
normalise(x::AbstractArray; dims=1)
|
normalise(x::AbstractArray; dims=1)
|
||||||
|
|
||||||
|
|
|
@ -283,7 +283,7 @@ ADAGrad(η = 0.1) = ADAGrad(η, IdDict())
|
||||||
|
|
||||||
function apply!(o::ADAGrad, x, Δ)
|
function apply!(o::ADAGrad, x, Δ)
|
||||||
η = o.eta
|
η = o.eta
|
||||||
acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(x)
|
acc = get!(o.acc, x, fill!(zero(x), ϵ))::typeof(x)
|
||||||
@. acc += Δ^2
|
@. acc += Δ^2
|
||||||
@. Δ *= η / (√acc + ϵ)
|
@. Δ *= η / (√acc + ϵ)
|
||||||
end
|
end
|
||||||
|
@ -349,10 +349,10 @@ AMSGrad(η = 0.001, β = (0.9, 0.999)) = AMSGrad(η, β, IdDict())
|
||||||
|
|
||||||
function apply!(o::AMSGrad, x, Δ)
|
function apply!(o::AMSGrad, x, Δ)
|
||||||
η, β = o.eta, o.beta
|
η, β = o.eta, o.beta
|
||||||
mt, vt, v̂t = get!(o.state, x, (fill(ϵ, size(x)), fill(ϵ, size(x)), fill(ϵ, size(x))))
|
mt, vt, v̂t = get!(o.state, x, (fill!(zero(x), ϵ), fill!(zero(x), ϵ), fill!(zero(x), ϵ)))
|
||||||
@. mt = β[1] * mt + (1 - β[1]) * Δ
|
@. mt = β[1] * mt + (1 - β[1]) * Δ
|
||||||
@. vt = β[2] * vt + (1 - β[2]) * Δ ^ 2
|
@. vt = β[2] * vt + (1 - β[2]) * Δ ^ 2
|
||||||
@. v̂t = max.(v̂t, vt)
|
@. v̂t = max(v̂t, vt)
|
||||||
@. Δ = η * mt / (√v̂t + ϵ)
|
@. Δ = η * mt / (√v̂t + ϵ)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,11 @@
|
||||||
# Arrays
|
# Arrays
|
||||||
glorot_uniform(dims...) = (rand(Float32, dims...) .- 0.5f0) .* sqrt(24.0f0/sum(dims))
|
nfan() = 1, 1 #fan_in, fan_out
|
||||||
glorot_normal(dims...) = randn(Float32, dims...) .* sqrt(2.0f0/sum(dims))
|
nfan(n) = 1, n #A vector is treated as a n×1 matrix
|
||||||
|
nfan(n_out, n_in) = n_in, n_out #In case of Dense kernels: arranged as matrices
|
||||||
|
nfan(dims...) = prod(dims[1:end-2]) .* (dims[end-1], dims[end]) #In case of convolution kernels
|
||||||
|
|
||||||
|
glorot_uniform(dims...) = (rand(Float32, dims...) .- 0.5f0) .* sqrt(24.0f0 / sum(nfan(dims...)))
|
||||||
|
glorot_normal(dims...) = randn(Float32, dims...) .* sqrt(2.0f0 / sum(nfan(dims...)))
|
||||||
|
|
||||||
ones(T::Type, dims...) = Base.ones(T, dims...)
|
ones(T::Type, dims...) = Base.ones(T, dims...)
|
||||||
zeros(T::Type, dims...) = Base.zeros(T, dims...)
|
zeros(T::Type, dims...) = Base.zeros(T, dims...)
|
||||||
|
|
|
@ -31,6 +31,11 @@ cx = gpu(x)
|
||||||
@test Flux.crossentropy(x,x, weight=1.0) ≈ Flux.crossentropy(cx,cx, weight=1.0)
|
@test Flux.crossentropy(x,x, weight=1.0) ≈ Flux.crossentropy(cx,cx, weight=1.0)
|
||||||
@test Flux.crossentropy(x,x, weight=[1.0;2.0;3.0]) ≈ Flux.crossentropy(cx,cx, weight=cu([1.0;2.0;3.0]))
|
@test Flux.crossentropy(x,x, weight=[1.0;2.0;3.0]) ≈ Flux.crossentropy(cx,cx, weight=cu([1.0;2.0;3.0]))
|
||||||
|
|
||||||
|
x = [-1.1491, 0.8619, 0.3127]
|
||||||
|
y = [1, 1, 0.]
|
||||||
|
@test Flux.binarycrossentropy.(σ.(x),y) ≈ Flux.binarycrossentropy.(cu(σ.(x)),cu(y))
|
||||||
|
@test Flux.logitbinarycrossentropy.(x,y) ≈ Flux.logitbinarycrossentropy.(cu(x),cu(y))
|
||||||
|
|
||||||
xs = rand(5, 5)
|
xs = rand(5, 5)
|
||||||
ys = Flux.onehotbatch(1:5,1:5)
|
ys = Flux.onehotbatch(1:5,1:5)
|
||||||
@test collect(cu(xs) .+ cu(ys)) ≈ collect(xs .+ ys)
|
@test collect(cu(xs) .+ cu(ys)) ≈ collect(xs .+ ys)
|
||||||
|
|
|
@ -4,11 +4,13 @@ import Flux: activations
|
||||||
@testset "basic" begin
|
@testset "basic" begin
|
||||||
@testset "helpers" begin
|
@testset "helpers" begin
|
||||||
@testset "activations" begin
|
@testset "activations" begin
|
||||||
dummy_model = Chain(Dense(10,5,σ),Dense(5,2),softmax)
|
dummy_model = Chain(x->x.^2, x->x .- 3, x -> tan.(x))
|
||||||
x = rand(10)
|
x = randn(10)
|
||||||
@test activations(Chain(), x) == []
|
@test activations(dummy_model, x)[1] == x.^2
|
||||||
@test activations(dummy_model, x)[1] == dummy_model[1](x)
|
@test activations(dummy_model, x)[2] == (x.^2 .- 3)
|
||||||
@test activations(dummy_model, x)[2] == x |> dummy_model[1] |> dummy_model[2]
|
@test activations(dummy_model, x)[3] == tan.(x.^2 .- 3)
|
||||||
|
|
||||||
|
@test activations(Chain(), x) == ()
|
||||||
@test activations(Chain(identity, x->:foo), x)[2] == :foo # results include `Any` type
|
@test activations(Chain(identity, x->:foo), x)[2] == :foo # results include `Any` type
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -19,6 +21,12 @@ import Flux: activations
|
||||||
# numeric test should be put into testset of corresponding layer
|
# numeric test should be put into testset of corresponding layer
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@testset "Activations" begin
|
||||||
|
c = Chain(Dense(3,5,relu), Dense(5,1,relu))
|
||||||
|
X = Float32.([1.0; 1.0; 1.0])
|
||||||
|
@test_nowarn gradient(()->Flux.activations(c, X)[2][1], params(c))
|
||||||
|
end
|
||||||
|
|
||||||
@testset "Dense" begin
|
@testset "Dense" begin
|
||||||
@test length(Dense(10, 5)(randn(10))) == 5
|
@test length(Dense(10, 5)(randn(10))) == 5
|
||||||
@test_throws DimensionMismatch Dense(10, 5)(randn(1))
|
@test_throws DimensionMismatch Dense(10, 5)(randn(1))
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
using Flux, Test
|
using Flux, Test
|
||||||
using Flux: maxpool, meanpool
|
using Flux: maxpool, meanpool
|
||||||
|
using Flux: gradient
|
||||||
|
|
||||||
@testset "Pooling" begin
|
@testset "Pooling" begin
|
||||||
x = randn(Float32, 10, 10, 3, 2)
|
x = randn(Float32, 10, 10, 3, 2)
|
||||||
|
@ -54,6 +55,10 @@ end
|
||||||
y = Conv((3,3), 1 => 1)(x)
|
y = Conv((3,3), 1 => 1)(x)
|
||||||
x_hat = ConvTranspose((3, 3), 1 => 1)(y)
|
x_hat = ConvTranspose((3, 3), 1 => 1)(y)
|
||||||
@test size(x_hat) == size(x)
|
@test size(x_hat) == size(x)
|
||||||
|
|
||||||
|
m = ConvTranspose((3,3), 1=>1)
|
||||||
|
# Test that the gradient call does not throw: #900
|
||||||
|
@test gradient(()->sum(m(x)), params(m)) isa Flux.Zygote.Grads
|
||||||
end
|
end
|
||||||
|
|
||||||
@testset "CrossCor" begin
|
@testset "CrossCor" begin
|
||||||
|
|
|
@ -191,6 +191,7 @@ end
|
||||||
|
|
||||||
end
|
end
|
||||||
|
|
||||||
|
if VERSION >= v"1.1"
|
||||||
@testset "GroupNorm" begin
|
@testset "GroupNorm" begin
|
||||||
# begin tests
|
# begin tests
|
||||||
squeeze(x) = dropdims(x, dims = tuple(findall(size(x) .== 1)...)) # To remove all singular dimensions
|
squeeze(x) = dropdims(x, dims = tuple(findall(size(x) .== 1)...)) # To remove all singular dimensions
|
||||||
|
@ -289,5 +290,5 @@ end
|
||||||
x = Float32.(reshape(collect(1:prod(sizes)), sizes))
|
x = Float32.(reshape(collect(1:prod(sizes)), sizes))
|
||||||
@test BN(x) ≈ GN(x)
|
@test BN(x) ≈ GN(x)
|
||||||
end
|
end
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
using Flux
|
using Flux
|
||||||
using Flux: throttle, glorot_uniform, glorot_normal, stack, unstack
|
using Flux: throttle, nfan, glorot_uniform, glorot_normal, stack, unstack
|
||||||
using StatsBase: std
|
using StatsBase: var
|
||||||
using Random
|
using Random
|
||||||
using Test
|
using Test
|
||||||
|
|
||||||
|
@ -56,18 +56,26 @@ end
|
||||||
# Set random seed so that these tests don't fail randomly
|
# Set random seed so that these tests don't fail randomly
|
||||||
Random.seed!(0)
|
Random.seed!(0)
|
||||||
|
|
||||||
# glorot_uniform should yield a kernel with stddev ~= sqrt(6/(n_in + n_out)),
|
@testset "Fan in/out" begin
|
||||||
# and glorot_normal should yield a kernel with stddev != 2/(n_in _ n_out)
|
@test nfan() == (1, 1) #For a constant
|
||||||
for (n_in, n_out) in [(100, 100), (100, 400)]
|
@test nfan(100) == (1, 100) #For vector
|
||||||
v = glorot_uniform(n_in, n_out)
|
@test nfan(100, 200) == (200, 100) #For Dense layer
|
||||||
@test minimum(v) > -1.1*sqrt(6/(n_in + n_out))
|
@test nfan(2, 30, 40) == (2 * 30, 2 * 40) #For 1D Conv layer
|
||||||
@test minimum(v) < -0.9*sqrt(6/(n_in + n_out))
|
@test nfan(2, 3, 40, 50) == (2 * 3 * 40, 2 * 3 * 50) #For 2D Conv layer
|
||||||
@test maximum(v) > 0.9*sqrt(6/(n_in + n_out))
|
@test nfan(2, 3, 4, 50, 60) == (2 * 3 * 4 * 50, 2 * 3 * 4 * 60) #For 3D Conv layer
|
||||||
@test maximum(v) < 1.1*sqrt(6/(n_in + n_out))
|
end
|
||||||
|
|
||||||
v = glorot_normal(n_in, n_out)
|
@testset "glorot" begin
|
||||||
@test std(v) > 0.9*sqrt(2/(n_in + n_out))
|
# glorot_uniform and glorot_normal should both yield a kernel with
|
||||||
@test std(v) < 1.1*sqrt(2/(n_in + n_out))
|
# variance ≈ 2/(fan_in + fan_out)
|
||||||
|
for dims ∈ [(1000,), (100, 100), (100, 400), (2, 3, 32, 64), (2, 3, 4, 32, 64)]
|
||||||
|
for init ∈ [glorot_uniform, glorot_normal]
|
||||||
|
v = init(dims...)
|
||||||
|
fan_in, fan_out = nfan(dims...)
|
||||||
|
σ2 = 2 / (fan_in + fan_out)
|
||||||
|
@test 0.9σ2 < var(v) < 1.1σ2
|
||||||
|
end
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue