commit
b348e31f07
|
@ -1,3 +1,5 @@
|
|||
# This file is machine-generated - editing it directly is not advised
|
||||
|
||||
[[AbstractTrees]]
|
||||
deps = ["Markdown", "Test"]
|
||||
git-tree-sha1 = "6621d9645702c1c4e6970cc6a3eae440c768000b"
|
||||
|
@ -27,9 +29,9 @@ version = "0.5.3"
|
|||
|
||||
[[CodecZlib]]
|
||||
deps = ["BinaryProvider", "Libdl", "Test", "TranscodingStreams"]
|
||||
git-tree-sha1 = "e3df104c84dfc108f0ca203fd7f5bbdc98641ae9"
|
||||
git-tree-sha1 = "36bbf5374c661054d41410dc53ff752972583b9b"
|
||||
uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
|
||||
version = "0.5.1"
|
||||
version = "0.5.2"
|
||||
|
||||
[[ColorTypes]]
|
||||
deps = ["FixedPointNumbers", "Random", "Test"]
|
||||
|
@ -51,9 +53,9 @@ version = "0.2.0"
|
|||
|
||||
[[Compat]]
|
||||
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
|
||||
git-tree-sha1 = "49269e311ffe11ac5b334681d212329002a9832a"
|
||||
git-tree-sha1 = "195a3ffcb8b0762684b6821de18f83a16455c6ea"
|
||||
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
|
||||
version = "1.5.1"
|
||||
version = "2.0.0"
|
||||
|
||||
[[DataStructures]]
|
||||
deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
|
||||
|
@ -71,18 +73,18 @@ uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
|
|||
|
||||
[[DiffResults]]
|
||||
deps = ["Compat", "StaticArrays"]
|
||||
git-tree-sha1 = "db8acf46717b13d6c48deb7a12007c7f85a70cf7"
|
||||
git-tree-sha1 = "34a4a1e8be7bc99bc9c611b895b5baf37a80584c"
|
||||
uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
|
||||
version = "0.0.3"
|
||||
version = "0.0.4"
|
||||
|
||||
[[DiffRules]]
|
||||
deps = ["Random", "Test"]
|
||||
git-tree-sha1 = "09d69da75967ec48a8b1ad0897ec9144ee052bf9"
|
||||
git-tree-sha1 = "dc0869fb2f5b23466b32ea799bd82c76480167f7"
|
||||
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
|
||||
version = "0.0.8"
|
||||
version = "0.0.10"
|
||||
|
||||
[[Distributed]]
|
||||
deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"]
|
||||
deps = ["Random", "Serialization", "Sockets"]
|
||||
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
|
||||
|
||||
[[FixedPointNumbers]]
|
||||
|
@ -93,12 +95,12 @@ version = "0.5.3"
|
|||
|
||||
[[ForwardDiff]]
|
||||
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"]
|
||||
git-tree-sha1 = "e393bd3b9102659fb24fe88caedec41f2bc2e7de"
|
||||
git-tree-sha1 = "4c4d727f1b7e0092134fabfab6396b8945c1ea5b"
|
||||
uuid = "f6369f11-7733-5829-9624-2563aa707210"
|
||||
version = "0.10.2"
|
||||
version = "0.10.3"
|
||||
|
||||
[[InteractiveUtils]]
|
||||
deps = ["LinearAlgebra", "Markdown"]
|
||||
deps = ["Markdown"]
|
||||
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
|
||||
|
||||
[[Juno]]
|
||||
|
@ -122,9 +124,9 @@ uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
|
|||
|
||||
[[MacroTools]]
|
||||
deps = ["Compat"]
|
||||
git-tree-sha1 = "c443e1c8d58a4e9f61b708ad0a88286c7042145b"
|
||||
git-tree-sha1 = "3fd1a3022952128935b449c33552eb65895380c1"
|
||||
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
|
||||
version = "0.4.4"
|
||||
version = "0.4.5"
|
||||
|
||||
[[Markdown]]
|
||||
deps = ["Base64"]
|
||||
|
@ -147,7 +149,7 @@ uuid = "a63ad114-7e13-5084-954f-fe012c677804"
|
|||
|
||||
[[NNlib]]
|
||||
deps = ["Libdl", "LinearAlgebra", "MacroTools", "Requires", "Test"]
|
||||
git-tree-sha1 = "5a8ed87d61b1ccb71d99235c2a96287addebbb9f"
|
||||
git-tree-sha1 = "9ac5cd21484189339b27840818c4882d1b6df7fd"
|
||||
repo-rev = "master"
|
||||
repo-url = "https://github.com/FluxML/NNlib.jl.git"
|
||||
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
|
||||
|
@ -228,9 +230,9 @@ version = "0.7.2"
|
|||
|
||||
[[StaticArrays]]
|
||||
deps = ["InteractiveUtils", "LinearAlgebra", "Random", "Statistics", "Test"]
|
||||
git-tree-sha1 = "1eb114d6e23a817cd3e99abc3226190876d7c898"
|
||||
git-tree-sha1 = "3841b39ed5f047db1162627bf5f80a9cd3e39ae2"
|
||||
uuid = "90137ffa-7385-5640-81b9-e52037218182"
|
||||
version = "0.10.2"
|
||||
version = "0.10.3"
|
||||
|
||||
[[Statistics]]
|
||||
deps = ["LinearAlgebra", "SparseArrays"]
|
||||
|
@ -238,19 +240,25 @@ uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
|
|||
|
||||
[[StatsBase]]
|
||||
deps = ["DataStructures", "DelimitedFiles", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "Test"]
|
||||
git-tree-sha1 = "7b596062316c7d846b67bf625d5963a832528598"
|
||||
git-tree-sha1 = "435707791dc85a67d98d671c1c3fcf1b20b00f94"
|
||||
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
|
||||
version = "0.27.0"
|
||||
version = "0.29.0"
|
||||
|
||||
[[Test]]
|
||||
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
|
||||
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
|
||||
|
||||
[[Tracker]]
|
||||
deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics", "Test"]
|
||||
git-tree-sha1 = "4eeea9f0ef9b8c7d1c5c5b1f8f68cb9b7f45d7df"
|
||||
uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
|
||||
version = "0.1.0"
|
||||
|
||||
[[TranscodingStreams]]
|
||||
deps = ["Pkg", "Random", "Test"]
|
||||
git-tree-sha1 = "a34a2d588e2d2825602bf14a24216d5c8b0921ec"
|
||||
git-tree-sha1 = "90f845c65c50bc57d6ffc815dbab2a4003ccf75c"
|
||||
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
|
||||
version = "0.8.1"
|
||||
version = "0.9.1"
|
||||
|
||||
[[URIParser]]
|
||||
deps = ["Test", "Unicode"]
|
||||
|
@ -259,7 +267,7 @@ uuid = "30578b45-9adc-5946-b283-645ec420af67"
|
|||
version = "0.4.0"
|
||||
|
||||
[[UUIDs]]
|
||||
deps = ["Random"]
|
||||
deps = ["Random", "SHA"]
|
||||
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
|
||||
|
||||
[[Unicode]]
|
||||
|
|
|
@ -6,21 +6,18 @@ AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
|
|||
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
|
||||
CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193"
|
||||
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
|
||||
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
|
||||
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
|
||||
Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
|
||||
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
|
||||
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
|
||||
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
|
||||
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
|
||||
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
|
||||
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
|
||||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
|
||||
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
|
||||
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
|
||||
SHA = "ea8e919c-243c-51af-8825-aaa63cd721ce"
|
||||
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
|
||||
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
|
||||
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
|
||||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
|
||||
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
|
||||
ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
|
||||
|
|
6
REQUIRE
6
REQUIRE
|
@ -10,9 +10,3 @@ ZipFile
|
|||
AbstractTrees
|
||||
Reexport
|
||||
StatsBase
|
||||
|
||||
# AD
|
||||
ForwardDiff 0.5.0
|
||||
DiffRules
|
||||
SpecialFunctions
|
||||
NaNMath
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
# This file is machine-generated - editing it directly is not advised
|
||||
|
||||
[[AbstractTrees]]
|
||||
deps = ["Markdown", "Test"]
|
||||
git-tree-sha1 = "6621d9645702c1c4e6970cc6a3eae440c768000b"
|
||||
|
@ -6,9 +8,9 @@ version = "0.2.1"
|
|||
|
||||
[[Adapt]]
|
||||
deps = ["LinearAlgebra", "Test"]
|
||||
git-tree-sha1 = "04d15700419b6949d76be1428ab6e0277ff43b06"
|
||||
git-tree-sha1 = "53d8fec4f662088c1202530e338a11a919407f3b"
|
||||
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
|
||||
version = "0.4.1"
|
||||
version = "0.4.2"
|
||||
|
||||
[[Base64]]
|
||||
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
|
||||
|
@ -27,9 +29,9 @@ version = "0.5.3"
|
|||
|
||||
[[CodecZlib]]
|
||||
deps = ["BinaryProvider", "Libdl", "Test", "TranscodingStreams"]
|
||||
git-tree-sha1 = "e3df104c84dfc108f0ca203fd7f5bbdc98641ae9"
|
||||
git-tree-sha1 = "36bbf5374c661054d41410dc53ff752972583b9b"
|
||||
uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
|
||||
version = "0.5.1"
|
||||
version = "0.5.2"
|
||||
|
||||
[[ColorTypes]]
|
||||
deps = ["FixedPointNumbers", "Random", "Test"]
|
||||
|
@ -51,9 +53,9 @@ version = "0.2.0"
|
|||
|
||||
[[Compat]]
|
||||
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
|
||||
git-tree-sha1 = "ec61a16eed883ad0cfa002d7489b3ce6d039bb9a"
|
||||
git-tree-sha1 = "195a3ffcb8b0762684b6821de18f83a16455c6ea"
|
||||
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
|
||||
version = "1.4.0"
|
||||
version = "2.0.0"
|
||||
|
||||
[[DataStructures]]
|
||||
deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
|
||||
|
@ -71,18 +73,18 @@ uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
|
|||
|
||||
[[DiffResults]]
|
||||
deps = ["Compat", "StaticArrays"]
|
||||
git-tree-sha1 = "db8acf46717b13d6c48deb7a12007c7f85a70cf7"
|
||||
git-tree-sha1 = "34a4a1e8be7bc99bc9c611b895b5baf37a80584c"
|
||||
uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
|
||||
version = "0.0.3"
|
||||
version = "0.0.4"
|
||||
|
||||
[[DiffRules]]
|
||||
deps = ["Random", "Test"]
|
||||
git-tree-sha1 = "c49ec69428ffea0c1d1bbdc63d1a70f5df5860ad"
|
||||
git-tree-sha1 = "dc0869fb2f5b23466b32ea799bd82c76480167f7"
|
||||
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
|
||||
version = "0.0.7"
|
||||
version = "0.0.10"
|
||||
|
||||
[[Distributed]]
|
||||
deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"]
|
||||
deps = ["Random", "Serialization", "Sockets"]
|
||||
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
|
||||
|
||||
[[DocStringExtensions]]
|
||||
|
@ -93,9 +95,9 @@ version = "0.6.0"
|
|||
|
||||
[[Documenter]]
|
||||
deps = ["Base64", "DocStringExtensions", "InteractiveUtils", "LibGit2", "Logging", "Markdown", "Pkg", "REPL", "Random", "Test", "Unicode"]
|
||||
git-tree-sha1 = "a6db1c69925cdc53aafb38caec4446be26e0c617"
|
||||
git-tree-sha1 = "a8c41ba3d0861240dbec942ee1d0f86c57c37c1c"
|
||||
uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
|
||||
version = "0.21.0"
|
||||
version = "0.21.5"
|
||||
|
||||
[[FixedPointNumbers]]
|
||||
deps = ["Test"]
|
||||
|
@ -104,26 +106,26 @@ uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
|
|||
version = "0.5.3"
|
||||
|
||||
[[Flux]]
|
||||
deps = ["AbstractTrees", "Adapt", "CodecZlib", "Colors", "DiffRules", "ForwardDiff", "Juno", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsBase", "Test", "ZipFile"]
|
||||
deps = ["AbstractTrees", "Adapt", "CodecZlib", "Colors", "Juno", "LinearAlgebra", "MacroTools", "NNlib", "Pkg", "Printf", "Random", "Reexport", "Requires", "SHA", "Statistics", "StatsBase", "Test", "Tracker", "ZipFile"]
|
||||
path = ".."
|
||||
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
|
||||
version = "0.6.10+"
|
||||
version = "0.7.3+"
|
||||
|
||||
[[ForwardDiff]]
|
||||
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"]
|
||||
git-tree-sha1 = "b91250044374764e7c29af59a774c4b8d6100b6e"
|
||||
git-tree-sha1 = "4c4d727f1b7e0092134fabfab6396b8945c1ea5b"
|
||||
uuid = "f6369f11-7733-5829-9624-2563aa707210"
|
||||
version = "0.10.1"
|
||||
version = "0.10.3"
|
||||
|
||||
[[InteractiveUtils]]
|
||||
deps = ["LinearAlgebra", "Markdown"]
|
||||
deps = ["Markdown"]
|
||||
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
|
||||
|
||||
[[Juno]]
|
||||
deps = ["Base64", "Logging", "Media", "Profile", "Test"]
|
||||
git-tree-sha1 = "3c29a199713e7ec62cfdc11f44d7760219d5f658"
|
||||
git-tree-sha1 = "ce6246e19061e36cbdce954caaae717498daeed8"
|
||||
uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
|
||||
version = "0.5.3"
|
||||
version = "0.5.4"
|
||||
|
||||
[[LibGit2]]
|
||||
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
|
||||
|
@ -140,9 +142,9 @@ uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
|
|||
|
||||
[[MacroTools]]
|
||||
deps = ["Compat"]
|
||||
git-tree-sha1 = "c443e1c8d58a4e9f61b708ad0a88286c7042145b"
|
||||
git-tree-sha1 = "3fd1a3022952128935b449c33552eb65895380c1"
|
||||
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
|
||||
version = "0.4.4"
|
||||
version = "0.4.5"
|
||||
|
||||
[[Markdown]]
|
||||
deps = ["Base64"]
|
||||
|
@ -156,9 +158,9 @@ version = "0.5.0"
|
|||
|
||||
[[Missings]]
|
||||
deps = ["Dates", "InteractiveUtils", "SparseArrays", "Test"]
|
||||
git-tree-sha1 = "adc26d2ee85a49c413464110d922cf21efc9d233"
|
||||
git-tree-sha1 = "d1d2585677f2bd93a97cfeb8faa7a0de0f982042"
|
||||
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
|
||||
version = "0.3.1"
|
||||
version = "0.4.0"
|
||||
|
||||
[[Mmap]]
|
||||
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
|
||||
|
@ -244,9 +246,9 @@ version = "0.7.2"
|
|||
|
||||
[[StaticArrays]]
|
||||
deps = ["InteractiveUtils", "LinearAlgebra", "Random", "Statistics", "Test"]
|
||||
git-tree-sha1 = "1eb114d6e23a817cd3e99abc3226190876d7c898"
|
||||
git-tree-sha1 = "3841b39ed5f047db1162627bf5f80a9cd3e39ae2"
|
||||
uuid = "90137ffa-7385-5640-81b9-e52037218182"
|
||||
version = "0.10.2"
|
||||
version = "0.10.3"
|
||||
|
||||
[[Statistics]]
|
||||
deps = ["LinearAlgebra", "SparseArrays"]
|
||||
|
@ -254,19 +256,25 @@ uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
|
|||
|
||||
[[StatsBase]]
|
||||
deps = ["DataStructures", "DelimitedFiles", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "Test"]
|
||||
git-tree-sha1 = "7b596062316c7d846b67bf625d5963a832528598"
|
||||
git-tree-sha1 = "435707791dc85a67d98d671c1c3fcf1b20b00f94"
|
||||
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
|
||||
version = "0.27.0"
|
||||
version = "0.29.0"
|
||||
|
||||
[[Test]]
|
||||
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
|
||||
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
|
||||
|
||||
[[Tracker]]
|
||||
deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics", "Test"]
|
||||
git-tree-sha1 = "4eeea9f0ef9b8c7d1c5c5b1f8f68cb9b7f45d7df"
|
||||
uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
|
||||
version = "0.1.0"
|
||||
|
||||
[[TranscodingStreams]]
|
||||
deps = ["Pkg", "Random", "Test"]
|
||||
git-tree-sha1 = "a34a2d588e2d2825602bf14a24216d5c8b0921ec"
|
||||
git-tree-sha1 = "90f845c65c50bc57d6ffc815dbab2a4003ccf75c"
|
||||
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
|
||||
version = "0.8.1"
|
||||
version = "0.9.1"
|
||||
|
||||
[[URIParser]]
|
||||
deps = ["Test", "Unicode"]
|
||||
|
@ -275,7 +283,7 @@ uuid = "30578b45-9adc-5946-b283-645ec420af67"
|
|||
version = "0.4.0"
|
||||
|
||||
[[UUIDs]]
|
||||
deps = ["Random"]
|
||||
deps = ["Random", "SHA"]
|
||||
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
|
||||
|
||||
[[Unicode]]
|
||||
|
|
|
@ -12,9 +12,8 @@ export Chain, Dense, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
|
|||
|
||||
@reexport using NNlib
|
||||
|
||||
include("tracker/Tracker.jl")
|
||||
using .Tracker
|
||||
using .Tracker: data
|
||||
using Tracker
|
||||
using Tracker: data
|
||||
export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param
|
||||
|
||||
include("optimise/Optimise.jl")
|
||||
|
|
|
@ -1,114 +0,0 @@
|
|||
module Tracker
|
||||
|
||||
using MacroTools
|
||||
using MacroTools: @q, @forward
|
||||
|
||||
import Base: ==
|
||||
|
||||
export TrackedArray, TrackedVector, TrackedMatrix, Params, gradient,
|
||||
jacobian, hessian, param, back!
|
||||
|
||||
tracker(x) = nothing
|
||||
|
||||
istracked(x) = tracker(x) ≠ nothing
|
||||
isleaf(x) = !istracked(x) || isleaf(tracker(x))
|
||||
grad(x) = grad(tracker(x))
|
||||
grad(::Nothing) = nothing
|
||||
data(x) = x
|
||||
|
||||
struct Call{F,As<:Tuple}
|
||||
func::F
|
||||
args::As
|
||||
end
|
||||
|
||||
Call(f::F, args::T) where {F,T} = Call{F,T}(f, args)
|
||||
Call() = Call(nothing, ())
|
||||
|
||||
# When deserialising, the object_id changes
|
||||
a::Call == b::Call = a.func == b.func && a.args == b.args
|
||||
|
||||
@inline (c::Call)() = c.func(data.(c.args)...)
|
||||
|
||||
mutable struct Tracked{T}
|
||||
ref::UInt32
|
||||
f::Call
|
||||
isleaf::Bool
|
||||
grad::T
|
||||
Tracked{T}(f::Call) where T = new(0, f, false)
|
||||
Tracked{T}(f::Call, grad::T) where T = new(0, f, false, grad)
|
||||
Tracked{T}(f::Call{Nothing}, grad::T) where T = new(0, f, true, grad)
|
||||
end
|
||||
|
||||
istracked(x::Tracked) = true
|
||||
isleaf(x::Tracked) = x.f == Call()
|
||||
grad(x::Tracked) = x.grad
|
||||
|
||||
track(f::Call, x) = Tracked{typeof(x)}(f)
|
||||
|
||||
function _forward end
|
||||
|
||||
function track(f::F, xs...; kw...) where F
|
||||
y, back = _forward(f, xs...; kw...)
|
||||
track(Call(back, tracker.(xs)), y)
|
||||
end
|
||||
|
||||
macro grad(ex)
|
||||
@capture(shortdef(ex), (name_(args__) = body_) |
|
||||
(name_(args__) where {T__} = body_)) || error("Need a function definition")
|
||||
T == nothing && (T = [])
|
||||
isexpr(name, :(::)) || (name = :(::typeof($name)))
|
||||
insert!(args, 1+isexpr(args[1], :parameters) , name)
|
||||
@q(Tracker._forward($(args...)) where $(T...) = $body) |> esc
|
||||
end
|
||||
|
||||
include("idset.jl")
|
||||
include("params.jl")
|
||||
include("back.jl")
|
||||
include("numeric.jl")
|
||||
include("lib/real.jl")
|
||||
include("lib/array.jl")
|
||||
include("forward.jl")
|
||||
|
||||
"""
|
||||
hook(f, x) -> x′
|
||||
|
||||
Hook into gradient backpropagation. `x` is unmodified, but when backpropagating
|
||||
`f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse
|
||||
the sign of the gradient applied to `x`.
|
||||
"""
|
||||
hook(f, x) = istracked(x) ? track(hook, f, x) : x
|
||||
@grad hook(f, x) = data(x), Δ -> (nothing, f(Δ))
|
||||
|
||||
"""
|
||||
checkpoint(f, args...)
|
||||
|
||||
Behaves like `f(args...)`, but avoids storing the intermediate values needed for
|
||||
calculating gradients. Instead, `f(args...)` will be called again during the
|
||||
backward pass. This can be used to save memory in larger models.
|
||||
"""
|
||||
checkpoint(f, args...) = track(checkpoint, f, args...)
|
||||
|
||||
@grad function checkpoint(f, args...)
|
||||
data(f(args...)), function (Δ)
|
||||
y, back = forward(f, args...)
|
||||
(nothing, back(Δ)...)
|
||||
end
|
||||
end
|
||||
|
||||
nobacksies(f, x) = track(nobacksies, f, x)
|
||||
nobacksies(f, xs::Tuple) = map(x -> nobacksies(f, x), xs)
|
||||
@grad nobacksies(f::Symbol, x) = data(x), Δ -> error("Nested AD not defined for $f")
|
||||
@grad nobacksies(f::String, x) = data(x), Δ -> error(f)
|
||||
|
||||
param(x::Number) = TrackedReal(float(x))
|
||||
param(xs::AbstractArray) = TrackedArray(float.(xs))
|
||||
|
||||
@grad identity(x) = data(x), Δ -> (Δ,)
|
||||
param(x::TrackedReal) = track(identity, x)
|
||||
param(x::TrackedArray) = track(identity, x)
|
||||
|
||||
import Adapt: adapt, adapt_structure
|
||||
|
||||
adapt_structure(T, xs::TrackedArray) = param(adapt(T, data(xs)))
|
||||
|
||||
end
|
|
@ -1,190 +0,0 @@
|
|||
# The AD generates fairly large backtraces that are unhelpful if you interrupt
|
||||
# while training; this just cleans that up.
|
||||
macro interrupts(ex)
|
||||
:(try $(esc(ex))
|
||||
catch e
|
||||
e isa InterruptException || rethrow()
|
||||
throw(e)
|
||||
end)
|
||||
end
|
||||
|
||||
# In-place gradients
|
||||
|
||||
init_grad(x) = zero(x)
|
||||
zero_grad!(x) = zero(x)
|
||||
zero_grad!(x::AbstractArray) = (x .= 0)
|
||||
|
||||
scan(c::Call) = foreach(scan, c.args)
|
||||
|
||||
function scan(x::Tracked)
|
||||
x.isleaf && return
|
||||
ref = x.ref += 1
|
||||
if ref == 1
|
||||
scan(x.f)
|
||||
isdefined(x, :grad) && (x.grad = zero_grad!(x.grad))
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
function scan(x)
|
||||
istracked(x) && scan(tracker(x))
|
||||
return
|
||||
end
|
||||
|
||||
function back_(c::Call, Δ, once)
|
||||
Δs = c.func(Δ)
|
||||
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
|
||||
error("Gradient is not a tuple of length $(length(c.args))")
|
||||
foreach((x, d) -> back(x, d, once), c.args, data.(Δs))
|
||||
end
|
||||
|
||||
back_(::Call{Nothing}, Δ, once) = nothing
|
||||
back_(::Call{Missing}, Δ, once) = error("`back!` was already used")
|
||||
|
||||
accum!(x, Δ) = x .+ Δ
|
||||
accum!(x::AbstractArray, Δ) = (x .+= Δ)
|
||||
|
||||
function back(x::Tracked, Δ, once)
|
||||
x.isleaf && (x.grad = accum!(x.grad, Δ); return)
|
||||
ref = x.ref -= 1
|
||||
grad = if isdefined(x, :grad)
|
||||
x.grad = accum!(x.grad, Δ)
|
||||
elseif ref > 0
|
||||
x.grad = Δ
|
||||
else
|
||||
Δ
|
||||
end
|
||||
if ref == 0
|
||||
back_(x.f, grad, once)
|
||||
once && !x.isleaf && (x.f = Call(missing, ()))
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
back(::Nothing, Δ, once) = return
|
||||
|
||||
# Interface methods
|
||||
|
||||
# TODO: if an error occurs in `back` the refcounts will be broken
|
||||
# and `back` will silently fail to update.
|
||||
# (but only if you re-use intermediate values between passes)
|
||||
# Refcounts are also probably not safe in some situations (e.g. back called
|
||||
# from within a backpropagator)
|
||||
|
||||
function back!(x, Δ; once = true)
|
||||
istracked(x) || return
|
||||
scan(x)
|
||||
back(tracker(x), Δ, once)
|
||||
return
|
||||
end
|
||||
|
||||
function extract_grad!(x)
|
||||
x̄ = copy(grad(x))
|
||||
x̄ = nobacksies("Use `gradient(...; nest = true)` for nested derivatives", x̄)
|
||||
tracker(x).grad = zero_grad!(grad(x))
|
||||
return x̄
|
||||
end
|
||||
|
||||
function gradient_(f, xs...)
|
||||
xs = param.(data.(xs))
|
||||
l = f(xs...)
|
||||
losscheck(l)
|
||||
@interrupts back!(l)
|
||||
extract_grad!.(xs)
|
||||
end
|
||||
|
||||
function gradient_(f, xs::Params)
|
||||
l = f()
|
||||
losscheck(l)
|
||||
@interrupts back!(l)
|
||||
gs = Grads()
|
||||
for x in xs
|
||||
gs[tracker(x)] = extract_grad!(x)
|
||||
end
|
||||
return gs
|
||||
end
|
||||
|
||||
# Out-of-place gradients
|
||||
|
||||
function back_(g::Grads, c::Call, Δ)
|
||||
Δs = c.func(Δ)
|
||||
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
|
||||
error("Gradient is not a tuple of length $(length(c.args))")
|
||||
foreach((x, Δ) -> back(g, x, Δ), c.args, Δs)
|
||||
end
|
||||
|
||||
back_(g::Grads, ::Call{Nothing}, Δ) = nothing
|
||||
|
||||
function back(g::Grads, x::Tracked, Δ)
|
||||
x.isleaf && (accum!(g, x, Δ); return)
|
||||
ref = x.ref -= 1
|
||||
if ref > 0 || haskey(g, x)
|
||||
accum!(g, x, Δ)
|
||||
ref == 0 && back_(g, x.f, g[x])
|
||||
else
|
||||
ref == 0 && back_(g, x.f, Δ)
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
back(::Grads, ::Nothing, _) = return
|
||||
|
||||
collectmemaybe(xs) = xs
|
||||
|
||||
function forward(f, ps::Params)
|
||||
y = collectmemaybe(f())
|
||||
y, function (Δ)
|
||||
g = Grads(ps)
|
||||
if istracked(y)
|
||||
scan(y)
|
||||
back(g, tracker(y), Δ)
|
||||
end
|
||||
return g
|
||||
end
|
||||
end
|
||||
|
||||
function forward(f, args...)
|
||||
args = param.(args)
|
||||
y, back = forward(() -> f(args...), Params(args))
|
||||
y, Δ -> getindex.(Ref(back(Δ)), args)
|
||||
end
|
||||
|
||||
function losscheck(x)
|
||||
x isa Real || error("Function output is not scalar")
|
||||
isinf(x) && error("Loss is infinite")
|
||||
isnan(x) && error("Loss is NaN")
|
||||
end
|
||||
|
||||
function gradient_nested(f, args...)
|
||||
y, back = forward(f, args...)
|
||||
losscheck(y)
|
||||
return back(1)
|
||||
end
|
||||
|
||||
gradient(f, xs...; nest = false) =
|
||||
nest ? gradient_nested(f, xs...) : gradient_(f, xs...)
|
||||
|
||||
# Jacobians and Hessians
|
||||
|
||||
import ..Flux
|
||||
|
||||
"""
|
||||
J = jacobian(m,x)
|
||||
|
||||
Calculate the output jacobian `J = d/dx m(x)` such that each row `i` of `J` corresponds to the gradient `J[i,:] = ∇ₓ(m(x)[i])`
|
||||
"""
|
||||
function jacobian(m,x)
|
||||
xp = param(x)
|
||||
y = m(xp)
|
||||
k = length(y)
|
||||
n = length(x)
|
||||
J = Matrix{eltype(x)}(undef,k,n)
|
||||
for i = 1:k
|
||||
Flux.back!(y[i], once = false) # Populate gradient accumulator
|
||||
J[i,:] = xp.grad
|
||||
xp.grad .= 0 # Reset gradient accumulator
|
||||
end
|
||||
J
|
||||
end
|
||||
|
||||
hessian(f, x) = jacobian(x -> gradient(f, x, nest=true)[1], x)
|
|
@ -1,53 +0,0 @@
|
|||
using ForwardDiff
|
||||
|
||||
seed(x::Real, ::Val) = Dual(x, true)
|
||||
|
||||
function seed(x, ::Val{N}, offset = 0) where N
|
||||
map(x, reshape(1:length(x), size(x))) do x, i
|
||||
Dual(x, ntuple(j -> j+offset == i, Val(N)))
|
||||
end
|
||||
end
|
||||
|
||||
extract(x::ForwardDiff.Dual) = x.value, [x.partials...]
|
||||
|
||||
function extract(xs::AbstractArray{ForwardDiff.Dual{T,V,N}}) where {T,V,N}
|
||||
J = similar(xs, V, N, length(xs))
|
||||
for i = 1:length(xs), j = 1:N
|
||||
J[j, i] = xs[i].partials.values[j]
|
||||
end
|
||||
return map(x -> x.value, xs), J
|
||||
end
|
||||
|
||||
function forward_jacobian(f, x, ::Val{N}) where N
|
||||
y, _J = extract(f(seed(x, Val(N))))
|
||||
J = similar(_J, length(x), length(y))
|
||||
J[1:N,:] = _J
|
||||
offset = 0
|
||||
while offset + N < length(x)
|
||||
offset += N
|
||||
_, _J = extract(f(seed(x, Val(N), offset)))
|
||||
range = (1+offset):min(N+offset,length(x))
|
||||
J[range,:] = @view _J[range.-offset,:]
|
||||
end
|
||||
return y, J
|
||||
end
|
||||
|
||||
function forward_jacobian(f, x)
|
||||
if length(x) < ForwardDiff.DEFAULT_CHUNK_THRESHOLD
|
||||
forward_jacobian(f, x, Val(length(x)))
|
||||
else
|
||||
forward_jacobian(f, x, Val(ForwardDiff.DEFAULT_CHUNK_THRESHOLD))
|
||||
end
|
||||
end
|
||||
|
||||
forwarddiff(f, x) = istracked(x) ? track(forwarddiff, f, x) : f(x)
|
||||
|
||||
vec_scalar(x) = vec(x)
|
||||
vec_scalar(x::Real) = [x]
|
||||
reshape_scalar(x, y) = reshape(y, size(x))
|
||||
reshape_scalar(x::Real, y) = y[]
|
||||
|
||||
@grad function forwarddiff(f, x)
|
||||
y, J = forward_jacobian(f, data(x))
|
||||
return y, ȳ -> (nothing, reshape_scalar(x, J*vec_scalar(ȳ)))
|
||||
end
|
|
@ -1,28 +0,0 @@
|
|||
struct IdSet{T} <: AbstractSet{T}
|
||||
dict::IdDict{T,Nothing}
|
||||
IdSet{T}() where T = new(IdDict{T,Nothing}())
|
||||
end
|
||||
|
||||
Base.eltype(::IdSet{T}) where T = T
|
||||
|
||||
IdSet() = IdSet{Any}()
|
||||
|
||||
Base.push!(s::IdSet) = s
|
||||
Base.push!(s::IdSet{T}, x::T) where T = (s.dict[x] = nothing; s)
|
||||
Base.delete!(s::IdSet{T}, x::T) where T = (delete!(s.dict, x); s)
|
||||
Base.in(x, s::IdSet) = haskey(s.dict, x)
|
||||
|
||||
IdSet{T}(xs) where T = push!(IdSet{T}(), xs...)
|
||||
|
||||
IdSet(xs) = IdSet{eltype(xs)}(xs)
|
||||
|
||||
Base.collect(s::IdSet) = Base.collect(keys(s.dict))
|
||||
Base.similar(s::IdSet, T::Type) = IdSet{T}()
|
||||
|
||||
@forward IdSet.dict Base.length
|
||||
|
||||
function Base.iterate(v::IdSet, state...)
|
||||
y = Base.iterate(keys(v.dict), state...)
|
||||
y === nothing && return nothing
|
||||
return (y[1], y[2])
|
||||
end
|
|
@ -1,521 +0,0 @@
|
|||
import Base: *
|
||||
|
||||
import LinearAlgebra
|
||||
import LinearAlgebra: inv, det, logdet, logabsdet, \, /
|
||||
|
||||
using Statistics
|
||||
using LinearAlgebra: Transpose, Adjoint, diagm, diag
|
||||
|
||||
struct TrackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
|
||||
tracker::Tracked{A}
|
||||
data::A
|
||||
grad::A
|
||||
TrackedArray{T,N,A}(t::Tracked{A}, data::A) where {T,N,A} = new(t, data)
|
||||
TrackedArray{T,N,A}(t::Tracked{A}, data::A, grad::A) where {T,N,A} = new(t, data, grad)
|
||||
end
|
||||
|
||||
data(x::TrackedArray) = x.data
|
||||
tracker(x::TrackedArray) = x.tracker
|
||||
|
||||
TrackedVector{T,A} = TrackedArray{T,1,A}
|
||||
TrackedMatrix{T,A} = TrackedArray{T,2,A}
|
||||
TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}}
|
||||
|
||||
track(c::Call, x::AbstractArray) = TrackedArray(c, x)
|
||||
|
||||
TrackedArray(c::Call, x::A) where A <: AbstractArray =
|
||||
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c), x)
|
||||
|
||||
TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
|
||||
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, Δ), x, Δ)
|
||||
|
||||
TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, zero(x))
|
||||
|
||||
Base.eltype(x::Type{<:TrackedArray{T}}) where T <: Real = TrackedReal{T}
|
||||
|
||||
Base.convert(::Type{T}, x::S) where {T<:TrackedArray,S<:T} = x
|
||||
|
||||
Base.convert(::Type{<:TrackedArray}, x::TrackedArray) =
|
||||
error("Not implemented: convert $(typeof(x)) to $T")
|
||||
|
||||
Base.convert(::Type{<:TrackedArray{T,N,A}}, x::AbstractArray) where {T,N,A} =
|
||||
TrackedArray(convert(A, x))
|
||||
|
||||
Base.show(io::IO, t::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
|
||||
@isdefined(A) ?
|
||||
print(io, "TrackedArray{…,$A}") :
|
||||
invoke(show, Tuple{IO,DataType}, io, t)
|
||||
|
||||
function Base.summary(io::IO, x::TrackedArray)
|
||||
print(io, "Tracked ")
|
||||
summary(io, data(x))
|
||||
end
|
||||
|
||||
Base.print_array(io::IO, x::TrackedArray) = Base.print_array(io, data(x))
|
||||
|
||||
function Base.show(io::IO, x::TrackedArray)
|
||||
show(io, data(x))
|
||||
print(io, " (tracked)")
|
||||
end
|
||||
|
||||
Base.copy(x::TrackedArray) = x
|
||||
|
||||
Base.setindex!(xs::TrackedArray, v, i...) =
|
||||
error("Can't differentiate `setindex!`")
|
||||
|
||||
back!(::TrackedArray) = error("Value is not scalar; use `back!(sum(x))` or `back!(x, Δ)`")
|
||||
|
||||
function update!(x::TrackedArray, Δ)
|
||||
x.data .+= data(Δ)
|
||||
tracker(x).grad .= 0
|
||||
return x
|
||||
end
|
||||
|
||||
function update!(x::AbstractArray, Δ)
|
||||
x .+= data(Δ)
|
||||
return x
|
||||
end
|
||||
|
||||
# Fallthrough methods
|
||||
|
||||
for f in :[Base.size, Base.ndims, Base.collect].args
|
||||
@eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...)
|
||||
end
|
||||
|
||||
Base.size(x::TrackedArray, i::Integer, j::Integer, is::Integer...) =
|
||||
size(data(x), i, j, is...)
|
||||
|
||||
Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
|
||||
similar(data(x), dims...)
|
||||
|
||||
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
|
||||
|
||||
for op in [:(==), :≈]
|
||||
@eval Base.$op(x::TrackedArray, y::AbstractArray) = Base.$op(data(x), y)
|
||||
@eval Base.$op(x::AbstractArray, y::TrackedArray) = Base.$op(x, data(y))
|
||||
@eval Base.$op(x::TrackedArray, y::TrackedArray) = Base.$op(data(x), data(y))
|
||||
end
|
||||
|
||||
# Array Stdlib
|
||||
|
||||
Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...)
|
||||
|
||||
@grad function getindex(xs::AbstractArray, i...)
|
||||
data(xs)[i...], function (Δ)
|
||||
Δ′ = zero(xs)
|
||||
Δ′[i...] = data(Δ)
|
||||
(nobacksies(:getindex, Δ′), map(_->nothing, i)...)
|
||||
end
|
||||
end
|
||||
|
||||
Base.view(x::TrackedArray, inds...) = track(Base.view, x, inds...)
|
||||
|
||||
@grad function view(x::AbstractArray, inds...)
|
||||
view(data(x), inds...), function (Δ)
|
||||
grad_output = zero(x)
|
||||
subgrad = view(grad_output, inds...)
|
||||
subgrad[:] = data(Δ)
|
||||
(nobacksies(:view, grad_output), map(_->nothing, inds)...)
|
||||
end
|
||||
end
|
||||
|
||||
Base.:-(xs::TrackedArray) = track(-, xs)
|
||||
|
||||
@grad -(xs) = -data(xs), Δ -> (-Δ,)
|
||||
|
||||
Base.transpose(xs::TrackedArray) = track(transpose, xs)
|
||||
Base.adjoint(xs::TrackedArray) = track(adjoint, xs)
|
||||
|
||||
@grad transpose(xs) = transpose(data(xs)), Δ -> (trim(xs, transpose(Δ)),)
|
||||
@grad adjoint(xs) = data(xs)', Δ -> (trim(xs, Δ'),)
|
||||
|
||||
det(xs::TrackedArray) = track(det, xs)
|
||||
@grad det(xs) = det(data(xs)), Δ -> (Δ * det(xs) * transpose(inv(xs)),)
|
||||
|
||||
logdet(xs::TrackedArray) = track(logdet, xs)
|
||||
@grad logdet(xs) = logdet(data(xs)), Δ -> (Δ * transpose(inv(xs)),)
|
||||
|
||||
logabsdet(xs::TrackedArray) = track(logabsdet, xs)
|
||||
@grad logabsdet(xs) = logabsdet(data(xs)), Δ -> (Δ[1] * transpose(inv(xs)),)
|
||||
|
||||
Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...)
|
||||
|
||||
@grad function repeat(xs; inner=ntuple(x->1, ndims(xs)), outer=ntuple(x->1, ndims(xs)))
|
||||
repeat(data(xs), inner = inner, outer = outer), function (Δ)
|
||||
Δ′ = zero(xs)
|
||||
S = size(xs)
|
||||
|
||||
# Loop through each element of Δ, calculate source dimensions, accumulate into Δ′
|
||||
for (dest_idx, val) in pairs(IndexCartesian(), data(Δ))
|
||||
# First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then
|
||||
# wrap around based on original size S.
|
||||
src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)]
|
||||
Δ′[src_idx...] += val
|
||||
end
|
||||
(nobacksies(:repeat, Δ′),)
|
||||
end
|
||||
end
|
||||
|
||||
function combinations(xs, n)
|
||||
n < 1 && return [[]]
|
||||
cs = combinations(xs, n-1)
|
||||
[[x, c...] for x in xs, c in cs]
|
||||
end
|
||||
|
||||
for i = 0:2, c = combinations([:AbstractArray, :TrackedArray, :Number], i), f = [:hcat, :vcat]
|
||||
cnames = map(_ -> gensym(), c)
|
||||
@eval Base.$f($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::Union{TrackedArray,TrackedReal}, xs::Union{AbstractArray,Number}...) =
|
||||
track($f, $(cnames...), x, xs...)
|
||||
end
|
||||
|
||||
for i = 0:2, c = combinations([:AbstractVecOrMat, :TrackedVecOrMat], i), f = [:hcat, :vcat]
|
||||
cnames = map(_ -> gensym(), c)
|
||||
@eval Base.$f($([:($x::$c{T}) for (x, c) in zip(cnames, c)]...), x::TrackedVecOrMat{T}, xs::AbstractVecOrMat{T}...) where T =
|
||||
track($f, $(cnames...), x, xs...)
|
||||
end
|
||||
|
||||
for i = 0:2, c = combinations([:AbstractVector, :TrackedVector], i), f = [:hcat, :vcat]
|
||||
cnames = map(_ -> gensym(), c)
|
||||
@eval Base.$f($([:($x::$c{T}) for (x, c) in zip(cnames, c)]...), x::TrackedVector{T}, xs::AbstractVector{T}...) where T =
|
||||
track($f, $(cnames...), x, xs...)
|
||||
end
|
||||
|
||||
@grad function vcat(xs...)
|
||||
vcat(data.(xs)...), function (Δ)
|
||||
start = 0
|
||||
Δs = [begin
|
||||
i = map(_ -> :, size(xsi)) |> Base.tail
|
||||
d = Δ[start+1:start+size(xsi,1), i...]
|
||||
start += size(xsi, 1)
|
||||
d
|
||||
end for xsi in xs]
|
||||
return (Δs...,)
|
||||
end
|
||||
end
|
||||
|
||||
@grad function hcat(xs...)
|
||||
hcat(data.(xs)...), function (Δ)
|
||||
start = 0
|
||||
Δs = [begin
|
||||
d = if ndims(xsi) == 1
|
||||
Δ[:, start+1]
|
||||
else
|
||||
i = map(_ -> :, size(xsi)) |> Base.tail |> Base.tail
|
||||
Δ[:, start+1:start+size(xsi,2), i...]
|
||||
end
|
||||
start += size(xsi, 2)
|
||||
d
|
||||
end for xsi in xs]
|
||||
return (Δs...,)
|
||||
end
|
||||
end
|
||||
|
||||
for i = 0:2, c = combinations([:AbstractArray, :TrackedArray], i)
|
||||
cnames = map(_ -> gensym(), c)
|
||||
@eval Base.cat($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::TrackedArray, xs::AbstractArray...; dims) =
|
||||
track(cat, $(cnames...), x, xs..., dims = dims)
|
||||
end
|
||||
|
||||
@grad function cat(Xs...; dims)
|
||||
cat(data.(Xs)..., dims = dims), function (Δ)
|
||||
start = ntuple(i -> 0, Val(ndims(Δ)))
|
||||
Δs = [begin
|
||||
dim_xs = 1:ndims(xs)
|
||||
till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val(ndims(Δ)))
|
||||
xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val(ndims(Δ)))
|
||||
d = reshape(Δ[xs_in_Δ...],size(xs))
|
||||
start = start .+ till_xs
|
||||
d
|
||||
end for xs in Xs]
|
||||
return (Δs...,)
|
||||
end
|
||||
end
|
||||
|
||||
Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) = reshape(xs, dims)
|
||||
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Union{Int64,Colon}}}) = reshape(xs, Base._reshape_uncolon(xs, dims))
|
||||
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64}}) = track(reshape, xs, dims)
|
||||
|
||||
@grad reshape(xs, dims) = reshape(data(xs), dims), Δ -> (reshape(Δ, size(xs)),nothing)
|
||||
|
||||
Base.permutedims(xs::TrackedArray, perm) = track(permutedims, xs, perm)
|
||||
@grad permutedims(xs, perm) = permutedims(data(xs), perm), Δ -> (permutedims(Δ, invperm(perm)),nothing)
|
||||
|
||||
Base.PermutedDimsArray(xs::TrackedArray, perm) = track(PermutedDimsArray, xs, perm)
|
||||
@grad PermutedDimsArray(xs, perm) = PermutedDimsArray(data(xs), perm), Δ -> (PermutedDimsArray(Δ, invperm(perm)),nothing)
|
||||
|
||||
function _kron(mat1::AbstractMatrix,mat2::AbstractMatrix)
|
||||
m1, n1 = size(mat1)
|
||||
mat1_rsh = reshape(mat1,(1,m1,1,n1))
|
||||
|
||||
m2, n2 = size(mat2)
|
||||
mat2_rsh = reshape(mat2,(m2,1,n2,1))
|
||||
|
||||
return reshape(mat1_rsh.*mat2_rsh, (m1*m2,n1*n2))
|
||||
end
|
||||
|
||||
Base.kron(a::TrackedMatrix, b::TrackedMatrix) = _kron(a, b)
|
||||
Base.kron(a::TrackedMatrix, b::AbstractMatrix) = _kron(a, b)
|
||||
Base.kron(a::AbstractMatrix, b::TrackedMatrix) = _kron(a, b)
|
||||
|
||||
|
||||
inv(A::TrackedArray) = Tracker.track(inv, A)
|
||||
@grad function inv(A)
|
||||
return inv(Tracker.data(A)), function (Δ)
|
||||
Ainv = inv(A)
|
||||
∇A = - Ainv' * Δ * Ainv'
|
||||
return (∇A, )
|
||||
end
|
||||
end
|
||||
|
||||
# (/) rdivide
|
||||
A::TrackedArray / B::TrackedArray = Tracker.track(/, A, B)
|
||||
A::AbstractVecOrMat / B::TrackedArray = Tracker.track(/, A, B)
|
||||
A::TrackedArray / B::AbstractVecOrMat = Tracker.track(/, A, B)
|
||||
@grad function (A / B)
|
||||
return Tracker.data(A) / Tracker.data(B), function (Δ)
|
||||
Binv = inv(B)
|
||||
∇B = - Binv' * A' * Δ * Binv'
|
||||
return (Δ * Binv', ∇B)
|
||||
end
|
||||
end
|
||||
|
||||
# (\) ldivide (left vec divide needs more work to resolve dispatch ambiguity)
|
||||
A::TrackedArray \ B::TrackedArray = Tracker.track(\, A, B)
|
||||
A::AbstractArray \ B::TrackedArray = Tracker.track(\, A, B)
|
||||
A::TrackedArray \ B::AbstractVecOrMat = Tracker.track(\, A, B)
|
||||
@grad function (A \ B)
|
||||
return Tracker.data(A) \ Tracker.data(B), function (Δ)
|
||||
Ainv = inv(A)
|
||||
∇A = - Ainv' * Δ * B' * Ainv'
|
||||
return (∇A, Ainv' * Δ)
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
# Reductions
|
||||
|
||||
Base.sum(xs::TrackedArray; dims = :) = track(sum, xs, dims = dims)
|
||||
Base.sum(f::Union{Function,Type},xs::TrackedArray) = sum(f.(xs))
|
||||
|
||||
@grad sum(xs; dims = :) = sum(data(xs), dims = dims),
|
||||
Δ -> (zero(xs) .+ Δ, )
|
||||
|
||||
Base.prod(xs::TrackedArray, dim) = track(prod, xs, dim)
|
||||
Base.prod(xs::TrackedArray) = track(prod, xs)
|
||||
Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs))
|
||||
|
||||
@grad prod(xs) = prod(data(xs)), Δ -> (prod(xs) ./ xs .* Δ,)
|
||||
@grad prod(xs, dim) = prod(data(xs), dims = dim),
|
||||
Δ -> (nobacksies(:sum,
|
||||
reshape(.*(circshift.([reshape(data(xs), length(xs))], 1:length(xs)-1)...), size(xs)) .* Δ),
|
||||
nothing)
|
||||
|
||||
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
|
||||
|
||||
Statistics.mean(xs::TrackedArray; dims = :) = track(mean, xs, dims = dims)
|
||||
|
||||
Base.maximum(xs::TrackedArray; dims = :) = track(maximum, xs, dims = dims)
|
||||
Base.minimum(xs::TrackedArray; dims = :) = track(minimum, xs, dims = dims)
|
||||
|
||||
import LinearAlgebra: dot
|
||||
|
||||
dot(xs::TrackedVector, ys::TrackedVector) = track(dot, xs, ys)
|
||||
dot(xs::AbstractVector, ys::TrackedVector) = track(dot, xs, ys)
|
||||
dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys)
|
||||
|
||||
@grad dot(xs, ys) = dot(data(xs), data(ys)), Δ -> (Δ .* ys, Δ .* xs)
|
||||
|
||||
# Hacks to get std working
|
||||
Statistics.std(x::TrackedArray; dims = :, mean = Statistics.mean(x, dims = dims), corrected::Bool = true) = _std(x,mean,dims,corrected)
|
||||
_std(x::TrackedArray, mean, dims, corrected) = sqrt.(sum((x .- mean).^2, dims = dims) ./ (mapreduce(i -> size(x,i),*, dims) - corrected))
|
||||
_std(x::TrackedArray, mean, ::Colon, corrected) = sqrt.(sum((x .- mean).^2) ./ (length(x) - corrected))
|
||||
|
||||
LinearAlgebra.norm(x::TrackedArray, p::Real = 2) =
|
||||
sum(abs.(x).^p .+ eps(0f0))^(1/p) # avoid d(sqrt(x))/dx == Inf at 0
|
||||
|
||||
@grad mean(xs; dims = :) = mean(data(xs), dims=dims), Δ -> (_backmean(xs,Δ,dims),)
|
||||
_backmean(xs, Δ, ::Colon) = zero(xs) .+ Δ ./ length(xs)
|
||||
_backmean(xs, Δ, dims) = zero(xs) .+ Δ ./ mapreduce(i -> size(data(xs),i),*,dims)
|
||||
|
||||
@grad function maximum(xs; dims = dims)
|
||||
maximum(data(xs), dims = dims), function (Δ)
|
||||
Δ′ = zero(xs)
|
||||
_, i = findmax(data(xs), dims = dims)
|
||||
Δ′[i] = data(Δ)
|
||||
return (nobacksies(:maximum, Δ′),)
|
||||
end
|
||||
end
|
||||
|
||||
@grad function minimum(xs; dims = dims)
|
||||
minimum(data(xs), dims = dims), function (Δ)
|
||||
Δ′ = zero(xs)
|
||||
_, i = findmin(data(xs), dims = dims)
|
||||
Δ′[i] = data(Δ)
|
||||
return (nobacksies(:minimum, Δ′),)
|
||||
end
|
||||
end
|
||||
|
||||
# BLAS
|
||||
|
||||
LinearAlgebra.diagm(x::Pair{<:Integer, <:TrackedVector}) = track(diagm, x...)
|
||||
@grad diagm(i, x) = diagm(i => data(x)), Δ -> (nothing, diag(Δ, i))
|
||||
|
||||
x::TrackedMatrix * y::AbstractMatrix = track(*, x, y)
|
||||
x::AbstractMatrix * y::TrackedMatrix = track(*, x, y)
|
||||
x::TrackedMatrix * y::TrackedMatrix = track(*, x, y)
|
||||
|
||||
x::TrackedMatrix * y::AbstractVector = track(*, x, y)
|
||||
x::AbstractMatrix * y::TrackedVector = track(*, x, y)
|
||||
x::TrackedMatrix * y::TrackedVector = track(*, x, y)
|
||||
|
||||
x::TrackedVector * y::AbstractVector = track(*, x, y)
|
||||
x::AbstractVector * y::TrackedVector = track(*, x, y)
|
||||
x::TrackedVector * y::TrackedVector = track(*, x, y)
|
||||
|
||||
@grad a::AbstractMatrix * b::AbstractVecOrMat =
|
||||
data(a)*data(b), Δ -> (Δ * transpose(b), transpose(a) * Δ)
|
||||
|
||||
# NNlib
|
||||
|
||||
using NNlib
|
||||
import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, ∇conv_data, depthwiseconv, maxpool, meanpool
|
||||
|
||||
softmax(xs::TrackedArray) = track(softmax, xs)
|
||||
|
||||
@grad softmax(xs) = softmax(data(xs)), Δ -> (nobacksies(:softmax, ∇softmax(data(Δ), data(xs))),)
|
||||
|
||||
logsoftmax(xs::TrackedArray) = track(logsoftmax, xs)
|
||||
|
||||
@grad logsoftmax(xs) = logsoftmax(data(xs)), Δ -> (nobacksies(:logsoftmax, ∇logsoftmax(data(Δ), data(xs))),)
|
||||
|
||||
depthwiseconv(x::TrackedArray, w::TrackedArray; kw...) = track(depthwiseconv, x, w; kw...)
|
||||
depthwiseconv(x::AbstractArray, w::TrackedArray; kw...) = track(depthwiseconv, x, w; kw...)
|
||||
depthwiseconv(x::TrackedArray, w::AbstractArray; kw...) = track(depthwiseconv, x, w; kw...)
|
||||
|
||||
@grad depthwiseconv(x, w; kw...) =
|
||||
depthwiseconv(data(x), data(w); kw...),
|
||||
Δ -> nobacksies(:depthwiseconv,
|
||||
(NNlib.∇depthwiseconv_data(data.((Δ, x, w))...; kw...),
|
||||
NNlib.∇depthwiseconv_filter(data.((Δ, x, w))...; kw...)))
|
||||
|
||||
conv(x::TrackedArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
|
||||
conv(x::AbstractArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
|
||||
conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...)
|
||||
|
||||
@grad conv(x, w; kw...) =
|
||||
conv(data(x), data(w); kw...),
|
||||
Δ -> nobacksies(:conv,
|
||||
(NNlib.∇conv_data(data.((Δ, w))...; size=size(x), kw...),
|
||||
NNlib.∇conv_filter(data.((Δ, x))...; size=size(w), kw...)))
|
||||
|
||||
∇conv_data(x::TrackedArray, w::TrackedArray; kw...) = track(∇conv_data, x, w; kw...)
|
||||
∇conv_data(x::AbstractArray, w::TrackedArray; kw...) = track(∇conv_data, x, w; kw...)
|
||||
∇conv_data(x::TrackedArray, w::AbstractArray; kw...) = track(∇conv_data, x, w; kw...)
|
||||
|
||||
@grad ∇conv_data(x, w; kw...) =
|
||||
∇conv_data(data(x), data(w); kw...),
|
||||
Δ -> nobacksies(:conv,
|
||||
(NNlib.conv(data.((Δ, w))...; size=size(x), kw...),
|
||||
NNlib.∇conv_filter(data.((x, Δ))...; size=size(w), kw...)))
|
||||
|
||||
maxpool(x::TrackedArray, k; kw...) = track(maxpool, x, k; kw...)
|
||||
|
||||
@grad function maxpool(x, k; kw...)
|
||||
y = maxpool(data(x), k; kw...)
|
||||
y, Δ -> (nobacksies(:maxpool, NNlib.∇maxpool(data.((Δ, y, x))..., k; kw...)), nothing)
|
||||
end
|
||||
|
||||
meanpool(x::TrackedArray, k; kw...) = track(meanpool, x, k; kw...)
|
||||
|
||||
@grad function meanpool(x, k; kw...)
|
||||
y = meanpool(data(x), k; kw...)
|
||||
y, Δ -> (nobacksies(:maxpool, NNlib.∇meanpool(data.((Δ, y, x))..., k; kw...)), nothing)
|
||||
end
|
||||
|
||||
# Broadcasting
|
||||
|
||||
using ForwardDiff: Dual, partials, value
|
||||
|
||||
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
|
||||
|
||||
unbroadcast(x::AbstractArray, Δ) =
|
||||
size(x) == size(Δ) ? Δ :
|
||||
length(x) == length(Δ) ? trim(x, Δ) :
|
||||
trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
|
||||
|
||||
unbroadcast(x::Number, Δ) = sum(Δ)
|
||||
unbroadcast(x::Base.RefValue, _) = nothing
|
||||
|
||||
dual(x, p) = x
|
||||
dual(x::Real, p) = Dual(x, p)
|
||||
|
||||
function partial(f::F, Δ, i, args::Vararg{Any,N}) where {F,N}
|
||||
dargs = ntuple(j -> dual(args[j], i==j), Val(N))
|
||||
return Δ * f(dargs...).partials[1]
|
||||
end
|
||||
|
||||
@inline function ∇broadcast(f::F, args::Vararg{Any,N}) where {F,N}
|
||||
y = broadcast(f, data.(args)...)
|
||||
eltype(y) <: Real || return y
|
||||
eltype(y) == Bool && return y
|
||||
function back(Δ)
|
||||
Δargs = ntuple(i -> partial.(f, Δ, i, args...), Val(N))
|
||||
dxs = map(unbroadcast, args, Δargs)
|
||||
return dxs
|
||||
end
|
||||
# So we can return non-tracked arrays
|
||||
track(Call(back, tracker.(args)), y)
|
||||
end
|
||||
|
||||
using Base.Broadcast: BroadcastStyle, ArrayStyle, Broadcasted, broadcasted
|
||||
|
||||
struct TrackedStyle <: BroadcastStyle end
|
||||
|
||||
Broadcast.BroadcastStyle(::Type{<:Union{TrackedArray,TrackedReal}}) = TrackedStyle()
|
||||
Broadcast.BroadcastStyle(::TrackedStyle, ::BroadcastStyle) = TrackedStyle()
|
||||
|
||||
# We have to re-build the original broadcast struct to get the appropriate array
|
||||
# style. We need this primarily to support CuArrays' broadcasting fixes.
|
||||
broadcast_rebuild(xs) = data(xs)
|
||||
|
||||
broadcast_rebuild(bc::Broadcasted) =
|
||||
broadcasted(bc.f, broadcast_rebuild.(bc.args)...)
|
||||
|
||||
preprocess(x) = x
|
||||
|
||||
function Base.Broadcast.materialize(bc::Broadcasted{TrackedStyle})
|
||||
bc1 = Broadcast.flatten(bc)
|
||||
bc2 = Broadcast.flatten(broadcast_rebuild(bc))
|
||||
∇broadcast(bc2.f, bc1.args...)
|
||||
end
|
||||
|
||||
using Requires
|
||||
|
||||
# https://github.com/FluxML/Flux.jl/issues/353
|
||||
if VERSION < v"1.1.0-DEV.548"
|
||||
@init Requires.isprecompiling() || @eval Base.Broadcast begin
|
||||
function flatten(bc::Broadcasted{Style}) where {Style}
|
||||
isflat(bc) && return bc
|
||||
args = cat_nested(bc)
|
||||
let makeargs = make_makeargs(bc), f = bc.f
|
||||
newf = @inline function(args::Vararg{Any,N}) where N
|
||||
f(makeargs(args...)...)
|
||||
end
|
||||
return Broadcasted{Style}(newf, args, bc.axes)
|
||||
end
|
||||
end
|
||||
@inline function make_makeargs(makeargs, t::Tuple{<:Broadcasted,Vararg{Any}})
|
||||
bc = t[1]
|
||||
let makeargs = make_makeargs(makeargs, tail(t)), f = bc.f
|
||||
let makeargs = make_makeargs(makeargs, bc.args)
|
||||
headargs, tailargs = make_headargs(bc.args), make_tailargs(bc.args)
|
||||
return @inline function(args::Vararg{Any,N}) where N
|
||||
args1 = makeargs(args...)
|
||||
a, b = headargs(args1...), tailargs(args1...)
|
||||
(f(a...), b...)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,160 +0,0 @@
|
|||
mutable struct TrackedReal{T<:Real} <: Real
|
||||
data::T
|
||||
tracker::Tracked{T}
|
||||
end
|
||||
|
||||
TrackedReal(x::Real) = TrackedReal(x, Tracked{typeof(x)}(Call(), zero(x)))
|
||||
|
||||
data(x::TrackedReal) = x.data
|
||||
tracker(x::TrackedReal) = x.tracker
|
||||
|
||||
track(f::Call, x::Real) = TrackedReal(x, Tracked{typeof(x)}(f, zero(x)))
|
||||
|
||||
function back!(x::TrackedReal; once = true)
|
||||
isinf(x) && error("Loss is Inf")
|
||||
isnan(x) && error("Loss is NaN")
|
||||
return back!(x, 1, once = once)
|
||||
end
|
||||
|
||||
function update!(x::TrackedReal, Δ)
|
||||
x.data += data(Δ)
|
||||
tracker(x).grad = 0
|
||||
return x
|
||||
end
|
||||
|
||||
function Base.show(io::IO, x::TrackedReal)
|
||||
T = get(io, :typeinfo, Any)
|
||||
show(io, data(x))
|
||||
T <: TrackedReal || print(io, " (tracked)")
|
||||
end
|
||||
|
||||
Base.decompose(x::TrackedReal) = Base.decompose(data(x))
|
||||
|
||||
Base.copy(x::TrackedReal) = x
|
||||
|
||||
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{T}) where T = x
|
||||
|
||||
Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x))
|
||||
|
||||
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} =
|
||||
error("Not implemented: convert tracked $S to tracked $T")
|
||||
|
||||
(T::Type{<:TrackedReal})(x::Real) = convert(T, x)
|
||||
|
||||
for op in [:(==), :≈, :<, :(<=)]
|
||||
@eval Base.$op(x::TrackedReal, y::Real) = Base.$op(data(x), y)
|
||||
@eval Base.$op(x::Real, y::TrackedReal) = Base.$op(x, data(y))
|
||||
@eval Base.$op(x::TrackedReal, y::TrackedReal) = Base.$op(data(x), data(y))
|
||||
end
|
||||
|
||||
Base.eps(x::TrackedReal) = eps(data(x))
|
||||
Base.eps(::Type{TrackedReal{T}}) where T = eps(T)
|
||||
|
||||
for f in :[isinf, isnan, isfinite].args
|
||||
@eval Base.$f(x::TrackedReal) = Base.$f(data(x))
|
||||
end
|
||||
|
||||
Base.Printf.fix_dec(x::TrackedReal, n::Int, a...) = Base.Printf.fix_dec(data(x), n, a...)
|
||||
|
||||
Base.float(x::TrackedReal) = x
|
||||
|
||||
Base.promote_rule(::Type{TrackedReal{S}},::Type{T}) where {S,T} =
|
||||
TrackedReal{promote_type(S,T)}
|
||||
|
||||
using Random
|
||||
|
||||
for f in :[rand, randn, randexp].args
|
||||
@eval Random.$f(rng::AbstractRNG,::Type{TrackedReal{T}}) where {T} = param(rand(rng,T))
|
||||
end
|
||||
|
||||
using DiffRules, SpecialFunctions, NaNMath
|
||||
|
||||
for (M, f, arity) in DiffRules.diffrules()
|
||||
arity == 1 || continue
|
||||
@eval begin
|
||||
@grad $M.$f(a::Real) =
|
||||
$M.$f(data(a)), Δ -> (Δ * $(DiffRules.diffrule(M, f, :a)),)
|
||||
$M.$f(a::TrackedReal) = track($M.$f, a)
|
||||
end
|
||||
end
|
||||
|
||||
# Work around zero(π) not working, for some reason
|
||||
_zero(::Irrational) = nothing
|
||||
_zero(x) = zero(x)
|
||||
|
||||
for (M, f, arity) in DiffRules.diffrules()
|
||||
arity == 2 || continue
|
||||
da, db = DiffRules.diffrule(M, f, :a, :b)
|
||||
f = :($M.$f)
|
||||
@eval begin
|
||||
@grad $f(a::TrackedReal, b::TrackedReal) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db)
|
||||
@grad $f(a::TrackedReal, b::Real) = $f(data(a), b), Δ -> (Δ * $da, _zero(b))
|
||||
@grad $f(a::Real, b::TrackedReal) = $f(a, data(b)), Δ -> (_zero(a), Δ * $db)
|
||||
$f(a::TrackedReal, b::TrackedReal) = track($f, a, b)
|
||||
$f(a::TrackedReal, b::Real) = track($f, a, b)
|
||||
$f(a::Real, b::TrackedReal) = track($f, a, b)
|
||||
end
|
||||
end
|
||||
|
||||
# Eliminating ambiguity
|
||||
import Base:^
|
||||
|
||||
^(a::TrackedReal, b::Integer) = track(^, a, b)
|
||||
|
||||
# Hack for conversions
|
||||
|
||||
using ForwardDiff: Dual
|
||||
|
||||
(T::Type{<:Real})(x::Dual) = Dual(T(x.value), map(T, x.partials.values))
|
||||
(Dual{T,V,N})(x::Dual) where {T,V,N} = invoke(Dual{T,V,N}, Tuple{Number}, x)
|
||||
|
||||
# Tuples
|
||||
|
||||
struct TrackedTuple{T<:Tuple}
|
||||
data::T
|
||||
tracker::Tracked{T}
|
||||
end
|
||||
|
||||
data(xs::TrackedTuple) = xs.data
|
||||
tracker(xs::TrackedTuple) = xs.tracker
|
||||
|
||||
accum!(x::Tuple, Δ::Tuple) = accum!.(x, Δ)
|
||||
init_grad(x::Tuple) = init_grad.(x)
|
||||
zero_grad!(x::Tuple) = zero_grad!.(x)
|
||||
|
||||
track(f::Call, xs::Tuple) = TrackedTuple(xs, Tracked{typeof(xs)}(f, zero.(xs)))
|
||||
|
||||
function Base.show(io::IO, xs::TrackedTuple)
|
||||
show(io, data(xs))
|
||||
print(io, " (tracked)")
|
||||
end
|
||||
|
||||
Base.length(x::TrackedTuple) = length(data(x))
|
||||
|
||||
Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i)
|
||||
|
||||
@grad function getindex(xs::TrackedTuple, i)
|
||||
data(xs)[i], Δ -> (ntuple(j -> i == j ? Δ : 0, length(xs)), nothing)
|
||||
end
|
||||
|
||||
# Array collection
|
||||
|
||||
function collect(xs)
|
||||
xs = Base.collect(xs)
|
||||
track(Call(collect, (tracker.(xs),)), data.(xs))
|
||||
end
|
||||
|
||||
function scan(c::Call{typeof(collect)})
|
||||
foreach(scan, c.args[1])
|
||||
end
|
||||
|
||||
function back_(c::Call{typeof(collect)}, Δ, once)
|
||||
foreach((x, d) -> back(x, d, once), c.args[1], data(Δ))
|
||||
end
|
||||
|
||||
function back_(g::Grads, c::Call{typeof(collect)}, Δ)
|
||||
foreach((x, Δ) -> back(g, x, Δ), c.args[1], Δ)
|
||||
end
|
||||
|
||||
collectmemaybe(xs::AbstractArray{>:TrackedReal}) = collect(xs)
|
||||
collectmemaybe(xs::AbstractArray{<:TrackedReal}) = collect(xs)
|
|
@ -1,18 +0,0 @@
|
|||
function ngradient(f, xs::AbstractArray...)
|
||||
grads = zero.(xs)
|
||||
for (x, Δ) in zip(xs, grads), i in 1:length(x)
|
||||
δ = sqrt(eps())
|
||||
tmp = x[i]
|
||||
x[i] = tmp - δ/2
|
||||
y1 = f(xs...)
|
||||
x[i] = tmp + δ/2
|
||||
y2 = f(xs...)
|
||||
x[i] = tmp
|
||||
Δ[i] = (y2-y1)/δ
|
||||
end
|
||||
return grads
|
||||
end
|
||||
|
||||
gradcheck(f, xs...) =
|
||||
all(isapprox.(ngradient(f, xs...),
|
||||
data.(gradient(f, xs...)), rtol = 1e-5, atol = 1e-5))
|
|
@ -1,46 +0,0 @@
|
|||
struct Params
|
||||
order::Vector{Any}
|
||||
params::IdSet{Any}
|
||||
Params() = new([], IdSet())
|
||||
end
|
||||
|
||||
@forward Params.order Base.iterate, Base.length
|
||||
|
||||
function Base.push!(ps::Params, x)
|
||||
if !(x in ps.params)
|
||||
push!(ps.order, x)
|
||||
push!(ps.params, x)
|
||||
end
|
||||
return ps
|
||||
end
|
||||
|
||||
Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps)
|
||||
|
||||
Params(xs) = push!(Params(), xs...)
|
||||
|
||||
function Base.show(io::IO, ps::Params)
|
||||
print(io, "Params([")
|
||||
join(io, ps.order, ", ")
|
||||
print(io, "])")
|
||||
end
|
||||
|
||||
struct Grads
|
||||
grads::IdDict{Any,Any}
|
||||
end
|
||||
|
||||
Base.show(io::IO, ps::Grads) = println(io, "Grads(...)")
|
||||
|
||||
Grads() = Grads(IdDict())
|
||||
|
||||
@forward Grads.grads Base.setindex!, Base.haskey, Base.length, Base.iterate
|
||||
|
||||
Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps))
|
||||
|
||||
Base.getindex(g::Grads, x::Tracked) = g.grads[x]
|
||||
|
||||
function Base.getindex(g::Grads, x)
|
||||
istracked(x) || error("Object not tracked: $x")
|
||||
g[tracker(x)]
|
||||
end
|
||||
|
||||
accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
|
338
test/tracker.jl
338
test/tracker.jl
|
@ -1,347 +1,15 @@
|
|||
using Flux
|
||||
using Flux.Tracker, Test, NNlib
|
||||
using Flux.Tracker: TrackedReal, gradient, gradcheck, grad, checkpoint, forwarddiff
|
||||
using NNlib: conv, ∇conv_data, depthwiseconv
|
||||
using Printf: @sprintf
|
||||
using LinearAlgebra: diagm, dot, LowerTriangular, norm, det, logdet, logabsdet
|
||||
using Statistics: mean, std
|
||||
using Random
|
||||
# using StatsBase
|
||||
using Flux, Test
|
||||
using Tracker: gradcheck
|
||||
|
||||
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
|
||||
gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
|
||||
@testset "Tracker" begin
|
||||
@test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2)
|
||||
@test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)
|
||||
@test gradtest((x, W, b) -> logσ.(W*x .+ b), 5, (2,5), 2)
|
||||
@test gradtest((x, W, b) -> logσ.(W*x .+ b), (5,3), (2,5), 2)
|
||||
@test gradtest((w, x) -> w'*x, randn(Float64,10, 2), randn(Float64,10))
|
||||
@test gradtest((w, x) -> w*x', randn(Float64,5,5), randn(Float64,5,5))
|
||||
@test gradtest(x -> sum(x, dims = (2, 3)), (3,4,5))
|
||||
@test gradtest(x -> sum(x, dims = 1), randn(Float64,2,3))
|
||||
@test gradtest(x -> sum(x, dims = [1,2]), randn(Float64,2,3))
|
||||
@test gradtest(x -> sum(x), randn(Float64,2,3))
|
||||
@test gradtest(x -> prod(x, dims=(2, 3)), (3,4,5))
|
||||
@test gradtest(x -> prod(x), (3,4,5))
|
||||
|
||||
@test gradtest(x -> softmax(x).*(1:3), 3)
|
||||
@test gradtest(x -> softmax(x).*(1:3), (3,5))
|
||||
@test gradtest(x -> logsoftmax(x).*(1:3), 3)
|
||||
@test gradtest(x -> logsoftmax(x).*(1:3), (3,5))
|
||||
@testset "Tracker" begin
|
||||
|
||||
@test gradtest(Flux.mse, rand(5,5), rand(5, 5))
|
||||
@test gradtest(Flux.crossentropy, rand(5,5), rand(5, 5))
|
||||
|
||||
@test gradtest(x -> x', rand(5))
|
||||
|
||||
@test gradtest(det, (4, 4))
|
||||
@test gradtest(logdet, map((x) -> x*x', (rand(4, 4),))[1])
|
||||
@test gradtest((x) -> logabsdet(x)[1], (4, 4))
|
||||
|
||||
@testset "indexing & slicing" begin
|
||||
gradtest(x->view(x, 1:2, 1:2), rand(4, 4))
|
||||
end
|
||||
|
||||
function promotiontest(f, A, B, C)
|
||||
r0 = f(A, B, C)
|
||||
r1 = f(param(A), B, C)
|
||||
r2 = f(A, param(B), C)
|
||||
r3 = f(A, B, param(C))
|
||||
r4 = f(param(A), param(B), param(C))
|
||||
|
||||
@test !isa(r0, TrackedArray)
|
||||
@test all(isa.([r1,r2,r3,r4], TrackedArray))
|
||||
@test r1 == r2 == r3 == r4
|
||||
@test r0 == Flux.data(r4)
|
||||
end
|
||||
|
||||
@testset "concat" begin
|
||||
cat1(x...) = cat(x..., dims = 1)
|
||||
cat2(x...) = cat(x..., dims = 2)
|
||||
|
||||
@testset for vcatf in [vcat, cat1]
|
||||
@test gradtest(vcatf, rand(5), rand(3))
|
||||
@test gradtest(vcatf, rand(5), rand(3), rand(8))
|
||||
@test gradtest(vcatf, rand(5)', rand(5)')
|
||||
@test gradtest(vcatf, rand(5,2), rand(3,2), rand(8,2))
|
||||
@test gradtest(vcatf, rand(5,2,3), rand(3,2,3), rand(8,2,3))
|
||||
@test gradtest(vcatf, rand(5), rand(3,1))
|
||||
@test gradtest(vcatf, rand(5)', rand(2,5))
|
||||
end
|
||||
|
||||
|
||||
@testset for hcatf in [hcat, cat2]
|
||||
@test gradtest(hcatf, rand(5), rand(5))
|
||||
@test gradtest(hcatf, rand(5)', rand(5)')
|
||||
@test gradtest(hcatf, rand(2,5), rand(2,3), rand(2,8))
|
||||
@test gradtest(hcatf, rand(2,5,3), rand(2,3,3), rand(2,8,3))
|
||||
@test gradtest(hcatf, rand(5), rand(5), rand(5,2))
|
||||
@test gradtest(hcatf, rand(5)', rand(1,3))
|
||||
@test gradtest(hcatf, rand(5), rand(5,2))
|
||||
end
|
||||
|
||||
@testset for catf in [vcat, cat1, hcat, cat2, (x...) -> cat(x..., dims = 3), (x...) -> cat(x..., dims = (1,2))]
|
||||
@test gradtest(catf, rand(5))
|
||||
@test gradtest(catf, rand(5)')
|
||||
@test gradtest(catf, rand(2,5))
|
||||
@test gradtest(catf, rand(2,5,3))
|
||||
end
|
||||
|
||||
@test gradtest((x...) -> cat(x..., dims = 3), rand(2,5,2), rand(2,5,3), rand(2,5,4))
|
||||
|
||||
@testset "cat($dim, ...)" for dim in 3:5
|
||||
catdim = (x...) -> cat(x..., dims = dim)
|
||||
@test gradtest(catdim, rand(5), rand(5), rand(5))
|
||||
@test gradtest(catdim, rand(2,5), rand(2,5), rand(2,5))
|
||||
@test gradtest(catdim, rand(2,5,3), rand(2,5,3), rand(2,5,3))
|
||||
end
|
||||
|
||||
@test !isa(vcat(rand(2)), TrackedArray)
|
||||
@test !isa(hcat(rand(2)), TrackedArray)
|
||||
@test !isa(cat(rand(2), dims=1), TrackedArray)
|
||||
|
||||
@test gradtest((a,b)->cat(a, b, dims = (2,3,5)), rand(2,3), rand(2,4,2,1))
|
||||
|
||||
@testset "promotiontest" begin
|
||||
@testset for fcat in [hcat, vcat, (x...) -> cat(x..., dims = 3), (x...) -> cat(x..., dims = (1,2))]
|
||||
promotiontest(fcat, rand(2), rand(2), rand(2))
|
||||
promotiontest(fcat, rand(2)', rand(2)', rand(2)')
|
||||
promotiontest(fcat, rand(2,2), rand(2,2), rand(2,2))
|
||||
promotiontest(fcat, rand(2,2,2), rand(2,2,2), rand(2,2,2))
|
||||
end
|
||||
|
||||
promotiontest(vcat, rand(1,2), rand(2)', rand(2,2))
|
||||
promotiontest(hcat, rand(2,1), rand(2), rand(2,2))
|
||||
promotiontest(vcat, rand(3,4,5), rand(1,4,5), rand(2,4,5))
|
||||
promotiontest(hcat, rand(4,3,5), rand(4,1,5), rand(4,2,5))
|
||||
promotiontest((x...) -> cat(x..., dims = 3), rand(4,5,3), rand(4,5,1), rand(4,5,2))
|
||||
end
|
||||
|
||||
@testset "scalars" begin
|
||||
@test vcat(param([1, 2, 3]), 1) isa TrackedArray
|
||||
@test vcat(1, param([1, 2, 3])) isa TrackedArray
|
||||
@test hcat(1, param([1 2 3;])) isa TrackedArray
|
||||
@test vcat(param(1), 2) isa TrackedArray
|
||||
end
|
||||
|
||||
end
|
||||
|
||||
@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))
|
||||
@test gradtest(x -> PermutedDimsArray(x, [3,1,2]), rand(4,5,6))
|
||||
|
||||
@test gradtest(x -> repeat(x; inner=2), rand(5))
|
||||
@test gradtest(x -> repeat(x; inner=2, outer=3), rand(5))
|
||||
@test gradtest(x -> repeat(x; inner=(2,2,1), outer=(1,1,3)), rand(5,4,3))
|
||||
|
||||
@test gradtest(kron, rand(5), rand(3))
|
||||
@test gradtest(kron, rand(5), rand(3), rand(8))
|
||||
@test gradtest(kron, rand(5,1), rand(3,1))
|
||||
@test gradtest(kron, rand(5,1), rand(3,1), rand(8,1))
|
||||
@test gradtest(kron, rand(5,2), rand(3,2), rand(8,2))
|
||||
|
||||
@test gradtest(x -> diagm(0 => x), rand(3))
|
||||
|
||||
@test gradtest(W -> inv(log.(W * W)), (5,5))
|
||||
@test gradtest((A, B) -> A / B , (1,5), (5,5))
|
||||
@test gradtest((A, B) -> log.(A * A) / exp.(B * B), (5,5), (5,5))
|
||||
@test gradtest((A, B) -> log.(A * A) \ exp.(B * B), (5,5), (5,5))
|
||||
|
||||
@testset "mean" begin
|
||||
@test gradtest(mean, rand(2, 3))
|
||||
|
||||
@test gradtest(x -> mean(x, dims=1), rand(2, 3))
|
||||
@test gradtest(x -> mean(x, dims=2), rand(2, 3))
|
||||
@test gradtest(x -> mean(x, dims=3), rand(2, 3, 4))
|
||||
|
||||
@test gradtest(x -> mean(x, dims=[1, 2]), rand(2, 3, 4))
|
||||
end
|
||||
|
||||
@testset "maximum" begin
|
||||
@test gradtest(maximum, rand(2, 3))
|
||||
|
||||
@test gradtest(x -> maximum(x, dims=1), rand(2, 3))
|
||||
@test gradtest(x -> maximum(x, dims=2), rand(2, 3))
|
||||
@test gradtest(x -> maximum(x, dims=3), rand(2, 3, 4))
|
||||
|
||||
@test gradtest(x -> maximum(x, dims=[1, 2]), rand(2, 3, 4))
|
||||
end
|
||||
|
||||
@testset "minimum" begin
|
||||
@test gradtest(minimum, rand(2, 3))
|
||||
|
||||
@test gradtest(x -> minimum(x, dims=1), rand(2, 3))
|
||||
@test gradtest(x -> minimum(x, dims=2), rand(2, 3))
|
||||
@test gradtest(x -> minimum(x, dims=3), rand(2, 3, 4))
|
||||
|
||||
@test gradtest(x -> minimum(x, dims=[1, 2]), rand(2, 3, 4))
|
||||
end
|
||||
|
||||
@test gradtest(x -> std(x), rand(5,5))
|
||||
@test gradtest(x -> std(x, dims = 1), rand(5,5))
|
||||
@test gradtest(x -> std(x, dims = 1, corrected = false), rand(5,5))
|
||||
|
||||
@test gradtest(x -> Flux.normalise(x), rand(4,3))
|
||||
@test gradtest(x -> Flux.normalise(x, dims = 2), rand(3,4))
|
||||
|
||||
@test gradtest((x, y) -> x .* y, rand(5), rand(5))
|
||||
@test gradtest(dot, rand(5), rand(5))
|
||||
|
||||
@test gradtest(norm, rand(5))
|
||||
|
||||
@test gradtest(rand(5)) do x
|
||||
y = x.^2
|
||||
2y + x
|
||||
end
|
||||
|
||||
@test gradtest(conv, rand(10, 3, 2), randn(Float64, 2, 3, 2))
|
||||
@test gradtest(conv, rand(10, 10, 3, 2), randn(Float64, 2, 2, 3, 2))
|
||||
@test gradtest(conv, rand(10, 10, 10, 3, 2), randn(Float64, 2, 2, 2, 3, 2))
|
||||
|
||||
@test gradtest(∇conv_data, rand(10, 3, 2), randn(Float64, 2, 2, 3))
|
||||
@test gradtest(∇conv_data, rand(10, 10, 3, 2), randn(Float64,2, 2, 2, 3))
|
||||
@test gradtest(∇conv_data, rand(10, 10, 10, 3, 2), randn(Float64,2, 2, 2, 2, 3))
|
||||
|
||||
@test gradtest(depthwiseconv, rand(10,10,3,2), randn(2, 2, 2, 3))
|
||||
|
||||
@test gradtest(∇conv_data, rand(10, 3, 2), randn(Float64, 2, 2, 3))
|
||||
@test gradtest(∇conv_data, rand(10, 10, 3, 2), randn(Float64, 2, 2, 2, 3))
|
||||
@test gradtest(∇conv_data, rand(10, 10, 10, 3, 2), randn(Float64, 2, 2, 2, 2, 3))
|
||||
|
||||
@test gradtest(x -> maxpool(x, (2,2)), rand(10, 10, 3, 2))
|
||||
@test gradtest(x -> maxpool(x, (2,2,2)), rand(10, 10, 10, 3, 2))
|
||||
|
||||
@test gradtest(x -> meanpool(x, (2,2)), rand(10, 10, 3, 2))
|
||||
@test gradtest(x -> meanpool(x, (2,2,2)), rand(5, 5, 5, 3, 2))
|
||||
|
||||
@test gradtest(x -> Float64.(x), 5)
|
||||
|
||||
@testset "equality & order" begin
|
||||
# TrackedReal
|
||||
@test param(2)^2 == param(4)
|
||||
@test param(2)^2 == 4
|
||||
@test 4 == param(2)^2
|
||||
|
||||
@test param(2)^2 ≈ param(4)
|
||||
@test param(2)^2 ≈ 4
|
||||
@test 4 ≈ param(2)^2
|
||||
|
||||
@test (param([1,2,3]) .< 2) == [true, false, false]
|
||||
@test (param([1,2,3]) .<= 2) == [true, true, false]
|
||||
@test (2 .> param([1,2,3])) == [true, false, false]
|
||||
@test (2 .>= param([1,2,3])) == [true, true, false]
|
||||
|
||||
# TrackedArray
|
||||
@test param([1,2,3]).^2 == param([1,4,9])
|
||||
@test [1,2,3].^2 == param([1,4,9])
|
||||
@test param([1,2,3]).^2 == [1,4,9]
|
||||
|
||||
@test param([1,2,3]).^2 ≈ param([1,4,9])
|
||||
@test [1,2,3].^2 ≈ param([1,4,9])
|
||||
@test param([1,2,3]).^2 ≈ [1,4,9]
|
||||
end
|
||||
|
||||
@testset "reshape" begin
|
||||
x = reshape(param(rand(2,2,2)), 4, 2)
|
||||
@test x isa TrackedArray
|
||||
@test size(x) == (4,2)
|
||||
x = reshape(param([1]), (1,:))
|
||||
@test x isa TrackedArray
|
||||
@test size(x) == (1,1)
|
||||
x = reshape(param(rand(2)), (2,:))
|
||||
@test x isa TrackedArray
|
||||
@test size(x) == (2,1)
|
||||
x = reshape(param(rand(2,2)), (1,:,2))
|
||||
@test x isa TrackedArray
|
||||
@test size(x) == (1,2,2)
|
||||
end
|
||||
|
||||
@testset "Intermediates" begin
|
||||
x = param([1])
|
||||
l = sum((x .+ x).^2)
|
||||
Flux.back!(l, once = false)
|
||||
@test x.grad == [8]
|
||||
x.grad .= 0
|
||||
Flux.back!(l, once = false)
|
||||
@test x.grad == [8]
|
||||
end
|
||||
|
||||
@testset "Fallbacks" begin
|
||||
xs = param([1 2; 3 4])
|
||||
@test similar(xs) isa Matrix{Float64}
|
||||
end
|
||||
|
||||
@test @sprintf("%.2f", sum(param([1,2,3]))) == "6.00"
|
||||
|
||||
@inferred NNlib.conv(param(rand(10,10,3,2)),randn(Float64,2,2,3,4))
|
||||
|
||||
b = param(rand())
|
||||
Tracker.back!(b)
|
||||
@test Tracker.grad(b) == 1
|
||||
|
||||
@testset "collect" begin
|
||||
x, y = param(2), param(3)
|
||||
xy = Tracker.collect([x, y])
|
||||
@test xy isa TrackedArray{Float64}
|
||||
z = xy[1]*xy[2]
|
||||
back!(z)
|
||||
@test grad.((x,y)) == (3, 2)
|
||||
|
||||
@test gradient(2, 3) do x, y
|
||||
xy = Tracker.collect([x, y])
|
||||
xy[1]*xy[2]
|
||||
end == (3, 2)
|
||||
end
|
||||
|
||||
# Gradient Hooks
|
||||
@testset "Hooks" begin
|
||||
x = param(2)
|
||||
y = Tracker.hook(-, x)
|
||||
back!(y)
|
||||
@test grad(x) == -1
|
||||
end
|
||||
|
||||
@testset "Checkpointing" begin
|
||||
count = 0
|
||||
function mul(a, b)
|
||||
count += 1
|
||||
a * b
|
||||
end
|
||||
@test gradient(x -> mul(5, x), 3)[1] == 5
|
||||
@test count == 1
|
||||
@test gradient(x -> checkpoint(mul, 5, x), 3)[1] == 5
|
||||
@test count == 3
|
||||
end
|
||||
|
||||
@testset "Updates" begin
|
||||
xs = param([1, 2, 3])
|
||||
Tracker.update!(xs, param([4, 5, 6]))
|
||||
@test xs == [5, 7, 9]
|
||||
x = param(3)
|
||||
Tracker.update!(x, param(4))
|
||||
@test x == 7
|
||||
end
|
||||
|
||||
@testset "Params" begin
|
||||
W = param(randn(5, 10))
|
||||
x = rand(10)
|
||||
dW = gradient(W -> sum(W*x), W)[1]
|
||||
gs = gradient(() -> sum(W*x), Tracker.Params([W]))
|
||||
@test gs[W] == dW
|
||||
end
|
||||
|
||||
@testset "Forward" begin
|
||||
@test @inferred(Tracker.forward_jacobian(x -> [sum(x)], rand(5,5), Val(12)))[2] ==
|
||||
reshape(ones(25), :, 1)
|
||||
@test gradient([2, 3]) do x
|
||||
forwarddiff(x) do x
|
||||
x[1]*x[2]
|
||||
end
|
||||
end == ([3, 2],)
|
||||
end
|
||||
|
||||
@testset "Custom Sensitivities" begin
|
||||
y, back = Tracker.forward(x -> [3x^2, 2x], 5)
|
||||
@test back([1, 1]) == (32,)
|
||||
end
|
||||
|
||||
end #testset
|
||||
|
|
Loading…
Reference in New Issue