Merge pull request #669 from FluxML/zygote

using Zygote
Mike J Innes 2019-09-11 16:22:26 +01:00 committed by GitHub
commit bdeb9c6d58
42 changed files with 500 additions and 1239 deletions


@ -6,7 +6,7 @@ os:
# - osx # - osx
julia: julia:
- 1.0 - 1.1
- nightly - nightly
matrix: matrix:


@ -174,6 +174,12 @@ git-tree-sha1 = "dd169c636d1d3656a9faca772f5bd7c226a61254"
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
version = "1.0.1" version = "1.0.1"
[[IRTools]]
deps = ["InteractiveUtils", "MacroTools", "Test"]
git-tree-sha1 = "e23faa71b8f54c3fdc99b230b9c2906cafdddca5"
uuid = "7869d1d1-7146-5819-86e3-90919afe41df"
version = "0.2.3"
[[InteractiveUtils]] [[InteractiveUtils]]
deps = ["Markdown"] deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
@ -226,10 +232,9 @@ uuid = "e89f7d12-3494-54d1-8411-f7d8b9ae1f27"
version = "0.5.0" version = "0.5.0"
[[Missings]] [[Missings]]
deps = ["SparseArrays", "Test"] git-tree-sha1 = "29858ce6c8ae629cf2d733bffa329619a1c843d0"
git-tree-sha1 = "f0719736664b4358aa9ec173077d4285775f8007"
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
version = "0.4.1" version = "0.4.2"
[[Mmap]] [[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804" uuid = "a63ad114-7e13-5084-954f-fe012c677804"
@ -254,9 +259,9 @@ version = "1.1.0"
[[Parsers]] [[Parsers]]
deps = ["Dates", "Test"] deps = ["Dates", "Test"]
git-tree-sha1 = "db2b35dedab3c0e46dc15996d170af07a5ab91c9" git-tree-sha1 = "ef0af6c8601db18c282d092ccbd2f01f3f0cd70b"
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
version = "0.3.6" version = "0.3.7"
[[Pkg]] [[Pkg]]
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
@ -314,10 +319,10 @@ deps = ["LinearAlgebra", "Random"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
[[SpecialFunctions]] [[SpecialFunctions]]
deps = ["BinDeps", "BinaryProvider", "Libdl", "Test"] deps = ["BinDeps", "BinaryProvider", "Libdl"]
git-tree-sha1 = "0b45dc2e45ed77f445617b99ff2adf0f5b0f23ea" git-tree-sha1 = "3bdd374b6fd78faf0119b8c5d538788dbf910c6e"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b" uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "0.7.2" version = "0.8.0"
[[StaticArrays]] [[StaticArrays]]
deps = ["LinearAlgebra", "Random", "Statistics"] deps = ["LinearAlgebra", "Random", "Statistics"]
@ -350,12 +355,6 @@ git-tree-sha1 = "dfcdbbfb2d0370716c815cbd6f8a364efb6f42cf"
uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624" uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624"
version = "0.5.6" version = "0.5.6"
[[Tracker]]
deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics", "Test"]
git-tree-sha1 = "1aa443d3b4bfa91a8aec32f169a479cb87309910"
uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
version = "0.2.3"
[[TranscodingStreams]] [[TranscodingStreams]]
deps = ["Random", "Test"] deps = ["Random", "Test"]
git-tree-sha1 = "7c53c35547de1c5b9d46a4797cf6d8253807108c" git-tree-sha1 = "7c53c35547de1c5b9d46a4797cf6d8253807108c"
@ -386,3 +385,17 @@ deps = ["BinaryProvider", "Libdl", "Printf"]
git-tree-sha1 = "580ce62b6c14244916cc28ad54f8a2e2886f843d" git-tree-sha1 = "580ce62b6c14244916cc28ad54f8a2e2886f843d"
uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
version = "0.8.3" version = "0.8.3"
[[Zygote]]
deps = ["DiffRules", "FFTW", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
git-tree-sha1 = "9186cb0b3b59219e4aba0840614d6a9d7282012e"
repo-rev = "master"
repo-url = "https://github.com/FluxML/Zygote.jl.git"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.3.4"
[[ZygoteRules]]
deps = ["MacroTools"]
git-tree-sha1 = "def5f96ac2895fd9b48435f6b97020979ee0a4c6"
uuid = "700de1a5-db45-46bc-99cf-38207098b444"
version = "0.1.0"


@ -21,18 +21,20 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SHA = "ea8e919c-243c-51af-8825-aaa63cd721ce" SHA = "ea8e919c-243c-51af-8825-aaa63cd721ce"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
[compat] [compat]
CUDAapi = "1.1" CUDAapi = "1.1"
CuArrays = "1.2" CuArrays = "1.2"
NNlib = "0.6" NNlib = "0.6"
Tracker = "0.2" Zygote = "0.3"
julia = "0.7, 1" julia = "1.1"
[extras] [extras]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[targets] [targets]
test = ["Test"] test = ["Test", "Documenter"]

REQUIRE

@ -1,13 +0,0 @@
julia 1.0
Juno
MacroTools 0.3.3
NNlib
Requires
Adapt 0.4
CodecZlib
Colors
ZipFile
AbstractTrees
Reexport
StatsBase
Tracker


@ -1,205 +1,56 @@
# This file is machine-generated - editing it directly is not advised # This file is machine-generated - editing it directly is not advised
[[AbstractTrees]]
deps = ["Markdown", "Test"]
git-tree-sha1 = "6621d9645702c1c4e6970cc6a3eae440c768000b"
uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
version = "0.2.1"
[[Adapt]]
deps = ["LinearAlgebra", "Test"]
git-tree-sha1 = "53d8fec4f662088c1202530e338a11a919407f3b"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "0.4.2"
[[Base64]] [[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
[[BinDeps]]
deps = ["Compat", "Libdl", "SHA", "URIParser"]
git-tree-sha1 = "12093ca6cdd0ee547c39b1870e0c9c3f154d9ca9"
uuid = "9e28174c-4ba2-5203-b857-d8d62c4213ee"
version = "0.8.10"
[[BinaryProvider]]
deps = ["Libdl", "Pkg", "SHA", "Test"]
git-tree-sha1 = "055eb2690182ebc31087859c3dd8598371d3ef9e"
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
version = "0.5.3"
[[CSTParser]]
deps = ["LibGit2", "Test", "Tokenize"]
git-tree-sha1 = "437c93bc191cd55957b3f8dee7794b6131997c56"
uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f"
version = "0.5.2"
[[CodecZlib]]
deps = ["BinaryProvider", "Libdl", "Test", "TranscodingStreams"]
git-tree-sha1 = "36bbf5374c661054d41410dc53ff752972583b9b"
uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
version = "0.5.2"
[[ColorTypes]]
deps = ["FixedPointNumbers", "Random", "Test"]
git-tree-sha1 = "f73b0e10f2a5756de7019818a41654686da06b09"
uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
version = "0.7.5"
[[Colors]]
deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Printf", "Reexport", "Test"]
git-tree-sha1 = "9f0a0210450acb91c730b730a994f8eef1d3d543"
uuid = "5ae59095-9a9b-59fe-a467-6f913c188581"
version = "0.9.5"
[[CommonSubexpressions]]
deps = ["Test"]
git-tree-sha1 = "efdaf19ab11c7889334ca247ff4c9f7c322817b0"
uuid = "bbf7d656-a473-5ed7-a52c-81e309532950"
version = "0.2.0"
[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "84aa74986c5b9b898b0d1acaf3258741ee64754f"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "2.1.0"
[[Crayons]]
deps = ["Test"]
git-tree-sha1 = "f621b8ef51fd2004c7cf157ea47f027fdeac5523"
uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
version = "4.0.0"
[[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
git-tree-sha1 = "ca971f03e146cf144a9e2f2ce59674f5bf0e8038"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.15.0"
[[Dates]] [[Dates]]
deps = ["Printf"] deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
[[DelimitedFiles]]
deps = ["Mmap"]
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
[[DiffResults]]
deps = ["Compat", "StaticArrays"]
git-tree-sha1 = "34a4a1e8be7bc99bc9c611b895b5baf37a80584c"
uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
version = "0.0.4"
[[DiffRules]]
deps = ["Random", "Test"]
git-tree-sha1 = "dc0869fb2f5b23466b32ea799bd82c76480167f7"
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "0.0.10"
[[Distributed]] [[Distributed]]
deps = ["Random", "Serialization", "Sockets"] deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
[[DocStringExtensions]] [[DocStringExtensions]]
deps = ["LibGit2", "Markdown", "Pkg", "Test"] deps = ["LibGit2", "Markdown", "Pkg", "Test"]
git-tree-sha1 = "4d30e889c9f106a51ffa4791a88ffd4765bf20c3" git-tree-sha1 = "0513f1a8991e9d83255e0140aace0d0fc4486600"
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
version = "0.7.0" version = "0.8.0"
[[Documenter]] [[Documenter]]
deps = ["Base64", "DocStringExtensions", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "Pkg", "REPL", "Random", "Test", "Unicode"] deps = ["Base64", "DocStringExtensions", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"]
git-tree-sha1 = "13a6d15102410d8e70146533b759fc48d844a1d0" git-tree-sha1 = "c61d6eedbc3c4323c08b64af12d29c8ee0fcbb5f"
uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4" uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
version = "0.22.3" version = "0.23.2"
[[FixedPointNumbers]]
deps = ["Test"]
git-tree-sha1 = "b8045033701c3b10bf2324d7203404be7aef88ba"
uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
version = "0.5.3"
[[Flux]]
deps = ["AbstractTrees", "Adapt", "CodecZlib", "Colors", "DelimitedFiles", "Juno", "LinearAlgebra", "MacroTools", "NNlib", "Pkg", "Printf", "Random", "Reexport", "Requires", "SHA", "Statistics", "StatsBase", "Tracker", "ZipFile"]
path = ".."
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
version = "0.8.2+"
[[ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"]
git-tree-sha1 = "4c4d727f1b7e0092134fabfab6396b8945c1ea5b"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.3"
[[InteractiveUtils]] [[InteractiveUtils]]
deps = ["Markdown"] deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
[[JSON]] [[JSON]]
deps = ["Dates", "Distributed", "Mmap", "Sockets", "Test", "Unicode"] deps = ["Dates", "Mmap", "Parsers", "Unicode"]
git-tree-sha1 = "1f7a25b53ec67f5e9422f1f551ee216503f4a0fa" git-tree-sha1 = "b34d7cef7b337321e97d22242c3c2b91f476748e"
uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
version = "0.20.0" version = "0.21.0"
[[Juno]]
deps = ["Base64", "Logging", "Media", "Profile", "Test"]
git-tree-sha1 = "4e4a8d43aa7ecec66cadaf311fbd1e5c9d7b9175"
uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
version = "0.7.0"
[[LibGit2]] [[LibGit2]]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
[[Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
[[LinearAlgebra]]
deps = ["Libdl"]
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
[[Logging]] [[Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
[[MacroTools]]
deps = ["CSTParser", "Compat", "DataStructures", "Test"]
git-tree-sha1 = "daecd9e452f38297c686eba90dba2a6d5da52162"
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
version = "0.5.0"
[[Markdown]] [[Markdown]]
deps = ["Base64"] deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
[[Media]]
deps = ["MacroTools", "Test"]
git-tree-sha1 = "75a54abd10709c01f1b86b84ec225d26e840ed58"
uuid = "e89f7d12-3494-54d1-8411-f7d8b9ae1f27"
version = "0.5.0"
[[Missings]]
deps = ["Dates", "InteractiveUtils", "SparseArrays", "Test"]
git-tree-sha1 = "d1d2585677f2bd93a97cfeb8faa7a0de0f982042"
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
version = "0.4.0"
[[Mmap]] [[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804" uuid = "a63ad114-7e13-5084-954f-fe012c677804"
[[NNlib]] [[Parsers]]
deps = ["Libdl", "LinearAlgebra", "Requires", "Statistics", "TimerOutputs"] deps = ["Dates", "Test"]
git-tree-sha1 = "0c667371391fc6bb31f7f12f96a56a17098b3de8" git-tree-sha1 = "db2b35dedab3c0e46dc15996d170af07a5ab91c9"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
version = "0.6.0" version = "0.3.6"
[[NaNMath]]
deps = ["Compat"]
git-tree-sha1 = "ce3b85e484a5d4c71dd5316215069311135fa9f2"
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
version = "0.3.2"
[[OrderedCollections]]
deps = ["Random", "Serialization", "Test"]
git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.1.0"
[[Pkg]] [[Pkg]]
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
@ -209,10 +60,6 @@ uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
deps = ["Unicode"] deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
[[Profile]]
deps = ["Printf"]
uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"
[[REPL]] [[REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets"] deps = ["InteractiveUtils", "Markdown", "Sockets"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
@ -221,106 +68,22 @@ uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
deps = ["Serialization"] deps = ["Serialization"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
[[Reexport]]
deps = ["Pkg"]
git-tree-sha1 = "7b1d07f411bc8ddb7977ec7f377b97b158514fe0"
uuid = "189a3867-3050-52da-a836-e630ba90ab69"
version = "0.2.0"
[[Requires]]
deps = ["Test"]
git-tree-sha1 = "f6fbf4ba64d295e146e49e021207993b6b48c7d1"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "0.5.2"
[[SHA]] [[SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
[[Serialization]] [[Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
[[SharedArrays]]
deps = ["Distributed", "Mmap", "Random", "Serialization"]
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
[[Sockets]] [[Sockets]]
uuid = "6462fe0b-24de-5631-8697-dd941f90decc" uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
[[SortingAlgorithms]]
deps = ["DataStructures", "Random", "Test"]
git-tree-sha1 = "03f5898c9959f8115e30bc7226ada7d0df554ddd"
uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
version = "0.3.1"
[[SparseArrays]]
deps = ["LinearAlgebra", "Random"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
[[SpecialFunctions]]
deps = ["BinDeps", "BinaryProvider", "Libdl", "Test"]
git-tree-sha1 = "0b45dc2e45ed77f445617b99ff2adf0f5b0f23ea"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "0.7.2"
[[StaticArrays]]
deps = ["InteractiveUtils", "LinearAlgebra", "Random", "Statistics", "Test"]
git-tree-sha1 = "3841b39ed5f047db1162627bf5f80a9cd3e39ae2"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "0.10.3"
[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[[StatsBase]]
deps = ["DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"]
git-tree-sha1 = "8a0f4b09c7426478ab677245ab2b0b68552143c7"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.30.0"
[[Test]] [[Test]]
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[[TimerOutputs]]
deps = ["Crayons", "Printf", "Test", "Unicode"]
git-tree-sha1 = "b80671c06f8f8bae08c55d67b5ce292c5ae2660c"
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
version = "0.5.0"
[[Tokenize]]
deps = ["Printf", "Test"]
git-tree-sha1 = "3e83f60b74911d3042d3550884ca2776386a02b8"
uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624"
version = "0.5.3"
[[Tracker]]
deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics", "Test"]
git-tree-sha1 = "0bec1b68c63a0e8a58d3944261cbf4cc9577c8a1"
uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
version = "0.2.0"
[[TranscodingStreams]]
deps = ["Random", "Test"]
git-tree-sha1 = "a25d8e5a28c3b1b06d3859f30757d43106791919"
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
version = "0.9.4"
[[URIParser]]
deps = ["Test", "Unicode"]
git-tree-sha1 = "6ddf8244220dfda2f17539fa8c9de20d6c575b69"
uuid = "30578b45-9adc-5946-b283-645ec420af67"
version = "0.4.0"
[[UUIDs]] [[UUIDs]]
deps = ["Random", "SHA"] deps = ["Random", "SHA"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
[[Unicode]] [[Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
[[ZipFile]]
deps = ["BinaryProvider", "Libdl", "Printf", "Test"]
git-tree-sha1 = "5f6f663890dfb9bad6af75a86a43f67904e5050e"
uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
version = "0.8.1"


@ -1,4 +1,2 @@
[deps] [deps]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"


@ -1,12 +1,13 @@
using Pkg;
Pkg.activate(joinpath(@__DIR__, "..")); Pkg.instantiate()
Pkg.activate(); Pkg.instantiate()
pushfirst!(LOAD_PATH, joinpath(@__DIR__, ".."))
using Documenter, Flux, NNlib using Documenter, Flux, NNlib
makedocs(modules=[Flux, NNlib], makedocs(modules=[Flux, NNlib],
doctest = true,
analytics = "UA-36890222-9",
sitename = "Flux", sitename = "Flux",
# Uncomment below for local build
#format = Documenter.HTML(prettyurls = false),
assets = ["assets/flux.css"],
pages = ["Home" => "index.md", pages = ["Home" => "index.md",
"Building Models" => "Building Models" =>
["Basics" => "models/basics.md", ["Basics" => "models/basics.md",
@ -20,8 +21,9 @@ makedocs(modules=[Flux, NNlib],
"GPU Support" => "gpu.md", "GPU Support" => "gpu.md",
"Saving & Loading" => "saving.md", "Saving & Loading" => "saving.md",
"Performance Tips" => "performance.md", "Performance Tips" => "performance.md",
"Internals" => "Community" => "community.md"],
["Backpropagation" => "internals/tracker.md"], format = Documenter.HTML(assets = ["assets/flux.css"],
"Community" => "community.md"]) analytics = "UA-36890222-9",
prettyurls = haskey(ENV, "CI")))
deploydocs(repo = "github.com/FluxML/Flux.jl.git") deploydocs(repo = "github.com/FluxML/Flux.jl.git")


@ -1,5 +1,5 @@
# Community # Community
All Flux users are welcome to join our community on the [Julia forum](https://discourse.julialang.org/), the [slack](https://discourse.julialang.org/t/announcing-a-julia-slack/4866) (channel #machine-learning), or Flux's [Gitter](https://gitter.im/FluxML/Lobby). If you have questions or issues we'll try to help you out. All Flux users are welcome to join our community on the [Julia forum](https://discourse.julialang.org/), or the [slack](https://discourse.julialang.org/t/announcing-a-julia-slack/4866) (channel #machine-learning). If you have questions or issues we'll try to help you out.
If you're interested in hacking on Flux, the [source code](https://github.com/FluxML/Flux.jl) is open and easy to understand -- it's all just the same Julia code you work with normally. You might be interested in our [intro issues](https://github.com/FluxML/Flux.jl/issues?q=is%3Aopen+is%3Aissue+label%3A%22help+wanted%22) to get started. If you're interested in hacking on Flux, the [source code](https://github.com/FluxML/Flux.jl) is open and easy to understand -- it's all just the same Julia code you work with normally. You might be interested in our [intro issues](https://github.com/FluxML/Flux.jl/issues?q=is%3Aopen+is%3Aissue+label%3A%22help+wanted%22) to get started.


@ -1,14 +1,6 @@
# GPU Support # GPU Support
## Installation NVIDIA GPU support should work out of the box on systems with CUDA and CUDNN installed. For more details see the [CuArrays](https://github.com/JuliaGPU/CuArrays.jl) readme.
To get GPU support for NVIDIA graphics cards, you need to install `CuArrays.jl`
**Steps needed**
1. Install [NVIDIA toolkit](https://developer.nvidia.com/cuda-downloads)
2. Install [NVIDIA cuDNN library](https://developer.nvidia.com/cudnn)
3. In Julia's terminal run `]add CuArrays`
## GPU Usage ## GPU Usage
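
A minimal sketch of what this page describes, assuming Flux is installed; `gpu` simply returns its argument unchanged on machines without a working CUDA setup:

```julia
using Flux

m = Dense(10, 5, relu) |> gpu   # moves the layer's parameters to the GPU when CuArrays is functional
x = rand(Float32, 10) |> gpu    # move the input the same way
y = m(x)                        # runs on the GPU if available, otherwise on the CPU
```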


@ -1,184 +0,0 @@
# Flux.Tracker
Backpropagation, or reverse-mode automatic differentiation, is handled by the `Flux.Tracker` module.
```julia
julia> using Flux.Tracker
```
Here we discuss some more advanced uses of this module, as well as covering its internals.
## Taking Gradients
In the [basics section](../models/basics.md) we covered basic usage of the `gradient` function.
```julia
using Flux.Tracker
Tracker.gradient((a, b) -> a*b, 2, 3) # (3.0 (tracked), 2.0 (tracked))
```
`gradient` is actually just a thin wrapper around the backpropagator-based interface, `forward`.
```julia
using Flux.Tracker: forward
y, back = forward((a, b) -> a*b, 2, 3) # (6.0 (tracked), Flux.Tracker.#9)
back(1) # (3.0 (tracked), 2.0 (tracked))
```
The `forward` function returns two results. The first, `y`, is the original value of the function (perhaps with tracking applied). The second, `back`, is a new function which, given a sensitivity, returns the sensitivity of the inputs to `forward` (we call this a "backpropagator"). One use of this interface is to provide custom sensitivities when outputs are not scalar.
```julia
julia> y, back = forward((a, b) -> a.*b, [1,2,3],[4,5,6])
(param([4.0, 10.0, 18.0]), Flux.Tracker.#9)
julia> back([1,1,1])
(param([4.0, 5.0, 6.0]), param([1.0, 2.0, 3.0]))
```
We can also take gradients in-place. This can be useful if you only care about first-order gradients.
```julia
a, b = param(2), param(3)
c = a*b # 6.0 (tracked)
Tracker.back!(c)
Tracker.grad(a), Tracker.grad(b) # (3.0, 2.0)
```
## Tracked Arrays
The `param` function converts a normal Julia array into a new object that, while behaving like an array, tracks extra information that allows us to calculate derivatives. For example, say we multiply two parameters:
```julia
julia> W = param([1 2; 3 4])
Tracked 2×2 Array{Float64,2}:
1.0 2.0
3.0 4.0
julia> x = param([5, 6])
Tracked 2-element Array{Float64,1}:
5.0
6.0
julia> y = W*x
Tracked 2-element Array{Float64,1}:
17.0
39.0
```
The output `y` is also a `TrackedArray` object. We can now backpropagate sensitivities to `W` and `x` via the `back!` function, and see the gradients accumulated in the `W` and `x` tracked arrays:
```julia
julia> Tracker.back!(y, [1, -1])
julia> W.grad
2×2 Array{Float64,2}:
5.0 6.0
-5.0 -6.0
julia> x.grad
2-element Array{Float64,1}:
-2.0
-2.0
```
You may sometimes want to drop derivative information and just get the plain value back. You can do this by calling `Tracker.data(W)`.
## Custom Gradients
We can hook in to the processes above to implement custom gradients for a function or kernel. For a toy example, imagine a custom implementation of `minus`:
```julia
minus(a, b) = a - b
```
Firstly, we must tell the tracker system to stop when it sees a call to `minus`, and record it. We can do this using dispatch:
```julia
using Flux.Tracker: TrackedArray, track, @grad
minus(a::TrackedArray, b::TrackedArray) = track(minus, a, b)
```
`track` takes care of building a new `Tracked` object and recording the operation on the tape. We just need to provide a gradient definition.
```julia
@grad function minus(a, b)
return minus(data(a), data(b)), Δ -> (Δ, -Δ)
end
```
This is essentially just a way of overloading the `forward` function we saw above. We strip tracking from `a` and `b` so that we are calling the original definition of `minus` (otherwise, we'd just try to track the call again and hit an infinite regress).
Note that in the backpropagator we don't call `data(a)`; we *do* in fact want to track this, since nested AD will take a derivative through the backpropagator itself. For example, the gradient of `*` might look like this.
```julia
@grad a * b = data(a)*data(b), Δ -> (Δ*b, a*Δ)
```
We can then calculate the first derivative of `minus` as follows:
```julia
a = param([1,2,3])
b = param([3,2,1])
c = minus(a, b) # [-2.0 (tracked), 0.0 (tracked), 2.0 (tracked)]
Tracker.back!(c, 1)
Tracker.grad(a) # [1.00, 1.00, 1.00]
Tracker.grad(b) # [-1.00, -1.00, -1.00]
```
For multi-argument functions with custom gradients, you likely want to catch not just `minus(::TrackedArray, ::TrackedArray)` but also `minus(::Array, TrackedArray)` and so on. To do so, just define those extra signatures as needed:
```julia
minus(a::AbstractArray, b::TrackedArray) = Tracker.track(minus, a, b)
minus(a::TrackedArray, b::AbstractArray) = Tracker.track(minus, a, b)
```
## Tracked Internals
All `Tracked*` objects (`TrackedArray`, `TrackedReal`) are light wrappers around the `Tracked` type, which you can access via the `.tracker` field.
```julia
julia> x.tracker
Flux.Tracker.Tracked{Array{Float64,1}}(0x00000000, Flux.Tracker.Call{Nothing,Tuple{}}(nothing, ()), true, [5.0, 6.0], [-2.0, -2.0])
```
The `Tracker` stores the gradient of a given object, which we've seen before.
```julia
julia> x.tracker.grad
2-element Array{Float64,1}:
-2.0
-2.0
```
The tracker also contains a `Call` object, which simply represents a function call that was made at some point during the forward pass. For example, the `+` call would look like this:
```julia
julia> Tracker.Call(+, 1, 2)
Flux.Tracker.Call{Base.#+,Tuple{Int64,Int64}}(+, (1, 2))
```
In the case of the `y` we produced above, we can see that it stores the call that produced it -- that is, `W*x`.
```julia
julia> y.tracker.f
Flux.Tracker.Call{...}(*, (param([1.0 2.0; 3.0 4.0]), param([5.0, 6.0])))
```
Notice that because the arguments to the call may also be tracked arrays, storing their own calls, this means that `Tracker` ends up forming a data structure that records everything that happened during the forward pass (often known as a *tape*).
When we call `back!(y, [1, -1])`, the sensitivities `[1, -1]` simply get forwarded to `y`'s call (`*`), effectively calling
```julia
Tracker.back(*, [1, -1], W, x)
```
which in turn calculates the sensitivities of the arguments (`W` and `x`) and back-propagates through their calls. This is recursive, so it will walk the entire program graph and propagate gradients to the original model parameters.
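
With this page removed, here is a hedged sketch of how the custom-gradient example above translates to Zygote (the `@adjoint` macro plays the role of `track`/`@grad`; `minus` is the same toy function, not a Flux or Zygote API):

```julia
using Zygote: @adjoint, gradient

minus(a, b) = a - b

# forward value plus a pullback mapping the output sensitivity Δ
# to a sensitivity for each argument
@adjoint minus(a, b) = minus(a, b), Δ -> (Δ, -Δ)

gradient((a, b) -> sum(minus(a, b)), [1, 2, 3], [3, 2, 1])
# first gradient is all ones, second all minus ones
```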


@ -5,55 +5,56 @@
Flux's core feature is taking gradients of Julia code. The `gradient` function takes another Julia function `f` and a set of arguments, and returns the gradient with respect to each argument. (It's a good idea to try pasting these examples in the Julia terminal.) Flux's core feature is taking gradients of Julia code. The `gradient` function takes another Julia function `f` and a set of arguments, and returns the gradient with respect to each argument. (It's a good idea to try pasting these examples in the Julia terminal.)
```jldoctest basics ```jldoctest basics
julia> using Flux.Tracker julia> using Flux
julia> f(x) = 3x^2 + 2x + 1; julia> f(x) = 3x^2 + 2x + 1;
julia> df(x) = Tracker.gradient(f, x; nest = true)[1]; # df/dx = 6x + 2 julia> df(x) = gradient(f, x)[1]; # df/dx = 6x + 2
julia> df(2) julia> df(2)
14.0 (tracked) 14
julia> d2f(x) = Tracker.gradient(df, x; nest = true)[1]; # d²f/dx² = 6 julia> d2f(x) = gradient(df, x)[1]; # d²f/dx² = 6
julia> d2f(2) julia> d2f(2)
6.0 (tracked) 6
``` ```
(We'll learn more about why these numbers show up as `(tracked)` below.) When a function has many parameters, we can get gradients of each one at the same time:
When a function has many parameters, we can pass them all in explicitly:
```jldoctest basics ```jldoctest basics
julia> f(W, b, x) = W * x + b; julia> f(x, y) = sum((x .- y).^2);
julia> Tracker.gradient(f, 2, 3, 4) julia> gradient(f, [2, 1], [2, 0])
(4.0 (tracked), 1.0 (tracked), 2.0 (tracked)) ([0, 2], [0, -2])
``` ```
But machine learning models can have *hundreds* of parameters! Flux offers a nice way to handle this. We can tell Flux to treat something as a parameter via `param`. Then we can collect these together and tell `gradient` to collect the gradients of all `params` at once. But machine learning models can have *hundreds* of parameters! To handle this, Flux lets you work with collections of parameters, via `params`. You can get the gradient of all parameters used in a program without explicitly passing them in.
```jldoctest basics ```jldoctest basics
julia> using Flux julia> using Flux
julia> W = param(2) julia> x = [2, 1];
2.0 (tracked)
julia> b = param(3) julia> y = [2, 0];
3.0 (tracked)
julia> f(x) = W * x + b; julia> gs = gradient(params(x, y)) do
f(x, y)
end
Grads(...)
julia> grads = Tracker.gradient(() -> f(4), params(W, b)); julia> gs[x]
2-element Array{Int64,1}:
0
2
julia> grads[W] julia> gs[y]
4.0 (tracked) 2-element Array{Int64,1}:
0
julia> grads[b] -2
1.0 (tracked)
``` ```
There are a few things to notice here. Firstly, `W` and `b` now show up as *tracked*. Tracked things behave like normal numbers or arrays, but keep records of everything you do with them, allowing Flux to calculate their gradients. `gradient` takes a zero-argument function; no arguments are necessary because the `params` tell it what to differentiate. Here, `gradient` takes a zero-argument function; no arguments are necessary because the `params` tell it what to differentiate.
This will come in really handy when dealing with big, complicated models. For now, though, let's start with something simple. This will come in really handy when dealing with big, complicated models. For now, though, let's start with something simple.
@ -76,26 +77,20 @@ x, y = rand(5), rand(2) # Dummy data
loss(x, y) # ~ 3 loss(x, y) # ~ 3
``` ```
To improve the prediction we can take the gradients of `W` and `b` with respect to the loss and perform gradient descent. Let's tell Flux that `W` and `b` are parameters, just like we did above. To improve the prediction we can take the gradients of `W` and `b` with respect to the loss and perform gradient descent.
```julia ```julia
using Flux.Tracker using Flux
W = param(W) gs = gradient(() -> loss(x, y), params(W, b))
b = param(b)
gs = Tracker.gradient(() -> loss(x, y), params(W, b))
``` ```
Now that we have gradients, we can pull them out and update `W` to train the model. The `update!(W, Δ)` function applies `W = W + Δ`, which we can use for gradient descent. Now that we have gradients, we can pull them out and update `W` to train the model.
```julia ```julia
using Flux.Tracker: update! W̄ = gs[W]
Δ = gs[W] W .-= 0.1 .* W̄
# Update the parameter and reset the gradient
update!(W, -0.1Δ)
loss(x, y) # ~ 2.5 loss(x, y) # ~ 2.5
``` ```
@ -111,12 +106,12 @@ It's common to create more complex models than the linear regression above. For
```julia ```julia
using Flux using Flux
W1 = param(rand(3, 5)) W1 = rand(3, 5)
b1 = param(rand(3)) b1 = rand(3)
layer1(x) = W1 * x .+ b1 layer1(x) = W1 * x .+ b1
W2 = param(rand(2, 3)) W2 = rand(2, 3)
b2 = param(rand(2)) b2 = rand(2)
layer2(x) = W2 * x .+ b2 layer2(x) = W2 * x .+ b2
model(x) = layer2(σ.(layer1(x))) model(x) = layer2(σ.(layer1(x)))
@ -128,8 +123,8 @@ This works but is fairly unwieldy, with a lot of repetition especially as we
```julia ```julia
function linear(in, out) function linear(in, out)
W = param(randn(out, in)) W = randn(out, in)
b = param(randn(out)) b = randn(out)
x -> W * x .+ b x -> W * x .+ b
end end
@ -150,7 +145,7 @@ struct Affine
end end
Affine(in::Integer, out::Integer) = Affine(in::Integer, out::Integer) =
Affine(param(randn(out, in)), param(randn(out))) Affine(randn(out, in), randn(out))
# Overload call, so the object can be used as a function # Overload call, so the object can be used as a function
(m::Affine)(x) = m.W * x .+ m.b (m::Affine)(x) = m.W * x .+ m.b
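
# A hedged end-to-end sketch tying this page together: the `Affine` layer from
# above, trained for one step with the Zygote-based `gradient` and plain
# gradient descent (no Flux optimiser involved; names follow the example above).
using Flux

struct Affine
  W
  b
end

Affine(in::Integer, out::Integer) = Affine(randn(out, in), randn(out))
(m::Affine)(x) = m.W * x .+ m.b

a = Affine(10, 5)
x, y = rand(10), rand(5)
loss(x, y) = sum((a(x) .- y).^2)

gs = gradient(() -> loss(x, y), params(a.W, a.b))
a.W .-= 0.01 .* gs[a.W]   # one gradient-descent step on each parameter
a.b .-= 0.01 .* gs[a.b]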


@ -59,7 +59,6 @@ swish
These layers don't affect the structure of the network but may improve training times or reduce overfitting. These layers don't affect the structure of the network but may improve training times or reduce overfitting.
```@docs ```@docs
Flux.testmode!
BatchNorm BatchNorm
Dropout Dropout
AlphaDropout AlphaDropout


@ -101,26 +101,4 @@ m = Chain(LSTM(10, 15), Dense(15, 5))
m.(seq) m.(seq)
``` ```
## Truncating Gradients Finally, we can reset the hidden state of the cell back to its initial value using `reset!(m)`.
By default, calculating the gradients in a recurrent layer involves its entire history. For example, if we call the model on 100 inputs, we'll have to calculate the gradient for those 100 calls. If we then calculate another 10 inputs, we have to calculate 110 gradients; this accumulates and quickly becomes expensive.
To avoid this we can *truncate* the gradient calculation, forgetting the history.
```julia
truncate!(m)
```
Calling `truncate!` wipes the slate clean, so we can call the model with more inputs without building up an expensive gradient computation.
`truncate!` makes sense when you are working with multiple chunks of a large sequence, but we may also want to work with a set of independent sequences. In this case the hidden state should be completely reset to its original value, throwing away any accumulated information. `reset!` does this for you.
In general, when training with recurrent layers in your model, you'll want to call `reset!` or `truncate!` for each loss calculation:
```julia
function loss(x,y)
l = Flux.mse(m(x), y)
Flux.reset!(m)
return l
end
```
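
A hedged sketch of the pattern described above, resetting the hidden state before each independent sequence (assumes a model like the `Chain(LSTM(10, 15), Dense(15, 5))` from earlier on this page; `Flux.reset!` and `Flux.mse` are existing Flux functions):

```julia
using Flux

m = Chain(LSTM(10, 15), Dense(15, 5))

# xs is one sequence of inputs, ys the matching targets
function loss(xs, ys)
  Flux.reset!(m)   # start from the initial hidden state
  sum(Flux.mse(m(x), y) for (x, y) in zip(xs, ys))
end

seq  = [rand(10) for _ = 1:8]
targ = [rand(5)  for _ = 1:8]
loss(seq, targ)
```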


@ -15,6 +15,8 @@ loss(x, y) = crossentropy(softmax(m(x)), y)
We can regularise this by taking the (L2) norm of the parameters, `m.W` and `m.b`. We can regularise this by taking the (L2) norm of the parameters, `m.W` and `m.b`.
```julia ```julia
using LinearAlgebra
penalty() = norm(m.W) + norm(m.b) penalty() = norm(m.W) + norm(m.b)
loss(x, y) = crossentropy(softmax(m(x)), y) + penalty() loss(x, y) = crossentropy(softmax(m(x)), y) + penalty()
``` ```
@ -48,15 +50,17 @@ loss(rand(28^2), rand(10))
One can also easily add per-layer regularisation via the `activations` function: One can also easily add per-layer regularisation via the `activations` function:
```julia ```julia
julia> using Flux: activations
julia> c = Chain(Dense(10,5,σ),Dense(5,2),softmax) julia> c = Chain(Dense(10,5,σ),Dense(5,2),softmax)
Chain(Dense(10, 5, NNlib.σ), Dense(5, 2), NNlib.softmax) Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
julia> activations(c, rand(10)) julia> activations(c, rand(10))
3-element Array{Any,1}: 3-element Array{Any,1}:
param([0.71068, 0.831145, 0.751219, 0.227116, 0.553074]) Float32[0.84682214, 0.6704139, 0.42177814, 0.257832, 0.36255655]
param([0.0330606, -0.456104]) Float32[0.1501253, 0.073269576]
param([0.61991, 0.38009]) Float32[0.5192045, 0.48079553]
julia> sum(norm, ans) julia> sum(norm, ans)
2.639678767773633 (tracked) 2.1166067f0
``` ```
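
The same idea scales to whole models: a hedged sketch that penalises every trainable array at once via `params` (the three-layer `m` here is just an example model):

```julia
using Flux, LinearAlgebra

m = Chain(Dense(28^2, 128, relu), Dense(128, 32, relu), Dense(32, 10), softmax)

penalty() = sum(norm, params(m))                      # L2 penalty over all parameters
loss(x, y) = Flux.crossentropy(m(x), y) + penalty()

loss(rand(28^2), rand(10))
```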


@ -53,7 +53,7 @@ julia> using Flux
julia> model = Chain(Dense(10,5,relu),Dense(5,2),softmax) julia> model = Chain(Dense(10,5,relu),Dense(5,2),softmax)
Chain(Dense(10, 5, NNlib.relu), Dense(5, 2), NNlib.softmax) Chain(Dense(10, 5, NNlib.relu), Dense(5, 2), NNlib.softmax)
julia> weights = Tracker.data.(params(model)); julia> weights = params(model);
julia> using BSON: @save julia> using BSON: @save
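
A hedged sketch of the full round trip this page builds toward, assuming the BSON package (the filename `"mymodel.bson"` is only an example):

```julia
using Flux
using BSON: @save, @load

model = Chain(Dense(10, 5, relu), Dense(5, 2), softmax)
weights = params(model)

@save "mymodel.bson" weights       # write the parameter arrays to disk

@load "mymodel.bson" weights       # later: read them back
Flux.loadparams!(model, weights)   # and copy them into a freshly built model
```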


@ -3,25 +3,25 @@
Consider a [simple linear regression](../models/basics.md). We create some dummy data, calculate a loss, and backpropagate to calculate gradients for the parameters `W` and `b`. Consider a [simple linear regression](../models/basics.md). We create some dummy data, calculate a loss, and backpropagate to calculate gradients for the parameters `W` and `b`.
```julia ```julia
using Flux, Flux.Tracker using Flux
W = param(rand(2, 5)) W = rand(2, 5)
b = param(rand(2)) b = rand(2)
predict(x) = W*x .+ b predict(x) = (W * x) .+ b
loss(x, y) = sum((predict(x) .- y).^2) loss(x, y) = sum((predict(x) .- y).^2)
x, y = rand(5), rand(2) # Dummy data x, y = rand(5), rand(2) # Dummy data
l = loss(x, y) # ~ 3 l = loss(x, y) # ~ 3
θ = Params([W, b]) θ = Params([W, b])
grads = Tracker.gradient(() -> loss(x, y), θ) grads = gradient(() -> loss(x, y), θ)
``` ```
We want to update each parameter, using the gradient, in order to improve (reduce) the loss. Here's one way to do that: We want to update each parameter, using the gradient, in order to improve (reduce) the loss. Here's one way to do that:
```julia ```julia
using Flux.Tracker: grad, update! using Flux: update!
η = 0.1 # Learning Rate η = 0.1 # Learning Rate
for p in (W, b) for p in (W, b)
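
# A hedged completion of the loop being introduced here, written with plain
# broadcasting so nothing beyond `gradient` and `params` is assumed
# (Flux's `update!` wraps the same idea). Repeated definitions make it standalone.
using Flux

W = rand(2, 5)
b = rand(2)
predict(x) = W*x .+ b
loss(x, y) = sum((predict(x) .- y).^2)

x, y = rand(5), rand(2)
grads = gradient(() -> loss(x, y), params(W, b))

η = 0.1                     # learning rate
for p in (W, b)
  p .-= η .* grads[p]       # move each parameter downhill along its gradient
end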


@ -3,19 +3,15 @@ module Flux
# Zero Flux Given # Zero Flux Given
using Base: tail using Base: tail
using MacroTools, Juno, Reexport, Statistics, Random using Zygote, MacroTools, Juno, Reexport, Statistics, Random
using MacroTools: @forward using MacroTools: @forward
@reexport using NNlib
using Zygote: Params, @adjoint, gradient, forward
export gradient
export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool, export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm, DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
SkipConnection, SkipConnection, params, mapleaves, cpu, gpu, f32, f64
params, mapleaves, cpu, gpu, f32, f64
@reexport using NNlib
using Tracker
using Tracker: data
export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param
include("optimise/Optimise.jl") include("optimise/Optimise.jl")
using .Optimise using .Optimise
@ -49,6 +45,8 @@ include("layers/normalise.jl")
include("data/Data.jl") include("data/Data.jl")
include("deprecations.jl")
if has_cuarrays() if has_cuarrays()
include("cuda/cuda.jl") include("cuda/cuda.jl")
end end


@ -1,13 +1,10 @@
using CuArrays: libcudnn using CuArrays: libcudnn
using CuArrays.CUDNN: @check, handle, cudnnStatus_t, cudnnTensorDescriptor_t, using CuArrays.CUDNN: @check, handle, cudnnStatus_t, cudnnTensorDescriptor_t,
cudnnBatchNormMode_t, cudnnHandle_t, cudnnDataType, TensorDesc, FilterDesc cudnnBatchNormMode_t, cudnnHandle_t, cudnnDataType, TensorDesc, FilterDesc
import CuArrays.CUDAdrv: CuPtr, CU_NULL import CuArrays.CUDAdrv: CuPtr, CU_NULL
using LinearAlgebra using LinearAlgebra
import ..Flux: data
mutable struct DropoutDesc mutable struct DropoutDesc
ptr::Ptr{Nothing} ptr::Ptr{Nothing}
states::CuVector{UInt8} states::CuVector{UInt8}
@ -198,36 +195,8 @@ end
# Flux Interface # Flux Interface
(BN::Flux.BatchNorm)(x::Union{CuParam{T,2},CuParam{T,4},CuParam{T,5}}, cache = nothing) where T<:Union{Float32, Float64} = (BN::Flux.BatchNorm)(x::Union{CuArray{T,2},CuArray{T,4},CuArray{T,5}}, cache = nothing) where T<:Union{Float32, Float64} =
BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum; cache = cache, alpha = 1, beta = 0, eps = BN.ϵ, training = BN.active)) BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum; cache = cache, alpha = 1, beta = 0, eps = BN.ϵ, training = Flux.istraining()))
batchnorm(g::TrackedArray, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T}, @adjoint batchnorm(g, b, x, running_mean, running_var, momentum; kw...) =
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} = batchnorm(g, b, x, running_mean, running_var, momentum; kw...), Δ -> (∇batchnorm(g, b, x, Δ, running_mean, running_var, momentum; kw...)..., nothing, nothing, nothing)
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
batchnorm(g::TrackedArray, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T},
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
batchnorm(g::TrackedArray, b::CuArray{T}, x::TrackedArray, running_mean::CuArray{T},
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
batchnorm(g::CuArray{T}, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T},
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
batchnorm(g::CuArray{T}, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T},
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
batchnorm(g::TrackedArray, b::CuArray{T}, x::CuArray{T}, running_mean::CuArray{T},
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
batchnorm(g::CuArray{T}, b::CuArray{T}, x::TrackedArray, running_mean::CuArray{T},
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
@grad batchnorm(g, b, x, running_mean, running_var, momentum; kw...) =
batchnorm(data.((g, b, x))..., running_mean, running_var, momentum; kw...), Δ -> (nobacksies(:batchnorm, ∇batchnorm(data.((g, b, x, Δ))..., running_mean, running_var, momentum; kw...))..., nothing, nothing, nothing)
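
# The rewrite above folds the many TrackedArray method signatures into one
# Zygote `@adjoint` whose pullback returns `nothing` for arguments that are
# not differentiated. A hedged, standalone sketch of that pattern on a toy
# function (all names below are made up for illustration):
using Zygote: @adjoint, gradient

# differentiable in `x`; `count` is bookkeeping, like the running statistics above
function scaleandcount(x, count)
  count[] += 1
  return 2 .* x
end

@adjoint scaleandcount(x, count) =
  scaleandcount(x, count), Δ -> (2 .* Δ, nothing)   # `nothing` marks the non-differentiated argument

c = Ref(0)
gradient(x -> sum(scaleandcount(x, c)), [1.0, 2.0])  # ([2.0, 2.0],)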


@ -225,7 +225,6 @@ end
# Interface # Interface
import ..Flux: Flux, relu import ..Flux: Flux, relu
import ..Tracker: TrackedArray
using CuArrays.CUDAnative using CuArrays.CUDAnative
using CuArrays: @cuindex, cudims using CuArrays: @cuindex, cudims
@ -240,17 +239,16 @@ function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
return dst return dst
end end
CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}} CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuArray{T,2},<:CuArray{T,1}}
CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuParam{T,2},<:CuParam{T,1}} CuGRU{T} = Flux.GRUCell{<:CuArray{T,2},<:CuArray{T,1}}
CuGRU{T} = Flux.GRUCell{<:CuParam{T,2},<:CuParam{T,1}} CuLSTM{T} = Flux.LSTMCell{<:CuArray{T,2},<:CuArray{T,1}}
CuLSTM{T} = Flux.LSTMCell{<:CuParam{T,2},<:CuParam{T,1}}
CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}} CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}}
function copyparams!(m::CuRNNs, d::RNNDesc) function copyparams!(m::CuRNNs, d::RNNDesc)
Wi, Wh = d.weights Wi, Wh = d.weights
copy_transpose!(Wi, Flux.data(m.Wi)) copy_transpose!(Wi, m.Wi)
copy_transpose!(Wh, Flux.data(m.Wh)) copy_transpose!(Wh, m.Wh)
copy_transpose!(d.bias, Flux.data(m.b)) copy_transpose!(d.bias, m.b)
return return
end end
@ -271,59 +269,58 @@ function desc(rnn)
return d return d
end end
import Flux.Tracker using ..Flux: @adjoint
import Flux.Tracker: data, istracked, track, unbroadcast, @grad, nobacksies
istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...)) function (m::CuRNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
y, h = forward(desc(m), x, h)
function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64} return h, y
result = istrain(m, h, x) ?
track(m, x, h, m.Wi, m.Wh, m.b) :
forward(desc(m), x, h)
return result[2], result[1]
end end
function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64} function (m::CuGRU{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
result = istrain(m, h, x) ? y, h = forward(desc(m), x, h)
track(m, x, h, m.Wi, m.Wh, m.b) : return h, y
forward(desc(m), x, h)
return result[2], result[1]
end end
function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64} function (m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
result = istrain(m, h, x) ? y, h, c = forward(desc(m), x, h[1], h[2])
track(m, x, h[1], h[2], m.Wi, m.Wh, m.b) : return (h, c), y
forward(desc(m), x, h[1], h[2])
return (result[2], result[3]), result[1]
end end
(m::CuRNN{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x)) (m::CuRNN{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
(m::CuGRU{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x)) (m::CuGRU{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
(m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x)) (m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
@grad function (m::Union{CuRNN,CuGRU})(x, h, Wi, Wh, b) trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
reserve, result = forwardTrain(desc(m), data(x), data(h))
result, function (Δ) unbroadcast(x::AbstractArray, Δ) =
y, ho = result size(x) == size(Δ) ? Δ :
dy, dho = Δ length(x) == length(Δ) ? trim(x, Δ) :
h_ = hBatch(x, data(h)) trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve) for RNN in (CuRNN, CuGRU)
nobacksies(:RNN, (dx, unbroadcast(h, dh), transpose(dWi), transpose(dWh), db)) @eval @adjoint function (m::$RNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
reserve, (y, ho) = forwardTrain(desc(m), x, h)
(ho, y), function (Δ)
dho, dy = Δ
h_ = hBatch(x, h)
dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
(dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
dm = Ref{Any}((σ=nothing,Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing))
(dm, unbroadcast(h, dh), dx)
end
end end
end end
@grad function (m::CuLSTM)(x, h, c, Wi, Wh, b) @adjoint function (m::CuLSTM)((h, c)::Tuple{CuArray{T},CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
reserve, result = forwardTrain(desc(m), data.((x, h, c))...) reserve, (y, ho, co) = forwardTrain(desc(m), x, h, c)
result, function (Δ) ((ho, co), y), function (Δ)
y, ho = result dhc, dy = Δ
dy, dho, dco = Δ dho, dco = dhc === nothing ? (nothing, nothing) : dhc
h_ = hBatch(x, data(h)) h_ = hBatch(x, h)
c_ = hBatch(x, data(c)) c_ = hBatch(x, c)
dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve) dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve) (dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
nobacksies(:RNN, dm = Ref{Any}((Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing,c=nothing))
(dx, unbroadcast(h, dh), unbroadcast(c, dc), (dm, (unbroadcast(h, dh), unbroadcast(c, dc)), dx)
transpose(dWi), transpose(dWh), db))
end end
end end
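
# A hedged, standalone check of the `unbroadcast` helper defined above: it
# reduces a gradient that still carries a broadcast batch dimension back to
# the shape of the original argument (definitions repeated so this runs alone).
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))

unbroadcast(x::AbstractArray, Δ) =
  size(x) == size(Δ) ? Δ :
  length(x) == length(Δ) ? trim(x, Δ) :
  trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))

h = ones(3, 1)      # hidden state broadcast across a batch of 4
Δ = ones(3, 4)      # gradient coming back with the batch dimension attached
unbroadcast(h, Δ)   # 3×1 matrix of 4.0s: batch contributions are summed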


@ -1,14 +1,10 @@
""" """
Iris
Fisher's classic iris dataset. Fisher's classic iris dataset.
Measurements from 3 different species of iris: setosa, versicolor and Measurements from 3 different species of iris: setosa, versicolor and
virginica. There are 50 examples of each species. virginica. There are 50 examples of each species.
There are 4 measurements for each example: sepal length, sepal width, petal There are 4 measurements for each example: sepal length, sepal width, petal
length and petal width. The measurements are in centimeters. length and petal width. The measurements are in centimeters.
The module retrieves the data from the [UCI Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/iris). The module retrieves the data from the [UCI Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/iris).
@ -35,10 +31,12 @@ end
labels() labels()
Get the labels of the iris dataset, a 150 element array of strings listing the Get the labels of the iris dataset, a 150 element array of strings listing the
species of each example. species of each example.
```jldoctest ```jldoctest
julia> using Flux
julia> labels = Flux.Data.Iris.labels(); julia> labels = Flux.Data.Iris.labels();
julia> summary(labels) julia> summary(labels)
@ -58,11 +56,13 @@ end
features() features()
Get the features of the iris dataset. This is a 4x150 matrix of Float64 Get the features of the iris dataset. This is a 4x150 matrix of Float64
elements. It has a row for each feature (sepal length, sepal width, elements. It has a row for each feature (sepal length, sepal width,
petal length, petal width) and a column for each example. petal length, petal width) and a column for each example.
```jldoctest ```jldoctest
julia> using Flux
julia> features = Flux.Data.Iris.features(); julia> features = Flux.Data.Iris.features();
julia> summary(features) julia> summary(features)
@ -81,6 +81,5 @@ function features()
iris = readdlm(deps("iris.data"), ',') iris = readdlm(deps("iris.data"), ',')
Matrix{Float64}(iris[1:end, 1:4]') Matrix{Float64}(iris[1:end, 1:4]')
end end
end end

src/deprecations.jl Normal file

@ -0,0 +1,2 @@
@deprecate param(x) x
@deprecate data(x) x
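
# A hedged note on what these two lines do: `@deprecate f(x) x` keeps old call
# sites compiling while turning the call into the identity. For example:
using Flux

W = Flux.param(rand(2, 2))   # may warn (depending on --depwarn) and simply returns the Array
Flux.data(W) === W           # true -- both are identity functions now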


@ -89,7 +89,7 @@ Dense(W, b) = Dense(W, b, identity)
function Dense(in::Integer, out::Integer, σ = identity; function Dense(in::Integer, out::Integer, σ = identity;
initW = glorot_uniform, initb = zeros) initW = glorot_uniform, initb = zeros)
return Dense(param(initW(out, in)), param(initb(out)), σ) return Dense(initW(out, in), initb(out), σ)
end end
@treelike Dense @treelike Dense
@ -129,7 +129,7 @@ struct Diagonal{T}
end end
Diagonal(in::Integer; initα = ones, initβ = zeros) = Diagonal(in::Integer; initα = ones, initβ = zeros) =
Diagonal(param(initα(in)), param(initβ(in))) Diagonal(initα(in), initβ(in))
@treelike Diagonal @treelike Diagonal
@ -204,7 +204,6 @@ A 'ResNet'-type skip-connection with identity shortcut would simply be
SkipConnection(layer, (a,b) -> a + b) SkipConnection(layer, (a,b) -> a + b)
``` ```
""" """
struct SkipConnection struct SkipConnection
layers layers
connection #user can pass arbitrary connections here, such as (a,b) -> a + b connection #user can pass arbitrary connections here, such as (a,b) -> a + b


@ -14,11 +14,11 @@ Example: Applying Conv layer to a 1-channel input using a 2x2 window size,
size = (2,2) size = (2,2)
in = 1 in = 1
out = 16 out = 16
Conv((2, 2), 1=>16, relu) Conv((2, 2), 1=>16, relu)
Data should be stored in WHCN order (width, height, # channels, # batches). Data should be stored in WHCN order (width, height, # channels, # batches).
In other words, a 100×100 RGB image would be a `100×100×3×1` array, In other words, a 100×100 RGB image would be a `100×100×3×1` array,
and a batch of 50 would be a `100×100×3×50` array. and a batch of 50 would be a `100×100×3×50` array.
Takes the keyword arguments `pad`, `stride` and `dilation`. Takes the keyword arguments `pad`, `stride` and `dilation`.
@ -42,7 +42,7 @@ end
Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N = init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ, Conv(init(k..., ch...), zeros(ch[2]), σ,
stride = stride, pad = pad, dilation = dilation) stride = stride, pad = pad, dilation = dilation)
@treelike Conv @treelike Conv
@ -74,8 +74,10 @@ end
Standard convolutional transpose layer. `size` should be a tuple like `(2, 2)`. Standard convolutional transpose layer. `size` should be a tuple like `(2, 2)`.
`in` and `out` specify the number of input and output channels respectively. `in` and `out` specify the number of input and output channels respectively.
Data should be stored in WHCN order. In other words, a 100×100 RGB image would Data should be stored in WHCN order. In other words, a 100×100 RGB image would
be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array. be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
Takes the keyword arguments `pad`, `stride` and `dilation`. Takes the keyword arguments `pad`, `stride` and `dilation`.
""" """
struct ConvTranspose{N,M,F,A,V} struct ConvTranspose{N,M,F,A,V}
@ -97,7 +99,7 @@ end
ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N = init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
ConvTranspose(param(init(k..., reverse(ch)...)), param(zeros(ch[2])), σ, ConvTranspose(init(k..., reverse(ch)...), zeros(ch[2]), σ,
stride = stride, pad = pad, dilation = dilation) stride = stride, pad = pad, dilation = dilation)
@treelike ConvTranspose @treelike ConvTranspose
@ -169,8 +171,8 @@ function DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ =
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N
@assert ch[2] % ch[1] == 0 "Output channels must be integer multiple of input channels" @assert ch[2] % ch[1] == 0 "Output channels must be integer multiple of input channels"
return DepthwiseConv( return DepthwiseConv(
param(init(k..., div(ch[2], ch[1]), ch[1])), init(k..., div(ch[2], ch[1]), ch[1]),
param(zeros(ch[2])), zeros(ch[2]),
σ; σ;
stride = stride, stride = stride,
pad = pad, pad = pad,
@ -198,25 +200,26 @@ end
(a::DepthwiseConv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} = (a::DepthwiseConv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
a(T.(x)) a(T.(x))
""" """
CrossCor(size, in=>out) CrossCor(size, in=>out)
CrossCor(size, in=>out, relu) CrossCor(size, in=>out, relu)
Standard cross convolutional layer. `size` should be a tuple like `(2, 2)`. Standard cross convolutional layer. `size` should be a tuple like `(2, 2)`.
`in` and `out` specify the number of input and output channels respectively. `in` and `out` specify the number of input and output channels respectively.
Example: Applying CrossCor layer to a 1-channel input using a 2x2 window size, Example: Applying CrossCor layer to a 1-channel input using a 2x2 window size,
giving us a 16-channel output. Output is activated with ReLU. giving us a 16-channel output. Output is activated with ReLU.
size = (2,2) size = (2,2)
in = 1 in = 1
out = 16 out = 16
CrossCor((2, 2), 1=>16, relu) CrossCor((2, 2), 1=>16, relu)
Data should be stored in WHCN order (width, height, # channels, # batches). Data should be stored in WHCN order (width, height, # channels, # batches).
In other words, a 100×100 RGB image would be a `100×100×3×1` array, In other words, a 100×100 RGB image would be a `100×100×3×1` array,
and a batch of 50 would be a `100×100×3×50` array. and a batch of 50 would be a `100×100×3×50` array.
Takes the keyword arguments `pad`, `stride` and `dilation`. Takes the keyword arguments `pad`, `stride` and `dilation`.
""" """
struct CrossCor{N,M,F,A,V} struct CrossCor{N,M,F,A,V}
@ -238,7 +241,7 @@ end
CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N = init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
CrossCor(param(init(k..., ch...)), param(zeros(ch[2])), σ, CrossCor(init(k..., ch...), zeros(ch[2]), σ,
stride = stride, pad = pad, dilation = dilation) stride = stride, pad = pad, dilation = dilation)
@treelike CrossCor @treelike CrossCor


@ -1,17 +1,20 @@
""" istraining() = false
testmode!(m)
testmode!(m, false)
Put layers like [`Dropout`](@ref) and [`BatchNorm`](@ref) into testing mode @adjoint istraining() = true, _ -> nothing
(or back to training mode with `false`).
""" _dropout_shape(s, ::Colon) = size(s)
function testmode!(m, val::Bool=true) _dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(size(s)))...)
prefor(x -> _testmode!(x, val), m)
return m _dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
dropout(x, p; dims = :) = x
@adjoint function dropout(x, p; dims = :)
y = rand!(similar(x, _dropout_shape(x, dims)))
y .= _dropout_kernel.(y, p, 1 - p)
return x .* y, Δ -> (Δ .* y, nothing)
end end
_testmode!(m, test) = nothing
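
# A hedged, minimal sketch of the `istraining` trick defined above: the plain
# function returns `false`, but its adjoint returns `true`, so the same code
# takes the training branch inside `gradient` without any mutable `active` flag
# (standalone copy, not the Flux definitions themselves).
using Zygote

istraining() = false
Zygote.@adjoint istraining() = true, _ -> nothing

shift(x) = istraining() ? x .+ 1 : x

shift([1.0, 2.0])                                # ordinary call: takes the `false` branch
Zygote.gradient(x -> sum(shift(x)), [1.0, 2.0])  # inside `gradient`: the `true` branch, grads are all ones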
""" """
Dropout(p, dims = :) Dropout(p, dims = :)
@ -19,79 +22,52 @@ A Dropout layer. For each input, either sets that input to `0` (with probability
`p`) or scales it by `1/(1-p)`. The `dims` argument is to specified the unbroadcasted `p`) or scales it by `1/(1-p)`. The `dims` argument is to specified the unbroadcasted
dimensions, i.e. `dims=1` does dropout along columns and `dims=2` along rows. This is dimensions, i.e. `dims=1` does dropout along columns and `dims=2` along rows. This is
used as a regularisation, i.e. it reduces overfitting during training. see also [`dropout`](@ref). used as a regularisation, i.e. it reduces overfitting during training. see also [`dropout`](@ref).
Does nothing to the input once in [`testmode!`](@ref).
""" """
mutable struct Dropout{F} mutable struct Dropout{F,D}
p::F p::F
dims::Union{Colon, Int, NTuple{N, Int} where N} dims::D
active::Bool
end end
function Dropout(p; dims = :) function Dropout(p; dims = :)
@assert 0 ≤ p ≤ 1 @assert 0 ≤ p ≤ 1
Dropout{typeof(p)}(p, dims, true) Dropout{typeof(p),typeof(dims)}(p, dims)
end end
_dropout_shape(s, ::Colon) = size(s) (a::Dropout)(x) = dropout(x, a.p; dims = a.dims)
_dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(size(s)))...)
_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0) function Base.show(io::IO, d::Dropout)
print(io, "Dropout(", d.p)
d.dims != (:) && print(io, ", dims = $(repr(d.dims))")
""" print(io, ")")
dropout(x, p; dims = :)
The dropout function. For each input, either sets that input to `0` (with probability
`p`) or scales it by `1/(1-p)`. The `dims` argument is to specified the unbroadcasted
dimensions, i.e. `dims=1` does dropout along columns and `dims=2` along rows. This is
used as a regularisation, i.e. it reduces overfitting during training.
"""
function dropout(x, p; dims = :)
y = similar(x, _dropout_shape(x, dims))
rand!(y)
y .= _dropout_kernel.(y, p, 1 - p)
return x .* y
end end
function (a::Dropout)(x)
a.active || return x
return dropout(x, a.p; dims = a.dims)
end
_testmode!(a::Dropout, test) = (a.active = !test)
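Correspondingly, the `Dropout` layer no longer carries an `active` field; whether it drops anything is decided by `istraining()` at differentiation time. A hedged sketch of the intended usage:
```julia
using Flux
using Flux: gradient

m = Chain(Dense(10, 10), Dropout(0.5))
x = rand(Float32, 10)

m(x)   # plain calls behave like test mode: Dropout is a no-op

# Taking gradients flips `istraining()` to true, so activations (and the
# matching gradient entries) are masked stochastically during training.
gs = gradient(() -> sum(m(x)), params(m))
```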
""" """
AlphaDropout(p) AlphaDropout(p)
A dropout layer. It is used in Self-Normalizing Neural Networks. A dropout layer. It is used in Self-Normalizing Neural Networks.
(https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf) (https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf)
The AlphaDropout layer ensures that mean and variance of activations remains the same as before. The AlphaDropout layer ensures that mean and variance of activations remains the same as before.
""" """
mutable struct AlphaDropout{F} mutable struct AlphaDropout{F}
p::F p::F
active::Bool function AlphaDropout(p)
end @assert 0 ≤ p ≤ 1
new{typeof(p)}(p)
function AlphaDropout(p) end
@assert 0 ≤ p ≤ 1
AlphaDropout(p,true)
end end
function (a::AlphaDropout)(x) function (a::AlphaDropout)(x)
a.active || return x istraining() || return x
λ = eltype(x)(1.0507009873554804934193349852946) λ = eltype(x)(1.0507009873554804934193349852946)
α = eltype(x)(1.6732632423543772848170429916717) α = eltype(x)(1.6732632423543772848170429916717)
α1 = eltype(x)(-λ*α) α1 = eltype(x)(-λ*α)
noise = randn(eltype(x), size(x)) noise = randn(eltype(x), size(x))
x = @. x*(noise > (1 - a.p)) + α1 * (noise <= (1 - a.p)) x = @. x*(noise > (1 - a.p)) + α1 * (noise < (1 - a.p))
A = (a.p + a.p * (1 - a.p) * α1 ^ 2)^0.5 A = (a.p + a.p * (1 - a.p) * α1 ^ 2)^0.5
B = -A * α1 * (1 - a.p) B = -A * α1 * (1 - a.p)
x = @. A * x + B x = @. A * x + B
return x return x
end end
_testmode!(a::AlphaDropout, test) = (a.active = !test)
""" """
LayerNorm(h::Integer) LayerNorm(h::Integer)
@ -151,25 +127,23 @@ mutable struct BatchNorm{F,V,W,N}
σ²::W # moving std σ²::W # moving std
ϵ::N ϵ::N
momentum::N momentum::N
active::Bool
end end
BatchNorm(chs::Integer, λ = identity; BatchNorm(chs::Integer, λ = identity;
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) = initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
BatchNorm(λ, param(initβ(chs)), param(initγ(chs)), BatchNorm(λ, initβ(chs), initγ(chs),
zeros(chs), ones(chs), ϵ, momentum, true) zeros(chs), ones(chs), ϵ, momentum)
function (BN::BatchNorm)(x) function (BN::BatchNorm)(x)
size(x, ndims(x)-1) == length(BN.β) || size(x, ndims(x)-1) == length(BN.β) ||
error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))") error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))")
dims = length(size(x)) dims = length(size(x))
channels = size(x, dims-1) channels = size(x, dims-1)
affine_shape = ones(Int, dims) affine_shape = ntuple(i->i == ndims(x) - 1 ? size(x, i) : 1, ndims(x))
affine_shape[end-1] = channels m = div(prod(size(x)), channels)
m = prod(size(x)[1:end-2]) * size(x)[end]
γ = reshape(BN.γ, affine_shape...) γ = reshape(BN.γ, affine_shape...)
β = reshape(BN.β, affine_shape...) β = reshape(BN.β, affine_shape...)
if !BN.active if !istraining()
μ = reshape(BN.μ, affine_shape...) μ = reshape(BN.μ, affine_shape...)
σ² = reshape(BN.σ², affine_shape...) σ² = reshape(BN.σ², affine_shape...)
ϵ = BN.ϵ ϵ = BN.ϵ
@ -178,11 +152,12 @@ function (BN::BatchNorm)(x)
axes = [1:dims-2; dims] # axes to reduce along (all but channels axis) axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
μ = mean(x, dims = axes) μ = mean(x, dims = axes)
σ² = sum((x .- μ) .^ 2, dims = axes) ./ m σ² = sum((x .- μ) .^ 2, dims = axes) ./ m
ϵ = data(convert(T, BN.ϵ)) ϵ = convert(T, BN.ϵ)
# update moving mean/std # update moving mean/std
mtm = data(convert(T, BN.momentum)) mtm = BN.momentum
BN.μ = (1 - mtm) .* BN.μ .+ mtm .* reshape(data(μ), :) S = eltype(BN.μ)
BN.σ² = (1 - mtm) .* BN.σ² .+ (mtm * m / (m - 1)) .* reshape(data(σ²), :) BN.μ = (1 - mtm) .* BN.μ .+ mtm .* S.(reshape(μ, :))
BN.σ² = (1 - mtm) .* BN.σ² .+ (mtm * m / (m - 1)) .* S.(reshape(σ², :))
end end
let λ = BN.λ let λ = BN.λ
@ -192,12 +167,10 @@ function (BN::BatchNorm)(x)
end end
children(BN::BatchNorm) = children(BN::BatchNorm) =
(BN.λ, BN.β, BN.γ, BN.μ, BN.σ², BN.ϵ, BN.momentum, BN.active) (BN.λ, BN.β, BN.γ, BN.μ, BN.σ², BN.ϵ, BN.momentum)
mapchildren(f, BN::BatchNorm) = # e.g. mapchildren(cu, BN) mapchildren(f, BN::BatchNorm) = # e.g. mapchildren(cu, BN)
BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ²), BN.ϵ, BN.momentum, BN.active) BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ²), BN.ϵ, BN.momentum)
_testmode!(BN::BatchNorm, test) = (BN.active = !test)
function Base.show(io::IO, l::BatchNorm) function Base.show(io::IO, l::BatchNorm)
print(io, "BatchNorm($(join(size(l.β), ", "))") print(io, "BatchNorm($(join(size(l.β), ", "))")
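The same `istraining()` mechanism drives `BatchNorm`: batch statistics and the running-average update only happen under differentiation, while plain calls use the stored `μ`/`σ²`. A rough usage sketch (the `Flux.forward` helper is the one the tests later in this diff rely on):
```julia
using Flux

bn = BatchNorm(2)
x  = Float32[1 3 5;
             2 4 6]            # 2 channels × 3 samples

# Differentiating (here via `forward`) hits the training branch: batch
# statistics are used and bn.μ / bn.σ² are updated in place.
y, back = Flux.forward(x -> bn(x), x)
bn.μ                            # no longer all zeros

bn(x)                           # a plain call normalises with the stored statistics
```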
@ -244,13 +217,12 @@ mutable struct InstanceNorm{F,V,W,N}
σ²::W # moving std σ²::W # moving std
ϵ::N ϵ::N
momentum::N momentum::N
active::Bool
end end
InstanceNorm(chs::Integer, λ = identity; InstanceNorm(chs::Integer, λ = identity;
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) = initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
InstanceNorm(λ, param(initβ(chs)), param(initγ(chs)), InstanceNorm(λ, initβ(chs), initγ(chs),
zeros(chs), ones(chs), ϵ, momentum, true) zeros(chs), ones(chs), ϵ, momentum)
function (in::InstanceNorm)(x) function (in::InstanceNorm)(x)
size(x, ndims(x)-1) == length(in.β) || size(x, ndims(x)-1) == length(in.β) ||
@ -261,28 +233,26 @@ function (in::InstanceNorm)(x)
dims = length(size(x)) dims = length(size(x))
c = size(x, dims-1) c = size(x, dims-1)
bs = size(x, dims) bs = size(x, dims)
affine_shape = ones(Int, dims) affine_shape = ntuple(i->i == ndims(x) - 1 || i == ndims(x) ? size(x, i) : 1, ndims(x))
affine_shape[end-1] = c m = div(prod(size(x)), c*bs)
affine_shape[end] = bs
m = prod(size(x)[1:end-2])
γ, β = expand_inst(in.γ, affine_shape), expand_inst(in.β, affine_shape) γ, β = expand_inst(in.γ, affine_shape), expand_inst(in.β, affine_shape)
if !in.active if !istraining()
μ = expand_inst(in.μ, affine_shape) μ = expand_inst(in.μ, affine_shape)
σ² = expand_inst(in.σ², affine_shape) σ² = expand_inst(in.σ², affine_shape)
ϵ = in.ϵ ϵ = in.ϵ
else else
T = eltype(x) T = eltype(x)
ϵ = data(convert(T, in.ϵ)) ϵ = convert(T, in.ϵ)
axes = 1:dims-2 # axes to reduce along (all but channels and batch size axes) axes = 1:dims-2 # axes to reduce along (all but channels and batch size axes)
μ = mean(x, dims = axes) μ = mean(x, dims = axes)
σ² = mean((x .- μ) .^ 2, dims = axes) σ² = mean((x .- μ) .^ 2, dims = axes)
S = eltype(in.μ)
# update moving mean/std # update moving mean/std
mtm = data(convert(T, in.momentum)) mtm = in.momentum
in.μ = dropdims(mean(repeat((1 - mtm) .* in.μ, outer=[1, bs]) .+ mtm .* reshape(data(μ), (c, bs)), dims = 2), dims=2) in.μ = dropdims(mean(repeat((1 - mtm) .* in.μ, outer=[1, bs]) .+ mtm .* S.(reshape(μ, (c, bs))), dims = 2), dims=2)
in.σ² = dropdims(mean((repeat((1 - mtm) .* in.σ², outer=[1, bs]) .+ (mtm * m / (m - 1)) .* reshape(data(σ²), (c, bs))), dims = 2), dims=2) in.σ² = dropdims(mean((repeat((1 - mtm) .* in.σ², outer=[1, bs]) .+ (mtm * m / (m - 1)) .* S.(reshape(σ², (c, bs)))), dims = 2), dims=2)
end end
let λ = in.λ let λ = in.λ
@ -292,12 +262,10 @@ function (in::InstanceNorm)(x)
end end
children(in::InstanceNorm) = children(in::InstanceNorm) =
(in.λ, in.β, in.γ, in.μ, in.σ², in.ϵ, in.momentum, in.active) (in.λ, in.β, in.γ, in.μ, in.σ², in.ϵ, in.momentum)
mapchildren(f, in::InstanceNorm) = # e.g. mapchildren(cu, in) mapchildren(f, in::InstanceNorm) = # e.g. mapchildren(cu, in)
InstanceNorm(in.λ, f(in.β), f(in.γ), f(in.μ), f(in.σ²), in.ϵ, in.momentum, in.active) InstanceNorm(in.λ, f(in.β), f(in.γ), f(in.μ), f(in.σ²), in.ϵ, in.momentum)
_testmode!(in::InstanceNorm, test) = (in.active = !test)
function Base.show(io::IO, l::InstanceNorm) function Base.show(io::IO, l::InstanceNorm)
print(io, "InstanceNorm($(join(size(l.β), ", "))") print(io, "InstanceNorm($(join(size(l.β), ", "))")
@ -306,11 +274,11 @@ function Base.show(io::IO, l::InstanceNorm)
end end
""" """
Group Normalization. Group Normalization.
This layer can outperform Batch-Normalization and Instance-Normalization. This layer can outperform Batch-Normalization and Instance-Normalization.
GroupNorm(chs::Integer, G::Integer, λ = identity; GroupNorm(chs::Integer, G::Integer, λ = identity;
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
ϵ = 1f-5, momentum = 0.1f0) ϵ = 1f-5, momentum = 0.1f0)
``chs`` is the number of channels, the channel dimension of your input. ``chs`` is the number of channels, the channel dimension of your input.
@ -322,12 +290,11 @@ The number of channels must be an integer multiple of the number of groups.
Example: Example:
``` ```
m = Chain(Conv((3,3), 1=>32, leakyrelu;pad = 1), m = Chain(Conv((3,3), 1=>32, leakyrelu;pad = 1),
GroupNorm(32,16)) # 32 channels, 16 groups (G = 16), thus 2 channels per group used GroupNorm(32,16)) # 32 channels, 16 groups (G = 16), thus 2 channels per group used
``` ```
Link : https://arxiv.org/pdf/1803.08494.pdf Link : https://arxiv.org/pdf/1803.08494.pdf
""" """
mutable struct GroupNorm{F,V,W,N,T} mutable struct GroupNorm{F,V,W,N,T}
G::T # number of groups G::T # number of groups
λ::F # activation function λ::F # activation function
@ -337,13 +304,12 @@ mutable struct GroupNorm{F,V,W,N,T}
σ²::W # moving std σ²::W # moving std
ϵ::N ϵ::N
momentum::N momentum::N
active::Bool
end end
GroupNorm(chs::Integer, G::Integer, λ = identity; GroupNorm(chs::Integer, G::Integer, λ = identity;
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) = initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
GroupNorm(G, λ, param(initβ(chs)), param(initγ(chs)), GroupNorm(G, λ, initβ(chs), initγ(chs),
zeros(G,1), ones(G,1), ϵ, momentum, true) zeros(G,1), ones(G,1), ϵ, momentum)
function(gn::GroupNorm)(x) function(gn::GroupNorm)(x)
size(x,ndims(x)-1) == length(gn.β) || error("Group Norm expected $(length(gn.β)) channels, but got $(size(x,ndims(x)-1)) channels") size(x,ndims(x)-1) == length(gn.β) || error("Group Norm expected $(length(gn.β)) channels, but got $(size(x,ndims(x)-1)) channels")
@ -355,20 +321,17 @@ function(gn::GroupNorm)(x)
channels = size(x, dims-1) channels = size(x, dims-1)
batches = size(x,dims) batches = size(x,dims)
channels_per_group = div(channels,groups) channels_per_group = div(channels,groups)
affine_shape = ones(Int, dims) affine_shape = ntuple(i->i == ndims(x) - 1 ? size(x, i) : 1, ndims(x))
# Output reshaped to (W,H...,C/G,G,N) # Output reshaped to (W,H...,C/G,G,N)
affine_shape[end-1] = channels μ_affine_shape = ntuple(i->i == ndims(x) ? groups : 1, ndims(x) + 1)
μ_affine_shape = ones(Int,dims + 1)
μ_affine_shape[end-1] = groups
m = prod(size(x)[1:end-2]) * channels_per_group m = prod(size(x)[1:end-2]) * channels_per_group
γ = reshape(gn.γ, affine_shape...) γ = reshape(gn.γ, affine_shape...)
β = reshape(gn.β, affine_shape...) β = reshape(gn.β, affine_shape...)
y = reshape(x,((size(x))[1:end-2]...,channels_per_group,groups,batches)) y = reshape(x,((size(x))[1:end-2]...,channels_per_group,groups,batches))
if !gn.active if !istraining()
og_shape = size(x) og_shape = size(x)
μ = reshape(gn.μ, μ_affine_shape...) # Shape : (1,1,...C/G,G,1) μ = reshape(gn.μ, μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
σ² = reshape(gn.σ², μ_affine_shape...) # Shape : (1,1,...C/G,G,1) σ² = reshape(gn.σ², μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
@ -379,31 +342,29 @@ function(gn::GroupNorm)(x)
axes = [(1:ndims(y)-2)...] # axes to reduce along (all but channels axis) axes = [(1:ndims(y)-2)...] # axes to reduce along (all but channels axis)
μ = mean(y, dims = axes) μ = mean(y, dims = axes)
σ² = mean((y .- μ) .^ 2, dims = axes) σ² = mean((y .- μ) .^ 2, dims = axes)
ϵ = data(convert(T, gn.ϵ))
# update moving mean/std
mtm = data(convert(T, gn.momentum))
gn.μ = mean((1 - mtm) .* gn.μ .+ mtm .* reshape(data(μ), (groups,batches)),dims=2) ϵ = convert(T, gn.ϵ)
gn.σ² = mean((1 - mtm) .* gn.σ² .+ (mtm * m / (m - 1)) .* reshape(data(σ²), (groups,batches)),dims=2) # update moving mean/std
mtm = gn.momentum
S = eltype(gn.μ)
gn.μ = mean((1 - mtm) .* gn.μ .+ mtm .* S.(reshape(μ, (groups,batches))),dims=2)
gn.σ² = mean((1 - mtm) .* gn.σ² .+ (mtm * m / (m - 1)) .* S.(reshape(σ², (groups,batches))),dims=2)
end end
let λ = gn.λ let λ = gn.λ
x̂ = (y .- μ) ./ sqrt.(σ² .+ ϵ) x̂ = (y .- μ) ./ sqrt.(σ² .+ ϵ)
# Reshape x̂ # Reshape x̂
x̂ = reshape(x̂,og_shape) x̂ = reshape(x̂,og_shape)
λ.(γ .* x̂ .+ β) λ.(γ .* x̂ .+ β)
end end
end end
children(gn::GroupNorm) = children(gn::GroupNorm) =
(gn.λ, gn.β, gn.γ, gn.μ, gn.σ², gn.ϵ, gn.momentum, gn.active) (gn.λ, gn.β, gn.γ, gn.μ, gn.σ², gn.ϵ, gn.momentum)
mapchildren(f, gn::GroupNorm) = # e.g. mapchildren(cu, BN) mapchildren(f, gn::GroupNorm) = # e.g. mapchildren(cu, BN)
GroupNorm(gn.G,gn.λ, f(gn.β), f(gn.γ), f(gn.μ), f(gn.σ²), gn.ϵ, gn.momentum, gn.active) GroupNorm(gn.G,gn.λ, f(gn.β), f(gn.γ), f(gn.μ), f(gn.σ²), gn.ϵ, gn.momentum)
_testmode!(gn::GroupNorm, test) = (gn.active = !test)
function Base.show(io::IO, l::GroupNorm) function Base.show(io::IO, l::GroupNorm)
print(io, "GroupNorm($(join(size(l.β), ", "))") print(io, "GroupNorm($(join(size(l.β), ", "))")
@ -42,21 +42,6 @@ end
Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")") Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
_truncate(x::AbstractArray) = Tracker.data(x)
_truncate(x::Tuple) = _truncate.(x)
"""
truncate!(rnn)
Truncates the gradient of the hidden state in recurrent layers. The value of the
state is preserved. See also `reset!`.
Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to
rnn.state = Tracker.data(rnn.state)
"""
truncate!(m) = prefor(x -> x isa Recur && (x.state = _truncate(x.state)), m)
""" """
reset!(rnn) reset!(rnn)
@ -83,8 +68,8 @@ end
RNNCell(in::Integer, out::Integer, σ = tanh; RNNCell(in::Integer, out::Integer, σ = tanh;
init = glorot_uniform) = init = glorot_uniform) =
RNNCell(σ, param(init(out, in)), param(init(out, out)), RNNCell(σ, init(out, in), init(out, out),
param(init(out)), param(zeros(out))) init(out), zeros(out))
function (m::RNNCell)(h, x) function (m::RNNCell)(h, x)
σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
@ -122,9 +107,9 @@ end
function LSTMCell(in::Integer, out::Integer; function LSTMCell(in::Integer, out::Integer;
init = glorot_uniform) init = glorot_uniform)
cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(init(out*4)), cell = LSTMCell(init(out * 4, in), init(out * 4, out), init(out * 4),
param(zeros(out)), param(zeros(out))) zeros(out), zeros(out))
cell.b.data[gate(out, 2)] .= 1 cell.b[gate(out, 2)] .= 1
return cell return cell
end end
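With `param` gone, the recurrent cells above hold plain arrays, and gradients come from Zygote's `gradient` over `params` instead of `back!`/`truncate!`. A minimal sketch, not part of the diff:
```julia
using Flux
using Flux: gradient

rnn = LSTM(10, 5)
x   = rand(Float32, 10)

typeof(rnn.cell.Wi)     # a plain Matrix now, no TrackedArray wrapper

Flux.reset!(rnn)
gs = gradient(() -> sum(rnn(x)), params(rnn))
gs[rnn.cell.Wi]         # gradient for the input weights
```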
@ -168,8 +153,8 @@ mutable struct GRUCell{A,V}
end end
GRUCell(in, out; init = glorot_uniform) = GRUCell(in, out; init = glorot_uniform) =
GRUCell(param(init(out*3, in)), param(init(out*3, out)), GRUCell(init(out * 3, in), init(out * 3, out),
param(init(out*3)), param(zeros(out))) init(out * 3), zeros(out))
function (m::GRUCell)(h, x) function (m::GRUCell)(h, x)
b, o = m.b, size(h, 1) b, o = m.b, size(h, 1)
@ -49,8 +49,3 @@ function normalise(x::AbstractArray; dims=1)
σ = std(x, dims = dims, mean = μ′, corrected=false) σ = std(x, dims = dims, mean = μ′, corrected=false)
return (x .- μ′) ./ σ return (x .- μ′) ./ σ
end end
function normalise(x::AbstractArray, dims)
Base.depwarn("`normalise(x::AbstractArray, dims)` is deprecated, use `normalise(a, dims=dims)` instead.", :normalise)
normalise(x, dims = dims)
end
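With the positional-`dims` deprecation removed, only the keyword form of `normalise` remains. For reference, a small example of what it computes:
```julia
using Flux, Statistics

x = rand(4, 3)
y = Flux.normalise(x, dims = 2)          # zero mean, unit std along dims = 2

mean(y, dims = 2)                        # ≈ 0 in every row
std(y, dims = 2, corrected = false)      # ≈ 1 in every row
```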
@ -54,17 +54,19 @@ it will error.
## Examples ## Examples
```jldoctest ```jldoctest
julia> using Flux: onehot
julia> onehot(:b, [:a, :b, :c]) julia> onehot(:b, [:a, :b, :c])
3-element Flux.OneHotVector: 3-element Flux.OneHotVector:
false 0
true 1
false 0
julia> onehot(:c, [:a, :b, :c]) julia> onehot(:c, [:a, :b, :c])
3-element Flux.OneHotVector: 3-element Flux.OneHotVector:
false 0
false 0
true 1
``` ```
""" """
function onehot(l, labels) function onehot(l, labels)
@ -88,12 +90,13 @@ Create an [`OneHotMatrix`](@ref) with a batch of labels based on possible `label
## Examples ## Examples
```jldoctest ```jldoctest
julia> onehotbatch([:b, :a, :b], [:a, :b, :c]) julia> using Flux: onehotbatch
3×3 Flux.OneHotMatrix:
false true false
true false true
false false false
julia> onehotbatch([:b, :a, :b], [:a, :b, :c])
3×3 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}:
0 1 0
1 0 1
0 0 0
``` ```
""" """
onehotbatch(ls, labels, unk...) = onehotbatch(ls, labels, unk...) =
@ -106,9 +109,9 @@ Base.argmax(xs::OneHotVector) = xs.ix
Inverse operations of [`onehot`](@ref). Inverse operations of [`onehot`](@ref).
## Examples
```jldoctest ```jldoctest
julia> using Flux: onecold
julia> onecold([true, false, false], [:a, :b, :c]) julia> onecold([true, false, false], [:a, :b, :c])
:a :a
@ -124,15 +127,6 @@ onecold(y::AbstractMatrix, labels...) =
onecold(y::OneHotMatrix, labels...) = onecold(y::OneHotMatrix, labels...) =
mapreduce(x -> Flux.onecold(x, labels...), |, y.data, dims = 2, init = 0) mapreduce(x -> Flux.onecold(x, labels...), |, y.data, dims = 2, init = 0)
function argmax(xs...) # TODO probably still want this as a custom adjoint Zygote
Base.depwarn("`argmax(...)` is deprecated, use `onecold(...)` instead.", :argmax) # onecold(x::TrackedVector, l...) = onecold(data(x), l...)
return onecold(xs...) # onecold(x::TrackedMatrix, l...) = onecold(data(x), l...)
end
# Ambiguity hack
a::TrackedMatrix * b::OneHotVector = invoke(*, Tuple{AbstractMatrix,OneHotVector}, a, b)
a::TrackedMatrix * b::OneHotMatrix = invoke(*, Tuple{AbstractMatrix,OneHotMatrix}, a, b)
onecold(x::TrackedVector, l...) = onecold(data(x), l...)
onecold(x::TrackedMatrix, l...) = onecold(data(x), l...)
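With the Tracker-specific `onecold` methods and the `argmax` deprecation gone, the encode/decode round trip is simply `onehotbatch` plus `onecold`, for example:
```julia
using Flux

labels = [:cat, :dog, :mouse]

y = Flux.onehotbatch([:dog, :cat, :dog], labels)   # 3×3 OneHotMatrix

Flux.onecold(y, labels)                            # [:dog, :cat, :dog]
```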
@ -7,6 +7,5 @@ export train!,
include("optimisers.jl") include("optimisers.jl")
include("train.jl") include("train.jl")
include("deprecations.jl")
end end
@ -1,126 +0,0 @@
using Base: depwarn
using Flux: Params
check_decay(opt, decay) = decay == 0 ? opt : Optimiser(opt, InvDecay(decay))
# legacy update rule
updaterule(opt, ps) = () -> _update_params!(opt, ps)
function SGD(params::Union{AbstractArray, Params}, η = 0.1; decay = 0.)
depwarn("SGD(params) is deprecated; use Descent(η::Float64) instead", :SGD)
ps = params
opt = Descent(η)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function Momentum(params::Union{AbstractArray, Params}, η = 0.01; ρ = 0.9, decay = 0.)
depwarn("Momentum(params) is deprecated; use Momentum(η::Float64) instead", :Momentum)
ps = params
opt = Momentum(η, ρ)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function Nesterov(params::Union{AbstractArray, Params}, η = 0.001; ρ = 0.9, decay = 0.)
depwarn("Nesterov(params) is deprecated; use Nesterov(η::Float64) instead", :Nesterov)
ps = params
opt = Nesterov(η, ρ)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function RMSProp(params::Union{AbstractArray, Params}, η = 0.001; ρ = 0.9, decay = 0.)
depwarn("RMSProp(params) is deprecated; use RMSProp(η::Float64) instead", :RMSProp)
ps = params
opt = RMSProp(η, ρ)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function ADAM(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("ADAM(params) is deprecated; use ADAM(η::Float64) instead", :ADAM)
ps = params
β = (β1, β2)
opt = ADAM(η, β)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function ADAGrad(params::Union{AbstractArray, Params}, η::Float64 = 0.1; decay = 0.)
depwarn("ADAGrad(params) is deprecated; use ADAGrad(η::Float64) instead", :ADAGrad)
ps = params
opt = ADAGrad(η)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function ADADelta(params::Union{AbstractArray, Params}, ρ::Float64 = 0.9; decay = 0.)
depwarn("ADADelta(params) is deprecated; use ADADelta(η::Float64) instead", :ADADelta)
ps = params
opt = ADADelta(ρ)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function AdaMax(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("AdaMax(params) is deprecated; use AdaMax(η::Float64) instead", :AdaMax)
ps = params
β = (β1, β2)
opt = AdaMax(η, β)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function AMSGrad(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("AMSGrad(params) is deprecated; use AMSGrad(η::Float64) instead", :AMSGrad)
ps = params
β = (β1, β2)
opt = AMSGrad(η, β)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function NADAM(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("NADAM(params) is deprecated; use NADAM(η::Float64) instead", :NADAM)
ps = params
β = (β1, β2)
opt = NADAM(η, β)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function ADAMW(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("ADAMW(params) is deprecated; use ADAMW(η::Float64) instead", :ADAMW)
ps = params
β = (β1, β2)
opt = ADAMW(η, β)
opt = check_decay(opt, decay)
decay != 0 && (opt = Optimiser(opt, WeightDecay(decay)))
updaterule(opt, ps)
end
# Old training loop
struct OldOptimiser
func
end
_update_params!(opt::OldOptimiser, ps) = opt.func()
# Train function
function train!(loss, data, opt; cb = () -> ())
depwarn("train!(loss, data, opt) is deprecated; use train!(loss, params, data, opt) instead", :train!)
train!(loss, (), data, OldOptimiser(opt); cb = cb)
end
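The whole file of parameter-carrying optimiser constructors above is deleted; optimisers are now value-only objects and the parameters are handed to `train!` explicitly. A sketch of the replacement pattern, with an illustrative model and data set:
```julia
using Flux

m = Dense(5, 2)
loss(x, y) = Flux.mse(m(x), y)

opt  = Descent(0.1)                      # value-only, no parameters captured inside
data = [(rand(Float32, 5), rand(Float32, 2)) for _ = 1:100]

Flux.train!(loss, params(m), data, opt)  # parameters passed explicitly
```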
@ -37,7 +37,7 @@ Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict())
function apply!(o::Momentum, x, Δ) function apply!(o::Momentum, x, Δ)
η, ρ = o.eta, o.rho η, ρ = o.eta, o.rho
v = get!(o.velocity, x, zero(x))::typeof(data(x)) v = get!(o.velocity, x, zero(x))::typeof(x)
@. v = ρ * v - η * Δ @. v = ρ * v - η * Δ
@. Δ = -v @. Δ = -v
end end
@ -57,7 +57,7 @@ Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict())
function apply!(o::Nesterov, x, Δ) function apply!(o::Nesterov, x, Δ)
η, ρ = o.eta, o.rho η, ρ = o.eta, o.rho
v = get!(o.velocity, x, zero(x))::typeof(data(x)) v = get!(o.velocity, x, zero(x))::typeof(x)
d = @. ρ^2 * v - (1+ρ) * η * Δ d = @. ρ^2 * v - (1+ρ) * η * Δ
@. v = ρ*v - η*Δ @. v = ρ*v - η*Δ
@. Δ = -d @. Δ = -d
@ -80,7 +80,7 @@ RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict())
function apply!(o::RMSProp, x, Δ) function apply!(o::RMSProp, x, Δ)
η, ρ = o.eta, o.rho η, ρ = o.eta, o.rho
acc = get!(o.acc, x, zero(x))::typeof(data(x)) acc = get!(o.acc, x, zero(x))::typeof(x)
@. acc = ρ * acc + (1 - ρ) * Δ^2 @. acc = ρ * acc + (1 - ρ) * Δ^2
@. Δ *= η / (√acc + ϵ) @. Δ *= η / (√acc + ϵ)
end end
@ -177,7 +177,7 @@ ADAGrad(η = 0.1) = ADAGrad(η, IdDict())
function apply!(o::ADAGrad, x, Δ) function apply!(o::ADAGrad, x, Δ)
η = o.eta η = o.eta
acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(data(x)) acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(x)
@. acc += Δ^2 @. acc += Δ^2
@. Δ *= η / (√acc + ϵ) @. Δ *= η / (√acc + ϵ)
end end
@ -352,5 +352,5 @@ WeightDecay() = WeightDecay(0)
function apply!(o::WeightDecay, x, Δ) function apply!(o::WeightDecay, x, Δ)
wd = o.wd wd = o.wd
@. Δ += wd * data(x) @. Δ += wd * x
end end
@ -1,32 +1,29 @@
using Juno using Juno
import Flux.Tracker: Params, gradient, data, update! import Zygote: Params, gradient
import Base.depwarn
function update!(x::AbstractArray, x̄)
x .+= x̄
return x
end
function update!(opt, x, x̄) function update!(opt, x, x̄)
update!(x, -apply!(opt, x, data(x̄))) x .-= apply!(opt, x, x̄)
end end
function update!(opt, xs::Params, gs) function update!(opt, xs::Params, gs)
for x in xs for x in xs
gs[x] == nothing && continue
update!(opt, x, gs[x]) update!(opt, x, gs[x])
end end
end end
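The new `update!` above works directly on arrays and on a `Params`/`Grads` pair from Zygote; there is no separate gradient buffer to zero afterwards. A minimal sketch:
```julia
using Flux
using Flux: gradient
using Flux.Optimise: update!

W   = rand(3, 3)
x   = rand(3)
opt = Descent(0.1)

ps = params(W)
gs = gradient(() -> sum(W * x), ps)
update!(opt, ps, gs)      # each parameter gets  p .-= apply!(opt, p, gs[p])
```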
# Added as an internal API but everyone started using it.
function _update_params!(opt, xs)
depwarn("`_update_params!` is deprecated, use `update!` instead.", :stop)
for x in xs
update!(opt, x, Tracker.grad(x))
x.tracker.grad = Tracker.zero_grad!(x.tracker.grad)
end
end
# Callback niceties # Callback niceties
call(f, xs...) = f(xs...) call(f, xs...) = f(xs...)
runall(f) = f runall(f) = f
runall(fs::AbstractVector) = () -> foreach(call, fs) runall(fs::AbstractVector) = () -> foreach(call, fs)
struct StopException <: Exception end struct StopException <: Exception end
""" """
stop() stop()
@ -72,10 +69,7 @@ function train!(loss, ps, data, opt; cb = () -> ())
loss(d...) loss(d...)
end end
update!(opt, ps, gs) update!(opt, ps, gs)
if cb() == :stop cb()
depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
break
end
catch ex catch ex
if ex isa StopException if ex isa StopException
break break
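Since the `:stop` return value is no longer honoured, callbacks end training by calling `Flux.stop()`, which throws the `StopException` caught above. A hedged example of that pattern:
```julia
using Flux
using Flux: throttle

m    = Dense(5, 2)
data = [(rand(Float32, 5), rand(Float32, 2)) for _ = 1:1000]
loss(x, y) = Flux.mse(m(x), y)
opt  = ADAM()

# Stop as soon as the loss on the first sample is small enough,
# checking at most once per second.
cb = throttle(() -> loss(data[1]...) < 1e-3 && Flux.stop(), 1)

Flux.train!(loss, params(m), data, opt, cb = cb)
```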
@ -1,5 +1,5 @@
import Adapt: adapt, adapt_storage import Adapt: adapt, adapt_storage
import .Tracker: IdSet import Zygote: IdSet
children(x) = () children(x) = ()
mapchildren(f, x) = x mapchildren(f, x) = x
@ -40,7 +40,7 @@ end
function params(m) function params(m)
ps = Params() ps = Params()
prefor(p -> prefor(p ->
Tracker.istracked(p) && Tracker.isleaf(p) && p isa AbstractArray{<:Real} &&
!any(p′ -> p′ === p, ps) && push!(ps, p), !any(p′ -> p′ === p, ps) && push!(ps, p),
m) m)
return ps return ps
@ -52,7 +52,7 @@ function loadparams!(m, xs)
for (p, x) in zip(params(m), xs) for (p, x) in zip(params(m), xs)
size(p) == size(x) || size(p) == size(x) ||
error("Expected param size $(size(p)), got $(size(x))") error("Expected param size $(size(p)), got $(size(x))")
copyto!(data(p), data(x)) copyto!(p, x)
end end
end end
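After this change `params` simply collects every real-valued array leaf (no `istracked` check), and `loadparams!` copies straight into those arrays. For instance:
```julia
using Flux

m  = Chain(Dense(10, 5, relu), Dense(5, 2))
ps = params(m)                  # collects W₁, b₁, W₂ and b₂ as plain arrays

# Save and restore the parameters by value.
weights = [copy(p) for p in ps]
Flux.loadparams!(m, weights)
```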
@ -81,8 +81,6 @@ f64(m) = paramtype(Float64, m)
function mapparams(f, m) function mapparams(f, m)
mapleaves(m) do x mapleaves(m) do x
Tracker.istracked(x) ? param(f(Tracker.data(x))) : x isa Union{AbstractArray,Number} ? f(x) : x
x isa Union{AbstractArray,Number} ? f(x) :
x
end end
end end
@ -1,4 +1,4 @@
using Flux, Flux.Tracker, CuArrays, Test using Flux, CuArrays, Test
using Flux: gpu using Flux: gpu
@info "Testing GPU Support" @info "Testing GPU Support"
@ -7,11 +7,11 @@ using Flux: gpu
CuArrays.allowscalar(false) CuArrays.allowscalar(false)
x = param(randn(5, 5)) x = randn(5, 5)
cx = gpu(x) cx = gpu(x)
@test cx isa TrackedArray && cx.data isa CuArray @test cx isa CuArray
@test Flux.onecold(param(gpu([1.,2.,3.]))) == 3 @test Flux.onecold(gpu([1.0, 2.0, 3.0])) == 3
x = Flux.onehotbatch([1, 2, 3], 1:3) x = Flux.onehotbatch([1, 2, 3], 1:3)
cx = gpu(x) cx = gpu(x)
@ -21,24 +21,26 @@ cx = gpu(x)
m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax) m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)
cm = gpu(m) cm = gpu(m)
@test all(p isa TrackedArray && p.data isa CuArray for p in params(cm)) @test all(p isa CuArray for p in params(cm))
@test cm(gpu(rand(10, 10))) isa TrackedArray{Float32,2,CuArray{Float32,2}} @test cm(gpu(rand(10, 10))) isa CuArray{Float32,2}
x = [1,2,3] x = [1,2,3]
cx = gpu(x) cx = gpu(x)
@test Flux.crossentropy(x,x) ≈ Flux.crossentropy(cx,cx) @test Flux.crossentropy(x,x) ≈ Flux.crossentropy(cx,cx)
xs = param(rand(5,5)) xs = rand(5, 5)
ys = Flux.onehotbatch(1:5,1:5) ys = Flux.onehotbatch(1:5,1:5)
@test collect(cu(xs) .+ cu(ys)) ≈ collect(xs .+ ys) @test collect(cu(xs) .+ cu(ys)) ≈ collect(xs .+ ys)
c = gpu(Conv((2,2),3=>4)) c = gpu(Conv((2,2),3=>4))
x = gpu(rand(10, 10, 3, 2))
l = c(gpu(rand(10,10,3,2))) l = c(gpu(rand(10,10,3,2)))
Flux.back!(sum(l)) @test gradient(x -> sum(c(x)), x)[1] isa CuArray
c = gpu(CrossCor((2,2),3=>4)) c = gpu(CrossCor((2,2),3=>4))
x = gpu(rand(10, 10, 3, 2))
l = c(gpu(rand(10,10,3,2))) l = c(gpu(rand(10,10,3,2)))
Flux.back!(sum(l)) @test gradient(x -> sum(c(x)), x)[1] isa CuArray
end end
@ -1,48 +1,44 @@
using Flux, Flux.Tracker, CuArrays, Test using Flux, CuArrays, Test
using Flux.Tracker: TrackedArray, data using Flux: forward
@testset "CUDNN BatchNorm" begin @testset "CUDNN BatchNorm" begin
@testset "4D Input" begin @testset "4D Input" begin
x = TrackedArray(Float64.(collect(reshape(1:12, 2, 2, 3, 1)))) x = Float64.(collect(reshape(1:12, 2, 2, 3, 1)))
m = BatchNorm(3) m = BatchNorm(3)
cx = gpu(x) cx = gpu(x)
cm = gpu(m) cm = gpu(m)
y = m(x) y, back = forward((m, x) -> m(x), m, x)
cy = cm(cx) cy, cback = forward((m, x) -> m(x), cm, cx)
@test cy isa TrackedArray{Float32,4,CuArray{Float32,4}} @test cpu(cy) ≈ y
@test cpu(data(cy)) ≈ data(y) Δ = randn(size(y))
dm, dx = back(Δ)
cdm, cdx = cback(gpu(Δ))
g = rand(size(y)...) @test dm[].γ ≈ cpu(cdm[].γ)
Flux.back!(y, g) @test dm[].β ≈ cpu(cdm[].β)
Flux.back!(cy, gpu(g)) @test dx ≈ cpu(cdx)
@test m.γ.grad ≈ cpu(cm.γ.grad)
@test m.β.grad ≈ cpu(cm.β.grad)
@test x.grad ≈ cpu(x.grad)
end end
@testset "2D Input" begin @testset "2D Input" begin
x = TrackedArray(Float64.(collect(reshape(1:12, 3, 4)))) x = Float64.(collect(reshape(1:12, 3, 4)))
m = BatchNorm(3) m = BatchNorm(3)
cx = gpu(x) cx = gpu(x)
cm = gpu(m) cm = gpu(m)
y = m(x) y, back = forward((m, x) -> m(x), m, x)
cy = cm(cx) cy, cback = forward((m, x) -> m(x), cm, cx)
@test cy isa TrackedArray{Float32,2,CuArray{Float32,2}} @test cpu(cy) ≈ y
@test cpu(data(cy)) ≈ data(y) Δ = randn(size(y))
dm, dx = back(Δ)
cdm, cdx = cback(gpu(Δ))
g = rand(size(y)...) @test dm[].γ ≈ cpu(cdm[].γ)
Flux.back!(y, g) @test dm[].β ≈ cpu(cdm[].β)
Flux.back!(cy, gpu(g)) @test dx ≈ cpu(cdx)
@test m.γ.grad ≈ cpu(cm.γ.grad)
@test m.β.grad ≈ cpu(cm.β.grad)
@test x.grad ≈ cpu(x.grad)
end end
end end
@ -1,46 +1,54 @@
using Flux, CuArrays, Test using Flux, CuArrays, Test
using Flux: forward
@testset "RNN" begin @testset "RNN" begin
@testset for R in [RNN, GRU, LSTM] @testset for R in [RNN, GRU, LSTM], batch_size in (1, 5)
rnn = R(10, 5) rnn = R(10, 5)
curnn = mapleaves(gpu, rnn) curnn = mapleaves(gpu, rnn)
@testset for batch_size in (1, 5)
Flux.reset!(rnn)
Flux.reset!(curnn)
x = batch_size == 1 ?
param(rand(10)) :
param(rand(10,batch_size))
cux = gpu(x)
y = (rnn(x); rnn(x))
cuy = (curnn(cux); curnn(cux))
@test y.data ≈ collect(cuy.data) Flux.reset!(rnn)
@test haskey(Flux.CUDA.descs, curnn.cell) Flux.reset!(curnn)
x = batch_size == 1 ?
rand(10) :
rand(10, batch_size)
cux = gpu(x)
Δ = randn(size(y)) y, back = forward((r, x) -> (r(x)), rnn, x)
cuy, cuback = forward((r, x) -> (r(x)), curnn, cux)
Flux.back!(y, Δ) @test y ≈ collect(cuy)
Flux.back!(cuy, gpu(Δ)) @test haskey(Flux.CUDA.descs, curnn.cell)
@test x.grad ≈ collect(cux.grad) ȳ = randn(size(y))
@test rnn.cell.Wi.grad ≈ collect(curnn.cell.Wi.grad) m̄, x̄ = back(ȳ)
@test rnn.cell.Wh.grad ≈ collect(curnn.cell.Wh.grad) cum̄, cux̄ = cuback(gpu(ȳ))
@test rnn.cell.b.grad ≈ collect(curnn.cell.b.grad)
@test rnn.cell.h.grad ≈ collect(curnn.cell.h.grad) m̄[].cell[].Wi
if isdefined(rnn.cell, :c)
@test rnn.cell.c.grad ≈ collect(curnn.cell.c.grad) m̄[].state
cum̄[].state
@test x̄ ≈ collect(cux̄)
@test m̄[].cell[].Wi ≈ collect(cum̄[].cell[].Wi)
@test m̄[].cell[].Wh ≈ collect(cum̄[].cell[].Wh)
@test m̄[].cell[].b ≈ collect(cum̄[].cell[].b)
if m̄[].state isa Tuple
for (x, cx) in zip(m̄[].state, cum̄[].state)
@test x ≈ collect(cx)
end end
else
Flux.reset!(rnn) @test m̄[].state ≈ collect(cum̄[].state)
Flux.reset!(curnn)
ohx = batch_size == 1 ?
Flux.onehot(rand(1:10), 1:10) :
Flux.onehotbatch(rand(1:10, batch_size), 1:10)
cuohx = gpu(ohx)
y = (rnn(ohx); rnn(ohx))
cuy = (curnn(cuohx); curnn(cuohx))
@test y.data ≈ collect(cuy.data)
end end
Flux.reset!(rnn)
Flux.reset!(curnn)
ohx = batch_size == 1 ?
Flux.onehot(rand(1:10), 1:10) :
Flux.onehotbatch(rand(1:10, batch_size), 1:10)
cuohx = gpu(ohx)
y = (rnn(ohx); rnn(ohx))
cuy = (curnn(cuohx); curnn(cuohx))
@test y ≈ collect(cuy)
end end
end end
@ -25,9 +25,9 @@ end
@testset "asymmetric padding" begin @testset "asymmetric padding" begin
r = ones(Float32, 28, 28, 1, 1) r = ones(Float32, 28, 28, 1, 1)
m = Conv((3, 3), 1=>1, relu; pad=(0,1,1,2)) m = Conv((3, 3), 1=>1, relu; pad=(0,1,1,2))
m.weight.data[:] .= 1.0 m.weight[:] .= 1.0
m.bias.data[:] .= 0.0 m.bias[:] .= 0.0
y_hat = Flux.data(m(r))[:,:,1,1] y_hat = m(r)[:,:,1,1]
@test size(y_hat) == (27, 29) @test size(y_hat) == (27, 29)
@test y_hat[1, 1] ≈ 6.0 @test y_hat[1, 1] ≈ 6.0
@test y_hat[2, 2] ≈ 9.0 @test y_hat[2, 2] ≈ 9.0
@ -41,7 +41,7 @@ end
r = zeros(Float32, 28, 28, 3, 5) r = zeros(Float32, 28, 28, 3, 5)
m1 = DepthwiseConv((2, 2), 3=>15) m1 = DepthwiseConv((2, 2), 3=>15)
@test size(m1(r), 3) == 15 @test size(m1(r), 3) == 15
m3 = DepthwiseConv((2, 3), 3=>9) m3 = DepthwiseConv((2, 3), 3=>9)
@test size(m3(r), 3) == 9 @test size(m3(r), 3) == 9
@ -62,7 +62,7 @@ end
y = CrossCor(w, [0.0]) y = CrossCor(w, [0.0])
@test sum(w .* x[1:2, 1:2, :, :]) == y(x)[1, 1, 1, 1] @test sum(w .* x[1:2, 1:2, :, :]) == y(x)[1, 1, 1, 1]
r = zeros(Float32, 28, 28, 1, 5) r = zeros(Float32, 28, 28, 1, 5)
m = Chain( m = Chain(
CrossCor((2, 2), 1=>16, relu), CrossCor((2, 2), 1=>16, relu),
@ -102,4 +102,3 @@ end
true true
end end
end end
@ -1,29 +1,29 @@
using Flux: testmode! using Flux, Test, Statistics
using Flux.Tracker: data using Zygote: forward
trainmode(f, x...) = forward(f, x...)[1]
trainmode(f) = (x...) -> trainmode(f, x...)
@testset "Dropout" begin @testset "Dropout" begin
x = [1.,2.,3.] x = [1.,2.,3.]
@test x == testmode!(Dropout(0.1))(x) @test x == Dropout(0.1)(x)
@test x == Dropout(0)(x) @test x == trainmode(Dropout(0), x)
@test zero(x) == Dropout(1)(x) @test zero(x) == trainmode(Dropout(1), x)
x = rand(100) x = rand(100)
m = Dropout(0.9) m = Dropout(0.9)
y = m(x) y = trainmode(m, x)
@test count(a->a==0, y) > 50 @test count(a->a==0, y) > 50
testmode!(m)
y = m(x) y = m(x)
@test count(a->a==0, y) == 0 @test count(a->a==0, y) == 0
testmode!(m, false) y = trainmode(m, x)
y = m(x)
@test count(a->a==0, y) > 50 @test count(a->a==0, y) > 50
x = rand(100) x = rand(Float32, 100)
m = Chain(Dense(100,100), m = Chain(Dense(100,100),
Dropout(0.9)) Dropout(0.9))
y = m(x) y = trainmode(m, x)
@test count(a->a == 0, y) > 50 @test count(a->a == 0, y) > 50
testmode!(m)
y = m(x) y = m(x)
@test count(a->a == 0, y) == 0 @test count(a->a == 0, y) == 0
@ -39,18 +39,16 @@ using Flux.Tracker: data
end end
@testset "BatchNorm" begin @testset "BatchNorm" begin
let m = BatchNorm(2), x = param([1 3 5; let m = BatchNorm(2), x = [1.0 3.0 5.0;
2 4 6]) 2.0 4.0 6.0]
@test m.β.data == [0, 0] # initβ(2) @test m.β == [0, 0] # initβ(2)
@test m.γ.data == [1, 1] # initγ(2) @test m.γ == [1, 1] # initγ(2)
# initial m.σ is 1 # initial m.σ is 1
# initial m.μ is 0 # initial m.μ is 0
@test m.active
# @test m(x).data ≈ [-1 -1; 0 0; 1 1]'
m(x)
y = trainmode(m, x)
@test isapprox(y, [-1.22474 0 1.22474; -1.22474 0 1.22474], atol = 1.0e-5)
# julia> x # julia> x
# 2×3 Array{Float64,2}: # 2×3 Array{Float64,2}:
# 1.0 3.0 5.0 # 1.0 3.0 5.0
@ -69,41 +67,32 @@ end
# 2×1 Array{Float64,2}: # 2×1 Array{Float64,2}:
# 1.3 # 1.3
# 1.3 # 1.3
@test m.σ² ≈ .1 .* var(x.data, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.] @test m.σ² ≈ .1 .* var(x, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
testmode!(m) x = m(x)
@test !m.active
x = m(x).data
@test isapprox(x[1], (1 .- 0.3) / sqrt(1.3), atol = 1.0e-5) @test isapprox(x[1], (1 .- 0.3) / sqrt(1.3), atol = 1.0e-5)
end end
# with activation function # with activation function
let m = BatchNorm(2, sigmoid), x = param([1 3 5; let m = BatchNorm(2, sigmoid), x = [1.0 3.0 5.0;
2 4 6]) 2.0 4.0 6.0]
@test m.active y = m(x)
m(x) @test isapprox(y, sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ)), atol = 1.0e-7)
testmode!(m)
@test !m.active
y = m(x).data
@test isapprox(y, data(sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ))), atol = 1.0e-7)
end end
let m = BatchNorm(2), x = param(reshape(1:6, 3, 2, 1)) let m = trainmode(BatchNorm(2)), x = reshape(Float32.(1:6), 3, 2, 1)
y = reshape(permutedims(x, [2, 1, 3]), 2, :) y = reshape(permutedims(x, [2, 1, 3]), 2, :)
y = permutedims(reshape(m(y), 2, 3, 1), [2, 1, 3]) y = permutedims(reshape(m(y), 2, 3, 1), [2, 1, 3])
@test m(x) == y @test m(x) == y
end end
let m = BatchNorm(2), x = param(reshape(1:12, 2, 3, 2, 1)) let m = trainmode(BatchNorm(2)), x = reshape(Float32.(1:12), 2, 3, 2, 1)
y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :) y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :)
y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4]) y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4])
@test m(x) == y @test m(x) == y
end end
let m = BatchNorm(2), x = param(reshape(1:24, 2, 2, 3, 2, 1)) let m = trainmode(BatchNorm(2)), x = reshape(Float32.(1:24), 2, 2, 3, 2, 1)
y = reshape(permutedims(x, [4, 1, 2, 3, 5]), 2, :) y = reshape(permutedims(x, [4, 1, 2, 3, 5]), 2, :)
y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5]) y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5])
@test m(x) == y @test m(x) == y
@ -115,20 +104,16 @@ end
end end
end end
@testset "InstanceNorm" begin @testset "InstanceNorm" begin
# helper functions # helper functions
expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...) expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...)
# begin tests # begin tests
let m = InstanceNorm(2), sizes = (3, 2, 2), let m = InstanceNorm(2), sizes = (3, 2, 2),
x = param(reshape(collect(1:prod(sizes)), sizes)) x = reshape(collect(1:prod(sizes)), sizes)
x = Float64.(x)
@test m.β.data == [0, 0] # initβ(2) @test m.β == [0, 0] # initβ(2)
@test m.γ.data == [1, 1] # initγ(2) @test m.γ == [1, 1] # initγ(2)
y = trainmode(m, x)
@test m.active
m(x)
#julia> x #julia> x
#[:, :, 1] = #[:, :, 1] =
@ -153,37 +138,28 @@ end
# (1. - .1) * 0 + .1 * (5. + 11.) / 2 = .8 # (1. - .1) * 0 + .1 * (5. + 11.) / 2 = .8
@test m.μ ≈ [0.5, 0.8] @test m.μ ≈ [0.5, 0.8]
# momentum * var * num_items / (num_items - 1) + (1 - momentum) * sigma_sq # momentum * var * num_items / (num_items - 1) + (1 - momentum) * sigma_sq
# julia> reshape(mean(.1 .* var(x.data, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1. # julia> reshape(mean(.1 .* var(x, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1.
# 2-element Array{Float64,1}: # 2-element Array{Float64,1}:
# 1. # 1.
# 1. # 1.
@test m.σ² ≈ reshape(mean(.1 .* var(x.data, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1. @test m.σ² ≈ reshape(mean(.1 .* var(x, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1.
testmode!(m) x = m(x)
@test !m.active
x = m(x).data
@test isapprox(x[1], (1 - 0.5) / sqrt(1. + 1f-5), atol = 1.0e-5) @test isapprox(x[1], (1 - 0.5) / sqrt(1. + 1f-5), atol = 1.0e-5)
end end
# with activation function # with activation function
let m = InstanceNorm(2, sigmoid), sizes = (3, 2, 2), let m = InstanceNorm(2, sigmoid), sizes = (3, 2, 2),
x = param(reshape(collect(1:prod(sizes)), sizes)) x = reshape(collect(1:prod(sizes)), sizes)
x = Float64.(x)
affine_shape = collect(sizes) affine_shape = collect(sizes)
affine_shape[1] = 1 affine_shape[1] = 1
@test m.active y = m(x)
m(x) @test isapprox(y, sigmoid.((x .- expand_inst(m.μ, affine_shape)) ./ sqrt.(expand_inst(m.σ², affine_shape) .+ m.ϵ)), atol = 1.0e-7)
testmode!(m)
@test !m.active
y = m(x).data
@test isapprox(y, data(sigmoid.((x .- expand_inst(m.μ, affine_shape)) ./ sqrt.(expand_inst(m.σ², affine_shape) .+ m.ϵ))), atol = 1.0e-7)
end end
let m = InstanceNorm(2), sizes = (2, 4, 1, 2, 3), let m = trainmode(InstanceNorm(2)), sizes = (2, 4, 1, 2, 3),
x = param(reshape(collect(1:prod(sizes)), sizes)) x = Float32.(reshape(collect(1:prod(sizes)), sizes))
y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3) y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
y = reshape(m(y), sizes...) y = reshape(m(y), sizes...)
@test m(x) == y @test m(x) == y
@ -191,16 +167,16 @@ end
# check that μ, σ², and the output are the correct size for higher rank tensors # check that μ, σ², and the output are the correct size for higher rank tensors
let m = InstanceNorm(2), sizes = (5, 5, 3, 4, 2, 6), let m = InstanceNorm(2), sizes = (5, 5, 3, 4, 2, 6),
x = param(reshape(collect(1:prod(sizes)), sizes)) x = reshape(Float32.(collect(1:prod(sizes))), sizes)
y = m(x) y = trainmode(m, x)
@test size(m.μ) == (sizes[end - 1], ) @test size(m.μ) == (sizes[end - 1], )
@test size(m.σ²) == (sizes[end - 1], ) @test size(m.σ²) == (sizes[end - 1], )
@test size(y) == sizes @test size(y) == sizes
end end
# show that instance norm is equal to batch norm when channel and batch dims are squashed # show that instance norm is equal to batch norm when channel and batch dims are squashed
let m_inorm = InstanceNorm(2), m_bnorm = BatchNorm(12), sizes = (5, 5, 3, 4, 2, 6), let m_inorm = trainmode(InstanceNorm(2)), m_bnorm = trainmode(BatchNorm(12)), sizes = (5, 5, 3, 4, 2, 6),
x = param(reshape(collect(1:prod(sizes)), sizes)) x = reshape(Float32.(collect(1:prod(sizes))), sizes)
@test m_inorm(x) == reshape(m_bnorm(reshape(x, (sizes[1:end - 2]..., :, 1))), sizes) @test m_inorm(x) == reshape(m_bnorm(reshape(x, (sizes[1:end - 2]..., :, 1))), sizes)
end end
@ -216,14 +192,12 @@ end
squeeze(x) = dropdims(x, dims = tuple(findall(size(x) .== 1)...)) # To remove all singular dimensions squeeze(x) = dropdims(x, dims = tuple(findall(size(x) .== 1)...)) # To remove all singular dimensions
let m = GroupNorm(4,2), sizes = (3,4,2), let m = GroupNorm(4,2), sizes = (3,4,2),
x = param(reshape(collect(1:prod(sizes)), sizes)) x = reshape(collect(1:prod(sizes)), sizes)
x = Float64.(x)
@test m.β == [0, 0, 0, 0] # initβ(32)
@test m.γ == [1, 1, 1, 1] # initγ(32)
@test m.β.data == [0, 0, 0, 0] # initβ(32) y = trainmode(m, x)
@test m.γ.data == [1, 1, 1, 1] # initγ(32)
@test m.active
m(x)
#julia> x #julia> x
#[:, :, 1] = #[:, :, 1] =
@ -243,7 +217,7 @@ end
# (13. + 14. + 15. + 16. + 17. + 18.) / 6 = 15.5 # (13. + 14. + 15. + 16. + 17. + 18.) / 6 = 15.5
# (19. + 20. + 21. + 22. + 23. + 24.) / 6 = 21.5 # (19. + 20. + 21. + 22. + 23. + 24.) / 6 = 21.5
# #
# μ = # μ =
# 3.5 15.5 # 3.5 15.5
# 9.5 21.5 # 9.5 21.5
# #
@ -253,46 +227,37 @@ end
@test m.μ ≈ [0.95, 1.55] @test m.μ ≈ [0.95, 1.55]
# julia> mean(var(reshape(x,3,2,2,2),dims=(1,2)).* .1,dims=2) .+ .9*1. # julia> mean(var(reshape(x,3,2,2,2),dims=(1,2)).* .1,dims=2) .+ .9*1.
# 2-element Array{Tracker.TrackedReal{Float64},1}: # 2-element Array{Float64,1}:
# 1.25 # 1.25
# 1.25 # 1.25
@test m.σ² ≈ mean(squeeze(var(reshape(x,3,2,2,2),dims=(1,2))).*.1,dims=2) .+ .9*1. @test m.σ² ≈ mean(squeeze(var(reshape(x,3,2,2,2),dims=(1,2))).*.1,dims=2) .+ .9*1.
testmode!(m) x = m(x)
@test !m.active
x = m(x).data
@test isapprox(x[1], (1 - 0.95) / sqrt(1.25 + 1f-5), atol = 1.0e-5) @test isapprox(x[1], (1 - 0.95) / sqrt(1.25 + 1f-5), atol = 1.0e-5)
end end
# with activation function # with activation function
let m = GroupNorm(4,2, sigmoid), sizes = (3, 4, 2), let m = GroupNorm(4,2, sigmoid), sizes = (3, 4, 2),
x = param(reshape(collect(1:prod(sizes)), sizes)) x = reshape(collect(1:prod(sizes)), sizes)
x = Float64.(x)
μ_affine_shape = ones(Int,length(sizes) + 1) μ_affine_shape = ones(Int,length(sizes) + 1)
μ_affine_shape[end-1] = 2 # Number of groups μ_affine_shape[end-1] = 2 # Number of groups
affine_shape = ones(Int,length(sizes) + 1) affine_shape = ones(Int,length(sizes) + 1)
affine_shape[end-2] = 2 # Channels per group affine_shape[end-2] = 2 # Channels per group
affine_shape[end-1] = 2 # Number of groups affine_shape[end-1] = 2 # Number of groups
affine_shape[1] = sizes[1] affine_shape[1] = sizes[1]
affine_shape[end] = sizes[end] affine_shape[end] = sizes[end]
og_shape = size(x) og_shape = size(x)
@test m.active
m(x)
testmode!(m)
@test !m.active
y = m(x) y = m(x)
x_ = reshape(x,affine_shape...) x_ = reshape(x,affine_shape...)
out = reshape(data(sigmoid.((x_ .- reshape(m.μ,μ_affine_shape...)) ./ sqrt.(reshape(m.σ²,μ_affine_shape...) .+ m.ϵ))),og_shape) out = reshape(sigmoid.((x_ .- reshape(m.μ,μ_affine_shape...)) ./ sqrt.(reshape(m.σ²,μ_affine_shape...) .+ m.ϵ)),og_shape)
@test isapprox(y, out, atol = 1.0e-7) @test isapprox(y, out, atol = 1.0e-7)
end end
let m = GroupNorm(2,2), sizes = (2, 4, 1, 2, 3), let m = trainmode(GroupNorm(2,2)), sizes = (2, 4, 1, 2, 3),
x = param(reshape(collect(1:prod(sizes)), sizes)) x = Float32.(reshape(collect(1:prod(sizes)), sizes))
y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3) y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
y = reshape(m(y), sizes...) y = reshape(m(y), sizes...)
@test m(x) == y @test m(x) == y
@ -300,22 +265,22 @@ end
# check that μ, σ², and the output are the correct size for higher rank tensors # check that μ, σ², and the output are the correct size for higher rank tensors
let m = GroupNorm(4,2), sizes = (5, 5, 3, 4, 4, 6), let m = GroupNorm(4,2), sizes = (5, 5, 3, 4, 4, 6),
x = param(reshape(collect(1:prod(sizes)), sizes)) x = Float32.(reshape(collect(1:prod(sizes)), sizes))
y = m(x) y = trainmode(m, x)
@test size(m.μ) == (m.G,1) @test size(m.μ) == (m.G,1)
@test size(m.σ²) == (m.G,1) @test size(m.σ²) == (m.G,1)
@test size(y) == sizes @test size(y) == sizes
end end
# show that group norm is the same as instance norm when the group size is the same as the number of channels # show that group norm is the same as instance norm when the group size is the same as the number of channels
let IN = InstanceNorm(4), GN = GroupNorm(4,4), sizes = (2,2,3,4,5), let IN = trainmode(InstanceNorm(4)), GN = trainmode(GroupNorm(4,4)), sizes = (2,2,3,4,5),
x = param(reshape(collect(1:prod(sizes)), sizes)) x = Float32.(reshape(collect(1:prod(sizes)), sizes))
@test IN(x) ≈ GN(x) @test IN(x) ≈ GN(x)
end end
# show that group norm is the same as batch norm for a group of size 1 and batch of size 1 # show that group norm is the same as batch norm for a group of size 1 and batch of size 1
let BN = BatchNorm(4), GN = GroupNorm(4,4), sizes = (2,2,3,4,1), let BN = trainmode(BatchNorm(4)), GN = trainmode(GroupNorm(4,4)), sizes = (2,2,3,4,1),
x = param(reshape(collect(1:prod(sizes)), sizes)) x = Float32.(reshape(collect(1:prod(sizes)), sizes))
@test BN(x) ≈ GN(x) @test BN(x) ≈ GN(x)
end end
@ -51,13 +51,13 @@ const ϵ = 1e-7
end end
@testset "no spurious promotions" begin @testset "no spurious promotions" begin
for T in (Float16, Float32, Float64) for T in (Float32, Float64)
y = rand(T, 2) y = rand(T, 2)
ŷ = rand(T, 2) ŷ = rand(T, 2)
for f in (mse, crossentropy, logitcrossentropy) for f in (mse, crossentropy, logitcrossentropy)
fwd, back = Flux.Tracker.forward(mse, ŷ, y) fwd, back = Flux.forward(f, ŷ, y)
@test typeof(fwd) == Flux.Tracker.TrackedReal{T} @test fwd isa T
@test eltype(back(one(T))[1]) == Flux.Tracker.TrackedReal{T} @test eltype(back(one(T))[1]) == T
end end
end end
end end
@ -1,42 +1,44 @@
using Flux.Optimise using Flux.Optimise
using Flux.Optimise: runall using Flux.Optimise: runall
using Flux.Tracker using Flux: Params, gradient
using Test using Test
@testset "Optimise" begin @testset "Optimise" begin
w = randn(10, 10) w = randn(10, 10)
@testset for opt in [ADAMW(), ADAGrad(0.1), AdaMax(), ADADelta(0.9), AMSGrad(), @testset for opt in [ADAMW(), ADAGrad(0.1), AdaMax(), ADADelta(0.9), AMSGrad(),
NADAM(), RADAM(), Descent(0.1), ADAM(), Nesterov(), RMSProp(), NADAM(), RADAM(), Descent(0.1), ADAM(), Nesterov(), RMSProp(),
Momentum()] Momentum()]
w = param(randn(10, 10)) w = randn(10, 10)
loss(x) = Flux.mse(w*x, w*x) loss(x) = Flux.mse(w*x, w*x)
for t = 1: 10^5 for t = 1: 10^5
θ = Params([w]) θ = Params([w])
θ̄ = gradient(() -> loss(rand(10)), θ) x = rand(10)
θ̄ = gradient(() -> loss(x), θ)
Optimise.update!(opt, θ, θ̄) Optimise.update!(opt, θ, θ̄)
end end
@test Flux.mse(w, w) < 0.01 @test loss(rand(10, 10)) < 0.01
end end
end end
@testset "Optimiser" begin @testset "Optimiser" begin
w = randn(10, 10) w = randn(10, 10)
@testset for Opt in [InvDecay, WeightDecay, ExpDecay] @testset for Opt in [InvDecay, WeightDecay, ExpDecay]
w = param(randn(10, 10)) w = randn(10, 10)
loss(x) = Flux.mse(w*x, w*x) loss(x) = Flux.mse(w*x, w*x)
opt = Optimiser(Opt(), ADAM(0.001)) opt = Optimiser(Opt(), ADAM(0.001))
for t = 1:10^5 for t = 1:10^5
l = loss(rand(10)) θ = Params([w])
back!(l) x = rand(10)
delta = Optimise.apply!(opt, w.data, w.grad) θ̄ = gradient(() -> loss(x), θ)
w.data .-= delta Optimise.update!(opt, θ, θ̄)
end end
@test Flux.mse(w, w) < 0.01 @test loss(rand(10, 10)) < 0.01
end end
end end
@testset "Training Loop" begin @testset "Training Loop" begin
i = 0 i = 0
l = param(1) l = 1
Flux.train!(() -> (sleep(0.1); i += 1; l), Flux.train!(() -> (sleep(0.1); i += 1; l),
(), (),
@ -57,17 +59,18 @@ end
@testset "ExpDecay" begin @testset "ExpDecay" begin
w = randn(10, 10) w = randn(10, 10)
o = ExpDecay(0.1, 0.1, 1000, 1e-4) o = ExpDecay(0.1, 0.1, 1000, 1e-4)
w1 = param(randn(10,10)) w1 = randn(10,10)
loss(x) = Flux.mse(w*x, w1*x) loss(x) = Flux.mse(w*x, w1*x)
flag = 1 flag = 1
decay_steps = [] decay_steps = []
for t = 1:10^5 for t = 1:10^5
l = loss(rand(10))
back!(l)
prev_eta = o.eta prev_eta = o.eta
prev_grad = collect(w1.grad) θ = Params([w1])
delta = Optimise.apply!(o, w1.data, w1.grad) x = rand(10)
w1.data .-= delta θ̄ = gradient(() -> loss(x), θ)
prev_grad = collect(θ̄[w1])
delta = Optimise.apply!(o, w1, θ̄[w1])
w1 .-= delta
new_eta = o.eta new_eta = o.eta
if new_eta != prev_eta if new_eta != prev_eta
push!(decay_steps, t) push!(decay_steps, t)
@ -1,11 +1,8 @@
using Flux, Test, Random, Statistics using Flux, Test, Random, Statistics, Documenter
using Random using Random
Random.seed!(0) Random.seed!(0)
# So we can use the system CuArrays
insert!(LOAD_PATH, 2, "@v#.#")
@testset "Flux" begin @testset "Flux" begin
@info "Testing Basics" @info "Testing Basics"
@ -22,14 +19,14 @@ include("layers/normalisation.jl")
include("layers/stateless.jl") include("layers/stateless.jl")
include("layers/conv.jl") include("layers/conv.jl")
@info "Running Gradient Checks"
include("tracker.jl")
if isdefined(Flux, :CUDA) if isdefined(Flux, :CUDA)
include("cuda/cuda.jl") include("cuda/cuda.jl")
else else
@warn "CUDA unavailable, not testing GPU support" @warn "CUDA unavailable, not testing GPU support"
end end
if VERSION >= v"1.2"
doctest(Flux)
end
end end
@ -1,15 +0,0 @@
using Flux, Test
using Tracker: gradcheck
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
@testset "Tracker" begin
@test gradtest(Flux.mse, rand(5,5), rand(5, 5))
@test gradtest(Flux.crossentropy, rand(5,5), rand(5, 5))
@test gradtest(x -> Flux.normalise(x), rand(4,3))
@test gradtest(x -> Flux.normalise(x, dims = 2), rand(3,4))
end
@ -1,5 +1,5 @@
using Flux using Flux
using Flux: throttle, jacobian, glorot_uniform, glorot_normal, stack, unstack using Flux: throttle, glorot_uniform, glorot_normal, stack, unstack
using StatsBase: std using StatsBase: std
using Random using Random
using Test using Test
@ -52,15 +52,6 @@ using Test
end end
end end
@testset "Jacobian" begin
A = param(randn(2,2))
x = randn(2)
m(x) = A*x
y = m(x)
J = jacobian(m,x)
@test J ≈ A.data
end
@testset "Initialization" begin @testset "Initialization" begin
# Set random seed so that these tests don't fail randomly # Set random seed so that these tests don't fail randomly
Random.seed!(0) Random.seed!(0)
@ -106,12 +97,11 @@ end
@testset "Precision" begin @testset "Precision" begin
m = Chain(Dense(10, 5, relu), Dense(5, 2)) m = Chain(Dense(10, 5, relu), Dense(5, 2))
x = rand(10) x = rand(10)
@test eltype(m[1].W.data) == Float32 @test eltype(m[1].W) == Float32
@test eltype(m(x).data) == Float32 @test eltype(m(x)) == Float32
@test eltype(f64(m)(x).data) == Float64 @test eltype(f64(m)(x)) == Float64
@test eltype(f64(m)[1].W.data) == Float64 @test eltype(f64(m)[1].W) == Float64
@test eltype(f32(f64(m))[1].W.data) == Float32 @test eltype(f32(f64(m))[1].W) == Float32
@test Tracker.isleaf(f32(f64(m))[1].W)
end end
@testset "Stacking" begin @testset "Stacking" begin