commit bdeb9c6d58

@@ -6,7 +6,7 @@ os:
 # - osx

 julia:
-  - 1.0
+  - 1.1
   - nightly

 matrix:
@@ -174,6 +174,12 @@ git-tree-sha1 = "dd169c636d1d3656a9faca772f5bd7c226a61254"
 uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
 version = "1.0.1"

+[[IRTools]]
+deps = ["InteractiveUtils", "MacroTools", "Test"]
+git-tree-sha1 = "e23faa71b8f54c3fdc99b230b9c2906cafdddca5"
+uuid = "7869d1d1-7146-5819-86e3-90919afe41df"
+version = "0.2.3"
+
 [[InteractiveUtils]]
 deps = ["Markdown"]
 uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
@@ -226,10 +232,9 @@ uuid = "e89f7d12-3494-54d1-8411-f7d8b9ae1f27"
 version = "0.5.0"

 [[Missings]]
-deps = ["SparseArrays", "Test"]
-git-tree-sha1 = "29858ce6c8ae629cf2d733bffa329619a1c843d0"
+git-tree-sha1 = "f0719736664b4358aa9ec173077d4285775f8007"
 uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
-version = "0.4.1"
+version = "0.4.2"

 [[Mmap]]
 uuid = "a63ad114-7e13-5084-954f-fe012c677804"
@@ -254,9 +259,9 @@ version = "1.1.0"

 [[Parsers]]
 deps = ["Dates", "Test"]
-git-tree-sha1 = "db2b35dedab3c0e46dc15996d170af07a5ab91c9"
+git-tree-sha1 = "ef0af6c8601db18c282d092ccbd2f01f3f0cd70b"
 uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
-version = "0.3.6"
+version = "0.3.7"

 [[Pkg]]
 deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
@@ -314,10 +319,10 @@ deps = ["LinearAlgebra", "Random"]
 uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

 [[SpecialFunctions]]
-deps = ["BinDeps", "BinaryProvider", "Libdl", "Test"]
-git-tree-sha1 = "0b45dc2e45ed77f445617b99ff2adf0f5b0f23ea"
+deps = ["BinDeps", "BinaryProvider", "Libdl"]
+git-tree-sha1 = "3bdd374b6fd78faf0119b8c5d538788dbf910c6e"
 uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
-version = "0.7.2"
+version = "0.8.0"

 [[StaticArrays]]
 deps = ["LinearAlgebra", "Random", "Statistics"]
@@ -350,12 +355,6 @@ git-tree-sha1 = "dfcdbbfb2d0370716c815cbd6f8a364efb6f42cf"
 uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624"
 version = "0.5.6"

-[[Tracker]]
-deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics", "Test"]
-git-tree-sha1 = "1aa443d3b4bfa91a8aec32f169a479cb87309910"
-uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
-version = "0.2.3"
-
 [[TranscodingStreams]]
 deps = ["Random", "Test"]
 git-tree-sha1 = "7c53c35547de1c5b9d46a4797cf6d8253807108c"
@@ -386,3 +385,17 @@ deps = ["BinaryProvider", "Libdl", "Printf"]
 git-tree-sha1 = "580ce62b6c14244916cc28ad54f8a2e2886f843d"
 uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
 version = "0.8.3"
+
+[[Zygote]]
+deps = ["DiffRules", "FFTW", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
+git-tree-sha1 = "9186cb0b3b59219e4aba0840614d6a9d7282012e"
+repo-rev = "master"
+repo-url = "https://github.com/FluxML/Zygote.jl.git"
+uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
+version = "0.3.4"
+
+[[ZygoteRules]]
+deps = ["MacroTools"]
+git-tree-sha1 = "def5f96ac2895fd9b48435f6b97020979ee0a4c6"
+uuid = "700de1a5-db45-46bc-99cf-38207098b444"
+version = "0.1.0"
Project.toml
@@ -21,18 +21,20 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
 SHA = "ea8e919c-243c-51af-8825-aaa63cd721ce"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
-Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

 [compat]
 CUDAapi = "1.1"
 CuArrays = "1.2"
 NNlib = "0.6"
-Tracker = "0.2"
-julia = "0.7, 1"
+Zygote = "0.3"
+julia = "1.1"

 [extras]
+Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

 [targets]
-test = ["Test"]
+test = ["Test", "Documenter"]
REQUIRE
@@ -1,13 +0,0 @@
-julia 1.0
-Juno
-MacroTools 0.3.3
-NNlib
-Requires
-Adapt 0.4
-CodecZlib
-Colors
-ZipFile
-AbstractTrees
-Reexport
-StatsBase
-Tracker
@@ -1,205 +1,56 @@
 # This file is machine-generated - editing it directly is not advised

-[[AbstractTrees]]
-deps = ["Markdown", "Test"]
-git-tree-sha1 = "6621d9645702c1c4e6970cc6a3eae440c768000b"
-uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
-version = "0.2.1"
-
-[[Adapt]]
-deps = ["LinearAlgebra", "Test"]
-git-tree-sha1 = "53d8fec4f662088c1202530e338a11a919407f3b"
-uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
-version = "0.4.2"
-
 [[Base64]]
 uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

-[[BinDeps]]
-deps = ["Compat", "Libdl", "SHA", "URIParser"]
-git-tree-sha1 = "12093ca6cdd0ee547c39b1870e0c9c3f154d9ca9"
-uuid = "9e28174c-4ba2-5203-b857-d8d62c4213ee"
-version = "0.8.10"
-
-[[BinaryProvider]]
-deps = ["Libdl", "Pkg", "SHA", "Test"]
-git-tree-sha1 = "055eb2690182ebc31087859c3dd8598371d3ef9e"
-uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
-version = "0.5.3"
-
-[[CSTParser]]
-deps = ["LibGit2", "Test", "Tokenize"]
-git-tree-sha1 = "437c93bc191cd55957b3f8dee7794b6131997c56"
-uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f"
-version = "0.5.2"
-
-[[CodecZlib]]
-deps = ["BinaryProvider", "Libdl", "Test", "TranscodingStreams"]
-git-tree-sha1 = "36bbf5374c661054d41410dc53ff752972583b9b"
-uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
-version = "0.5.2"
-
-[[ColorTypes]]
-deps = ["FixedPointNumbers", "Random", "Test"]
-git-tree-sha1 = "f73b0e10f2a5756de7019818a41654686da06b09"
-uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
-version = "0.7.5"
-
-[[Colors]]
-deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Printf", "Reexport", "Test"]
-git-tree-sha1 = "9f0a0210450acb91c730b730a994f8eef1d3d543"
-uuid = "5ae59095-9a9b-59fe-a467-6f913c188581"
-version = "0.9.5"
-
-[[CommonSubexpressions]]
-deps = ["Test"]
-git-tree-sha1 = "efdaf19ab11c7889334ca247ff4c9f7c322817b0"
-uuid = "bbf7d656-a473-5ed7-a52c-81e309532950"
-version = "0.2.0"
-
-[[Compat]]
-deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
-git-tree-sha1 = "84aa74986c5b9b898b0d1acaf3258741ee64754f"
-uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
-version = "2.1.0"
-
-[[Crayons]]
-deps = ["Test"]
-git-tree-sha1 = "f621b8ef51fd2004c7cf157ea47f027fdeac5523"
-uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
-version = "4.0.0"
-
-[[DataStructures]]
-deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
-git-tree-sha1 = "ca971f03e146cf144a9e2f2ce59674f5bf0e8038"
-uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
-version = "0.15.0"
-
 [[Dates]]
 deps = ["Printf"]
 uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"

-[[DelimitedFiles]]
-deps = ["Mmap"]
-uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
-
-[[DiffResults]]
-deps = ["Compat", "StaticArrays"]
-git-tree-sha1 = "34a4a1e8be7bc99bc9c611b895b5baf37a80584c"
-uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
-version = "0.0.4"
-
-[[DiffRules]]
-deps = ["Random", "Test"]
-git-tree-sha1 = "dc0869fb2f5b23466b32ea799bd82c76480167f7"
-uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
-version = "0.0.10"
-
 [[Distributed]]
 deps = ["Random", "Serialization", "Sockets"]
 uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

 [[DocStringExtensions]]
 deps = ["LibGit2", "Markdown", "Pkg", "Test"]
-git-tree-sha1 = "4d30e889c9f106a51ffa4791a88ffd4765bf20c3"
+git-tree-sha1 = "0513f1a8991e9d83255e0140aace0d0fc4486600"
 uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
-version = "0.7.0"
+version = "0.8.0"

 [[Documenter]]
-deps = ["Base64", "DocStringExtensions", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "Pkg", "REPL", "Random", "Test", "Unicode"]
-git-tree-sha1 = "13a6d15102410d8e70146533b759fc48d844a1d0"
+deps = ["Base64", "DocStringExtensions", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"]
+git-tree-sha1 = "c61d6eedbc3c4323c08b64af12d29c8ee0fcbb5f"
 uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
-version = "0.22.3"
+version = "0.23.2"

-[[FixedPointNumbers]]
-deps = ["Test"]
-git-tree-sha1 = "b8045033701c3b10bf2324d7203404be7aef88ba"
-uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
-version = "0.5.3"
-
-[[Flux]]
-deps = ["AbstractTrees", "Adapt", "CodecZlib", "Colors", "DelimitedFiles", "Juno", "LinearAlgebra", "MacroTools", "NNlib", "Pkg", "Printf", "Random", "Reexport", "Requires", "SHA", "Statistics", "StatsBase", "Tracker", "ZipFile"]
-path = ".."
-uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
-version = "0.8.2+"
-
-[[ForwardDiff]]
-deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"]
-git-tree-sha1 = "4c4d727f1b7e0092134fabfab6396b8945c1ea5b"
-uuid = "f6369f11-7733-5829-9624-2563aa707210"
-version = "0.10.3"
-
 [[InteractiveUtils]]
 deps = ["Markdown"]
 uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

 [[JSON]]
-deps = ["Dates", "Distributed", "Mmap", "Sockets", "Test", "Unicode"]
-git-tree-sha1 = "1f7a25b53ec67f5e9422f1f551ee216503f4a0fa"
+deps = ["Dates", "Mmap", "Parsers", "Unicode"]
+git-tree-sha1 = "b34d7cef7b337321e97d22242c3c2b91f476748e"
 uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
-version = "0.20.0"
+version = "0.21.0"

-[[Juno]]
-deps = ["Base64", "Logging", "Media", "Profile", "Test"]
-git-tree-sha1 = "4e4a8d43aa7ecec66cadaf311fbd1e5c9d7b9175"
-uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
-version = "0.7.0"
-
 [[LibGit2]]
 uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"

-[[Libdl]]
-uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
-
-[[LinearAlgebra]]
-deps = ["Libdl"]
-uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
-
 [[Logging]]
 uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

-[[MacroTools]]
-deps = ["CSTParser", "Compat", "DataStructures", "Test"]
-git-tree-sha1 = "daecd9e452f38297c686eba90dba2a6d5da52162"
-uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
-version = "0.5.0"
-
 [[Markdown]]
 deps = ["Base64"]
 uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"

-[[Media]]
-deps = ["MacroTools", "Test"]
-git-tree-sha1 = "75a54abd10709c01f1b86b84ec225d26e840ed58"
-uuid = "e89f7d12-3494-54d1-8411-f7d8b9ae1f27"
-version = "0.5.0"
-
-[[Missings]]
-deps = ["Dates", "InteractiveUtils", "SparseArrays", "Test"]
-git-tree-sha1 = "d1d2585677f2bd93a97cfeb8faa7a0de0f982042"
-uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
-version = "0.4.0"
-
 [[Mmap]]
 uuid = "a63ad114-7e13-5084-954f-fe012c677804"

-[[NNlib]]
-deps = ["Libdl", "LinearAlgebra", "Requires", "Statistics", "TimerOutputs"]
-git-tree-sha1 = "0c667371391fc6bb31f7f12f96a56a17098b3de8"
-uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-version = "0.6.0"
-
-[[NaNMath]]
-deps = ["Compat"]
-git-tree-sha1 = "ce3b85e484a5d4c71dd5316215069311135fa9f2"
-uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
-version = "0.3.2"
-
-[[OrderedCollections]]
-deps = ["Random", "Serialization", "Test"]
-git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1"
-uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
-version = "1.1.0"
+[[Parsers]]
+deps = ["Dates", "Test"]
+git-tree-sha1 = "db2b35dedab3c0e46dc15996d170af07a5ab91c9"
+uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
+version = "0.3.6"

 [[Pkg]]
 deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
@@ -209,10 +60,6 @@ uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 deps = ["Unicode"]
 uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"

-[[Profile]]
-deps = ["Printf"]
-uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"
-
 [[REPL]]
 deps = ["InteractiveUtils", "Markdown", "Sockets"]
 uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
@@ -221,106 +68,22 @@ uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
 deps = ["Serialization"]
 uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

-[[Reexport]]
-deps = ["Pkg"]
-git-tree-sha1 = "7b1d07f411bc8ddb7977ec7f377b97b158514fe0"
-uuid = "189a3867-3050-52da-a836-e630ba90ab69"
-version = "0.2.0"
-
-[[Requires]]
-deps = ["Test"]
-git-tree-sha1 = "f6fbf4ba64d295e146e49e021207993b6b48c7d1"
-uuid = "ae029012-a4dd-5104-9daa-d747884805df"
-version = "0.5.2"
-
 [[SHA]]
 uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"

 [[Serialization]]
 uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"

-[[SharedArrays]]
-deps = ["Distributed", "Mmap", "Random", "Serialization"]
-uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
-
 [[Sockets]]
 uuid = "6462fe0b-24de-5631-8697-dd941f90decc"

-[[SortingAlgorithms]]
-deps = ["DataStructures", "Random", "Test"]
-git-tree-sha1 = "03f5898c9959f8115e30bc7226ada7d0df554ddd"
-uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
-version = "0.3.1"
-
-[[SparseArrays]]
-deps = ["LinearAlgebra", "Random"]
-uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
-
-[[SpecialFunctions]]
-deps = ["BinDeps", "BinaryProvider", "Libdl", "Test"]
-git-tree-sha1 = "0b45dc2e45ed77f445617b99ff2adf0f5b0f23ea"
-uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
-version = "0.7.2"
-
-[[StaticArrays]]
-deps = ["InteractiveUtils", "LinearAlgebra", "Random", "Statistics", "Test"]
-git-tree-sha1 = "3841b39ed5f047db1162627bf5f80a9cd3e39ae2"
-uuid = "90137ffa-7385-5640-81b9-e52037218182"
-version = "0.10.3"
-
-[[Statistics]]
-deps = ["LinearAlgebra", "SparseArrays"]
-uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
-
-[[StatsBase]]
-deps = ["DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"]
-git-tree-sha1 = "8a0f4b09c7426478ab677245ab2b0b68552143c7"
-uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
-version = "0.30.0"
-
 [[Test]]
 deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
 uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

-[[TimerOutputs]]
-deps = ["Crayons", "Printf", "Test", "Unicode"]
-git-tree-sha1 = "b80671c06f8f8bae08c55d67b5ce292c5ae2660c"
-uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
-version = "0.5.0"
-
-[[Tokenize]]
-deps = ["Printf", "Test"]
-git-tree-sha1 = "3e83f60b74911d3042d3550884ca2776386a02b8"
-uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624"
-version = "0.5.3"
-
-[[Tracker]]
-deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics", "Test"]
-git-tree-sha1 = "0bec1b68c63a0e8a58d3944261cbf4cc9577c8a1"
-uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
-version = "0.2.0"
-
-[[TranscodingStreams]]
-deps = ["Random", "Test"]
-git-tree-sha1 = "a25d8e5a28c3b1b06d3859f30757d43106791919"
-uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
-version = "0.9.4"
-
-[[URIParser]]
-deps = ["Test", "Unicode"]
-git-tree-sha1 = "6ddf8244220dfda2f17539fa8c9de20d6c575b69"
-uuid = "30578b45-9adc-5946-b283-645ec420af67"
-version = "0.4.0"
-
 [[UUIDs]]
 deps = ["Random", "SHA"]
 uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

 [[Unicode]]
 uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

-[[ZipFile]]
-deps = ["BinaryProvider", "Libdl", "Printf", "Test"]
-git-tree-sha1 = "5f6f663890dfb9bad6af75a86a43f67904e5050e"
-uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
-version = "0.8.1"
@@ -1,4 +1,2 @@
 [deps]
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
-Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
-NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
docs/make.jl
@@ -1,12 +1,13 @@
+using Pkg;
+Pkg.activate(joinpath(@__DIR__, "..")); Pkg.instantiate()
+Pkg.activate(); Pkg.instantiate()
+
+pushfirst!(LOAD_PATH, joinpath(@__DIR__, ".."))
+
 using Documenter, Flux, NNlib

 makedocs(modules=[Flux, NNlib],
-         doctest = true,
-         analytics = "UA-36890222-9",
          sitename = "Flux",
-         # Uncomment below for local build
-         #format = Documenter.HTML(prettyurls = false),
-         assets = ["assets/flux.css"],
          pages = ["Home" => "index.md",
                   "Building Models" =>
                     ["Basics" => "models/basics.md",
@@ -20,8 +21,9 @@ makedocs(modules=[Flux, NNlib],
                   "GPU Support" => "gpu.md",
                   "Saving & Loading" => "saving.md",
                   "Performance Tips" => "performance.md",
-                  "Internals" =>
-                    ["Backpropagation" => "internals/tracker.md"],
-                  "Community" => "community.md"])
+                  "Community" => "community.md"],
+         format = Documenter.HTML(assets = ["assets/flux.css"],
+                                  analytics = "UA-36890222-9",
+                                  prettyurls = haskey(ENV, "CI")))

 deploydocs(repo = "github.com/FluxML/Flux.jl.git")
@@ -1,5 +1,5 @@
 # Community

-All Flux users are welcome to join our community on the [Julia forum](https://discourse.julialang.org/), the [slack](https://discourse.julialang.org/t/announcing-a-julia-slack/4866) (channel #machine-learning), or Flux's [Gitter](https://gitter.im/FluxML/Lobby). If you have questions or issues we'll try to help you out.
+All Flux users are welcome to join our community on the [Julia forum](https://discourse.julialang.org/), or the [slack](https://discourse.julialang.org/t/announcing-a-julia-slack/4866) (channel #machine-learning). If you have questions or issues we'll try to help you out.

 If you're interested in hacking on Flux, the [source code](https://github.com/FluxML/Flux.jl) is open and easy to understand -- it's all just the same Julia code you work with normally. You might be interested in our [intro issues](https://github.com/FluxML/Flux.jl/issues?q=is%3Aopen+is%3Aissue+label%3A%22help+wanted%22) to get started.
@@ -1,14 +1,6 @@
 # GPU Support

-## Installation
-
-To get GPU support for NVIDIA graphics cards, you need to install `CuArrays.jl`
-
-**Steps needed**
-
-1. Install [NVIDIA toolkit](https://developer.nvidia.com/cuda-downloads)
-2. Install [NVIDIA cuDNN library](https://developer.nvidia.com/cudnn)
-3. In Julia's terminal run `]add CuArrays`
-
+NVIDIA GPU support should work out of the box on systems with CUDA and CUDNN installed. For more details see the [CuArrays](https://github.com/JuliaGPU/CuArrays.jl) readme.
+
 ## GPU Usage
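For readers following this hunk, a minimal sketch of the usage the rewritten GPU page goes on to describe: moving a model and its inputs to the GPU with `gpu`. The layer sizes and dummy data are illustrative only.

```julia
using Flux, CuArrays   # CuArrays supplies the CUDA array type

m = Dense(10, 5) |> gpu        # move the layer's parameters to the GPU
x = rand(Float32, 10) |> gpu   # move the input as well
y = m(x)                       # runs on the GPU and returns a CuArray
```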
@@ -1,184 +0,0 @@
-# Flux.Tracker
-
-Backpropagation, or reverse-mode automatic differentiation, is handled by the `Flux.Tracker` module.
-
-```julia
-julia> using Flux.Tracker
-```
-
-Here we discuss some more advanced uses of this module, as well as covering its internals.
-
-## Taking Gradients
-
-In the [basics section](../models/basics.md) we covered basic usage of the `gradient` function.
-
-```julia
-using Flux.Tracker
-
-Tracker.gradient((a, b) -> a*b, 2, 3) # (3.0 (tracked), 2.0 (tracked))
-```
-
-`gradient` is actually just a thin wrapper around the backpropagator-based interface, `forward`.
-
-```julia
-using Flux.Tracker: forward
-
-y, back = forward((a, b) -> a*b, 2, 3) # (6.0 (tracked), Flux.Tracker.#9)
-
-back(1) # (3.0 (tracked), 2.0 (tracked))
-```
-
-The `forward` function returns two results. The first, `y`, is the original value of the function (perhaps with tracking applied). The second, `back`, is a new function which, given a sensitivity, returns the sensitivity of the inputs to `forward` (we call this a "backpropagator"). One use of this interface is to provide custom sensitivities when outputs are not scalar.
-
-```julia
-julia> y, back = forward((a, b) -> a.*b, [1,2,3],[4,5,6])
-(param([4.0, 10.0, 18.0]), Flux.Tracker.#9)
-
-julia> back([1,1,1])
-(param([4.0, 5.0, 6.0]), param([1.0, 2.0, 3.0]))
-```
-
-We can also take gradients in-place. This can be useful if you only care about first-order gradients.
-
-```julia
-a, b = param(2), param(3)
-
-c = a*b # 6.0 (tracked)
-
-Tracker.back!(c)
-
-Tracker.grad(a), Tracker.grad(b) # (3.0, 2.0)
-```
-
-## Tracked Arrays
-
-The `param` function converts a normal Julia array into a new object that, while behaving like an array, tracks extra information that allows us to calculate derivatives. For example, say we multiply two parameters:
-
-```julia
-julia> W = param([1 2; 3 4])
-Tracked 2×2 Array{Float64,2}:
- 1.0  2.0
- 3.0  4.0
-
-julia> x = param([5, 6])
-Tracked 2-element Array{Float64,1}:
- 5.0
- 6.0
-
-julia> y = W*x
-Tracked 2-element Array{Float64,1}:
- 17.0
- 39.0
-```
-
-The output `y` is also a `TrackedArray` object. We can now backpropagate sensitivities to `W` and `x` via the `back!` function, and see the gradients accumulated in the `W` and `x` tracked arrays:
-
-```julia
-julia> Tracker.back!(y, [1, -1])
-
-julia> W.grad
-2×2 Array{Float64,2}:
- 5.0  6.0
- -5.0 -6.0
-
-julia> x.grad
-2-element Array{Float64,1}:
- -2.0
- -2.0
-```
-
-You may sometimes want to drop derivative information and just get the plain value back. You can do this by calling `Tracker.data(W)`.
-
-## Custom Gradients
-
-We can hook in to the processes above to implement custom gradients for a function or kernel. For a toy example, imagine a custom implementation of `minus`:
-
-```julia
-minus(a, b) = a - b
-```
-
-Firstly, we must tell the tracker system to stop when it sees a call to `minus`, and record it. We can do this using dispatch:
-
-```julia
-using Flux.Tracker: TrackedArray, track, @grad
-
-minus(a::TrackedArray, b::TrackedArray) = track(minus, a, b)
-```
-
-`track` takes care of building a new `Tracked` object and recording the operation on the tape. We just need to provide a gradient definition.
-
-```julia
-@grad function minus(a, b)
-  return minus(data(a), data(b)), Δ -> (Δ, -Δ)
-end
-```
-
-This is essentially just a way of overloading the `forward` function we saw above. We strip tracking from `a` and `b` so that we are calling the original definition of `minus` (otherwise, we'd just try to track the call again and hit an infinite regress).
-
-Note that in the backpropagator we don't call `data(a)`; we *do* in fact want to track this, since nest AD will take a derivative through the backpropagator itself. For example, the gradient of `*` might look like this.
-
-```julia
-@grad a * b = data(a)*data(b), Δ -> (Δ*b, a*Δ)
-```
-
-We can then calculate the first derivative of `minus` as follows:
-
-```julia
-a = param([1,2,3])
-b = param([3,2,1])
-
-c = minus(a, b) # [-2.0 (tracked), 0.0 (tracked), 2.0 (tracked)]
-
-Tracker.back!(c, 1)
-Tracker.grad(a) # [1.00, 1.00, 1.00]
-Tracker.grad(b) # [-1.00, -1.00, -1.00]
-```
-
-For multi-argument functions with custom gradients, you likely want to catch not just `minus(::TrackedArray, ::TrackedArray)` but also `minus(::Array, TrackedArray)` and so on. To do so, just define those extra signatures as needed:
-
-```julia
-minus(a::AbstractArray, b::TrackedArray) = Tracker.track(minus, a, b)
-minus(a::TrackedArray, b::AbstractArray) = Tracker.track(minus, a, b)
-```
-
-## Tracked Internals
-
-All `Tracked*` objects (`TrackedArray`, `TrackedReal`) are light wrappers around the `Tracked` type, which you can access via the `.tracker` field.
-
-```julia
-julia> x.tracker
-Flux.Tracker.Tracked{Array{Float64,1}}(0x00000000, Flux.Tracker.Call{Nothing,Tuple{}}(nothing, ()), true, [5.0, 6.0], [-2.0, -2.0])
-```
-
-The `Tracker` stores the gradient of a given object, which we've seen before.
-
-```julia
-julia> x.tracker.grad
-2-element Array{Float64,1}:
- -2.0
- -2.0
-```
-
-The tracker also contains a `Call` object, which simply represents a function call that was made at some point during the forward pass. For example, the `+` call would look like this:
-
-```julia
-julia> Tracker.Call(+, 1, 2)
-Flux.Tracker.Call{Base.#+,Tuple{Int64,Int64}}(+, (1, 2))
-```
-
-In the case of the `y` we produced above, we can see that it stores the call that produced it -- that is, `W*x`.
-
-```julia
-julia> y.tracker.f
-Flux.Tracker.Call{...}(*, (param([1.0 2.0; 3.0 4.0]), param([5.0, 6.0])))
-```
-
-Notice that because the arguments to the call may also be tracked arrays, storing their own calls, this means that `Tracker` ends up forming a data structure that records everything that happened during the forward pass (often known as a *tape*).
-
-When we call `back!(y, [1, -1])`, the sensitivities `[1, -1]` simply get forwarded to `y`'s call (`*`), effectively calling
-
-```julia
-Tracker.back(*, [1, -1], W, x)
-```
-
-which in turn calculates the sensitivities of the arguments (`W` and `x`) and back-propagates through their calls. This is recursive, so it will walk the entire program graph and propagate gradients to the original model parameters.
@@ -5,55 +5,56 @@
 Flux's core feature is taking gradients of Julia code. The `gradient` function takes another Julia function `f` and a set of arguments, and returns the gradient with respect to each argument. (It's a good idea to try pasting these examples in the Julia terminal.)

 ```jldoctest basics
-julia> using Flux.Tracker
+julia> using Flux

 julia> f(x) = 3x^2 + 2x + 1;

-julia> df(x) = Tracker.gradient(f, x; nest = true)[1]; # df/dx = 6x + 2
+julia> df(x) = gradient(f, x)[1]; # df/dx = 6x + 2

 julia> df(2)
-14.0 (tracked)
+14

-julia> d2f(x) = Tracker.gradient(df, x; nest = true)[1]; # d²f/dx² = 6
+julia> d2f(x) = gradient(df, x)[1]; # d²f/dx² = 6

 julia> d2f(2)
-6.0 (tracked)
+6
 ```

-(We'll learn more about why these numbers show up as `(tracked)` below.)
-
-When a function has many parameters, we can pass them all in explicitly:
+When a function has many parameters, we can get gradients of each one at the same time:

 ```jldoctest basics
-julia> f(W, b, x) = W * x + b;
+julia> f(x, y) = sum((x .- y).^2);

-julia> Tracker.gradient(f, 2, 3, 4)
-(4.0 (tracked), 1.0 (tracked), 2.0 (tracked))
+julia> gradient(f, [2, 1], [2, 0])
+([0, 2], [0, -2])
 ```

-But machine learning models can have *hundreds* of parameters! Flux offers a nice way to handle this. We can tell Flux to treat something as a parameter via `param`. Then we can collect these together and tell `gradient` to collect the gradients of all `params` at once.
+But machine learning models can have *hundreds* of parameters! To handle this, Flux lets you work with collections of parameters, via `params`. You can get the gradient of all parameters used in a program without explicitly passing them in.

 ```jldoctest basics
 julia> using Flux

-julia> W = param(2)
-2.0 (tracked)
+julia> x = [2, 1];

-julia> b = param(3)
-3.0 (tracked)
+julia> y = [2, 0];

-julia> f(x) = W * x + b;
+julia> gs = gradient(params(x, y)) do
+         f(x, y)
+       end
+Grads(...)

-julia> grads = Tracker.gradient(() -> f(4), params(W, b));
+julia> gs[x]
+2-element Array{Int64,1}:
+ 0
+ 2

-julia> grads[W]
-4.0 (tracked)
-
-julia> grads[b]
-1.0 (tracked)
+julia> gs[y]
+2-element Array{Int64,1}:
+ 0
+ -2
 ```

-There are a few things to notice here. Firstly, `W` and `b` now show up as *tracked*. Tracked things behave like normal numbers or arrays, but keep records of everything you do with them, allowing Flux to calculate their gradients. `gradient` takes a zero-argument function; no arguments are necessary because the `params` tell it what to differentiate.
+Here, `gradient` takes a zero-argument function; no arguments are necessary because the `params` tell it what to differentiate.

 This will come in really handy when dealing with big, complicated models. For now, though, let's start with something simple.
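As a note for reviewers of this hunk: the do-block form added above is equivalent to passing a zero-argument closure explicitly, which is the form the rest of the docs use. A minimal sketch, with the same illustrative values:

```julia
using Flux

x, y = [2, 1], [2, 0]
f(x, y) = sum((x .- y).^2)

# Same as the do-block: a zero-argument closure plus the parameter
# collection to differentiate with respect to.
gs = gradient(() -> f(x, y), params(x, y))
gs[x]   # [0, 2]
gs[y]   # [0, -2]
```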
@@ -76,26 +77,20 @@ x, y = rand(5), rand(2) # Dummy data
 loss(x, y) # ~ 3
 ```

-To improve the prediction we can take the gradients of `W` and `b` with respect to the loss and perform gradient descent. Let's tell Flux that `W` and `b` are parameters, just like we did above.
+To improve the prediction we can take the gradients of `W` and `b` with respect to the loss and perform gradient descent.

 ```julia
-using Flux.Tracker
+using Flux

-W = param(W)
-b = param(b)
-
-gs = Tracker.gradient(() -> loss(x, y), params(W, b))
+gs = gradient(() -> loss(x, y), params(W, b))
 ```

-Now that we have gradients, we can pull them out and update `W` to train the model. The `update!(W, Δ)` function applies `W = W + Δ`, which we can use for gradient descent.
+Now that we have gradients, we can pull them out and update `W` to train the model.

 ```julia
-using Flux.Tracker: update!
-
-Δ = gs[W]
-
-# Update the parameter and reset the gradient
-update!(W, -0.1Δ)
+W̄ = gs[W]
+
+W .-= 0.1 .* W̄

 loss(x, y) # ~ 2.5
 ```
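For context on this hunk, a minimal sketch of the same manual gradient-descent step applied to both parameters; `W`, `b`, `loss`, `x` and `y` are the ones defined earlier on the page, and the learning rate 0.1 is illustrative:

```julia
gs = gradient(() -> loss(x, y), params(W, b))

for p in (W, b)
  p .-= 0.1 .* gs[p]   # in-place descent step on each parameter array
end
```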
@@ -111,12 +106,12 @@ It's common to create more complex models than the linear regression above. For
 ```julia
 using Flux

-W1 = param(rand(3, 5))
-b1 = param(rand(3))
+W1 = rand(3, 5)
+b1 = rand(3)
 layer1(x) = W1 * x .+ b1

-W2 = param(rand(2, 3))
-b2 = param(rand(2))
+W2 = rand(2, 3)
+b2 = rand(2)
 layer2(x) = W2 * x .+ b2

 model(x) = layer2(σ.(layer1(x)))
@@ -128,8 +123,8 @@ This works but is fairly unwieldy, with a lot of repetition – especially as we
 ```julia
 function linear(in, out)
-  W = param(randn(out, in))
-  b = param(randn(out))
+  W = randn(out, in)
+  b = randn(out)
   x -> W * x .+ b
 end
|
|
||||||
@ -150,7 +145,7 @@ struct Affine
|
|||||||
end
|
end
|
||||||
|
|
||||||
Affine(in::Integer, out::Integer) =
|
Affine(in::Integer, out::Integer) =
|
||||||
Affine(param(randn(out, in)), param(randn(out)))
|
Affine(randn(out, in), randn(out))
|
||||||
|
|
||||||
# Overload call, so the object can be used as a function
|
# Overload call, so the object can be used as a function
|
||||||
(m::Affine)(x) = m.W * x .+ m.b
|
(m::Affine)(x) = m.W * x .+ m.b
|
||||||
|
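A small usage sketch of the `Affine` layer as defined in this hunk; the input size 10 and output size 5 are arbitrary:

```julia
a = Affine(10, 5)   # random W (5×10) and b (length 5) via the constructor above
x = rand(10)
y = a(x)            # calls (m::Affine)(x) = m.W * x .+ m.b; y has length 5
```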
@@ -59,7 +59,6 @@ swish
 These layers don't affect the structure of the network but may improve training times or reduce overfitting.

 ```@docs
-Flux.testmode!
 BatchNorm
 Dropout
 AlphaDropout
@@ -101,26 +101,4 @@ m = Chain(LSTM(10, 15), Dense(15, 5))
 m.(seq)
 ```

-## Truncating Gradients
-
-By default, calculating the gradients in a recurrent layer involves its entire history. For example, if we call the model on 100 inputs, we'll have to calculate the gradient for those 100 calls. If we then calculate another 10 inputs we have to calculate 110 gradients – this accumulates and quickly becomes expensive.
-
-To avoid this we can *truncate* the gradient calculation, forgetting the history.
-
-```julia
-truncate!(m)
-```
-
-Calling `truncate!` wipes the slate clean, so we can call the model with more inputs without building up an expensive gradient computation.
-
-`truncate!` makes sense when you are working with multiple chunks of a large sequence, but we may also want to work with a set of independent sequences. In this case the hidden state should be completely reset to its original value, throwing away any accumulated information. `reset!` does this for you.
-
-In general, when training with recurrent layers in your model, you'll want to call `reset!` or `truncate!` for each loss calculation:
-
-```julia
-function loss(x,y)
-  l = Flux.mse(m(x), y)
-  Flux.reset!(m)
-  return l
-end
-```
+Finally, we can reset the hidden state of the cell back to its initial value using `reset!(m)`.
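For reviewers, a minimal sketch of how the `reset!(m)` call mentioned in the new line is typically used between independent sequences; it reuses the `mse` loss and the recurrent model `m` from earlier on the page, following the pattern of the removed example:

```julia
function loss(x, y)
  l = Flux.mse(m(x), y)
  Flux.reset!(m)   # start the next sequence from the initial hidden state
  return l
end
```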
@@ -15,6 +15,8 @@ loss(x, y) = crossentropy(softmax(m(x)), y)
 We can regularise this by taking the (L2) norm of the parameters, `m.W` and `m.b`.

 ```julia
+using LinearAlgebra
+
 penalty() = norm(m.W) + norm(m.b)
 loss(x, y) = crossentropy(softmax(m(x)), y) + penalty()
 ```
@@ -48,15 +50,17 @@ loss(rand(28^2), rand(10))
 One can also easily add per-layer regularisation via the `activations` function:

 ```julia
+julia> using Flux: activations
+
 julia> c = Chain(Dense(10,5,σ),Dense(5,2),softmax)
-Chain(Dense(10, 5, NNlib.σ), Dense(5, 2), NNlib.softmax)
+Chain(Dense(10, 5, σ), Dense(5, 2), softmax)

 julia> activations(c, rand(10))
 3-element Array{Any,1}:
-  param([0.71068, 0.831145, 0.751219, 0.227116, 0.553074])
-  param([0.0330606, -0.456104])
-  param([0.61991, 0.38009])
+ Float32[0.84682214, 0.6704139, 0.42177814, 0.257832, 0.36255655]
+ Float32[0.1501253, 0.073269576]
+ Float32[0.5192045, 0.48079553]

 julia> sum(norm, ans)
-2.639678767773633 (tracked)
+2.1166067f0
 ```
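As an aside on this hunk, summing a norm penalty over every trainable parameter is the usual generalisation of the `m.W`/`m.b` version shown earlier on the page; a minimal sketch, with illustrative layer sizes:

```julia
using Flux, LinearAlgebra

m = Chain(Dense(28^2, 32, relu), Dense(32, 10), softmax)

# L2 penalty over all parameters collected by params(m)
penalty() = sum(norm, params(m))
loss(x, y) = Flux.crossentropy(m(x), y) + penalty()
```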
@@ -53,7 +53,7 @@ julia> using Flux
 julia> model = Chain(Dense(10,5,relu),Dense(5,2),softmax)
 Chain(Dense(10, 5, NNlib.relu), Dense(5, 2), NNlib.softmax)

-julia> weights = Tracker.data.(params(model));
+julia> weights = params(model);

 julia> using BSON: @save
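A brief usage sketch of the save/restore flow this page describes, with an assumed file name; `Flux.loadparams!` is the loading counterpart documented on the same page:

```julia
using Flux
using BSON: @save, @load

model = Chain(Dense(10, 5, relu), Dense(5, 2), softmax)
weights = params(model)
@save "mymodel.bson" weights       # write the weights to disk

@load "mymodel.bson" weights       # later: read them back
Flux.loadparams!(model, weights)   # copy them into a freshly built model
```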
@@ -3,25 +3,25 @@
 Consider a [simple linear regression](../models/basics.md). We create some dummy data, calculate a loss, and backpropagate to calculate gradients for the parameters `W` and `b`.

 ```julia
-using Flux, Flux.Tracker
+using Flux

-W = param(rand(2, 5))
-b = param(rand(2))
+W = rand(2, 5)
+b = rand(2)

-predict(x) = W*x .+ b
+predict(x) = (W * x) .+ b
 loss(x, y) = sum((predict(x) .- y).^2)

 x, y = rand(5), rand(2) # Dummy data
 l = loss(x, y) # ~ 3

 θ = Params([W, b])
-grads = Tracker.gradient(() -> loss(x, y), θ)
+grads = gradient(() -> loss(x, y), θ)
 ```

 We want to update each parameter, using the gradient, in order to improve (reduce) the loss. Here's one way to do that:

 ```julia
-using Flux.Tracker: grad, update!
+using Flux: update!

 η = 0.1 # Learning Rate
 for p in (W, b)
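For orientation while reviewing, a minimal sketch of the manual update loop this hunk leads into, plus the equivalent step expressed with the `Descent` optimiser and `update!(opt, p, g)` that the optimiser API on this page describes; the learning rate is illustrative:

```julia
using Flux
using Flux: update!

η = 0.1                    # learning rate
for p in (W, b)
  p .-= η .* grads[p]      # plain gradient-descent step
end

# The same step with a built-in optimiser:
opt = Descent(η)
for p in (W, b)
  update!(opt, p, grads[p])
end
```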
src/Flux.jl
@@ -3,19 +3,15 @@ module Flux
 # Zero Flux Given

 using Base: tail
-using MacroTools, Juno, Reexport, Statistics, Random
+using Zygote, MacroTools, Juno, Reexport, Statistics, Random
 using MacroTools: @forward
+@reexport using NNlib
+using Zygote: Params, @adjoint, gradient, forward
+export gradient

 export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
        DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
-       SkipConnection,
-       params, mapleaves, cpu, gpu, f32, f64
-
-@reexport using NNlib
-
-using Tracker
-using Tracker: data
-export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param
+       SkipConnection, params, mapleaves, cpu, gpu, f32, f64

 include("optimise/Optimise.jl")
 using .Optimise
@@ -49,6 +45,8 @@ include("layers/normalise.jl")

 include("data/Data.jl")

+include("deprecations.jl")
+
 if has_cuarrays()
   include("cuda/cuda.jl")
 end
@@ -1,13 +1,10 @@
 using CuArrays: libcudnn
 using CuArrays.CUDNN: @check, handle, cudnnStatus_t, cudnnTensorDescriptor_t,
   cudnnBatchNormMode_t, cudnnHandle_t, cudnnDataType, TensorDesc, FilterDesc

 import CuArrays.CUDAdrv: CuPtr, CU_NULL

 using LinearAlgebra

-import ..Flux: data
-
 mutable struct DropoutDesc
   ptr::Ptr{Nothing}
   states::CuVector{UInt8}
@@ -198,36 +195,8 @@ end

 # Flux Interface

-(BN::Flux.BatchNorm)(x::Union{CuParam{T,2},CuParam{T,4},CuParam{T,5}}, cache = nothing) where T<:Union{Float32, Float64} =
-  BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum; cache = cache, alpha = 1, beta = 0, eps = BN.ϵ, training = BN.active))
-
-batchnorm(g::TrackedArray, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T},
-          running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
-  track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
-
-batchnorm(g::TrackedArray, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T},
-          running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
-  track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
-
-batchnorm(g::TrackedArray, b::CuArray{T}, x::TrackedArray, running_mean::CuArray{T},
-          running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
-  track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
-
-batchnorm(g::CuArray{T}, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T},
-          running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
-  track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
-
-batchnorm(g::CuArray{T}, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T},
-          running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
-  track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
-
-batchnorm(g::TrackedArray, b::CuArray{T}, x::CuArray{T}, running_mean::CuArray{T},
-          running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
-  track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
-
-batchnorm(g::CuArray{T}, b::CuArray{T}, x::TrackedArray, running_mean::CuArray{T},
-          running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
-  track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
-
-@grad batchnorm(g, b, x, running_mean, running_var, momentum; kw...) =
-  batchnorm(data.((g, b, x))..., running_mean, running_var, momentum; kw...), Δ -> (nobacksies(:batchnorm, ∇batchnorm(data.((g, b, x, Δ))..., running_mean, running_var, momentum; kw...))..., nothing, nothing, nothing)
+(BN::Flux.BatchNorm)(x::Union{CuArray{T,2},CuArray{T,4},CuArray{T,5}}, cache = nothing) where T<:Union{Float32, Float64} =
+  BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum; cache = cache, alpha = 1, beta = 0, eps = BN.ϵ, training = Flux.istraining()))
+
+@adjoint batchnorm(g, b, x, running_mean, running_var, momentum; kw...) =
+  batchnorm(g, b, x, running_mean, running_var, momentum; kw...), Δ -> (∇batchnorm(g, b, x, Δ, running_mean, running_var, momentum; kw...)..., nothing, nothing, nothing)
@@ -225,7 +225,6 @@ end
 # Interface

 import ..Flux: Flux, relu
-import ..Tracker: TrackedArray
 using CuArrays.CUDAnative
 using CuArrays: @cuindex, cudims
@@ -240,17 +239,16 @@ function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
   return dst
 end

-CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}}
-CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuParam{T,2},<:CuParam{T,1}}
-CuGRU{T} = Flux.GRUCell{<:CuParam{T,2},<:CuParam{T,1}}
-CuLSTM{T} = Flux.LSTMCell{<:CuParam{T,2},<:CuParam{T,1}}
+CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuArray{T,2},<:CuArray{T,1}}
+CuGRU{T} = Flux.GRUCell{<:CuArray{T,2},<:CuArray{T,1}}
+CuLSTM{T} = Flux.LSTMCell{<:CuArray{T,2},<:CuArray{T,1}}
 CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}}

 function copyparams!(m::CuRNNs, d::RNNDesc)
   Wi, Wh = d.weights
-  copy_transpose!(Wi, Flux.data(m.Wi))
-  copy_transpose!(Wh, Flux.data(m.Wh))
-  copy_transpose!(d.bias, Flux.data(m.b))
+  copy_transpose!(Wi, m.Wi)
+  copy_transpose!(Wh, m.Wh)
+  copy_transpose!(d.bias, m.b)
   return
 end
@ -271,59 +269,58 @@ function desc(rnn)
|
|||||||
return d
|
return d
|
||||||
end
|
end
|
||||||
|
|
||||||
import Flux.Tracker
|
using ..Flux: @adjoint
|
||||||
import Flux.Tracker: data, istracked, track, unbroadcast, @grad, nobacksies
|
|
||||||
|
|
||||||
istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...))
|
function (m::CuRNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
||||||
|
y, h′ = forward(desc(m), x, h)
|
||||||
function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
|
return h′, y
|
||||||
result = istrain(m, h, x) ?
|
|
||||||
track(m, x, h, m.Wi, m.Wh, m.b) :
|
|
||||||
forward(desc(m), x, h)
|
|
||||||
return result[2], result[1]
|
|
||||||
end
|
end
|
||||||
|
|
||||||
function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
|
function (m::CuGRU{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
||||||
result = istrain(m, h, x) ?
|
y, h′ = forward(desc(m), x, h)
|
||||||
track(m, x, h, m.Wi, m.Wh, m.b) :
|
return h′, y
|
||||||
forward(desc(m), x, h)
|
|
||||||
return result[2], result[1]
|
|
||||||
end
|
end
|
||||||
|
|
||||||
function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64}
|
function (m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
||||||
result = istrain(m, h, x) ?
|
y, h′, c′ = forward(desc(m), x, h[1], h[2])
|
||||||
track(m, x, h[1], h[2], m.Wi, m.Wh, m.b) :
|
return (h′, c′), y
|
||||||
forward(desc(m), x, h[1], h[2])
|
|
||||||
return (result[2], result[3]), result[1]
|
|
||||||
end
|
end
|
||||||
|
|
||||||
(m::CuRNN{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
(m::CuRNN{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||||
(m::CuGRU{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
(m::CuGRU{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||||
(m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
(m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||||
|
|
||||||
@grad function (m::Union{CuRNN,CuGRU})(x, h, Wi, Wh, b)
|
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
|
||||||
reserve, result = forwardTrain(desc(m), data(x), data(h))
|
|
||||||
result, function (Δ)
|
unbroadcast(x::AbstractArray, Δ) =
|
||||||
y, ho = result
|
size(x) == size(Δ) ? Δ :
|
||||||
dy, dho = Δ
|
length(x) == length(Δ) ? trim(x, Δ) :
|
||||||
h_ = hBatch(x, data(h))
|
trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
|
||||||
dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
|
|
||||||
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
|
for RNN in (CuRNN, CuGRU)
|
||||||
nobacksies(:RNN, (dx, unbroadcast(h, dh), transpose(dWi), transpose(dWh), db))
|
@eval @adjoint function (m::$RNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
||||||
|
reserve, (y, ho) = forwardTrain(desc(m), x, h)
|
||||||
|
(ho, y), function (Δ)
|
||||||
|
dho, dy = Δ
|
||||||
|
h_ = hBatch(x, h)
|
||||||
|
dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
|
||||||
|
(dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
|
||||||
|
dm = Ref{Any}((σ=nothing,Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing))
|
||||||
|
(dm, unbroadcast(h, dh), dx)
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@grad function (m::CuLSTM)(x, h, c, Wi, Wh, b)
|
@adjoint function (m::CuLSTM)((h, c)::Tuple{CuArray{T},CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
||||||
reserve, result = forwardTrain(desc(m), data.((x, h, c))...)
|
reserve, (y, ho, co) = forwardTrain(desc(m), x, h, c)
|
||||||
result, function (Δ)
|
((ho, co), y), function (Δ)
|
||||||
y, ho = result
|
dhc, dy = Δ
|
||||||
dy, dho, dco = Δ
|
dho, dco = dhc === nothing ? (nothing, nothing) : dhc
|
||||||
h_ = hBatch(x, data(h))
|
h_ = hBatch(x, h)
|
||||||
c_ = hBatch(x, data(c))
|
c_ = hBatch(x, c)
|
||||||
dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
|
dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
|
||||||
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
|
(dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
|
||||||
nobacksies(:RNN,
|
dm = Ref{Any}((Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing,c=nothing))
|
||||||
(dx, unbroadcast(h, dh), unbroadcast(c, dc),
|
(dm, (unbroadcast(h, dh), unbroadcast(c, dc)), dx)
|
||||||
transpose(dWi), transpose(dWh), db))
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
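The cuDNN RNN wrappers above now hook into Zygote through `@adjoint` definitions instead of Tracker's `@grad`. As a rough illustration of the mechanism only (not Flux's actual cuDNN code; `mylinear` is an invented example), a custom adjoint returns the primal result together with a pullback:

```julia
# Minimal sketch of a Zygote custom adjoint, assuming only Zygote itself.
using Zygote

mylinear(W, x) = W * x

Zygote.@adjoint function mylinear(W, x)
  y = W * x
  # Return the primal value and a pullback mapping ȳ to gradients w.r.t. (W, x).
  return y, ȳ -> (ȳ * x', W' * ȳ)
end

W, x = randn(3, 4), randn(4)
gradient((W, x) -> sum(mylinear(W, x)), W, x)  # dispatches to the custom adjoint
```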
@ -1,14 +1,10 @@
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
Iris
|
|
||||||
|
|
||||||
Fisher's classic iris dataset.
|
Fisher's classic iris dataset.
|
||||||
|
|
||||||
Measurements from 3 different species of iris: setosa, versicolor and
|
Measurements from 3 different species of iris: setosa, versicolor and
|
||||||
virginica. There are 50 examples of each species.
|
virginica. There are 50 examples of each species.
|
||||||
|
|
||||||
There are 4 measurements for each example: sepal length, sepal width, petal
|
There are 4 measurements for each example: sepal length, sepal width, petal
|
||||||
length and petal width. The measurements are in centimeters.
|
length and petal width. The measurements are in centimeters.
|
||||||
|
|
||||||
The module retrieves the data from the [UCI Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/iris).
|
The module retrieves the data from the [UCI Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/iris).
|
||||||
@ -35,10 +31,12 @@ end
|
|||||||
|
|
||||||
labels()
|
labels()
|
||||||
|
|
||||||
Get the labels of the iris dataset, a 150 element array of strings listing the
|
Get the labels of the iris dataset, a 150 element array of strings listing the
|
||||||
species of each example.
|
species of each example.
|
||||||
|
|
||||||
```jldoctest
|
```jldoctest
|
||||||
|
julia> using Flux
|
||||||
|
|
||||||
julia> labels = Flux.Data.Iris.labels();
|
julia> labels = Flux.Data.Iris.labels();
|
||||||
|
|
||||||
julia> summary(labels)
|
julia> summary(labels)
|
||||||
@ -58,11 +56,13 @@ end
|
|||||||
|
|
||||||
features()
|
features()
|
||||||
|
|
||||||
Get the features of the iris dataset. This is a 4x150 matrix of Float64
|
Get the features of the iris dataset. This is a 4x150 matrix of Float64
|
||||||
elements. It has a row for each feature (sepal length, sepal width,
|
elements. It has a row for each feature (sepal length, sepal width,
|
||||||
petal length, petal width) and a column for each example.
|
petal length, petal width) and a column for each example.
|
||||||
|
|
||||||
```jldoctest
|
```jldoctest
|
||||||
|
julia> using Flux
|
||||||
|
|
||||||
julia> features = Flux.Data.Iris.features();
|
julia> features = Flux.Data.Iris.features();
|
||||||
|
|
||||||
julia> summary(features)
|
julia> summary(features)
|
||||||
@ -81,6 +81,5 @@ function features()
|
|||||||
iris = readdlm(deps("iris.data"), ',')
|
iris = readdlm(deps("iris.data"), ',')
|
||||||
Matrix{Float64}(iris[1:end, 1:4]')
|
Matrix{Float64}(iris[1:end, 1:4]')
|
||||||
end
|
end
|
||||||
|
|
||||||
end
|
end
|
||||||
|
|
||||||
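A brief usage sketch tying the two Iris accessors together (hedged: it assumes the dataset download succeeds and uses Flux's one-hot utilities):

```julia
# Sketch: load the Iris data and one-hot encode the species labels.
using Flux

X = Flux.Data.Iris.features()    # 4×150 matrix of measurements
y = Flux.Data.Iris.labels()      # 150-element vector of species names

Y = Flux.onehotbatch(y, sort(unique(y)))   # 3×150 one-hot label matrix
size(X), size(Y)                           # ((4, 150), (3, 150))
```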
|
|
||||||
|
src/deprecations.jl (new file, 2 lines)
@ -0,0 +1,2 @@
|
|||||||
|
@deprecate param(x) x
|
||||||
|
@deprecate data(x) x
|
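With Tracker gone, `param` and `data` become identity functions kept only for backwards compatibility. A hedged sketch of how `Base.@deprecate` behaves in general (the names below are invented for illustration):

```julia
# Sketch of Base.@deprecate: calls to the old name forward to the new definition
# and emit a deprecation warning when depwarn is enabled.
newscale(x) = 2x
Base.@deprecate oldscale(x) newscale(x)

oldscale(3)   # warns about the deprecation, then returns 6
```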
@ -89,7 +89,7 @@ Dense(W, b) = Dense(W, b, identity)
|
|||||||
|
|
||||||
function Dense(in::Integer, out::Integer, σ = identity;
|
function Dense(in::Integer, out::Integer, σ = identity;
|
||||||
initW = glorot_uniform, initb = zeros)
|
initW = glorot_uniform, initb = zeros)
|
||||||
return Dense(param(initW(out, in)), param(initb(out)), σ)
|
return Dense(initW(out, in), initb(out), σ)
|
||||||
end
|
end
|
||||||
|
|
||||||
@treelike Dense
|
@treelike Dense
|
||||||
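Since `Dense` now stores plain arrays, gradients come straight from Zygote over `params`. A minimal sketch, assuming Flux's default initialisation:

```julia
# Sketch: layer fields are ordinary arrays, and implicit-parameter gradients
# are taken with Zygote instead of Tracker.
using Flux

d = Dense(10, 5, relu)
x = rand(10)
gs = gradient(() -> sum(d(x)), params(d))
gs[d.W]   # a plain array of the same size as d.W
```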
@ -129,7 +129,7 @@ struct Diagonal{T}
|
|||||||
end
|
end
|
||||||
|
|
||||||
Diagonal(in::Integer; initα = ones, initβ = zeros) =
|
Diagonal(in::Integer; initα = ones, initβ = zeros) =
|
||||||
Diagonal(param(initα(in)), param(initβ(in)))
|
Diagonal(initα(in), initβ(in))
|
||||||
|
|
||||||
@treelike Diagonal
|
@treelike Diagonal
|
||||||
|
|
||||||
@ -204,7 +204,6 @@ A 'ResNet'-type skip-connection with identity shortcut would simply be
|
|||||||
SkipConnection(layer, (a,b) -> a + b)
|
SkipConnection(layer, (a,b) -> a + b)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
struct SkipConnection
|
struct SkipConnection
|
||||||
layers
|
layers
|
||||||
connection #user can pass arbitrary connections here, such as (a,b) -> a + b
|
connection #user can pass arbitrary connections here, such as (a,b) -> a + b
|
||||||
|
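A hedged usage sketch of `SkipConnection` with an identity shortcut:

```julia
# Sketch: apply `layers` to the input and combine the result with the input.
using Flux

sc = SkipConnection(Dense(10, 10, relu), (a, b) -> a + b)
x = rand(10, 4)
size(sc(x))   # (10, 4)
```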
@ -14,11 +14,11 @@ Example: Applying Conv layer to a 1-channel input using a 2x2 window size,
|
|||||||
|
|
||||||
size = (2,2)
|
size = (2,2)
|
||||||
in = 1
|
in = 1
|
||||||
out = 16
|
out = 16
|
||||||
Conv((2, 2), 1=>16, relu)
|
Conv((2, 2), 1=>16, relu)
|
||||||
|
|
||||||
Data should be stored in WHCN order (width, height, # channels, # batches).
|
Data should be stored in WHCN order (width, height, # channels, # batches).
|
||||||
In other words, a 100×100 RGB image would be a `100×100×3×1` array,
|
In other words, a 100×100 RGB image would be a `100×100×3×1` array,
|
||||||
and a batch of 50 would be a `100×100×3×50` array.
|
and a batch of 50 would be a `100×100×3×50` array.
|
||||||
|
|
||||||
Takes the keyword arguments `pad`, `stride` and `dilation`.
|
Takes the keyword arguments `pad`, `stride` and `dilation`.
|
||||||
@ -42,7 +42,7 @@ end
|
|||||||
|
|
||||||
Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
|
Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
|
||||||
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
|
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
|
||||||
Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ,
|
Conv(init(k..., ch...), zeros(ch[2]), σ,
|
||||||
stride = stride, pad = pad, dilation = dilation)
|
stride = stride, pad = pad, dilation = dilation)
|
||||||
|
|
||||||
@treelike Conv
|
@treelike Conv
|
||||||
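The constructor above now builds plain weight and bias arrays. A short sketch of applying `Conv` to WHCN data:

```julia
# Sketch: a 2×2 kernel over a batch of single-channel 28×28 inputs.
using Flux

c = Conv((2, 2), 1=>16, relu)
x = rand(28, 28, 1, 8)     # width × height × channels × batch
size(c(x))                 # (27, 27, 16, 8)
```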
@ -74,8 +74,10 @@ end
|
|||||||
|
|
||||||
Standard convolutional transpose layer. `size` should be a tuple like `(2, 2)`.
|
Standard convolutional transpose layer. `size` should be a tuple like `(2, 2)`.
|
||||||
`in` and `out` specify the number of input and output channels respectively.
|
`in` and `out` specify the number of input and output channels respectively.
|
||||||
|
|
||||||
Data should be stored in WHCN order. In other words, a 100×100 RGB image would
|
Data should be stored in WHCN order. In other words, a 100×100 RGB image would
|
||||||
be a `100×100×3×1` array, and a batch of 50 would be a `100×100×3×50` array.
|
be a `100×100×3×1` array, and a batch of 50 would be a `100×100×3×50` array.
|
||||||
|
|
||||||
Takes the keyword arguments `pad`, `stride` and `dilation`.
|
Takes the keyword arguments `pad`, `stride` and `dilation`.
|
||||||
"""
|
"""
|
||||||
struct ConvTranspose{N,M,F,A,V}
|
struct ConvTranspose{N,M,F,A,V}
|
||||||
@ -97,7 +99,7 @@ end
|
|||||||
|
|
||||||
ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
|
ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
|
||||||
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
|
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
|
||||||
ConvTranspose(param(init(k..., reverse(ch)...)), param(zeros(ch[2])), σ,
|
ConvTranspose(init(k..., reverse(ch)...), zeros(ch[2]), σ,
|
||||||
stride = stride, pad = pad, dilation = dilation)
|
stride = stride, pad = pad, dilation = dilation)
|
||||||
|
|
||||||
@treelike ConvTranspose
|
@treelike ConvTranspose
|
||||||
@ -169,8 +171,8 @@ function DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ =
|
|||||||
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N
|
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N
|
||||||
@assert ch[2] % ch[1] == 0 "Output channels must be an integer multiple of input channels"
|
@assert ch[2] % ch[1] == 0 "Output channels must be an integer multiple of input channels"
|
||||||
return DepthwiseConv(
|
return DepthwiseConv(
|
||||||
param(init(k..., div(ch[2], ch[1]), ch[1])),
|
init(k..., div(ch[2], ch[1]), ch[1]),
|
||||||
param(zeros(ch[2])),
|
zeros(ch[2]),
|
||||||
σ;
|
σ;
|
||||||
stride = stride,
|
stride = stride,
|
||||||
pad = pad,
|
pad = pad,
|
||||||
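The output-channel count of `DepthwiseConv` must be a multiple of the input channels, as asserted above. A small sketch mirroring the test suite later in this diff:

```julia
# Sketch: 3 input channels expanded to 15 output channels (5 filters per channel).
using Flux

dw = DepthwiseConv((2, 2), 3=>15)
x = zeros(Float32, 28, 28, 3, 4)
size(dw(x))    # (27, 27, 15, 4)
```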
@ -198,25 +200,26 @@ end
|
|||||||
|
|
||||||
(a::DepthwiseConv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
|
(a::DepthwiseConv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
|
||||||
a(T.(x))
|
a(T.(x))
|
||||||
|
|
||||||
"""
|
"""
|
||||||
CrossCor(size, in=>out)
|
CrossCor(size, in=>out)
|
||||||
CrossCor(size, in=>out, relu)
|
CrossCor(size, in=>out, relu)
|
||||||
|
|
||||||
Standard cross convolutional layer. `size` should be a tuple like `(2, 2)`.
|
Standard cross convolutional layer. `size` should be a tuple like `(2, 2)`.
|
||||||
`in` and `out` specify the number of input and output channels respectively.
|
`in` and `out` specify the number of input and output channels respectively.
|
||||||
|
|
||||||
Example: Applying a CrossCor layer to a 1-channel input using a 2x2 window size,
|
Example: Applying a CrossCor layer to a 1-channel input using a 2x2 window size,
|
||||||
giving us a 16-channel output. Output is activated with ReLU.
|
giving us a 16-channel output. Output is activated with ReLU.
|
||||||
|
|
||||||
size = (2,2)
|
size = (2,2)
|
||||||
in = 1
|
in = 1
|
||||||
out = 16
|
out = 16
|
||||||
CrossCor((2, 2), 1=>16, relu)
|
CrossCor((2, 2), 1=>16, relu)
|
||||||
|
|
||||||
Data should be stored in WHCN order (width, height, # channels, # batches).
|
Data should be stored in WHCN order (width, height, # channels, # batches).
|
||||||
In other words, a 100×100 RGB image would be a `100×100×3×1` array,
|
In other words, a 100×100 RGB image would be a `100×100×3×1` array,
|
||||||
and a batch of 50 would be a `100×100×3×50` array.
|
and a batch of 50 would be a `100×100×3×50` array.
|
||||||
|
|
||||||
Takes the keyword arguments `pad`, `stride` and `dilation`.
|
Takes the keyword arguments `pad`, `stride` and `dilation`.
|
||||||
"""
|
"""
|
||||||
struct CrossCor{N,M,F,A,V}
|
struct CrossCor{N,M,F,A,V}
|
||||||
@ -238,7 +241,7 @@ end
|
|||||||
|
|
||||||
CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
|
CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
|
||||||
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
|
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
|
||||||
CrossCor(param(init(k..., ch...)), param(zeros(ch[2])), σ,
|
CrossCor(init(k..., ch...), zeros(ch[2]), σ,
|
||||||
stride = stride, pad = pad, dilation = dilation)
|
stride = stride, pad = pad, dilation = dilation)
|
||||||
|
|
||||||
@treelike CrossCor
|
@treelike CrossCor
|
||||||
|
@ -1,17 +1,20 @@
|
|||||||
"""
|
istraining() = false
|
||||||
testmode!(m)
|
|
||||||
testmode!(m, false)
|
|
||||||
|
|
||||||
Put layers like [`Dropout`](@ref) and [`BatchNorm`](@ref) into testing mode
|
@adjoint istraining() = true, _ -> nothing
|
||||||
(or back to training mode with `false`).
|
|
||||||
"""
|
_dropout_shape(s, ::Colon) = size(s)
|
||||||
function testmode!(m, val::Bool=true)
|
_dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(size(s)))...)
|
||||||
prefor(x -> _testmode!(x, val), m)
|
|
||||||
return m
|
_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
|
||||||
|
|
||||||
|
dropout(x, p; dims = :) = x
|
||||||
|
|
||||||
|
@adjoint function dropout(x, p; dims = :)
|
||||||
|
y = rand!(similar(x, _dropout_shape(x, dims)))
|
||||||
|
y .= _dropout_kernel.(y, p, 1 - p)
|
||||||
|
return x .* y, Δ -> (Δ .* y, nothing)
|
||||||
end
|
end
|
||||||
|
|
||||||
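The `istraining`/`@adjoint` pair above replaces the old per-layer `active` flags: the primal answer is `false`, but evaluating it under Zygote flips it to `true`. A minimal sketch:

```julia
# Sketch: istraining() is false in plain code and true inside a Zygote forward pass.
using Flux, Zygote

Flux.istraining()                         # false
first(Zygote.forward(Flux.istraining))    # true
```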
_testmode!(m, test) = nothing
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Dropout(p, dims = :)
|
Dropout(p, dims = :)
|
||||||
|
|
||||||
@ -19,79 +22,52 @@ A Dropout layer. For each input, either sets that input to `0` (with probability
|
|||||||
`p`) or scales it by `1/(1-p)`. The `dims` argument specifies the unbroadcasted
|
`p`) or scales it by `1/(1-p)`. The `dims` argument specifies the unbroadcasted
|
||||||
dimensions, i.e. `dims=1` does dropout along columns and `dims=2` along rows. This is
|
dimensions, i.e. `dims=1` does dropout along columns and `dims=2` along rows. This is
|
||||||
used as regularisation, i.e. it reduces overfitting during training. See also [`dropout`](@ref).
|
used as regularisation, i.e. it reduces overfitting during training. See also [`dropout`](@ref).
|
||||||
|
|
||||||
Does nothing to the input once in [`testmode!`](@ref).
|
|
||||||
"""
|
"""
|
||||||
mutable struct Dropout{F}
|
mutable struct Dropout{F,D}
|
||||||
p::F
|
p::F
|
||||||
dims::Union{Colon, Int, NTuple{N, Int} where N}
|
dims::D
|
||||||
active::Bool
|
|
||||||
end
|
end
|
||||||
|
|
||||||
function Dropout(p; dims = :)
|
function Dropout(p; dims = :)
|
||||||
@assert 0 ≤ p ≤ 1
|
@assert 0 ≤ p ≤ 1
|
||||||
Dropout{typeof(p)}(p, dims, true)
|
Dropout{typeof(p),typeof(dims)}(p, dims)
|
||||||
end
|
end
|
||||||
|
|
||||||
_dropout_shape(s, ::Colon) = size(s)
|
(a::Dropout)(x) = dropout(x, a.p; dims = a.dims)
|
||||||
_dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(size(s)))...)
|
|
||||||
|
|
||||||
_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
|
function Base.show(io::IO, d::Dropout)
|
||||||
|
print(io, "Dropout(", d.p)
|
||||||
|
d.dims != (:) && print(io, ", dims = $(repr(d.dims))")
|
||||||
"""
|
print(io, ")")
|
||||||
dropout(x, p; dims = :)
|
|
||||||
|
|
||||||
The dropout function. For each input, either sets that input to `0` (with probability
|
|
||||||
`p`) or scales it by `1/(1-p)`. The `dims` argument specifies the unbroadcasted
|
|
||||||
dimensions, i.e. `dims=1` does dropout along columns and `dims=2` along rows. This is
|
|
||||||
used as a regularisation, i.e. it reduces overfitting during training.
|
|
||||||
"""
|
|
||||||
function dropout(x, p; dims = :)
|
|
||||||
y = similar(x, _dropout_shape(x, dims))
|
|
||||||
rand!(y)
|
|
||||||
y .= _dropout_kernel.(y, p, 1 - p)
|
|
||||||
return x .* y
|
|
||||||
end
|
end
|
||||||
|
|
||||||
function (a::Dropout)(x)
|
|
||||||
a.active || return x
|
|
||||||
return dropout(x, a.p; dims = a.dims)
|
|
||||||
end
|
|
||||||
|
|
||||||
_testmode!(a::Dropout, test) = (a.active = !test)
|
|
||||||
|
|
||||||
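Because `dropout`'s primal is the identity and only its adjoint applies the mask, a `Dropout` layer is now a no-op outside of differentiation. A hedged sketch:

```julia
# Sketch: Dropout passes data through untouched unless evaluated under Zygote.
using Flux, Zygote

m = Dropout(0.5)
x = ones(10)

m(x) == x                        # true: no masking outside a gradient call
y, back = Zygote.forward(m, x)   # here elements of y are dropped or scaled by 1/(1-p)
```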
"""
|
"""
|
||||||
AlphaDropout(p)
|
AlphaDropout(p)
|
||||||
A dropout layer. It is used in Self-Normalizing Neural Networks.
|
A dropout layer. It is used in Self-Normalizing Neural Networks.
|
||||||
(https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf)
|
(https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf)
|
||||||
The AlphaDropout layer ensures that the mean and variance of activations remain the same as before.
|
The AlphaDropout layer ensures that the mean and variance of activations remain the same as before.
|
||||||
"""
|
"""
|
||||||
mutable struct AlphaDropout{F}
|
mutable struct AlphaDropout{F}
|
||||||
p::F
|
p::F
|
||||||
active::Bool
|
function AlphaDropout(p)
|
||||||
end
|
@assert 0 ≤ p ≤ 1
|
||||||
|
new{typeof(p)}(p)
|
||||||
function AlphaDropout(p)
|
end
|
||||||
@assert 0 ≤ p ≤ 1
|
|
||||||
AlphaDropout(p,true)
|
|
||||||
end
|
end
|
||||||
|
|
||||||
function (a::AlphaDropout)(x)
|
function (a::AlphaDropout)(x)
|
||||||
a.active || return x
|
istraining() || return x
|
||||||
λ = eltype(x)(1.0507009873554804934193349852946)
|
λ = eltype(x)(1.0507009873554804934193349852946)
|
||||||
α = eltype(x)(1.6732632423543772848170429916717)
|
α = eltype(x)(1.6732632423543772848170429916717)
|
||||||
α1 = eltype(x)(-λ*α)
|
α1 = eltype(x)(-λ*α)
|
||||||
noise = randn(eltype(x), size(x))
|
noise = randn(eltype(x), size(x))
|
||||||
x = @. x*(noise > (1 - a.p)) + α1 * (noise <= (1 - a.p))
|
x = @. x*(noise > (1 - a.p)) + α1 * (noise < (1 - a.p))
|
||||||
A = (a.p + a.p * (1 - a.p) * α1 ^ 2)^0.5
|
A = (a.p + a.p * (1 - a.p) * α1 ^ 2)^0.5
|
||||||
B = -A * α1 * (1 - a.p)
|
B = -A * α1 * (1 - a.p)
|
||||||
x = @. A * x + B
|
x = @. A * x + B
|
||||||
return x
|
return x
|
||||||
end
|
end
|
||||||
|
|
||||||
_testmode!(a::AlphaDropout, test) = (a.active = !test)
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
LayerNorm(h::Integer)
|
LayerNorm(h::Integer)
|
||||||
|
|
||||||
@ -151,25 +127,23 @@ mutable struct BatchNorm{F,V,W,N}
|
|||||||
σ²::W # moving variance
|
σ²::W # moving variance
|
||||||
ϵ::N
|
ϵ::N
|
||||||
momentum::N
|
momentum::N
|
||||||
active::Bool
|
|
||||||
end
|
end
|
||||||
|
|
||||||
BatchNorm(chs::Integer, λ = identity;
|
BatchNorm(chs::Integer, λ = identity;
|
||||||
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
|
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
|
||||||
BatchNorm(λ, param(initβ(chs)), param(initγ(chs)),
|
BatchNorm(λ, initβ(chs), initγ(chs),
|
||||||
zeros(chs), ones(chs), ϵ, momentum, true)
|
zeros(chs), ones(chs), ϵ, momentum)
|
||||||
|
|
||||||
function (BN::BatchNorm)(x)
|
function (BN::BatchNorm)(x)
|
||||||
size(x, ndims(x)-1) == length(BN.β) ||
|
size(x, ndims(x)-1) == length(BN.β) ||
|
||||||
error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))")
|
error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))")
|
||||||
dims = length(size(x))
|
dims = length(size(x))
|
||||||
channels = size(x, dims-1)
|
channels = size(x, dims-1)
|
||||||
affine_shape = ones(Int, dims)
|
affine_shape = ntuple(i->i == ndims(x) - 1 ? size(x, i) : 1, ndims(x))
|
||||||
affine_shape[end-1] = channels
|
m = div(prod(size(x)), channels)
|
||||||
m = prod(size(x)[1:end-2]) * size(x)[end]
|
|
||||||
γ = reshape(BN.γ, affine_shape...)
|
γ = reshape(BN.γ, affine_shape...)
|
||||||
β = reshape(BN.β, affine_shape...)
|
β = reshape(BN.β, affine_shape...)
|
||||||
if !BN.active
|
if !istraining()
|
||||||
μ = reshape(BN.μ, affine_shape...)
|
μ = reshape(BN.μ, affine_shape...)
|
||||||
σ² = reshape(BN.σ², affine_shape...)
|
σ² = reshape(BN.σ², affine_shape...)
|
||||||
ϵ = BN.ϵ
|
ϵ = BN.ϵ
|
||||||
@ -178,11 +152,12 @@ function (BN::BatchNorm)(x)
|
|||||||
axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
|
axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
|
||||||
μ = mean(x, dims = axes)
|
μ = mean(x, dims = axes)
|
||||||
σ² = sum((x .- μ) .^ 2, dims = axes) ./ m
|
σ² = sum((x .- μ) .^ 2, dims = axes) ./ m
|
||||||
ϵ = data(convert(T, BN.ϵ))
|
ϵ = convert(T, BN.ϵ)
|
||||||
# update moving mean/std
|
# update moving mean/std
|
||||||
mtm = data(convert(T, BN.momentum))
|
mtm = BN.momentum
|
||||||
BN.μ = (1 - mtm) .* BN.μ .+ mtm .* reshape(data(μ), :)
|
S = eltype(BN.μ)
|
||||||
BN.σ² = (1 - mtm) .* BN.σ² .+ (mtm * m / (m - 1)) .* reshape(data(σ²), :)
|
BN.μ = (1 - mtm) .* BN.μ .+ mtm .* S.(reshape(μ, :))
|
||||||
|
BN.σ² = (1 - mtm) .* BN.σ² .+ (mtm * m / (m - 1)) .* S.(reshape(σ², :))
|
||||||
end
|
end
|
||||||
|
|
||||||
let λ = BN.λ
|
let λ = BN.λ
|
||||||
@ -192,12 +167,10 @@ function (BN::BatchNorm)(x)
|
|||||||
end
|
end
|
||||||
|
|
||||||
children(BN::BatchNorm) =
|
children(BN::BatchNorm) =
|
||||||
(BN.λ, BN.β, BN.γ, BN.μ, BN.σ², BN.ϵ, BN.momentum, BN.active)
|
(BN.λ, BN.β, BN.γ, BN.μ, BN.σ², BN.ϵ, BN.momentum)
|
||||||
|
|
||||||
mapchildren(f, BN::BatchNorm) = # e.g. mapchildren(cu, BN)
|
mapchildren(f, BN::BatchNorm) = # e.g. mapchildren(cu, BN)
|
||||||
BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ²), BN.ϵ, BN.momentum, BN.active)
|
BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ²), BN.ϵ, BN.momentum)
|
||||||
|
|
||||||
_testmode!(BN::BatchNorm, test) = (BN.active = !test)
|
|
||||||
|
|
||||||
function Base.show(io::IO, l::BatchNorm)
|
function Base.show(io::IO, l::BatchNorm)
|
||||||
print(io, "BatchNorm($(join(size(l.β), ", "))")
|
print(io, "BatchNorm($(join(size(l.β), ", "))")
|
||||||
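The `ntuple`-built `affine_shape` replaces the old mutable `ones(Int, dims)` vector; it is a broadcast shape that is 1 everywhere except the channel dimension. A small sketch:

```julia
# Sketch: reshape per-channel parameters so they broadcast over W×H×N.
x = rand(Float32, 5, 5, 3, 2)    # WHCN input with 3 channels
γ = ones(Float32, 3)

affine_shape = ntuple(i -> i == ndims(x) - 1 ? size(x, i) : 1, ndims(x))   # (1, 1, 3, 1)
size(reshape(γ, affine_shape...) .* x)                                     # (5, 5, 3, 2)
```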
@ -244,13 +217,12 @@ mutable struct InstanceNorm{F,V,W,N}
|
|||||||
σ²::W # moving variance
|
σ²::W # moving variance
|
||||||
ϵ::N
|
ϵ::N
|
||||||
momentum::N
|
momentum::N
|
||||||
active::Bool
|
|
||||||
end
|
end
|
||||||
|
|
||||||
InstanceNorm(chs::Integer, λ = identity;
|
InstanceNorm(chs::Integer, λ = identity;
|
||||||
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
|
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
|
||||||
InstanceNorm(λ, param(initβ(chs)), param(initγ(chs)),
|
InstanceNorm(λ, initβ(chs), initγ(chs),
|
||||||
zeros(chs), ones(chs), ϵ, momentum, true)
|
zeros(chs), ones(chs), ϵ, momentum)
|
||||||
|
|
||||||
function (in::InstanceNorm)(x)
|
function (in::InstanceNorm)(x)
|
||||||
size(x, ndims(x)-1) == length(in.β) ||
|
size(x, ndims(x)-1) == length(in.β) ||
|
||||||
@ -261,28 +233,26 @@ function (in::InstanceNorm)(x)
|
|||||||
dims = length(size(x))
|
dims = length(size(x))
|
||||||
c = size(x, dims-1)
|
c = size(x, dims-1)
|
||||||
bs = size(x, dims)
|
bs = size(x, dims)
|
||||||
affine_shape = ones(Int, dims)
|
affine_shape = ntuple(i->i == ndims(x) - 1 || i == ndims(x) ? size(x, i) : 1, ndims(x))
|
||||||
affine_shape[end-1] = c
|
m = div(prod(size(x)), c*bs)
|
||||||
affine_shape[end] = bs
|
|
||||||
m = prod(size(x)[1:end-2])
|
|
||||||
γ, β = expand_inst(in.γ, affine_shape), expand_inst(in.β, affine_shape)
|
γ, β = expand_inst(in.γ, affine_shape), expand_inst(in.β, affine_shape)
|
||||||
|
|
||||||
if !in.active
|
if !istraining()
|
||||||
μ = expand_inst(in.μ, affine_shape)
|
μ = expand_inst(in.μ, affine_shape)
|
||||||
σ² = expand_inst(in.σ², affine_shape)
|
σ² = expand_inst(in.σ², affine_shape)
|
||||||
ϵ = in.ϵ
|
ϵ = in.ϵ
|
||||||
else
|
else
|
||||||
T = eltype(x)
|
T = eltype(x)
|
||||||
|
|
||||||
ϵ = data(convert(T, in.ϵ))
|
ϵ = convert(T, in.ϵ)
|
||||||
axes = 1:dims-2 # axes to reduce along (all but channels and batch size axes)
|
axes = 1:dims-2 # axes to reduce along (all but channels and batch size axes)
|
||||||
μ = mean(x, dims = axes)
|
μ = mean(x, dims = axes)
|
||||||
σ² = mean((x .- μ) .^ 2, dims = axes)
|
σ² = mean((x .- μ) .^ 2, dims = axes)
|
||||||
|
S = eltype(in.μ)
|
||||||
# update moving mean/std
|
# update moving mean/std
|
||||||
mtm = data(convert(T, in.momentum))
|
mtm = in.momentum
|
||||||
in.μ = dropdims(mean(repeat((1 - mtm) .* in.μ, outer=[1, bs]) .+ mtm .* reshape(data(μ), (c, bs)), dims = 2), dims=2)
|
in.μ = dropdims(mean(repeat((1 - mtm) .* in.μ, outer=[1, bs]) .+ mtm .* S.(reshape(μ, (c, bs))), dims = 2), dims=2)
|
||||||
in.σ² = dropdims(mean((repeat((1 - mtm) .* in.σ², outer=[1, bs]) .+ (mtm * m / (m - 1)) .* reshape(data(σ²), (c, bs))), dims = 2), dims=2)
|
in.σ² = dropdims(mean((repeat((1 - mtm) .* in.σ², outer=[1, bs]) .+ (mtm * m / (m - 1)) .* S.(reshape(σ², (c, bs)))), dims = 2), dims=2)
|
||||||
end
|
end
|
||||||
|
|
||||||
let λ = in.λ
|
let λ = in.λ
|
||||||
@ -292,12 +262,10 @@ function (in::InstanceNorm)(x)
|
|||||||
end
|
end
|
||||||
|
|
||||||
children(in::InstanceNorm) =
|
children(in::InstanceNorm) =
|
||||||
(in.λ, in.β, in.γ, in.μ, in.σ², in.ϵ, in.momentum, in.active)
|
(in.λ, in.β, in.γ, in.μ, in.σ², in.ϵ, in.momentum)
|
||||||
|
|
||||||
mapchildren(f, in::InstanceNorm) = # e.g. mapchildren(cu, in)
|
mapchildren(f, in::InstanceNorm) = # e.g. mapchildren(cu, in)
|
||||||
InstanceNorm(in.λ, f(in.β), f(in.γ), f(in.μ), f(in.σ²), in.ϵ, in.momentum, in.active)
|
InstanceNorm(in.λ, f(in.β), f(in.γ), f(in.μ), f(in.σ²), in.ϵ, in.momentum)
|
||||||
|
|
||||||
_testmode!(in::InstanceNorm, test) = (in.active = !test)
|
|
||||||
|
|
||||||
function Base.show(io::IO, l::InstanceNorm)
|
function Base.show(io::IO, l::InstanceNorm)
|
||||||
print(io, "InstanceNorm($(join(size(l.β), ", "))")
|
print(io, "InstanceNorm($(join(size(l.β), ", "))")
|
||||||
@ -306,11 +274,11 @@ function Base.show(io::IO, l::InstanceNorm)
|
|||||||
end
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Group Normalization.
|
Group Normalization.
|
||||||
This layer can outperform Batch-Normalization and Instance-Normalization.
|
This layer can outperform Batch-Normalization and Instance-Normalization.
|
||||||
|
|
||||||
GroupNorm(chs::Integer, G::Integer, λ = identity;
|
GroupNorm(chs::Integer, G::Integer, λ = identity;
|
||||||
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
|
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
|
||||||
ϵ = 1f-5, momentum = 0.1f0)
|
ϵ = 1f-5, momentum = 0.1f0)
|
||||||
|
|
||||||
``chs`` is the number of channels, the channel dimension of your input.
|
``chs`` is the number of channels, the channel dimension of your input.
|
||||||
@ -322,12 +290,11 @@ The number of channels must be an integer multiple of the number of groups.
|
|||||||
Example:
|
Example:
|
||||||
```
|
```
|
||||||
m = Chain(Conv((3,3), 1=>32, leakyrelu;pad = 1),
|
m = Chain(Conv((3,3), 1=>32, leakyrelu;pad = 1),
|
||||||
GroupNorm(32,16)) # 32 channels, 16 groups (G = 16), thus 2 channels per group used
|
GroupNorm(32,16)) # 32 channels, 16 groups (G = 16), thus 2 channels per group used
|
||||||
```
|
```
|
||||||
|
|
||||||
Link : https://arxiv.org/pdf/1803.08494.pdf
|
Link : https://arxiv.org/pdf/1803.08494.pdf
|
||||||
"""
|
"""
|
||||||
|
|
||||||
mutable struct GroupNorm{F,V,W,N,T}
|
mutable struct GroupNorm{F,V,W,N,T}
|
||||||
G::T # number of groups
|
G::T # number of groups
|
||||||
λ::F # activation function
|
λ::F # activation function
|
||||||
@ -337,13 +304,12 @@ mutable struct GroupNorm{F,V,W,N,T}
|
|||||||
σ²::W # moving variance
|
σ²::W # moving variance
|
||||||
ϵ::N
|
ϵ::N
|
||||||
momentum::N
|
momentum::N
|
||||||
active::Bool
|
|
||||||
end
|
end
|
||||||
|
|
||||||
GroupNorm(chs::Integer, G::Integer, λ = identity;
|
GroupNorm(chs::Integer, G::Integer, λ = identity;
|
||||||
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
|
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
|
||||||
GroupNorm(G, λ, param(initβ(chs)), param(initγ(chs)),
|
GroupNorm(G, λ, initβ(chs), initγ(chs),
|
||||||
zeros(G,1), ones(G,1), ϵ, momentum, true)
|
zeros(G,1), ones(G,1), ϵ, momentum)
|
||||||
|
|
||||||
function(gn::GroupNorm)(x)
|
function(gn::GroupNorm)(x)
|
||||||
size(x,ndims(x)-1) == length(gn.β) || error("Group Norm expected $(length(gn.β)) channels, but got $(size(x,ndims(x)-1)) channels")
|
size(x,ndims(x)-1) == length(gn.β) || error("Group Norm expected $(length(gn.β)) channels, but got $(size(x,ndims(x)-1)) channels")
|
||||||
@ -355,20 +321,17 @@ function(gn::GroupNorm)(x)
|
|||||||
channels = size(x, dims-1)
|
channels = size(x, dims-1)
|
||||||
batches = size(x,dims)
|
batches = size(x,dims)
|
||||||
channels_per_group = div(channels,groups)
|
channels_per_group = div(channels,groups)
|
||||||
affine_shape = ones(Int, dims)
|
affine_shape = ntuple(i->i == ndims(x) - 1 ? size(x, i) : 1, ndims(x))
|
||||||
|
|
||||||
# Output reshaped to (W,H...,C/G,G,N)
|
# Output reshaped to (W,H...,C/G,G,N)
|
||||||
affine_shape[end-1] = channels
|
μ_affine_shape = ntuple(i->i == ndims(x) ? groups : 1, ndims(x) + 1)
|
||||||
|
|
||||||
μ_affine_shape = ones(Int,dims + 1)
|
|
||||||
μ_affine_shape[end-1] = groups
|
|
||||||
|
|
||||||
m = prod(size(x)[1:end-2]) * channels_per_group
|
m = prod(size(x)[1:end-2]) * channels_per_group
|
||||||
γ = reshape(gn.γ, affine_shape...)
|
γ = reshape(gn.γ, affine_shape...)
|
||||||
β = reshape(gn.β, affine_shape...)
|
β = reshape(gn.β, affine_shape...)
|
||||||
|
|
||||||
y = reshape(x,((size(x))[1:end-2]...,channels_per_group,groups,batches))
|
y = reshape(x,((size(x))[1:end-2]...,channels_per_group,groups,batches))
|
||||||
if !gn.active
|
if !istraining()
|
||||||
og_shape = size(x)
|
og_shape = size(x)
|
||||||
μ = reshape(gn.μ, μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
|
μ = reshape(gn.μ, μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
|
||||||
σ² = reshape(gn.σ², μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
|
σ² = reshape(gn.σ², μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
|
||||||
@ -379,31 +342,29 @@ function(gn::GroupNorm)(x)
|
|||||||
axes = [(1:ndims(y)-2)...] # axes to reduce along (all but channels axis)
|
axes = [(1:ndims(y)-2)...] # axes to reduce along (all but channels axis)
|
||||||
μ = mean(y, dims = axes)
|
μ = mean(y, dims = axes)
|
||||||
σ² = mean((y .- μ) .^ 2, dims = axes)
|
σ² = mean((y .- μ) .^ 2, dims = axes)
|
||||||
|
|
||||||
ϵ = data(convert(T, gn.ϵ))
|
|
||||||
# update moving mean/std
|
|
||||||
mtm = data(convert(T, gn.momentum))
|
|
||||||
|
|
||||||
gn.μ = mean((1 - mtm) .* gn.μ .+ mtm .* reshape(data(μ), (groups,batches)),dims=2)
|
ϵ = convert(T, gn.ϵ)
|
||||||
gn.σ² = mean((1 - mtm) .* gn.σ² .+ (mtm * m / (m - 1)) .* reshape(data(σ²), (groups,batches)),dims=2)
|
# update moving mean/std
|
||||||
|
mtm = gn.momentum
|
||||||
|
S = eltype(gn.μ)
|
||||||
|
gn.μ = mean((1 - mtm) .* gn.μ .+ mtm .* S.(reshape(μ, (groups,batches))),dims=2)
|
||||||
|
gn.σ² = mean((1 - mtm) .* gn.σ² .+ (mtm * m / (m - 1)) .* S.(reshape(σ², (groups,batches))),dims=2)
|
||||||
end
|
end
|
||||||
|
|
||||||
let λ = gn.λ
|
let λ = gn.λ
|
||||||
x̂ = (y .- μ) ./ sqrt.(σ² .+ ϵ)
|
x̂ = (y .- μ) ./ sqrt.(σ² .+ ϵ)
|
||||||
|
|
||||||
# Reshape x̂
|
# Reshape x̂
|
||||||
x̂ = reshape(x̂,og_shape)
|
x̂ = reshape(x̂,og_shape)
|
||||||
λ.(γ .* x̂ .+ β)
|
λ.(γ .* x̂ .+ β)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
children(gn::GroupNorm) =
|
children(gn::GroupNorm) =
|
||||||
(gn.λ, gn.β, gn.γ, gn.μ, gn.σ², gn.ϵ, gn.momentum, gn.active)
|
(gn.λ, gn.β, gn.γ, gn.μ, gn.σ², gn.ϵ, gn.momentum)
|
||||||
|
|
||||||
mapchildren(f, gn::GroupNorm) = # e.g. mapchildren(cu, BN)
|
mapchildren(f, gn::GroupNorm) = # e.g. mapchildren(cu, BN)
|
||||||
GroupNorm(gn.G,gn.λ, f(gn.β), f(gn.γ), f(gn.μ), f(gn.σ²), gn.ϵ, gn.momentum, gn.active)
|
GroupNorm(gn.G,gn.λ, f(gn.β), f(gn.γ), f(gn.μ), f(gn.σ²), gn.ϵ, gn.momentum)
|
||||||
|
|
||||||
_testmode!(gn::GroupNorm, test) = (gn.active = !test)
|
|
||||||
|
|
||||||
function Base.show(io::IO, l::GroupNorm)
|
function Base.show(io::IO, l::GroupNorm)
|
||||||
print(io, "GroupNorm($(join(size(l.β), ", "))")
|
print(io, "GroupNorm($(join(size(l.β), ", "))")
|
||||||
|
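A hedged usage sketch of `GroupNorm` after these changes (shape-preserving, with the channel count divisible by the number of groups):

```julia
# Sketch: 32 channels normalised in 8 groups of 4; output shape matches input.
using Flux

gn = GroupNorm(32, 8)
x = rand(Float32, 7, 7, 32, 2)
size(gn(x)) == size(x)    # true
```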
@ -42,21 +42,6 @@ end
|
|||||||
|
|
||||||
Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
|
Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
|
||||||
|
|
||||||
_truncate(x::AbstractArray) = Tracker.data(x)
|
|
||||||
_truncate(x::Tuple) = _truncate.(x)
|
|
||||||
|
|
||||||
"""
|
|
||||||
truncate!(rnn)
|
|
||||||
|
|
||||||
Truncates the gradient of the hidden state in recurrent layers. The value of the
|
|
||||||
state is preserved. See also `reset!`.
|
|
||||||
|
|
||||||
Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to
|
|
||||||
|
|
||||||
rnn.state = Tracker.data(rnn.state)
|
|
||||||
"""
|
|
||||||
truncate!(m) = prefor(x -> x isa Recur && (x.state = _truncate(x.state)), m)
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
reset!(rnn)
|
reset!(rnn)
|
||||||
|
|
||||||
@ -83,8 +68,8 @@ end
|
|||||||
|
|
||||||
RNNCell(in::Integer, out::Integer, σ = tanh;
|
RNNCell(in::Integer, out::Integer, σ = tanh;
|
||||||
init = glorot_uniform) =
|
init = glorot_uniform) =
|
||||||
RNNCell(σ, param(init(out, in)), param(init(out, out)),
|
RNNCell(σ, init(out, in), init(out, out),
|
||||||
param(init(out)), param(zeros(out)))
|
init(out), zeros(out))
|
||||||
|
|
||||||
function (m::RNNCell)(h, x)
|
function (m::RNNCell)(h, x)
|
||||||
σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
|
σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
|
||||||
@ -122,9 +107,9 @@ end
|
|||||||
|
|
||||||
function LSTMCell(in::Integer, out::Integer;
|
function LSTMCell(in::Integer, out::Integer;
|
||||||
init = glorot_uniform)
|
init = glorot_uniform)
|
||||||
cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(init(out*4)),
|
cell = LSTMCell(init(out * 4, in), init(out * 4, out), init(out * 4),
|
||||||
param(zeros(out)), param(zeros(out)))
|
zeros(out), zeros(out))
|
||||||
cell.b.data[gate(out, 2)] .= 1
|
cell.b[gate(out, 2)] .= 1
|
||||||
return cell
|
return cell
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -168,8 +153,8 @@ mutable struct GRUCell{A,V}
|
|||||||
end
|
end
|
||||||
|
|
||||||
GRUCell(in, out; init = glorot_uniform) =
|
GRUCell(in, out; init = glorot_uniform) =
|
||||||
GRUCell(param(init(out*3, in)), param(init(out*3, out)),
|
GRUCell(init(out * 3, in), init(out * 3, out),
|
||||||
param(init(out*3)), param(zeros(out)))
|
init(out * 3), zeros(out))
|
||||||
|
|
||||||
function (m::GRUCell)(h, x)
|
function (m::GRUCell)(h, x)
|
||||||
b, o = m.b, size(h, 1)
|
b, o = m.b, size(h, 1)
|
||||||
|
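With tracked arrays gone, `truncate!` disappears (no gradient history lives in the state) and `reset!` alone manages the hidden state. A short sketch:

```julia
# Sketch: recurrent layers carry plain-array state; reset! restores the initial state.
using Flux

rnn = RNN(10, 5)
xs = [rand(10) for _ in 1:3]
ys = [rnn(x) for x in xs]   # state is carried across the three calls
Flux.reset!(rnn)            # back to the stored initial state
```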
@ -49,8 +49,3 @@ function normalise(x::AbstractArray; dims=1)
|
|||||||
σ′ = std(x, dims = dims, mean = μ′, corrected=false)
|
σ′ = std(x, dims = dims, mean = μ′, corrected=false)
|
||||||
return (x .- μ′) ./ σ′
|
return (x .- μ′) ./ σ′
|
||||||
end
|
end
|
||||||
|
|
||||||
function normalise(x::AbstractArray, dims)
|
|
||||||
Base.depwarn("`normalise(x::AbstractArray, dims)` is deprecated, use `normalise(a, dims=dims)` instead.", :normalise)
|
|
||||||
normalise(x, dims = dims)
|
|
||||||
end
|
|
||||||
|
@ -54,17 +54,19 @@ it will error.
|
|||||||
## Examples
|
## Examples
|
||||||
|
|
||||||
```jldoctest
|
```jldoctest
|
||||||
|
julia> using Flux: onehot
|
||||||
|
|
||||||
julia> onehot(:b, [:a, :b, :c])
|
julia> onehot(:b, [:a, :b, :c])
|
||||||
3-element Flux.OneHotVector:
|
3-element Flux.OneHotVector:
|
||||||
false
|
0
|
||||||
true
|
1
|
||||||
false
|
0
|
||||||
|
|
||||||
julia> onehot(:c, [:a, :b, :c])
|
julia> onehot(:c, [:a, :b, :c])
|
||||||
3-element Flux.OneHotVector:
|
3-element Flux.OneHotVector:
|
||||||
false
|
0
|
||||||
false
|
0
|
||||||
true
|
1
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
function onehot(l, labels)
|
function onehot(l, labels)
|
||||||
@ -88,12 +90,13 @@ Create an [`OneHotMatrix`](@ref) with a batch of labels based on possible `label
|
|||||||
## Examples
|
## Examples
|
||||||
|
|
||||||
```jldoctest
|
```jldoctest
|
||||||
julia> onehotbatch([:b, :a, :b], [:a, :b, :c])
|
julia> using Flux: onehotbatch
|
||||||
3×3 Flux.OneHotMatrix:
|
|
||||||
false true false
|
|
||||||
true false true
|
|
||||||
false false false
|
|
||||||
|
|
||||||
|
julia> onehotbatch([:b, :a, :b], [:a, :b, :c])
|
||||||
|
3×3 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}:
|
||||||
|
0 1 0
|
||||||
|
1 0 1
|
||||||
|
0 0 0
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
onehotbatch(ls, labels, unk...) =
|
onehotbatch(ls, labels, unk...) =
|
||||||
@ -106,9 +109,9 @@ Base.argmax(xs::OneHotVector) = xs.ix
|
|||||||
|
|
||||||
Inverse operations of [`onehot`](@ref).
|
Inverse operations of [`onehot`](@ref).
|
||||||
|
|
||||||
## Examples
|
|
||||||
|
|
||||||
```jldoctest
|
```jldoctest
|
||||||
|
julia> using Flux: onecold
|
||||||
|
|
||||||
julia> onecold([true, false, false], [:a, :b, :c])
|
julia> onecold([true, false, false], [:a, :b, :c])
|
||||||
:a
|
:a
|
||||||
|
|
||||||
@ -124,15 +127,6 @@ onecold(y::AbstractMatrix, labels...) =
|
|||||||
onecold(y::OneHotMatrix, labels...) =
|
onecold(y::OneHotMatrix, labels...) =
|
||||||
mapreduce(x -> Flux.onecold(x, labels...), |, y.data, dims = 2, init = 0)
|
mapreduce(x -> Flux.onecold(x, labels...), |, y.data, dims = 2, init = 0)
|
||||||
|
|
||||||
function argmax(xs...)
|
# TODO probably still want this as a custom adjoint Zygote
|
||||||
Base.depwarn("`argmax(...)` is deprecated, use `onecold(...)` instead.", :argmax)
|
# onecold(x::TrackedVector, l...) = onecold(data(x), l...)
|
||||||
return onecold(xs...)
|
# onecold(x::TrackedMatrix, l...) = onecold(data(x), l...)
|
||||||
end
|
|
||||||
|
|
||||||
# Ambiguity hack
|
|
||||||
|
|
||||||
a::TrackedMatrix * b::OneHotVector = invoke(*, Tuple{AbstractMatrix,OneHotVector}, a, b)
|
|
||||||
a::TrackedMatrix * b::OneHotMatrix = invoke(*, Tuple{AbstractMatrix,OneHotMatrix}, a, b)
|
|
||||||
|
|
||||||
onecold(x::TrackedVector, l...) = onecold(data(x), l...)
|
|
||||||
onecold(x::TrackedMatrix, l...) = onecold(data(x), l...)
|
|
||||||
|
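A brief sketch combining the one-hot utilities documented above:

```julia
# Sketch: onehot/onehotbatch encode labels, onecold decodes them again.
using Flux: onehot, onehotbatch, onecold

v = onehot(:b, [:a, :b, :c])                  # hot at index 2
M = onehotbatch([:b, :a, :b], [:a, :b, :c])   # 3×3 one-hot matrix, one column per example
onecold(v, [:a, :b, :c])                      # :b
onecold([0.1, 0.9, 0.0], [:a, :b, :c])        # :b — also works on plain score vectors
```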
@ -7,6 +7,5 @@ export train!,
|
|||||||
|
|
||||||
include("optimisers.jl")
|
include("optimisers.jl")
|
||||||
include("train.jl")
|
include("train.jl")
|
||||||
include("deprecations.jl")
|
|
||||||
|
|
||||||
end
|
end
|
||||||
|
@ -1,126 +0,0 @@
|
|||||||
using Base: depwarn
|
|
||||||
using Flux: Params
|
|
||||||
|
|
||||||
check_decay(opt, decay) = decay == 0 ? opt : Optimiser(opt, InvDecay(decay))
|
|
||||||
|
|
||||||
# legacy update rule
|
|
||||||
updaterule(opt, ps) = () -> _update_params!(opt, ps)
|
|
||||||
|
|
||||||
function SGD(params::Union{AbstractArray, Params}, η = 0.1; decay = 0.)
|
|
||||||
depwarn("SGD(params) is deprecated; use Descent(η::Float64) instead", :SGD)
|
|
||||||
|
|
||||||
ps = params
|
|
||||||
opt = Descent(η)
|
|
||||||
opt = check_decay(opt, decay)
|
|
||||||
updaterule(opt, ps)
|
|
||||||
end
|
|
||||||
|
|
||||||
function Momentum(params::Union{AbstractArray, Params}, η = 0.01; ρ = 0.9, decay = 0.)
|
|
||||||
depwarn("Momentum(params) is deprecated; use Momentum(η::Float64) instead", :Momentum)
|
|
||||||
|
|
||||||
ps = params
|
|
||||||
opt = Momentum(η, ρ)
|
|
||||||
opt = check_decay(opt, decay)
|
|
||||||
updaterule(opt, ps)
|
|
||||||
end
|
|
||||||
|
|
||||||
function Nesterov(params::Union{AbstractArray, Params}, η = 0.001; ρ = 0.9, decay = 0.)
|
|
||||||
depwarn("Nesterov(params) is deprecated; use Nesterov(η::Float64) instead", :Nesterov)
|
|
||||||
|
|
||||||
ps = params
|
|
||||||
opt = Nesterov(η, ρ)
|
|
||||||
opt = check_decay(opt, decay)
|
|
||||||
updaterule(opt, ps)
|
|
||||||
end
|
|
||||||
|
|
||||||
function RMSProp(params::Union{AbstractArray, Params}, η = 0.001; ρ = 0.9, decay = 0.)
|
|
||||||
depwarn("RMSProp(params) is deprecated; use RMSProp(η::Float64) instead", :RMSProp)
|
|
||||||
|
|
||||||
ps = params
|
|
||||||
opt = RMSProp(η, ρ)
|
|
||||||
opt = check_decay(opt, decay)
|
|
||||||
updaterule(opt, ps)
|
|
||||||
end
|
|
||||||
|
|
||||||
function ADAM(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
|
|
||||||
depwarn("ADAM(params) is deprecated; use ADAM(η::Float64) instead", :ADAM)
|
|
||||||
|
|
||||||
ps = params
|
|
||||||
β = (β1, β2)
|
|
||||||
opt = ADAM(η, β)
|
|
||||||
opt = check_decay(opt, decay)
|
|
||||||
updaterule(opt, ps)
|
|
||||||
end
|
|
||||||
|
|
||||||
function ADAGrad(params::Union{AbstractArray, Params}, η::Float64 = 0.1; decay = 0.)
|
|
||||||
depwarn("ADAGrad(params) is deprecated; use ADAGrad(η::Float64) instead", :ADAGrad)
|
|
||||||
|
|
||||||
ps = params
|
|
||||||
opt = ADAGrad(η)
|
|
||||||
opt = check_decay(opt, decay)
|
|
||||||
updaterule(opt, ps)
|
|
||||||
end
|
|
||||||
|
|
||||||
function ADADelta(params::Union{AbstractArray, Params}, ρ::Float64 = 0.9; decay = 0.)
|
|
||||||
depwarn("ADADelta(params) is deprecated; use ADADelta(η::Float64) instead", :ADADelta)
|
|
||||||
|
|
||||||
ps = params
|
|
||||||
opt = ADADelta(ρ)
|
|
||||||
opt = check_decay(opt, decay)
|
|
||||||
updaterule(opt, ps)
|
|
||||||
end
|
|
||||||
|
|
||||||
function AdaMax(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
|
|
||||||
depwarn("AdaMax(params) is deprecated; use AdaMax(η::Float64) instead", :AdaMax)
|
|
||||||
|
|
||||||
ps = params
|
|
||||||
β = (β1, β2)
|
|
||||||
opt = AdaMax(η, β)
|
|
||||||
opt = check_decay(opt, decay)
|
|
||||||
updaterule(opt, ps)
|
|
||||||
end
|
|
||||||
|
|
||||||
function AMSGrad(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
|
|
||||||
depwarn("AMSGrad(params) is deprecated; use AMSGrad(η::Float64) instead", :AMSGrad)
|
|
||||||
|
|
||||||
ps = params
|
|
||||||
β = (β1, β2)
|
|
||||||
opt = AMSGrad(η, β)
|
|
||||||
opt = check_decay(opt, decay)
|
|
||||||
updaterule(opt, ps)
|
|
||||||
end
|
|
||||||
|
|
||||||
function NADAM(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
|
|
||||||
depwarn("NADAM(params) is deprecated; use NADAM(η::Float64) instead", :NADAM)
|
|
||||||
|
|
||||||
ps = params
|
|
||||||
β = (β1, β2)
|
|
||||||
opt = NADAM(η, β)
|
|
||||||
opt = check_decay(opt, decay)
|
|
||||||
updaterule(opt, ps)
|
|
||||||
end
|
|
||||||
|
|
||||||
function ADAMW(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
|
|
||||||
depwarn("ADAMW(params) is deprecated; use ADAMW(η::Float64) instead", :ADAMW)
|
|
||||||
|
|
||||||
ps = params
|
|
||||||
β = (β1, β2)
|
|
||||||
opt = ADAMW(η, β)
|
|
||||||
opt = check_decay(opt, decay)
|
|
||||||
decay != 0 && (opt = Optimiser(opt, WeightDecay(decay)))
|
|
||||||
updaterule(opt, ps)
|
|
||||||
end
|
|
||||||
|
|
||||||
# Old training loop
|
|
||||||
|
|
||||||
struct OldOptimiser
|
|
||||||
func
|
|
||||||
end
|
|
||||||
|
|
||||||
_update_params!(opt::OldOptimiser, ps) = opt.func()
|
|
||||||
|
|
||||||
# Train function
|
|
||||||
function train!(loss, data, opt; cb = () -> ())
|
|
||||||
depwarn("train!(loss, data, opt) is deprecated; use train!(loss, params, data, opt) instead", :train!)
|
|
||||||
train!(loss, (), data, OldOptimiser(opt); cb = cb)
|
|
||||||
end
|
|
@ -37,7 +37,7 @@ Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict())
|
|||||||
|
|
||||||
function apply!(o::Momentum, x, Δ)
|
function apply!(o::Momentum, x, Δ)
|
||||||
η, ρ = o.eta, o.rho
|
η, ρ = o.eta, o.rho
|
||||||
v = get!(o.velocity, x, zero(x))::typeof(data(x))
|
v = get!(o.velocity, x, zero(x))::typeof(x)
|
||||||
@. v = ρ * v - η * Δ
|
@. v = ρ * v - η * Δ
|
||||||
@. Δ = -v
|
@. Δ = -v
|
||||||
end
|
end
|
||||||
@ -57,7 +57,7 @@ Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict())
|
|||||||
|
|
||||||
function apply!(o::Nesterov, x, Δ)
|
function apply!(o::Nesterov, x, Δ)
|
||||||
η, ρ = o.eta, o.rho
|
η, ρ = o.eta, o.rho
|
||||||
v = get!(o.velocity, x, zero(x))::typeof(data(x))
|
v = get!(o.velocity, x, zero(x))::typeof(x)
|
||||||
d = @. ρ^2 * v - (1+ρ) * η * Δ
|
d = @. ρ^2 * v - (1+ρ) * η * Δ
|
||||||
@. v = ρ*v - η*Δ
|
@. v = ρ*v - η*Δ
|
||||||
@. Δ = -d
|
@. Δ = -d
|
||||||
@ -80,7 +80,7 @@ RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict())
|
|||||||
|
|
||||||
function apply!(o::RMSProp, x, Δ)
|
function apply!(o::RMSProp, x, Δ)
|
||||||
η, ρ = o.eta, o.rho
|
η, ρ = o.eta, o.rho
|
||||||
acc = get!(o.acc, x, zero(x))::typeof(data(x))
|
acc = get!(o.acc, x, zero(x))::typeof(x)
|
||||||
@. acc = ρ * acc + (1 - ρ) * Δ^2
|
@. acc = ρ * acc + (1 - ρ) * Δ^2
|
||||||
@. Δ *= η / (√acc + ϵ)
|
@. Δ *= η / (√acc + ϵ)
|
||||||
end
|
end
|
||||||
@ -177,7 +177,7 @@ ADAGrad(η = 0.1) = ADAGrad(η, IdDict())
|
|||||||
|
|
||||||
function apply!(o::ADAGrad, x, Δ)
|
function apply!(o::ADAGrad, x, Δ)
|
||||||
η = o.eta
|
η = o.eta
|
||||||
acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(data(x))
|
acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(x)
|
||||||
@. acc += Δ^2
|
@. acc += Δ^2
|
||||||
@. Δ *= η / (√acc + ϵ)
|
@. Δ *= η / (√acc + ϵ)
|
||||||
end
|
end
|
||||||
@ -352,5 +352,5 @@ WeightDecay() = WeightDecay(0)
|
|||||||
|
|
||||||
function apply!(o::WeightDecay, x, Δ)
|
function apply!(o::WeightDecay, x, Δ)
|
||||||
wd = o.wd
|
wd = o.wd
|
||||||
@. Δ += wd * data(x)
|
@. Δ += wd * x
|
||||||
end
|
end
|
||||||
|
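`apply!` now keys its per-parameter state on the plain array itself rather than on tracked data. A hedged sketch of one optimiser step:

```julia
# Sketch: apply!(opt, x, Δ) turns a raw gradient into the update step, in place.
using Flux

opt = Momentum(0.1, 0.9)
x = rand(3)
Δ = ones(3)
step = Flux.Optimise.apply!(opt, x, Δ)   # first step equals 0.1 .* ones(3)
x .-= step
```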
@ -1,32 +1,29 @@
|
|||||||
using Juno
|
using Juno
|
||||||
import Flux.Tracker: Params, gradient, data, update!
|
import Zygote: Params, gradient
|
||||||
import Base.depwarn
|
|
||||||
|
function update!(x::AbstractArray, x̄)
|
||||||
|
x .+= x̄
|
||||||
|
return x
|
||||||
|
end
|
||||||
|
|
||||||
function update!(opt, x, x̄)
|
function update!(opt, x, x̄)
|
||||||
update!(x, -apply!(opt, x, data(x̄)))
|
x .-= apply!(opt, x, x̄)
|
||||||
end
|
end
|
||||||
|
|
||||||
function update!(opt, xs::Params, gs)
|
function update!(opt, xs::Params, gs)
|
||||||
for x in xs
|
for x in xs
|
||||||
|
gs[x] == nothing && continue
|
||||||
update!(opt, x, gs[x])
|
update!(opt, x, gs[x])
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
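Putting the pieces together, the new training step takes Zygote gradients over implicit `Params` and feeds them to `update!`. A minimal sketch:

```julia
# Sketch: one explicit optimisation step with Zygote gradients and update!.
using Flux

m = Dense(4, 2)
x, y = rand(4, 8), rand(2, 8)
loss() = Flux.mse(m(x), y)

opt = Descent(0.1)
gs = gradient(loss, params(m))
Flux.Optimise.update!(opt, params(m), gs)
```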
# Added as an internal API but everyone started using it.
|
|
||||||
function _update_params!(opt, xs)
|
|
||||||
depwarn("`_update_params!` is deprecated, use `update!` instead.", :stop)
|
|
||||||
for x in xs
|
|
||||||
update!(opt, x, Tracker.grad(x))
|
|
||||||
x.tracker.grad = Tracker.zero_grad!(x.tracker.grad)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
# Callback niceties
|
# Callback niceties
|
||||||
call(f, xs...) = f(xs...)
|
call(f, xs...) = f(xs...)
|
||||||
runall(f) = f
|
runall(f) = f
|
||||||
runall(fs::AbstractVector) = () -> foreach(call, fs)
|
runall(fs::AbstractVector) = () -> foreach(call, fs)
|
||||||
|
|
||||||
struct StopException <: Exception end
|
struct StopException <: Exception end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
stop()
|
stop()
|
||||||
|
|
||||||
@ -72,10 +69,7 @@ function train!(loss, ps, data, opt; cb = () -> ())
|
|||||||
loss(d...)
|
loss(d...)
|
||||||
end
|
end
|
||||||
update!(opt, ps, gs)
|
update!(opt, ps, gs)
|
||||||
if cb() == :stop
|
cb()
|
||||||
depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
|
|
||||||
break
|
|
||||||
end
|
|
||||||
catch ex
|
catch ex
|
||||||
if ex isa StopException
|
if ex isa StopException
|
||||||
break
|
break
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import Adapt: adapt, adapt_storage
|
import Adapt: adapt, adapt_storage
|
||||||
import .Tracker: IdSet
|
import Zygote: IdSet
|
||||||
|
|
||||||
children(x) = ()
|
children(x) = ()
|
||||||
mapchildren(f, x) = x
|
mapchildren(f, x) = x
|
||||||
@ -40,7 +40,7 @@ end
|
|||||||
function params(m)
|
function params(m)
|
||||||
ps = Params()
|
ps = Params()
|
||||||
prefor(p ->
|
prefor(p ->
|
||||||
Tracker.istracked(p) && Tracker.isleaf(p) &&
|
p isa AbstractArray{<:Real} &&
|
||||||
!any(p′ -> p′ === p, ps) && push!(ps, p),
|
!any(p′ -> p′ === p, ps) && push!(ps, p),
|
||||||
m)
|
m)
|
||||||
return ps
|
return ps
|
||||||
@ -52,7 +52,7 @@ function loadparams!(m, xs)
|
|||||||
for (p, x) in zip(params(m), xs)
|
for (p, x) in zip(params(m), xs)
|
||||||
size(p) == size(x) ||
|
size(p) == size(x) ||
|
||||||
error("Expected param size $(size(p)), got $(size(x))")
|
error("Expected param size $(size(p)), got $(size(x))")
|
||||||
copyto!(data(p), data(x))
|
copyto!(p, x)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
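`params` now simply collects every real-valued array leaf of the model. A small sketch:

```julia
# Sketch: params gathers the plain weight and bias arrays of a model.
using Flux

m = Chain(Dense(10, 5, relu), Dense(5, 2))
ps = params(m)
[size(p) for p in ps]    # typically [(5, 10), (5,), (2, 5), (2,)]
```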
@ -81,8 +81,6 @@ f64(m) = paramtype(Float64, m)
|
|||||||
|
|
||||||
function mapparams(f, m)
|
function mapparams(f, m)
|
||||||
mapleaves(m) do x
|
mapleaves(m) do x
|
||||||
Tracker.istracked(x) ? param(f(Tracker.data(x))) :
|
x isa Union{AbstractArray,Number} ? f(x) : x
|
||||||
x isa Union{AbstractArray,Number} ? f(x) :
|
|
||||||
x
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
using Flux, Flux.Tracker, CuArrays, Test
|
using Flux, CuArrays, Test
|
||||||
using Flux: gpu
|
using Flux: gpu
|
||||||
|
|
||||||
@info "Testing GPU Support"
|
@info "Testing GPU Support"
|
||||||
@ -7,11 +7,11 @@ using Flux: gpu
|
|||||||
|
|
||||||
CuArrays.allowscalar(false)
|
CuArrays.allowscalar(false)
|
||||||
|
|
||||||
x = param(randn(5, 5))
|
x = randn(5, 5)
|
||||||
cx = gpu(x)
|
cx = gpu(x)
|
||||||
@test cx isa TrackedArray && cx.data isa CuArray
|
@test cx isa CuArray
|
||||||
|
|
||||||
@test Flux.onecold(param(gpu([1.,2.,3.]))) == 3
|
@test Flux.onecold(gpu([1.0, 2.0, 3.0])) == 3
|
||||||
|
|
||||||
x = Flux.onehotbatch([1, 2, 3], 1:3)
|
x = Flux.onehotbatch([1, 2, 3], 1:3)
|
||||||
cx = gpu(x)
|
cx = gpu(x)
|
||||||
@ -21,24 +21,26 @@ cx = gpu(x)
|
|||||||
m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)
|
m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)
|
||||||
cm = gpu(m)
|
cm = gpu(m)
|
||||||
|
|
||||||
@test all(p isa TrackedArray && p.data isa CuArray for p in params(cm))
|
@test all(p isa CuArray for p in params(cm))
|
||||||
@test cm(gpu(rand(10, 10))) isa TrackedArray{Float32,2,CuArray{Float32,2}}
|
@test cm(gpu(rand(10, 10))) isa CuArray{Float32,2}
|
||||||
|
|
||||||
x = [1,2,3]
|
x = [1,2,3]
|
||||||
cx = gpu(x)
|
cx = gpu(x)
|
||||||
@test Flux.crossentropy(x,x) ≈ Flux.crossentropy(cx,cx)
|
@test Flux.crossentropy(x,x) ≈ Flux.crossentropy(cx,cx)
|
||||||
|
|
||||||
xs = param(rand(5,5))
|
xs = rand(5, 5)
|
||||||
ys = Flux.onehotbatch(1:5,1:5)
|
ys = Flux.onehotbatch(1:5,1:5)
|
||||||
@test collect(cu(xs) .+ cu(ys)) ≈ collect(xs .+ ys)
|
@test collect(cu(xs) .+ cu(ys)) ≈ collect(xs .+ ys)
|
||||||
|
|
||||||
c = gpu(Conv((2,2),3=>4))
|
c = gpu(Conv((2,2),3=>4))
|
||||||
|
x = gpu(rand(10, 10, 3, 2))
|
||||||
l = c(gpu(rand(10,10,3,2)))
|
l = c(gpu(rand(10,10,3,2)))
|
||||||
Flux.back!(sum(l))
|
@test gradient(x -> sum(c(x)), x)[1] isa CuArray
|
||||||
|
|
||||||
c = gpu(CrossCor((2,2),3=>4))
|
c = gpu(CrossCor((2,2),3=>4))
|
||||||
|
x = gpu(rand(10, 10, 3, 2))
|
||||||
l = c(gpu(rand(10,10,3,2)))
|
l = c(gpu(rand(10,10,3,2)))
|
||||||
Flux.back!(sum(l))
|
@test gradient(x -> sum(c(x)), x)[1] isa CuArray
|
||||||
|
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -1,48 +1,44 @@
|
|||||||
using Flux, Flux.Tracker, CuArrays, Test
|
using Flux, CuArrays, Test
|
||||||
using Flux.Tracker: TrackedArray, data
|
using Flux: forward
|
||||||
|
|
||||||
@testset "CUDNN BatchNorm" begin
|
@testset "CUDNN BatchNorm" begin
|
||||||
@testset "4D Input" begin
|
@testset "4D Input" begin
|
||||||
x = TrackedArray(Float64.(collect(reshape(1:12, 2, 2, 3, 1))))
|
x = Float64.(collect(reshape(1:12, 2, 2, 3, 1)))
|
||||||
m = BatchNorm(3)
|
m = BatchNorm(3)
|
||||||
cx = gpu(x)
|
cx = gpu(x)
|
||||||
cm = gpu(m)
|
cm = gpu(m)
|
||||||
|
|
||||||
y = m(x)
|
y, back = forward((m, x) -> m(x), m, x)
|
||||||
cy = cm(cx)
|
cy, cback = forward((m, x) -> m(x), cm, cx)
|
||||||
|
|
||||||
@test cy isa TrackedArray{Float32,4,CuArray{Float32,4}}
|
@test cpu(cy) ≈ y
|
||||||
|
|
||||||
@test cpu(data(cy)) ≈ data(y)
|
Δ = randn(size(y))
|
||||||
|
dm, dx = back(Δ)
|
||||||
|
cdm, cdx = cback(gpu(Δ))
|
||||||
|
|
||||||
g = rand(size(y)...)
|
@test dm[].γ ≈ cpu(cdm[].γ)
|
||||||
Flux.back!(y, g)
|
@test dm[].β ≈ cpu(cdm[].β)
|
||||||
Flux.back!(cy, gpu(g))
|
@test dx ≈ cpu(cdx)
|
||||||
|
|
||||||
@test m.γ.grad ≈ cpu(cm.γ.grad)
|
|
||||||
@test m.β.grad ≈ cpu(cm.β.grad)
|
|
||||||
@test x.grad ≈ cpu(x.grad)
|
|
||||||
end
|
end
|
||||||
|
|
||||||
@testset "2D Input" begin
|
@testset "2D Input" begin
|
||||||
x = TrackedArray(Float64.(collect(reshape(1:12, 3, 4))))
|
x = Float64.(collect(reshape(1:12, 3, 4)))
|
||||||
m = BatchNorm(3)
|
m = BatchNorm(3)
|
||||||
cx = gpu(x)
|
cx = gpu(x)
|
||||||
cm = gpu(m)
|
cm = gpu(m)
|
||||||
|
|
||||||
y = m(x)
|
y, back = forward((m, x) -> m(x), m, x)
|
||||||
cy = cm(cx)
|
cy, cback = forward((m, x) -> m(x), cm, cx)
|
||||||
|
|
||||||
@test cy isa TrackedArray{Float32,2,CuArray{Float32,2}}
|
@test cpu(cy) ≈ y
|
||||||
|
|
||||||
@test cpu(data(cy)) ≈ data(y)
|
Δ = randn(size(y))
|
||||||
|
dm, dx = back(Δ)
|
||||||
|
cdm, cdx = cback(gpu(Δ))
|
||||||
|
|
||||||
g = rand(size(y)...)
|
@test dm[].γ ≈ cpu(cdm[].γ)
|
||||||
Flux.back!(y, g)
|
@test dm[].β ≈ cpu(cdm[].β)
|
||||||
Flux.back!(cy, gpu(g))
|
@test dx ≈ cpu(cdx)
|
||||||
|
|
||||||
@test m.γ.grad ≈ cpu(cm.γ.grad)
|
|
||||||
@test m.β.grad ≈ cpu(cm.β.grad)
|
|
||||||
@test x.grad ≈ cpu(x.grad)
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -1,46 +1,54 @@
|
|||||||
using Flux, CuArrays, Test
|
using Flux, CuArrays, Test
|
||||||
|
using Flux: forward
|
||||||
|
|
||||||
@testset "RNN" begin
|
@testset "RNN" begin
|
||||||
@testset for R in [RNN, GRU, LSTM]
|
@testset for R in [RNN, GRU, LSTM], batch_size in (1, 5)
|
||||||
rnn = R(10, 5)
|
rnn = R(10, 5)
|
||||||
curnn = mapleaves(gpu, rnn)
|
curnn = mapleaves(gpu, rnn)
|
||||||
@testset for batch_size in (1, 5)
|
|
||||||
Flux.reset!(rnn)
|
|
||||||
Flux.reset!(curnn)
|
|
||||||
x = batch_size == 1 ?
|
|
||||||
param(rand(10)) :
|
|
||||||
param(rand(10,batch_size))
|
|
||||||
cux = gpu(x)
|
|
||||||
y = (rnn(x); rnn(x))
|
|
||||||
cuy = (curnn(cux); curnn(cux))
|
|
||||||
|
|
||||||
@test y.data ≈ collect(cuy.data)
|
Flux.reset!(rnn)
|
||||||
@test haskey(Flux.CUDA.descs, curnn.cell)
|
Flux.reset!(curnn)
|
||||||
|
x = batch_size == 1 ?
|
||||||
|
rand(10) :
|
||||||
|
rand(10, batch_size)
|
||||||
|
cux = gpu(x)
|
||||||
|
|
||||||
Δ = randn(size(y))
|
y, back = forward((r, x) -> (r(x)), rnn, x)
|
||||||
|
cuy, cuback = forward((r, x) -> (r(x)), curnn, cux)
|
||||||
|
|
||||||
Flux.back!(y, Δ)
|
@test y ≈ collect(cuy)
|
||||||
Flux.back!(cuy, gpu(Δ))
|
@test haskey(Flux.CUDA.descs, curnn.cell)
|
||||||
|
|
||||||
@test x.grad ≈ collect(cux.grad)
|
ȳ = randn(size(y))
|
||||||
@test rnn.cell.Wi.grad ≈ collect(curnn.cell.Wi.grad)
|
m̄, x̄ = back(ȳ)
|
||||||
@test rnn.cell.Wh.grad ≈ collect(curnn.cell.Wh.grad)
|
cum̄, cux̄ = cuback(gpu(ȳ))
|
||||||
@test rnn.cell.b.grad ≈ collect(curnn.cell.b.grad)
|
|
||||||
@test rnn.cell.h.grad ≈ collect(curnn.cell.h.grad)
|
m̄[].cell[].Wi
|
||||||
if isdefined(rnn.cell, :c)
|
|
||||||
@test rnn.cell.c.grad ≈ collect(curnn.cell.c.grad)
|
m̄[].state
|
||||||
|
cum̄[].state
|
||||||
|
|
||||||
|
@test x̄ ≈ collect(cux̄)
|
||||||
|
@test m̄[].cell[].Wi ≈ collect(cum̄[].cell[].Wi)
|
||||||
|
@test m̄[].cell[].Wh ≈ collect(cum̄[].cell[].Wh)
|
||||||
|
@test m̄[].cell[].b ≈ collect(cum̄[].cell[].b)
|
||||||
|
if m̄[].state isa Tuple
|
||||||
|
for (x, cx) in zip(m̄[].state, cum̄[].state)
|
||||||
|
@test x ≈ collect(cx)
|
||||||
end
|
end
|
||||||
|
else
|
||||||
Flux.reset!(rnn)
|
@test m̄[].state ≈ collect(cum̄[].state)
|
||||||
Flux.reset!(curnn)
|
|
||||||
ohx = batch_size == 1 ?
|
|
||||||
Flux.onehot(rand(1:10), 1:10) :
|
|
||||||
Flux.onehotbatch(rand(1:10, batch_size), 1:10)
|
|
||||||
cuohx = gpu(ohx)
|
|
||||||
y = (rnn(ohx); rnn(ohx))
|
|
||||||
cuy = (curnn(cuohx); curnn(cuohx))
|
|
||||||
|
|
||||||
@test y.data ≈ collect(cuy.data)
|
|
||||||
end
|
end
|
||||||
|
|
||||||
|
Flux.reset!(rnn)
|
||||||
|
Flux.reset!(curnn)
|
||||||
|
ohx = batch_size == 1 ?
|
||||||
|
Flux.onehot(rand(1:10), 1:10) :
|
||||||
|
Flux.onehotbatch(rand(1:10, batch_size), 1:10)
|
||||||
|
cuohx = gpu(ohx)
|
||||||
|
y = (rnn(ohx); rnn(ohx))
|
||||||
|
cuy = (curnn(cuohx); curnn(cuohx))
|
||||||
|
|
||||||
|
@test y ≈ collect(cuy)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
@@ -25,9 +25,9 @@ end
   @testset "asymmetric padding" begin
     r = ones(Float32, 28, 28, 1, 1)
     m = Conv((3, 3), 1=>1, relu; pad=(0,1,1,2))
-    m.weight.data[:] .= 1.0
-    m.bias.data[:] .= 0.0
-    y_hat = Flux.data(m(r))[:,:,1,1]
+    m.weight[:] .= 1.0
+    m.bias[:] .= 0.0
+    y_hat = m(r)[:,:,1,1]
     @test size(y_hat) == (27, 29)
     @test y_hat[1, 1] ≈ 6.0
     @test y_hat[2, 2] ≈ 9.0

@@ -41,7 +41,7 @@ end
   r = zeros(Float32, 28, 28, 3, 5)
   m1 = DepthwiseConv((2, 2), 3=>15)
   @test size(m1(r), 3) == 15

   m3 = DepthwiseConv((2, 3), 3=>9)
   @test size(m3(r), 3) == 9

@@ -62,7 +62,7 @@ end
   y = CrossCor(w, [0.0])

   @test sum(w .* x[1:2, 1:2, :, :]) == y(x)[1, 1, 1, 1]

   r = zeros(Float32, 28, 28, 1, 5)
   m = Chain(
     CrossCor((2, 2), 1=>16, relu),

@@ -102,4 +102,3 @@ end
   true
   end
 end
@@ -1,29 +1,29 @@
-using Flux: testmode!
-using Flux.Tracker: data
+using Flux, Test, Statistics
+using Zygote: forward

+trainmode(f, x...) = forward(f, x...)[1]
+trainmode(f) = (x...) -> trainmode(f, x...)

 @testset "Dropout" begin
   x = [1.,2.,3.]
-  @test x == testmode!(Dropout(0.1))(x)
-  @test x == Dropout(0)(x)
-  @test zero(x) == Dropout(1)(x)
+  @test x == Dropout(0.1)(x)
+  @test x == trainmode(Dropout(0), x)
+  @test zero(x) == trainmode(Dropout(1), x)

   x = rand(100)
   m = Dropout(0.9)
-  y = m(x)
+  y = trainmode(m, x)
   @test count(a->a==0, y) > 50
-  testmode!(m)
   y = m(x)
   @test count(a->a==0, y) == 0
-  testmode!(m, false)
-  y = m(x)
+  y = trainmode(m, x)
   @test count(a->a==0, y) > 50

-  x = rand(100)
+  x = rand(Float32, 100)
   m = Chain(Dense(100,100),
             Dropout(0.9))
-  y = m(x)
+  y = trainmode(m, x)
   @test count(a->a == 0, y) > 50
-  testmode!(m)
   y = m(x)
   @test count(a->a == 0, y) == 0
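Note: with `testmode!` dropped from these tests, dropout is only active inside a gradient context, which is what the `trainmode` helper defined at the top of the file provides. A small usage sketch under that assumption (the data is random, so only the rough zero counts are meaningful):

    using Flux
    using Zygote: forward
    trainmode(f, x...) = forward(f, x...)[1]   # same helper as in the diff

    d = Dropout(0.5)
    x = ones(1000)
    d(x) == x                       # plain call: dropout is a no-op here
    count(iszero, trainmode(d, x))  # under forward: roughly half the entries are zeroed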
@@ -39,18 +39,16 @@ using Flux.Tracker: data
 end

 @testset "BatchNorm" begin
-  let m = BatchNorm(2), x = param([1 3 5;
-                                   2 4 6])
+  let m = BatchNorm(2), x = [1.0 3.0 5.0;
+                             2.0 4.0 6.0]

-    @test m.β.data == [0, 0] # initβ(2)
-    @test m.γ.data == [1, 1] # initγ(2)
+    @test m.β == [0, 0] # initβ(2)
+    @test m.γ == [1, 1] # initγ(2)
     # initial m.σ is 1
     # initial m.μ is 0
-    @test m.active

-    # @test m(x).data ≈ [-1 -1; 0 0; 1 1]'
-    m(x)
+    y = trainmode(m, x)
+    @test isapprox(y, [-1.22474 0 1.22474; -1.22474 0 1.22474], atol = 1.0e-5)
     # julia> x
     # 2×3 Array{Float64,2}:
     # 1.0 3.0 5.0
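Note: the expected values in the BatchNorm test above can be checked by hand. For the input [1.0 3.0 5.0; 2.0 4.0 6.0] each row is one channel with three samples, so channel 1 has mean 3 and population variance 8/3. This is illustrative arithmetic only, not part of the diff:

    x1 = [1.0, 3.0, 5.0]
    μ  = sum(x1) / 3              # 3.0
    σ² = sum((x1 .- μ) .^ 2) / 3  # 8/3 ≈ 2.667
    (x1 .- μ) ./ sqrt(σ²)         # ≈ [-1.22474, 0.0, 1.22474]
    # Running stats after one call with momentum 0.1 (matching the later tests):
    # μ̂  = 0.1 * 3.0 = 0.3
    # σ̂² = 0.1 * (8/3) * (3/2) + 0.9 * 1.0 = 1.3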
@@ -69,41 +67,32 @@ end
     # 2×1 Array{Float64,2}:
     # 1.3
     # 1.3
-    @test m.σ² ≈ .1 .* var(x.data, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
+    @test m.σ² ≈ .1 .* var(x, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]

-    testmode!(m)
-    @test !m.active
-
-    x′ = m(x).data
+    x′ = m(x)
     @test isapprox(x′[1], (1 .- 0.3) / sqrt(1.3), atol = 1.0e-5)
   end

   # with activation function
-  let m = BatchNorm(2, sigmoid), x = param([1 3 5;
-                                            2 4 6])
-    @test m.active
-    m(x)
-
-    testmode!(m)
-    @test !m.active
-
-    y = m(x).data
-    @test isapprox(y, data(sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ))), atol = 1.0e-7)
+  let m = BatchNorm(2, sigmoid), x = [1.0 3.0 5.0;
+                                      2.0 4.0 6.0]
+    y = m(x)
+    @test isapprox(y, sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ)), atol = 1.0e-7)
   end

-  let m = BatchNorm(2), x = param(reshape(1:6, 3, 2, 1))
+  let m = trainmode(BatchNorm(2)), x = reshape(Float32.(1:6), 3, 2, 1)
     y = reshape(permutedims(x, [2, 1, 3]), 2, :)
     y = permutedims(reshape(m(y), 2, 3, 1), [2, 1, 3])
     @test m(x) == y
   end

-  let m = BatchNorm(2), x = param(reshape(1:12, 2, 3, 2, 1))
+  let m = trainmode(BatchNorm(2)), x = reshape(Float32.(1:12), 2, 3, 2, 1)
     y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :)
     y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4])
     @test m(x) == y
   end

-  let m = BatchNorm(2), x = param(reshape(1:24, 2, 2, 3, 2, 1))
+  let m = trainmode(BatchNorm(2)), x = reshape(Float32.(1:24), 2, 2, 3, 2, 1)
     y = reshape(permutedims(x, [4, 1, 2, 3, 5]), 2, :)
     y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5])
     @test m(x) == y
@@ -115,20 +104,16 @@ end
   end
 end

 @testset "InstanceNorm" begin
   # helper functions
   expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...)
   # begin tests
   let m = InstanceNorm(2), sizes = (3, 2, 2),
-      x = param(reshape(collect(1:prod(sizes)), sizes))
+      x = reshape(collect(1:prod(sizes)), sizes)
+    x = Float64.(x)

-    @test m.β.data == [0, 0] # initβ(2)
-    @test m.γ.data == [1, 1] # initγ(2)
-    @test m.active
-
-    m(x)
+    @test m.β == [0, 0] # initβ(2)
+    @test m.γ == [1, 1] # initγ(2)
+    y = trainmode(m, x)

     #julia> x
     #[:, :, 1] =
@@ -153,37 +138,28 @@ end
     # (1. - .1) * 0 + .1 * (5. + 11.) / 2 = .8
     @test m.μ ≈ [0.5, 0.8]
     # momentum * var * num_items / (num_items - 1) + (1 - momentum) * sigma_sq
-    # julia> reshape(mean(.1 .* var(x.data, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1.
+    # julia> reshape(mean(.1 .* var(x, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1.
     # 2-element Array{Float64,1}:
     # 1.
     # 1.
-    @test m.σ² ≈ reshape(mean(.1 .* var(x.data, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1.
+    @test m.σ² ≈ reshape(mean(.1 .* var(x, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1.

-    testmode!(m)
-    @test !m.active
-
-    x′ = m(x).data
+    x′ = m(x)
     @test isapprox(x′[1], (1 - 0.5) / sqrt(1. + 1f-5), atol = 1.0e-5)
   end
   # with activation function
   let m = InstanceNorm(2, sigmoid), sizes = (3, 2, 2),
-      x = param(reshape(collect(1:prod(sizes)), sizes))
+      x = reshape(collect(1:prod(sizes)), sizes)
+    x = Float64.(x)
     affine_shape = collect(sizes)
     affine_shape[1] = 1

-    @test m.active
-    m(x)
-
-    testmode!(m)
-    @test !m.active
-
-    y = m(x).data
-    @test isapprox(y, data(sigmoid.((x .- expand_inst(m.μ, affine_shape)) ./ sqrt.(expand_inst(m.σ², affine_shape) .+ m.ϵ))), atol = 1.0e-7)
+    y = m(x)
+    @test isapprox(y, sigmoid.((x .- expand_inst(m.μ, affine_shape)) ./ sqrt.(expand_inst(m.σ², affine_shape) .+ m.ϵ)), atol = 1.0e-7)
   end

-  let m = InstanceNorm(2), sizes = (2, 4, 1, 2, 3),
-      x = param(reshape(collect(1:prod(sizes)), sizes))
+  let m = trainmode(InstanceNorm(2)), sizes = (2, 4, 1, 2, 3),
+      x = Float32.(reshape(collect(1:prod(sizes)), sizes))
     y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
     y = reshape(m(y), sizes...)
     @test m(x) == y
@@ -191,16 +167,16 @@ end

   # check that μ, σ², and the output are the correct size for higher rank tensors
   let m = InstanceNorm(2), sizes = (5, 5, 3, 4, 2, 6),
-      x = param(reshape(collect(1:prod(sizes)), sizes))
-    y = m(x)
+      x = reshape(Float32.(collect(1:prod(sizes))), sizes)
+    y = trainmode(m, x)
     @test size(m.μ) == (sizes[end - 1], )
     @test size(m.σ²) == (sizes[end - 1], )
     @test size(y) == sizes
   end

   # show that instance norm is equal to batch norm when channel and batch dims are squashed
-  let m_inorm = InstanceNorm(2), m_bnorm = BatchNorm(12), sizes = (5, 5, 3, 4, 2, 6),
-      x = param(reshape(collect(1:prod(sizes)), sizes))
+  let m_inorm = trainmode(InstanceNorm(2)), m_bnorm = trainmode(BatchNorm(12)), sizes = (5, 5, 3, 4, 2, 6),
+      x = reshape(Float32.(collect(1:prod(sizes))), sizes)
     @test m_inorm(x) == reshape(m_bnorm(reshape(x, (sizes[1:end - 2]..., :, 1))), sizes)
   end

@@ -216,14 +192,12 @@ end
   squeeze(x) = dropdims(x, dims = tuple(findall(size(x) .== 1)...)) # To remove all singular dimensions

   let m = GroupNorm(4,2), sizes = (3,4,2),
-      x = param(reshape(collect(1:prod(sizes)), sizes))
+      x = reshape(collect(1:prod(sizes)), sizes)
+    x = Float64.(x)

-    @test m.β.data == [0, 0, 0, 0] # initβ(32)
-    @test m.γ.data == [1, 1, 1, 1] # initγ(32)
-
-    @test m.active
-
-    m(x)
+    @test m.β == [0, 0, 0, 0] # initβ(32)
+    @test m.γ == [1, 1, 1, 1] # initγ(32)
+    y = trainmode(m, x)

     #julia> x
     #[:, :, 1] =
@@ -243,7 +217,7 @@ end
     # (13. + 14. + 15. + 16. + 17. + 18.) / 6 = 15.5
     # (19. + 20. + 21. + 22. + 23. + 24.) / 6 = 21.5
     #
     # μ =
     # 3.5 15.5
     # 9.5 21.5
     #
@@ -253,46 +227,37 @@ end
     @test m.μ ≈ [0.95, 1.55]

     # julia> mean(var(reshape(x,3,2,2,2),dims=(1,2)).* .1,dims=2) .+ .9*1.
-    # 2-element Array{Tracker.TrackedReal{Float64},1}:
+    # 2-element Array{Float64,1}:
     # 1.25
     # 1.25
     @test m.σ² ≈ mean(squeeze(var(reshape(x,3,2,2,2),dims=(1,2))).*.1,dims=2) .+ .9*1.

-    testmode!(m)
-    @test !m.active
-
-    x′ = m(x).data
+    x′ = m(x)
     @test isapprox(x′[1], (1 - 0.95) / sqrt(1.25 + 1f-5), atol = 1.0e-5)
   end
   # with activation function
   let m = GroupNorm(4,2, sigmoid), sizes = (3, 4, 2),
-      x = param(reshape(collect(1:prod(sizes)), sizes))
+      x = reshape(collect(1:prod(sizes)), sizes)
+    x = Float64.(x)
     μ_affine_shape = ones(Int,length(sizes) + 1)
     μ_affine_shape[end-1] = 2 # Number of groups

     affine_shape = ones(Int,length(sizes) + 1)
     affine_shape[end-2] = 2 # Channels per group
     affine_shape[end-1] = 2 # Number of groups
     affine_shape[1] = sizes[1]
     affine_shape[end] = sizes[end]

     og_shape = size(x)

-    @test m.active
-    m(x)
-
-    testmode!(m)
-    @test !m.active
-
     y = m(x)
     x_ = reshape(x,affine_shape...)
-    out = reshape(data(sigmoid.((x_ .- reshape(m.μ,μ_affine_shape...)) ./ sqrt.(reshape(m.σ²,μ_affine_shape...) .+ m.ϵ))),og_shape)
+    out = reshape(sigmoid.((x_ .- reshape(m.μ,μ_affine_shape...)) ./ sqrt.(reshape(m.σ²,μ_affine_shape...) .+ m.ϵ)),og_shape)
     @test isapprox(y, out, atol = 1.0e-7)
   end

-  let m = GroupNorm(2,2), sizes = (2, 4, 1, 2, 3),
-      x = param(reshape(collect(1:prod(sizes)), sizes))
+  let m = trainmode(GroupNorm(2,2)), sizes = (2, 4, 1, 2, 3),
+      x = Float32.(reshape(collect(1:prod(sizes)), sizes))
     y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
     y = reshape(m(y), sizes...)
     @test m(x) == y
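Note: the GroupNorm expectation `m.μ ≈ [0.95, 1.55]` follows the same momentum-0.1 bookkeeping as BatchNorm. Using the per-group means listed in the comments above (3.5 and 9.5 for the first sample, 15.5 and 21.5 for the second), each group's running mean is the momentum times its mean over the batch; illustrative arithmetic only:

    0.1 * (3.5 + 15.5) / 2   # 0.95, first entry of m.μ
    0.1 * (9.5 + 21.5) / 2   # 1.55, second entry of m.μ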
@@ -300,22 +265,22 @@ end

   # check that μ, σ², and the output are the correct size for higher rank tensors
   let m = GroupNorm(4,2), sizes = (5, 5, 3, 4, 4, 6),
-      x = param(reshape(collect(1:prod(sizes)), sizes))
-    y = m(x)
+      x = Float32.(reshape(collect(1:prod(sizes)), sizes))
+    y = trainmode(m, x)
     @test size(m.μ) == (m.G,1)
     @test size(m.σ²) == (m.G,1)
     @test size(y) == sizes
   end

   # show that group norm is the same as instance norm when the group size is the same as the number of channels
-  let IN = InstanceNorm(4), GN = GroupNorm(4,4), sizes = (2,2,3,4,5),
-      x = param(reshape(collect(1:prod(sizes)), sizes))
+  let IN = trainmode(InstanceNorm(4)), GN = trainmode(GroupNorm(4,4)), sizes = (2,2,3,4,5),
+      x = Float32.(reshape(collect(1:prod(sizes)), sizes))
     @test IN(x) ≈ GN(x)
   end

   # show that group norm is the same as batch norm for a group of size 1 and batch of size 1
-  let BN = BatchNorm(4), GN = GroupNorm(4,4), sizes = (2,2,3,4,1),
-      x = param(reshape(collect(1:prod(sizes)), sizes))
+  let BN = trainmode(BatchNorm(4)), GN = trainmode(GroupNorm(4,4)), sizes = (2,2,3,4,1),
+      x = Float32.(reshape(collect(1:prod(sizes)), sizes))
     @test BN(x) ≈ GN(x)
   end

@@ -51,13 +51,13 @@ const ϵ = 1e-7
 end

 @testset "no spurious promotions" begin
-  for T in (Float16, Float32, Float64)
+  for T in (Float32, Float64)
     y = rand(T, 2)
     ŷ = rand(T, 2)
     for f in (mse, crossentropy, logitcrossentropy)
-      fwd, back = Flux.Tracker.forward(mse, ŷ, y)
-      @test typeof(fwd) == Flux.Tracker.TrackedReal{T}
-      @test eltype(back(one(T))[1]) == Flux.Tracker.TrackedReal{T}
+      fwd, back = Flux.forward(f, ŷ, y)
+      @test fwd isa T
+      @test eltype(back(one(T))[1]) == T
     end
   end
 end
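Note: the "no spurious promotions" test now runs each loss through `Flux.forward`, which returns the loss value together with a pullback, and checks that neither the value nor the gradient is promoted to a wider element type. A minimal sketch of the same check for one loss, assuming `Flux.forward` is the Zygote pullback surfaced by Flux in this version:

    using Flux
    y, ŷ = rand(Float32, 2), rand(Float32, 2)
    fwd, back = Flux.forward(Flux.mse, ŷ, y)
    fwd isa Float32                            # the loss value keeps the input eltype
    eltype(back(one(Float32))[1]) == Float32   # so does the gradient w.r.t. ŷ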
@@ -1,42 +1,44 @@
 using Flux.Optimise
 using Flux.Optimise: runall
-using Flux.Tracker
+using Flux: Params, gradient
 using Test

 @testset "Optimise" begin
   w = randn(10, 10)
   @testset for opt in [ADAMW(), ADAGrad(0.1), AdaMax(), ADADelta(0.9), AMSGrad(),
                        NADAM(), RADAM(), Descent(0.1), ADAM(), Nesterov(), RMSProp(),
                        Momentum()]
-    w′ = param(randn(10, 10))
+    w′ = randn(10, 10)
     loss(x) = Flux.mse(w*x, w′*x)
     for t = 1: 10^5
       θ = Params([w′])
-      θ̄ = gradient(() -> loss(rand(10)), θ)
+      x = rand(10)
+      θ̄ = gradient(() -> loss(x), θ)
       Optimise.update!(opt, θ, θ̄)
     end
-    @test Flux.mse(w, w′) < 0.01
+    @test loss(rand(10, 10)) < 0.01
   end
 end

 @testset "Optimiser" begin
   w = randn(10, 10)
   @testset for Opt in [InvDecay, WeightDecay, ExpDecay]
-    w′ = param(randn(10, 10))
+    w′ = randn(10, 10)
     loss(x) = Flux.mse(w*x, w′*x)
     opt = Optimiser(Opt(), ADAM(0.001))
     for t = 1:10^5
-      l = loss(rand(10))
-      back!(l)
-      delta = Optimise.apply!(opt, w′.data, w′.grad)
-      w′.data .-= delta
+      θ = Params([w′])
+      x = rand(10)
+      θ̄ = gradient(() -> loss(x), θ)
+      Optimise.update!(opt, θ, θ̄)
     end
-    @test Flux.mse(w, w′) < 0.01
+    @test loss(rand(10, 10)) < 0.01
   end
 end

 @testset "Training Loop" begin
   i = 0
-  l = param(1)
+  l = 1

   Flux.train!(() -> (sleep(0.1); i += 1; l),
               (),
@@ -57,17 +59,18 @@ end
 @testset "ExpDecay" begin
   w = randn(10, 10)
   o = ExpDecay(0.1, 0.1, 1000, 1e-4)
-  w1 = param(randn(10,10))
+  w1 = randn(10,10)
   loss(x) = Flux.mse(w*x, w1*x)
   flag = 1
   decay_steps = []
   for t = 1:10^5
-    l = loss(rand(10))
-    back!(l)
     prev_eta = o.eta
-    prev_grad = collect(w1.grad)
-    delta = Optimise.apply!(o, w1.data, w1.grad)
-    w1.data .-= delta
+    θ = Params([w1])
+    x = rand(10)
+    θ̄ = gradient(() -> loss(x), θ)
+    prev_grad = collect(θ̄[w1])
+    delta = Optimise.apply!(o, w1, θ̄[w1])
+    w1 .-= delta
     new_eta = o.eta
     if new_eta != prev_eta
       push!(decay_steps, t)
@@ -1,11 +1,8 @@
-using Flux, Test, Random, Statistics
+using Flux, Test, Random, Statistics, Documenter
 using Random

 Random.seed!(0)

-# So we can use the system CuArrays
-insert!(LOAD_PATH, 2, "@v#.#")
-
 @testset "Flux" begin

 @info "Testing Basics"

@@ -22,14 +19,14 @@ include("layers/normalisation.jl")
 include("layers/stateless.jl")
 include("layers/conv.jl")

-@info "Running Gradient Checks"
-
-include("tracker.jl")
-
 if isdefined(Flux, :CUDA)
   include("cuda/cuda.jl")
 else
   @warn "CUDA unavailable, not testing GPU support"
 end

+if VERSION >= v"1.2"
+  doctest(Flux)
+end
+
 end
@@ -1,15 +0,0 @@
-using Flux, Test
-using Tracker: gradcheck
-
-gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
-gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
-
-@testset "Tracker" begin
-
-@test gradtest(Flux.mse, rand(5,5), rand(5, 5))
-@test gradtest(Flux.crossentropy, rand(5,5), rand(5, 5))
-
-@test gradtest(x -> Flux.normalise(x), rand(4,3))
-@test gradtest(x -> Flux.normalise(x, dims = 2), rand(3,4))
-
-end
@@ -1,5 +1,5 @@
 using Flux
-using Flux: throttle, jacobian, glorot_uniform, glorot_normal, stack, unstack
+using Flux: throttle, glorot_uniform, glorot_normal, stack, unstack
 using StatsBase: std
 using Random
 using Test

@@ -52,15 +52,6 @@ using Test
   end
 end

-@testset "Jacobian" begin
-  A = param(randn(2,2))
-  x = randn(2)
-  m(x) = A*x
-  y = m(x)
-  J = jacobian(m,x)
-  @test J ≈ A.data
-end
-
 @testset "Initialization" begin
   # Set random seed so that these tests don't fail randomly
   Random.seed!(0)

@@ -106,12 +97,11 @@ end
 @testset "Precision" begin
   m = Chain(Dense(10, 5, relu), Dense(5, 2))
   x = rand(10)
-  @test eltype(m[1].W.data) == Float32
-  @test eltype(m(x).data) == Float32
-  @test eltype(f64(m)(x).data) == Float64
-  @test eltype(f64(m)[1].W.data) == Float64
-  @test eltype(f32(f64(m))[1].W.data) == Float32
-  @test Tracker.isleaf(f32(f64(m))[1].W)
+  @test eltype(m[1].W) == Float32
+  @test eltype(m(x)) == Float32
+  @test eltype(f64(m)(x)) == Float64
+  @test eltype(f64(m)[1].W) == Float64
+  @test eltype(f32(f64(m))[1].W) == Float32
 end

 @testset "Stacking" begin