Merge branch 'master' into tiny_stack_bugfix

commit c50ad6cdb5
.gitignore
@@ -3,5 +3,4 @@
 *.jl.mem
 docs/build/
 docs/site/
-docs/flux.css
 deps
.travis.yml (25 lines changed)
@@ -1,18 +1,29 @@
 # Documentation: http://docs.travis-ci.com/user/languages/julia/
 language: julia

 os:
   - linux
 # - osx

 julia:
   - 1.0
   - nightly
-# uncomment the following lines to override the default test script
-# script:
-#  - if [[ -a .git/shallow ]]; then git fetch --unshallow; fi
-#  - julia -e 'Pkg.clone(pwd()); Pkg.build("Flux"); Pkg.test("Flux"; coverage=true)'

 matrix:
   allow_failures:
     - julia: nightly
-after_success:
-  - julia -e 'using Pkg; ps=Pkg.PackageSpec(name="Documenter", version="0.19"); Pkg.add(ps); Pkg.pin(ps); Pkg.add("NNlib")'
-  - julia -e 'using Pkg; cd(Pkg.dir("Flux")); include(joinpath("docs", "make.jl"))'
+jobs:
+  include:
+    - stage: "Documentation"
+      julia: 1.0
+      os: linux
+      script:
+        - julia --project=docs/ -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd()));
+                                               Pkg.instantiate()'
+        - julia --project=docs/ docs/make.jl
+      after_success: skip
+
+## uncomment the following lines to override the default test script
+script:
+  - julia --color=yes -e 'using Pkg; Pkg.activate(); Pkg.instantiate(); Pkg.test()'
Manifest.toml
@@ -1,5 +1,3 @@
 # This file is machine-generated - editing it directly is not advised

 [[AbstractTrees]]
 deps = ["Markdown", "Test"]
 git-tree-sha1 = "6621d9645702c1c4e6970cc6a3eae440c768000b"
@@ -8,9 +6,9 @@ version = "0.2.1"

 [[Adapt]]
 deps = ["LinearAlgebra", "Test"]
-git-tree-sha1 = "04d15700419b6949d76be1428ab6e0277ff43b06"
+git-tree-sha1 = "53d8fec4f662088c1202530e338a11a919407f3b"
 uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
-version = "0.4.1"
+version = "0.4.2"

 [[Base64]]
 uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
@@ -53,15 +51,15 @@ version = "0.2.0"

 [[Compat]]
 deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
-git-tree-sha1 = "ec61a16eed883ad0cfa002d7489b3ce6d039bb9a"
+git-tree-sha1 = "49269e311ffe11ac5b334681d212329002a9832a"
 uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
-version = "1.4.0"
+version = "1.5.1"

 [[DataStructures]]
 deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
-git-tree-sha1 = "8fc6e166e24fda04b2b648d4260cdad241788c54"
+git-tree-sha1 = "ca971f03e146cf144a9e2f2ce59674f5bf0e8038"
 uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
-version = "0.14.0"
+version = "0.15.0"

 [[Dates]]
 deps = ["Printf"]
@@ -79,12 +77,12 @@ version = "0.0.3"

 [[DiffRules]]
 deps = ["Random", "Test"]
-git-tree-sha1 = "c49ec69428ffea0c1d1bbdc63d1a70f5df5860ad"
+git-tree-sha1 = "09d69da75967ec48a8b1ad0897ec9144ee052bf9"
 uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
-version = "0.0.7"
+version = "0.0.8"

 [[Distributed]]
-deps = ["Random", "Serialization", "Sockets"]
+deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"]
 uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

 [[FixedPointNumbers]]
@@ -95,19 +93,19 @@ version = "0.5.3"

 [[ForwardDiff]]
 deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"]
-git-tree-sha1 = "b91250044374764e7c29af59a774c4b8d6100b6e"
+git-tree-sha1 = "e393bd3b9102659fb24fe88caedec41f2bc2e7de"
 uuid = "f6369f11-7733-5829-9624-2563aa707210"
-version = "0.10.1"
+version = "0.10.2"

 [[InteractiveUtils]]
-deps = ["Markdown"]
+deps = ["LinearAlgebra", "Markdown"]
 uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

 [[Juno]]
 deps = ["Base64", "Logging", "Media", "Profile", "Test"]
-git-tree-sha1 = "3c29a199713e7ec62cfdc11f44d7760219d5f658"
+git-tree-sha1 = "ce6246e19061e36cbdce954caaae717498daeed8"
 uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
-version = "0.5.3"
+version = "0.5.4"

 [[LibGit2]]
 uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
@@ -140,18 +138,20 @@ version = "0.5.0"

 [[Missings]]
 deps = ["Dates", "InteractiveUtils", "SparseArrays", "Test"]
-git-tree-sha1 = "adc26d2ee85a49c413464110d922cf21efc9d233"
+git-tree-sha1 = "d1d2585677f2bd93a97cfeb8faa7a0de0f982042"
 uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
-version = "0.3.1"
+version = "0.4.0"

 [[Mmap]]
 uuid = "a63ad114-7e13-5084-954f-fe012c677804"

 [[NNlib]]
 deps = ["Libdl", "LinearAlgebra", "MacroTools", "Requires", "Test"]
-git-tree-sha1 = "51330bb45927379007e089997bf548fbe232589d"
+git-tree-sha1 = "5a8ed87d61b1ccb71d99235c2a96287addebbb9f"
+repo-rev = "master"
+repo-url = "https://github.com/FluxML/NNlib.jl.git"
 uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-version = "0.4.3"
+version = "0.4.3+"

 [[NaNMath]]
 deps = ["Compat"]
@@ -228,19 +228,19 @@ version = "0.7.2"

 [[StaticArrays]]
 deps = ["InteractiveUtils", "LinearAlgebra", "Random", "Statistics", "Test"]
-git-tree-sha1 = "97c4bf0f647488dd7ac01ea12be5885f88762938"
+git-tree-sha1 = "1eb114d6e23a817cd3e99abc3226190876d7c898"
 uuid = "90137ffa-7385-5640-81b9-e52037218182"
-version = "0.10.0"
+version = "0.10.2"

 [[Statistics]]
 deps = ["LinearAlgebra", "SparseArrays"]
 uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

 [[StatsBase]]
-deps = ["DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "Test"]
-git-tree-sha1 = "2722397d88f8ffef551948f6c20e1d74a743298c"
+deps = ["DataStructures", "DelimitedFiles", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "Test"]
+git-tree-sha1 = "7b596062316c7d846b67bf625d5963a832528598"
 uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
-version = "0.26.0"
+version = "0.27.0"

 [[Test]]
 deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
@@ -259,7 +259,7 @@ uuid = "30578b45-9adc-5946-b283-645ec420af67"
 version = "0.4.0"

 [[UUIDs]]
-deps = ["Random", "SHA"]
+deps = ["Random"]
 uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

 [[Unicode]]
Project.toml
@@ -13,10 +13,12 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
+Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
+SHA = "ea8e919c-243c-51af-8825-aaa63cd721ce"
 SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
docs/Manifest.toml (new file)
@@ -0,0 +1,288 @@
[[AbstractTrees]]
deps = ["Markdown", "Test"]
git-tree-sha1 = "6621d9645702c1c4e6970cc6a3eae440c768000b"
uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
version = "0.2.1"

[[Adapt]]
deps = ["LinearAlgebra", "Test"]
git-tree-sha1 = "04d15700419b6949d76be1428ab6e0277ff43b06"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "0.4.1"

[[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

[[BinDeps]]
deps = ["Compat", "Libdl", "SHA", "URIParser"]
git-tree-sha1 = "12093ca6cdd0ee547c39b1870e0c9c3f154d9ca9"
uuid = "9e28174c-4ba2-5203-b857-d8d62c4213ee"
version = "0.8.10"

[[BinaryProvider]]
deps = ["Libdl", "Pkg", "SHA", "Test"]
git-tree-sha1 = "055eb2690182ebc31087859c3dd8598371d3ef9e"
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
version = "0.5.3"

[[CodecZlib]]
deps = ["BinaryProvider", "Libdl", "Test", "TranscodingStreams"]
git-tree-sha1 = "e3df104c84dfc108f0ca203fd7f5bbdc98641ae9"
uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
version = "0.5.1"

[[ColorTypes]]
deps = ["FixedPointNumbers", "Random", "Test"]
git-tree-sha1 = "f73b0e10f2a5756de7019818a41654686da06b09"
uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
version = "0.7.5"

[[Colors]]
deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Printf", "Reexport", "Test"]
git-tree-sha1 = "9f0a0210450acb91c730b730a994f8eef1d3d543"
uuid = "5ae59095-9a9b-59fe-a467-6f913c188581"
version = "0.9.5"

[[CommonSubexpressions]]
deps = ["Test"]
git-tree-sha1 = "efdaf19ab11c7889334ca247ff4c9f7c322817b0"
uuid = "bbf7d656-a473-5ed7-a52c-81e309532950"
version = "0.2.0"

[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "ec61a16eed883ad0cfa002d7489b3ce6d039bb9a"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "1.4.0"

[[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
git-tree-sha1 = "ca971f03e146cf144a9e2f2ce59674f5bf0e8038"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.15.0"

[[Dates]]
deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"

[[DelimitedFiles]]
deps = ["Mmap"]
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"

[[DiffResults]]
deps = ["Compat", "StaticArrays"]
git-tree-sha1 = "db8acf46717b13d6c48deb7a12007c7f85a70cf7"
uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
version = "0.0.3"

[[DiffRules]]
deps = ["Random", "Test"]
git-tree-sha1 = "c49ec69428ffea0c1d1bbdc63d1a70f5df5860ad"
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "0.0.7"

[[Distributed]]
deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

[[DocStringExtensions]]
deps = ["LibGit2", "Markdown", "Pkg", "Test"]
git-tree-sha1 = "1df01539a1c952cef21f2d2d1c092c2bcf0177d7"
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
version = "0.6.0"

[[Documenter]]
deps = ["Base64", "DocStringExtensions", "InteractiveUtils", "LibGit2", "Logging", "Markdown", "Pkg", "REPL", "Random", "Test", "Unicode"]
git-tree-sha1 = "a6db1c69925cdc53aafb38caec4446be26e0c617"
uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
version = "0.21.0"

[[FixedPointNumbers]]
deps = ["Test"]
git-tree-sha1 = "b8045033701c3b10bf2324d7203404be7aef88ba"
uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
version = "0.5.3"

[[Flux]]
deps = ["AbstractTrees", "Adapt", "CodecZlib", "Colors", "DiffRules", "ForwardDiff", "Juno", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsBase", "Test", "ZipFile"]
path = ".."
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
version = "0.6.10+"

[[ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"]
git-tree-sha1 = "b91250044374764e7c29af59a774c4b8d6100b6e"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.1"

[[InteractiveUtils]]
deps = ["LinearAlgebra", "Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

[[Juno]]
deps = ["Base64", "Logging", "Media", "Profile", "Test"]
git-tree-sha1 = "3c29a199713e7ec62cfdc11f44d7760219d5f658"
uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
version = "0.5.3"

[[LibGit2]]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"

[[Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"

[[LinearAlgebra]]
deps = ["Libdl"]
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[[Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

[[MacroTools]]
deps = ["Compat"]
git-tree-sha1 = "c443e1c8d58a4e9f61b708ad0a88286c7042145b"
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
version = "0.4.4"

[[Markdown]]
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"

[[Media]]
deps = ["MacroTools", "Test"]
git-tree-sha1 = "75a54abd10709c01f1b86b84ec225d26e840ed58"
uuid = "e89f7d12-3494-54d1-8411-f7d8b9ae1f27"
version = "0.5.0"

[[Missings]]
deps = ["Dates", "InteractiveUtils", "SparseArrays", "Test"]
git-tree-sha1 = "adc26d2ee85a49c413464110d922cf21efc9d233"
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
version = "0.3.1"

[[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"

[[NNlib]]
deps = ["Libdl", "LinearAlgebra", "MacroTools", "Requires", "Test"]
git-tree-sha1 = "51330bb45927379007e089997bf548fbe232589d"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.4.3"

[[NaNMath]]
deps = ["Compat"]
git-tree-sha1 = "ce3b85e484a5d4c71dd5316215069311135fa9f2"
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
version = "0.3.2"

[[OrderedCollections]]
deps = ["Random", "Serialization", "Test"]
git-tree-sha1 = "85619a3f3e17bb4761fe1b1fd47f0e979f964d5b"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.0.2"

[[Pkg]]
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"

[[Printf]]
deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"

[[Profile]]
deps = ["Printf"]
uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"

[[REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"

[[Random]]
deps = ["Serialization"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[[Reexport]]
deps = ["Pkg"]
git-tree-sha1 = "7b1d07f411bc8ddb7977ec7f377b97b158514fe0"
uuid = "189a3867-3050-52da-a836-e630ba90ab69"
version = "0.2.0"

[[Requires]]
deps = ["Test"]
git-tree-sha1 = "f6fbf4ba64d295e146e49e021207993b6b48c7d1"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "0.5.2"

[[SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"

[[Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"

[[SharedArrays]]
deps = ["Distributed", "Mmap", "Random", "Serialization"]
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"

[[Sockets]]
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"

[[SortingAlgorithms]]
deps = ["DataStructures", "Random", "Test"]
git-tree-sha1 = "03f5898c9959f8115e30bc7226ada7d0df554ddd"
uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
version = "0.3.1"

[[SparseArrays]]
deps = ["LinearAlgebra", "Random"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[[SpecialFunctions]]
deps = ["BinDeps", "BinaryProvider", "Libdl", "Test"]
git-tree-sha1 = "0b45dc2e45ed77f445617b99ff2adf0f5b0f23ea"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "0.7.2"

[[StaticArrays]]
deps = ["InteractiveUtils", "LinearAlgebra", "Random", "Statistics", "Test"]
git-tree-sha1 = "1eb114d6e23a817cd3e99abc3226190876d7c898"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "0.10.2"

[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[[StatsBase]]
deps = ["DataStructures", "DelimitedFiles", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "Test"]
git-tree-sha1 = "7b596062316c7d846b67bf625d5963a832528598"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.27.0"

[[Test]]
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[[TranscodingStreams]]
deps = ["Pkg", "Random", "Test"]
git-tree-sha1 = "a34a2d588e2d2825602bf14a24216d5c8b0921ec"
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
version = "0.8.1"

[[URIParser]]
deps = ["Test", "Unicode"]
git-tree-sha1 = "6ddf8244220dfda2f17539fa8c9de20d6c575b69"
uuid = "30578b45-9adc-5946-b283-645ec420af67"
version = "0.4.0"

[[UUIDs]]
deps = ["Random"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[[Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

[[ZipFile]]
deps = ["BinaryProvider", "Libdl", "Printf", "Test"]
git-tree-sha1 = "4000c633efe994b2e10b31b6d91382c4b7412dac"
uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
version = "0.8.0"
docs/Project.toml (new file)
@@ -0,0 +1,4 @@
[deps]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
docs/make.jl (15 lines changed)
@@ -1,11 +1,12 @@
 using Documenter, Flux, NNlib

 makedocs(modules=[Flux, NNlib],
-         doctest = false,
-         format = :html,
+         doctest = true,
          analytics = "UA-36890222-9",
          sitename = "Flux",
-         assets = ["../flux.css"],
+         # Uncomment below for local build
+         #format = Documenter.HTML(prettyurls = false),
+         assets = ["assets/flux.css"],
          pages = ["Home" => "index.md",
                   "Building Models" =>
                     ["Basics" => "models/basics.md",
@@ -22,10 +23,4 @@ makedocs(modules=[Flux, NNlib],
            ["Backpropagation" => "internals/tracker.md"],
          "Community" => "community.md"])

-deploydocs(
-   repo = "github.com/FluxML/Flux.jl.git",
-   target = "build",
-   osname = "linux",
-   julia = "1.0",
-   deps = nothing,
-   make = nothing)
+deploydocs(repo = "github.com/FluxML/Flux.jl.git")
docs/src/assets/flux.css (new file)
@@ -0,0 +1,113 @@
@import url('https://fonts.googleapis.com/css?family=Lato:400,400i');

body {
  font-family: Lato, "Segoe UI",Roboto,"Helvetica Neue",Arial,sans-serif;
}

nav.toc {
  padding-top: 0;
  background: rgb(240, 240, 240);
  line-height: 2em;
  cursor: default;
  user-select: none;
}

h1+h2 {
  margin-top: 0;
}

/* Green banner in ToC */
nav.toc > h1 {
  margin-top: 0;
  padding-top: 0.4em;
  padding-bottom: 0.5em;
  border-bottom: 5px solid white;
  box-shadow: 0px -2px 5px rgb(60,60,60);
  margin-bottom: 0.5em;
  background: rgb(60, 150, 60);

  font-style: italic;
  font-weight: normal;
  font-size: 50pt;
  text-transform: lowercase;
  text-shadow: 2px 2px 5px rgba(0,0,0,0.2);
  color: white;
}

/* Reduce ToC font size */
.toctext {
  font-size: 10pt;
}

/* Fade out non-clickable ToC headers */
nav.toc ul span.toctext {
  color: rgb(180, 180, 180);
}

nav.toc ul .toctext {
  color: rgb(100, 100, 100);
}

nav.toc ul a.toctext:hover {
  color: inherit;
  background: rgb(220, 220, 220);
  cursor: default;
}

nav.toc li.current > .toctext {
  background: linear-gradient(90deg, rgb(245,245,245) 0%, white 90%);
  font-weight: normal;
}

nav.toc ul.internal li.toplevel {
  font-weight: normal;
}

/* Content */

article { max-width: none; }

article > p, article > ul {
  max-width: 45em;
}

/* Links */
a, a:visited { color: rgb(0, 120, 0); }
article p a { border-bottom: 1px solid rgb(200, 230, 200); }
a:hover, a:visited:hover { color: rgb(0, 80, 0); }

/* Article Links */
article p a { border-bottom: 1px solid rgb(200, 230, 200); }
article p a:hover, article a:visited:hover { color: rgb(0, 120, 0); }
article p a:hover { border-bottom: 1px solid rgb(150, 200, 150); }

/* Doctstrings */
article section.docstring {
  padding: 0.5em 0;
  border-left: none;
  border-right: none;
  border-bottom: none;
}

/* Code */

article pre, article p > code {
  background: rgb(245, 250, 245);
}

article pre {
  border: none;
  max-width: none;
  padding: 1em;
  border-radius: 10px 0px 0px 10px;
  margin-left: -1em;
  margin-right: -2em;
}

.hljs-comment {
  font-style: italic;
}

.hljs-number {
  color: rgb(0, 150, 150);
}
docs/src/gpu.md
@@ -1,10 +1,22 @@
 # GPU Support

+## Installation
+
+To get GPU support for NVIDIA graphics cards, you need to install `CuArrays.jl`
+
+**Steps needed**
+
+1. Install [NVIDIA toolkit](https://developer.nvidia.com/cuda-downloads)
+2. Install [NVIDIA cuDNN library](https://developer.nvidia.com/cudnn)
+3. In Julia's terminal run `]add CuArrays`
+
+## GPU Usage
+
 Support for array operations on other hardware backends, like GPUs, is provided by external packages like [CuArrays](https://github.com/JuliaGPU/CuArrays.jl). Flux is agnostic to array types, so we simply need to move model weights and data to the GPU and Flux will handle it.

 For example, we can use `CuArrays` (with the `cu` converter) to run our [basic example](models/basics.md) on an NVIDIA GPU.

-(Note that you need to have CUDA available to use CuArrays – please see the [CuArrays.jl (https://github.com/JuliaGPU/CuArrays.jl) instructions for more details.)
+(Note that you need to have CUDA available to use CuArrays – please see the [CuArrays.jl](https://github.com/JuliaGPU/CuArrays.jl) instructions for more details.)

 ```julia
 using CuArrays
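
For context, a minimal sketch of the `cu`/`gpu` workflow this page describes. The layer sizes are illustrative, and it assumes a working CUDA setup with CuArrays installed:

```julia
using Flux, CuArrays

# Move an illustrative two-layer model and some input data to the GPU.
m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax) |> gpu
x = cu(rand(Float32, 10))  # input lives on the GPU as a CuArray
y = m(x)                   # forward pass runs on the GPU
```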
docs/src/models/basics.md
@@ -4,49 +4,56 @@

 Flux's core feature is taking gradients of Julia code. The `gradient` function takes another Julia function `f` and a set of arguments, and returns the gradient with respect to each argument. (It's a good idea to try pasting these examples in the Julia terminal.)

-```julia
-using Flux.Tracker
+```jldoctest basics
+julia> using Flux.Tracker

-f(x) = 3x^2 + 2x + 1
+julia> f(x) = 3x^2 + 2x + 1;

-# df/dx = 6x + 2
-df(x) = Tracker.gradient(f, x)[1]
+julia> df(x) = Tracker.gradient(f, x; nest = true)[1]; # df/dx = 6x + 2

-df(2) # 14.0 (tracked)
+julia> df(2)
+14.0 (tracked)

-# d²f/dx² = 6
-d2f(x) = Tracker.gradient(df, x)[1]
+julia> d2f(x) = Tracker.gradient(df, x; nest = true)[1]; # d²f/dx² = 6

-d2f(2) # 6.0 (tracked)
+julia> d2f(2)
+6.0 (tracked)
 ```

 (We'll learn more about why these numbers show up as `(tracked)` below.)

 When a function has many parameters, we can pass them all in explicitly:

-```julia
-f(W, b, x) = W * x + b
+```jldoctest basics
+julia> f(W, b, x) = W * x + b;

-Tracker.gradient(f, 2, 3, 4)
-(4.0 (tracked), 1.0, 2.0 (tracked))
+julia> Tracker.gradient(f, 2, 3, 4)
+(4.0 (tracked), 1.0 (tracked), 2.0 (tracked))
 ```

-But machine learning models can have *hundreds* of parameters! Flux offers a nice way to handle this. We can tell Flux to treat something as a parameter via `param`. Then we can collect these together and tell `gradient` to collect the gradients of all of them at once.
+But machine learning models can have *hundreds* of parameters! Flux offers a nice way to handle this. We can tell Flux to treat something as a parameter via `param`. Then we can collect these together and tell `gradient` to collect the gradients of all `params` at once.

-```julia
-W = param(2) # 2.0 (tracked)
-b = param(3) # 3.0 (tracked)
+```jldoctest basics
+julia> using Flux

-f(x) = W * x + b
+julia> W = param(2)
+2.0 (tracked)

-params = Params([W, b])
-grads = Tracker.gradient(() -> f(4), params)
+julia> b = param(3)
+3.0 (tracked)

-grads[W] # 4.0
-grads[b] # 1.0
+julia> f(x) = W * x + b;
+
+julia> grads = Tracker.gradient(() -> f(4), params(W, b));
+
+julia> grads[W]
+4.0
+
+julia> grads[b]
+1.0
 ```

-There are a few things to notice here. Firstly, `W` and `b` now show up as *tracked*. Tracked things behave like normal numbers or arrays, but keep records of everything you do with them, allowing Flux to calculate their gradients. `gradient` takes a zero-argument function; no arguments are necessary because the `Params` tell it what to differentiate.
+There are a few things to notice here. Firstly, `W` and `b` now show up as *tracked*. Tracked things behave like normal numbers or arrays, but keep records of everything you do with them, allowing Flux to calculate their gradients. `gradient` takes a zero-argument function; no arguments are necessary because the `params` tell it what to differentiate.

 This will come in really handy when dealing with big, complicated models. For now, though, let's start with something simple.
@@ -77,7 +84,7 @@ using Flux.Tracker
 W = param(W)
 b = param(b)

-gs = Tracker.gradient(() -> loss(x, y), Params([W, b]))
+gs = Tracker.gradient(() -> loss(x, y), params(W, b))
 ```

 Now that we have gradients, we can pull them out and update `W` to train the model. The `update!(W, Δ)` function applies `W = W + Δ`, which we can use for gradient descent.
@@ -102,6 +109,8 @@ All deep learning in Flux, however complex, is a simple generalisation of this example.
 It's common to create more complex models than the linear regression above. For example, we might want to have two linear layers with a nonlinearity like [sigmoid](https://en.wikipedia.org/wiki/Sigmoid_function) (`σ`) in between them. In the above style we could write this as:

 ```julia
+using Flux
+
 W1 = param(rand(3, 5))
 b1 = param(rand(3))
 layer1(x) = W1 * x .+ b1
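
Taken together, the hunks above amount to the following workflow, a sketch assembled from the page's own snippets:

```julia
using Flux, Flux.Tracker

W = param(rand(2, 5))
b = param(rand(2))
predict(x) = W * x .+ b
loss(x, y) = sum((predict(x) .- y).^2)

x, y = rand(5), rand(2)  # dummy data

# `gradient` takes a zero-argument function; `params` tells it what to differentiate.
grads = Tracker.gradient(() -> loss(x, y), params(W, b))
grads[W]  # gradient of the loss with respect to W
```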
docs/src/models/layers.md
@@ -14,6 +14,7 @@ MeanPool

 ```@docs
 DepthwiseConv
+ConvTranspose
 ```

 ## Recurrent Layers
docs/src/training/optimisers.md
@@ -3,7 +3,7 @@
 Consider a [simple linear regression](../models/basics.md). We create some dummy data, calculate a loss, and backpropagate to calculate gradients for the parameters `W` and `b`.

 ```julia
-using Flux.Tracker
+using Flux, Flux.Tracker

 W = param(rand(2, 5))
 b = param(rand(2))
@@ -14,8 +14,8 @@ loss(x, y) = sum((predict(x) .- y).^2)
 x, y = rand(5), rand(2) # Dummy data
 l = loss(x, y) # ~ 3

-params = Params([W, b])
-grads = Tracker.gradient(() -> loss(x, y), params)
+θ = Params([W, b])
+grads = Tracker.gradient(() -> loss(x, y), θ)
 ```

 We want to update each parameter, using the gradient, in order to improve (reduce) the loss. Here's one way to do that:
@@ -23,44 +23,30 @@ We want to update each parameter, using the gradient, in order to improve (reduce) the loss.
 ```julia
 using Flux.Tracker: grad, update!

-function sgd()
-  η = 0.1 # Learning Rate
-  for p in (W, b)
-    update!(p, -η * grads[p])
-  end
-end
+η = 0.1 # Learning Rate
+for p in (W, b)
+  update!(p, -η * grads[p])
+end
 ```

-If we call `sgd`, the parameters `W` and `b` will change and our loss should go down.
-
-There are two pieces here: one is that we need a list of trainable parameters for the model (`[W, b]` in this case), and the other is the update step. In this case the update is simply gradient descent (`x .-= η .* Δ`), but we might choose to do something more advanced, like adding momentum.
-
-In this case, getting the variables is trivial, but you can imagine it'd be more of a pain with some complex stack of layers.
+Running this will alter the parameters `W` and `b` and our loss should go down. Flux provides a more general way to do optimiser updates like this.

 ```julia
-m = Chain(
-  Dense(10, 5, σ),
-  Dense(5, 2), softmax)
+opt = Descent(0.1) # Gradient descent with learning rate 0.1

+for p in (W, b)
+  update!(opt, p, grads[p])
+end
 ```

-Instead of having to write `[m[1].W, m[1].b, ...]`, Flux provides a params function `params(m)` that returns a list of all parameters in the model for you.
-
-For the update step, there's nothing whatsoever wrong with writing the loop above – it'll work just fine – but Flux provides various *optimisers* that make it more convenient.
-
-```julia
-opt = SGD([W, b], 0.1) # Gradient descent with learning rate 0.1
-
-opt() # Carry out the update, modifying `W` and `b`.
-```
-
-An optimiser takes a parameter list and returns a function that does the same thing as `update` above. We can pass either `opt` or `update` to our [training loop](training.md), which will then run the optimiser after every mini-batch of data.
+An optimiser `update!` accepts a parameter and a gradient, and updates the parameter according to the chosen rule. We can also pass `opt` to our [training loop](training.md), which will update all parameters of the model in a loop. However, we can now easily replace `Descent` with a more advanced optimiser such as `ADAM`.

 ## Optimiser Reference

-All optimisers return a function that, when called, will update the parameters passed to it.
+All optimisers return an object that, when passed to `train!`, will update the parameters passed to it.

 ```@docs
-SGD
+Descent
 Momentum
 Nesterov
 ADAM
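
A compact sketch of the new update loop this page settles on, combining its own snippets (`Descent` plus the optimiser-aware `update!`):

```julia
using Flux, Flux.Tracker
using Flux.Tracker: update!

W = param(rand(2, 5)); b = param(rand(2))
loss(x, y) = sum((W * x .+ b .- y).^2)
x, y = rand(5), rand(2)  # dummy data

grads = Tracker.gradient(() -> loss(x, y), params(W, b))

opt = Descent(0.1)           # gradient descent with learning rate 0.1
for p in (W, b)
  update!(opt, p, grads[p])  # apply the optimiser's rule to each parameter
end
```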
docs/src/training/training.md
@@ -9,7 +9,7 @@ To actually train a model we need three things:
 With these we can call `Flux.train!`:

 ```julia
-Flux.train!(objective, data, opt)
+Flux.train!(objective, params, data, opt)
 ```

 There are plenty of examples in the [model zoo](https://github.com/FluxML/model-zoo).
@@ -24,9 +24,10 @@ m = Chain(
   Dense(32, 10), softmax)

 loss(x, y) = Flux.mse(m(x), y)
+ps = Flux.params(m)

 # later
-Flux.train!(loss, data, opt)
+Flux.train!(loss, ps, data, opt)
 ```

 The objective will almost always be defined in terms of some *cost function* that measures the distance of the prediction `m(x)` from the target `y`. Flux has several of these built in, like `mse` for mean squared error or `crossentropy` for cross entropy loss, but you can calculate it however you want.
@@ -78,7 +79,7 @@ julia> @epochs 2 Flux.train!(...)
 `train!` takes an additional argument, `cb`, that's used for callbacks so that you can observe the training process. For example:

 ```julia
-train!(objective, data, opt, cb = () -> println("training"))
+train!(objective, ps, data, opt, cb = () -> println("training"))
 ```

 Callbacks are called for every batch of training data. You can slow this down using `Flux.throttle(f, timeout)` which prevents `f` from being called more than once every `timeout` seconds.
@@ -89,6 +90,6 @@ A more typical callback might look like this:
 test_x, test_y = # ... create single batch of test data ...
 evalcb() = @show(loss(test_x, test_y))

-Flux.train!(objective, data, opt,
+Flux.train!(objective, ps, data, opt,
             cb = throttle(evalcb, 5))
 ```
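
End to end, the new `train!` signature from the hunks above looks like this — one dummy batch, layer sizes as in the page's own example:

```julia
using Flux
using Flux: throttle

m = Chain(Dense(784, 32, σ), Dense(32, 10), softmax)
loss(x, y) = Flux.mse(m(x), y)
ps = Flux.params(m)

data = [(rand(Float32, 784, 10), rand(Float32, 10, 10))]  # one dummy batch
opt = Descent(0.1)
evalcb() = @show(loss(data[1]...))

Flux.train!(loss, ps, data, opt, cb = throttle(evalcb, 5))
```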
src/Flux.jl
@@ -6,9 +6,9 @@ using Base: tail
 using MacroTools, Juno, Requires, Reexport, Statistics, Random
 using MacroTools: @forward

-export Chain, Dense, RNN, LSTM, GRU, Conv, MaxPool, MeanPool,
+export Chain, Dense, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
        DepthwiseConv, Dropout, LayerNorm, BatchNorm,
-       params, mapleaves, cpu, gpu
+       params, mapleaves, cpu, gpu, f32, f64

 @reexport using NNlib
src/cuda/cuda.jl
@@ -1,6 +1,22 @@
 module CUDA

 using ..CuArrays
+using Pkg.TOML
+
+function version_check()
+  minor_version = 9
+  project = joinpath(dirname(pathof(CuArrays)), "../Project.toml")
+  project = TOML.parse(String(read(project)))
+  version = VersionNumber(get(project, "version", "0.0.0"))
+  if !(version.major == 0 && version.minor == minor_version)
+    @warn """
+    Flux is only supported with CuArrays v0.$minor_version.
+    Try running `] pin CuArrays@0.$minor_version`.
+    """
+  end
+end
+
+version_check()

 if !applicable(CuArray{UInt8}, undef, 1)
   (T::Type{<:CuArray})(::UndefInitializer, sz...) = T(sz...)
src/data/Data.jl
@@ -1,11 +1,27 @@
 module Data

 import ..Flux
+import SHA

 export CMUDict, cmudict

 deps(path...) = joinpath(@__DIR__, "..", "..", "deps", path...)

+function download_and_verify(url, path, hash)
+  tmppath = tempname()
+  download(url, tmppath)
+  hash_download = open(tmppath) do f
+    bytes2hex(SHA.sha256(f))
+  end
+  if hash_download !== hash
+    msg = "Hash Mismatch!\n"
+    msg *= "  Expected sha256:   $hash\n"
+    msg *= "  Calculated sha256: $hash_download"
+    error(msg)
+  end
+  mv(tmppath, path; force=true)
+end
+
 function __init__()
   mkpath(deps())
 end
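
A sketch of how the new helper is intended to be called. The URL and hash below are placeholders, not real values — the dataset modules in the following hunks supply the real ones:

```julia
# Inside a dataset submodule, as cmudict.jl and mnist.jl do below:
using ..Data: deps, download_and_verify

# Placeholder URL and sha256; real call sites pass known-good values.
download_and_verify("https://example.com/dataset.bin",
                    deps("dataset.bin"),
                    "0000000000000000000000000000000000000000000000000000000000000000")
```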
src/data/cmudict.jl
@@ -2,23 +2,25 @@ module CMUDict

 export cmudict

-using ..Data: deps
+using ..Data: deps, download_and_verify

 const version = "0.7b"
 const cache_prefix = "https://cache.julialang.org"

 function load()
-  suffixes = ["", ".phones", ".symbols"]
+  suffixes_and_hashes = [("" , "209a8b4cd265013e96f4658632a9878103b0c5abf62b50d4ef3ae1be226b29e4"),
+                         (".phones" , "ffb588a5e55684723582c7256e1d2f9fadb130011392d9e59237c76e34c2cfd6"),
+                         (".symbols", "408ccaae803641c6d7b626b6299949320c2dbca96b2220fd3fb17887b023b027")]
   if isdir(deps("cmudict"))
-    if all(isfile(deps("cmudict", "cmudict$x")) for x in suffixes)
+    if all(isfile(deps("cmudict", "cmudict$x")) for (x, _) in suffixes_and_hashes)
       return
     end
   end
   @info "Downloading CMUDict dataset"
   mkpath(deps("cmudict"))
-  for x in suffixes
-    download("$cache_prefix/http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-$version$x",
-             deps("cmudict", "cmudict$x"))
+  for (x, hash) in suffixes_and_hashes
+    download_and_verify("$cache_prefix/http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-$version$x",
+                        deps("cmudict", "cmudict$x"), hash)
   end
 end
src/data/fashion-mnist.jl
@@ -1,19 +1,20 @@
 module FashionMNIST

 using ..MNIST: gzopen, imageheader, rawimage, labelheader, rawlabel
+using ..Data: download_and_verify

 const dir = joinpath(@__DIR__, "../../deps/fashion-mnist")

 function load()
   mkpath(dir)
   cd(dir) do
-    for file in ["train-images-idx3-ubyte",
-                 "train-labels-idx1-ubyte",
-                 "t10k-images-idx3-ubyte",
-                 "t10k-labels-idx1-ubyte"]
+    for (file, hash) in [("train-images-idx3-ubyte", "3aede38d61863908ad78613f6a32ed271626dd12800ba2636569512369268a84"),
+                         ("train-labels-idx1-ubyte", "a04f17134ac03560a47e3764e11b92fc97de4d1bfaf8ba1a3aa29af54cc90845"),
+                         ("t10k-images-idx3-ubyte" , "346e55b948d973a97e58d2351dde16a484bd415d4595297633bb08f03db6a073"),
+                         ("t10k-labels-idx1-ubyte" , "67da17c76eaffca5446c3361aaab5c3cd6d1c2608764d35dfb1850b086bf8dd5")]
       isfile(file) && continue
       @info "Downloading Fashion-MNIST dataset"
-      download("http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/$file.gz", "$file.gz")
+      download_and_verify("http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/$file.gz", "$file.gz", hash)
       open(file, "w") do io
         write(io, gzopen(read, "$file.gz"))
       end
src/data/mnist.jl
@@ -1,6 +1,7 @@
 module MNIST

 using CodecZlib, Colors
+using ..Data: download_and_verify

 const Gray = Colors.Gray{Colors.N0f8}

@@ -15,13 +16,13 @@ end
 function load()
   mkpath(dir)
   cd(dir) do
-    for file in ["train-images-idx3-ubyte",
-                 "train-labels-idx1-ubyte",
-                 "t10k-images-idx3-ubyte",
-                 "t10k-labels-idx1-ubyte"]
+    for (file, hash) in [("train-images-idx3-ubyte", "440fcabf73cc546fa21475e81ea370265605f56be210a4024d2ca8f203523609"),
+                         ("train-labels-idx1-ubyte", "3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c"),
+                         ("t10k-images-idx3-ubyte" , "8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6"),
+                         ("t10k-labels-idx1-ubyte" , "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6")]
       isfile(file) && continue
       @info "Downloading MNIST dataset"
-      download("https://cache.julialang.org/http://yann.lecun.com/exdb/mnist/$file.gz", "$file.gz")
+      download_and_verify("https://cache.julialang.org/http://yann.lecun.com/exdb/mnist/$file.gz", "$file.gz", hash)
       open(file, "w") do io
         write(io, gzopen(read, "$file.gz"))
       end
src/data/sentiment.jl
@@ -1,13 +1,13 @@
 module Sentiment

 using ZipFile
-using ..Data: deps
+using ..Data: deps, download_and_verify

 function load()
   isfile(deps("sentiment.zip")) && return
   @info "Downloading sentiment treebank dataset"
-  download("https://cache.julialang.org/https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip",
-           deps("sentiment.zip"))
+  download_and_verify("https://cache.julialang.org/https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip",
+                      deps("sentiment.zip"), "5c613a4f673fc74097d523a2c83f38e0cc462984d847b82c7aaf36b01cbbbfcc")
 end

 getfile(r, name) = r.files[findfirst(x -> x.name == name, r.files)]
src/layers/basic.jl
@@ -21,8 +21,8 @@ struct Chain{T<:Tuple}
   Chain(xs...) = new{typeof(xs)}(xs)
 end

-@forward Chain.layers Base.getindex, Base.first, Base.last, Base.lastindex
-@forward Chain.layers Base.iterate
+@forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
+  Base.iterate, Base.lastindex

 children(c::Chain) = c.layers
 mapchildren(f, c::Chain) = Chain(f.(c.layers)...)
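
The widened `@forward` list above makes a `Chain` behave like the tuple of layers it wraps; a quick sketch of what that enables:

```julia
using Flux

m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)

m[1]       # first layer, via Base.getindex
length(m)  # number of layers, via the newly forwarded Base.length
last(m)    # softmax, via Base.last
for layer in m  # iteration, via Base.iterate
  println(layer)
end
```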
src/layers/conv.jl
@@ -1,4 +1,4 @@
-using NNlib: conv, depthwiseconv
+using NNlib: conv, ∇conv_data, depthwiseconv

 @generated sub2(::Val{N}) where N = :(Val($(N-2)))

@@ -13,7 +13,7 @@ Standard convolutional layer. `size` should be a tuple like `(2, 2)`.
 `in` and `out` specify the number of input and output channels respectively.

 Data should be stored in WHCN order. In other words, a 100×100 RGB image would
-be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
+be a `100×100×3×1` array, and a batch of 50 would be a `100×100×3×50` array.

 Takes the keyword arguments `pad`, `stride` and `dilation`.
 """
@@ -57,6 +57,54 @@ end
 (a::Conv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
   a(T.(x))

+"""
+    ConvTranspose(size, in=>out)
+    ConvTranspose(size, in=>out, relu)
+
+Standard convolutional transpose layer. `size` should be a tuple like `(2, 2)`.
+`in` and `out` specify the number of input and output channels respectively.
+Data should be stored in WHCN order. In other words, a 100×100 RGB image would
+be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
+Takes the keyword arguments `pad`, `stride` and `dilation`.
+"""
+struct ConvTranspose{N,F,A,V}
+  σ::F
+  weight::A
+  bias::V
+  stride::NTuple{N,Int}
+  pad::NTuple{N,Int}
+  dilation::NTuple{N,Int}
+end
+
+ConvTranspose(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
+              stride = 1, pad = 0, dilation = 1) where {T,N} =
+  ConvTranspose(σ, w, b, expand.(sub2(Val(N)), (stride, pad, dilation))...)
+
+ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
+              init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
+  ConvTranspose(param(init(k..., reverse(ch)...)), param(zeros(ch[2])), σ,
+                stride = stride, pad = pad, dilation = dilation)
+
+@treelike ConvTranspose
+
+function (c::ConvTranspose)(x::AbstractArray)
+  # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
+  σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
+  σ.(∇conv_data(x, c.weight, stride = c.stride, pad = c.pad, dilation = c.dilation) .+ b)
+end
+
+function Base.show(io::IO, l::ConvTranspose)
+  print(io, "ConvTranspose(", size(l.weight)[1:ndims(l.weight)-2])
+  print(io, ", ", size(l.weight, ndims(l.weight)), "=>", size(l.weight, ndims(l.weight)-1))
+  l.σ == identity || print(io, ", ", l.σ)
+  print(io, ")")
+end
+
+(a::ConvTranspose{<:Any,<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
+  invoke(a, Tuple{AbstractArray}, x)
+
+(a::ConvTranspose{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
+  a(T.(x))
 """
     DepthwiseConv(size, in)
     DepthwiseConv(size, in=>mul)
@@ -83,12 +131,12 @@ DepthwiseConv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
               stride = 1, pad = 0) where {T,N} =
   DepthwiseConv(σ, w, b, expand.(sub2(Val(N)), (stride, pad))...)

-DepthwiseConv(k::NTuple{N,Integer}, ch::Integer, σ = identity; init = initn,
+DepthwiseConv(k::NTuple{N,Integer}, ch::Integer, σ = identity; init = glorot_uniform,
               stride = 1, pad = 0) where N =
   DepthwiseConv(param(init(k..., 1, ch)), param(zeros(ch)), σ,
                 stride = stride, pad = pad)

-DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
+DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = glorot_uniform,
               stride::NTuple{N,Integer} = map(_->1,k),
               pad::NTuple{N,Integer} = map(_->0,k)) where N =
   DepthwiseConv(param(init(k..., ch[2], ch[1])), param(zeros(ch[2]*ch[1])), σ,
|
|||
end
|
||||
|
||||
BatchNorm(chs::Integer, λ = identity;
|
||||
initβ = (i) -> zeros(i), initγ = (i) -> ones(i), ϵ = 1e-5, momentum = .1) =
|
||||
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
|
||||
BatchNorm(λ, param(initβ(chs)), param(initγ(chs)),
|
||||
zeros(chs), ones(chs), ϵ, momentum, true)
|
||||
|
||||
|
@ -138,7 +138,9 @@ function (BN::BatchNorm)(x)
|
|||
end
|
||||
|
||||
let λ = BN.λ
|
||||
λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ sqrt.(σ² .+ BN.ϵ)) .+ reshape(β, affine_shape...))
|
||||
temp = reshape(γ, affine_shape...) .* ((x .- μ) ./ sqrt.(σ² .+ BN.ϵ))
|
||||
# This is intentionally not fused because of an extreme slowdown doing so
|
||||
λ.(temp .+ reshape(β, affine_shape...))
|
||||
end
|
||||
end
|
||||
|
||||
|
|
|
src/layers/recurrent.jl
@@ -84,7 +84,7 @@ end
 RNNCell(in::Integer, out::Integer, σ = tanh;
         init = glorot_uniform) =
   RNNCell(σ, param(init(out, in)), param(init(out, out)),
-          param(zeros(out)), param(init(out)))
+          param(init(out)), param(zeros(out)))

 function (m::RNNCell)(h, x)
   σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
@@ -122,8 +122,8 @@ end

 function LSTMCell(in::Integer, out::Integer;
                   init = glorot_uniform)
-  cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(zeros(out*4)),
-                  param(init(out)), param(init(out)))
+  cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(init(out*4)),
+                  param(zeros(out)), param(zeros(out)))
   cell.b.data[gate(out, 2)] .= 1
   return cell
 end
@@ -169,7 +169,7 @@ end

 GRUCell(in, out; init = glorot_uniform) =
   GRUCell(param(init(out*3, in)), param(init(out*3, out)),
-          param(zeros(out*3)), param(init(out)))
+          param(init(out*3)), param(zeros(out)))

 function (m::GRUCell)(h, x)
   b, o = m.b, size(h, 1)
src/layers/stateless.jl
@@ -2,16 +2,14 @@ using NNlib: logsoftmax, logσ

 # Cost functions

-mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)
+mse(ŷ, y) = sum((ŷ .- y).^2) * 1 // length(y)

 function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
-  -sum(y .* log.(ŷ) .* weight) / size(y, 2)
+  -sum(y .* log.(ŷ) .* weight) * 1 // size(y, 2)
 end

-@deprecate logloss(x, y) crossentropy(x, y)
-
 function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
-  return -sum(y .* logsoftmax(logŷ) .* weight) / size(y, 2)
+  return -sum(y .* logsoftmax(logŷ) .* weight) * 1 // size(y, 2)
 end

 """
@@ -42,12 +40,17 @@ but it is more numerically stable.
 logitbinarycrossentropy(logŷ, y) = (1 - y)*logŷ - logσ(logŷ)

 """
-    normalise(x::AbstractVecOrMat)
+    normalise(x::AbstractArray; dims=1)

-Normalise each column of `x` to mean 0 and standard deviation 1.
+Normalises x to mean 0 and standard deviation 1, across the dimensions given by dims. Defaults to normalising over columns.
 """
-function normalise(x::AbstractVecOrMat)
-  μ′ = mean(x, dims = 1)
-  σ′ = std(x, dims = 1, mean = μ′)
+function normalise(x::AbstractArray; dims=1)
+  μ′ = mean(x, dims = dims)
+  σ′ = std(x, dims = dims, mean = μ′, corrected=false)
   return (x .- μ′) ./ σ′
 end
+
+function normalise(x::AbstractArray, dims)
+  Base.depwarn("`normalise(x::AbstractArray, dims)` is deprecated, use `normalise(a, dims=dims)` instead.", :normalise)
+  normalise(x, dims = dims)
+end
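
A sketch of the revised `normalise` with its new `dims` keyword (note `corrected=false`, i.e. the population standard deviation):

```julia
using Flux, Statistics

x = rand(3, 4)
xc = Flux.normalise(x)            # normalise each column (dims = 1, the default)
xr = Flux.normalise(x, dims = 2)  # normalise each row instead

mean(xc, dims = 1)                   # ≈ 0 for every column
std(xc, dims = 1, corrected = false) # ≈ 1 for every column
```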
src/onehot.jl
@@ -68,3 +68,6 @@ end

 a::TrackedMatrix * b::OneHotVector = invoke(*, Tuple{AbstractMatrix,OneHotVector}, a, b)
 a::TrackedMatrix * b::OneHotMatrix = invoke(*, Tuple{AbstractMatrix,OneHotMatrix}, a, b)
+
+onecold(x::TrackedVector, l...) = onecold(data(x), l...)
+onecold(x::TrackedMatrix, l...) = onecold(data(x), l...)
src/optimise/deprecations.jl
@@ -4,7 +4,7 @@ using Flux: Params
 check_decay(opt, decay) = decay == 0 ? opt : Optimiser(opt, InvDecay(decay))

 # legacy update rule
-updaterule(opt, ps) = () -> update!(opt, ps)
+updaterule(opt, ps) = () -> _update_params!(opt, ps)

 function SGD(params::Union{AbstractArray, Params}, η = 0.1; decay = 0.)
   depwarn("SGD(params) is deprecated; use Descent(η::Float64) instead", :SGD)
@@ -117,7 +117,7 @@ struct OldOptimiser
   func
 end

-update!(opt::OldOptimiser, ps) = opt.func()
+_update_params!(opt::OldOptimiser, ps) = opt.func()

 # Train function
 function train!(loss, data, opt; cb = () -> ())
src/optimise/optimisers.jl
@@ -18,7 +18,7 @@ end

 Descent() = Descent(0.1)

-function update!(o::Descent, x, Δ)
+function apply!(o::Descent, x, Δ)
   Δ .*= o.eta
 end

@@ -35,7 +35,7 @@ end

 Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict())

-function update!(o::Momentum, x, Δ)
+function apply!(o::Momentum, x, Δ)
   η, ρ = o.eta, o.rho
   v = get!(o.velocity, x, zero(x))::typeof(x)
   @. v = ρ * v - η * Δ
@@ -55,7 +55,7 @@ end

 Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict())

-function update!(o::Nesterov, x, Δ)
+function apply!(o::Nesterov, x, Δ)
   η, ρ = o.eta, o.rho
   v = get!(o.velocity, x, zero(x))::typeof(x)
   d = @. ρ^2 * v - (1+ρ) * η * Δ
@@ -78,7 +78,7 @@ end

 RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict())

-function update!(o::RMSProp, x, Δ)
+function apply!(o::RMSProp, x, Δ)
   η, ρ = o.eta, o.rho
   acc = get!(o.acc, x, zero(x))::typeof(x)
   @. acc = ρ * acc + (1 - ρ) * Δ^2
@@ -98,7 +98,7 @@ end

 ADAM(η = 0.001, β = (0.9, 0.999)) = ADAM(η, β, IdDict())

-function update!(o::ADAM, x, Δ)
+function apply!(o::ADAM, x, Δ)
   η, β = o.eta, o.beta
   mt, vt, βp = get!(o.state, x, (zero(x), zero(x), β))
   @. mt = β[1] * mt + (1 - β[1]) * Δ
@@ -122,7 +122,7 @@ end

 AdaMax(η = 0.001, β = (0.9, 0.999)) = AdaMax(η, β, IdDict())

-function update!(o::AdaMax, x, Δ)
+function apply!(o::AdaMax, x, Δ)
   η, β = o.eta, o.beta
   mt, ut, βp = get!(o.state, x, (zero(x), zero(x), β))
   @. mt = β[1] * mt + (1 - β[1]) * Δ
@@ -145,7 +145,7 @@ end

 ADAGrad(η = 0.1) = ADAGrad(η, IdDict())

-function update!(o::ADAGrad, x, Δ)
+function apply!(o::ADAGrad, x, Δ)
   η = o.eta
   acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(x)
   @. acc += Δ^2
@@ -165,7 +165,7 @@ end

 ADADelta(ρ = 0.9) = ADADelta(ρ, IdDict())

-function update!(o::ADADelta, x, Δ)
+function apply!(o::ADADelta, x, Δ)
   ρ = o.rho
   acc, Δacc = get!(o.state, x, (zero(x), zero(x)))
   @. acc = ρ * acc + (1 - ρ) * Δ^2
@@ -188,7 +188,7 @@ end

 AMSGrad(η = 0.001, β = (0.9, 0.999)) = AMSGrad(η, β, IdDict())

-function update!(o::AMSGrad, x, Δ)
+function apply!(o::AMSGrad, x, Δ)
   η, β = o.eta, o.beta
   mt, vt, v̂t = get!(o.state, x, (fill(ϵ, size(x)), fill(ϵ, size(x)), fill(ϵ, size(x))))
   @. mt = β[1] * mt + (1 - β[1]) * Δ
@@ -211,7 +211,7 @@ end

 NADAM(η = 0.001, β = (0.9, 0.999)) = NADAM(η, β, IdDict())

-function update!(o::NADAM, x, Δ)
+function apply!(o::NADAM, x, Δ)
   η, β = o.eta, o.beta
   β1p, β2p = o.beta
   mt, vt = get!(o.state, x, (zero(x), zero(x)))
@@ -228,7 +228,7 @@ end
 [ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam.
 """
 ADAMW(η = 0.001, β = (0.9, 0.999), decay = 0) =
-  Optimiser(ADAM(η, β), WeightDecay(wd))
+  Optimiser(ADAM(η, β), WeightDecay(decay))

 # Compose optimizers

@@ -250,13 +250,21 @@ Optimiser(o...) = Optimiser(Any[o...])

 Base.getindex(c::Optimiser, i::AbstractArray) = Optimiser(c.os[i]...)

-function update!(o::Optimiser, x, Δ)
+function apply!(o::Optimiser, x, Δ)
   for opt in o.os
-    Δ = update!(opt, x, Δ)
+    Δ = apply!(opt, x, Δ)
   end
   return Δ
 end

+"""
+`InvDecay(γ)`
+
+Apply inverse time decay to an optimiser
+```julia
+Optimiser(InvDecay(..), Opt(..))
+```
+"""
 mutable struct InvDecay
   gamma::Float64
   state::IdDict
@@ -264,7 +272,7 @@ end

 InvDecay(γ = 0.001) = InvDecay(γ, IdDict())

-function update!(o::InvDecay, x, Δ)
+function apply!(o::InvDecay, x, Δ)
   γ = o.gamma
   n = get!(o.state, x, 1)
   Δ .*= 1 / (1 + γ * n)
@@ -272,6 +280,16 @@ function apply!(o::InvDecay, x, Δ)
   return Δ
 end

+"""
+`ExpDecay(eta, decay, decay_step, clip)`
+
+Schedule the learning rate `eta` by `decay` every `decay_step` till a minimum of `clip`.
+
+To apply exponential decay to an optimiser:
+```julia
+Optimiser(ExpDecay(..), Opt(..))
+```
+"""
 mutable struct ExpDecay
   eta::Float64
   decay::Float64
@@ -282,7 +300,7 @@ end

 ExpDecay(opt = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4) = ExpDecay(opt, decay, decay_step, clip, IdDict())

-function update!(o::ExpDecay, x, Δ)
+function apply!(o::ExpDecay, x, Δ)
   η, s, decay = o.eta, o.step, o.decay
   n = o.current[x] = get(o.current, x, 0) + 1
   if o.current[x]%s == 0 && count(x -> x%s == 0, values(o.current)) == 1
@@ -292,13 +310,18 @@ function apply!(o::ExpDecay, x, Δ)
   @. Δ *= decay
 end

+"""
+`WeightDecay(wd)`
+
+Decay the weight parameter by `wd`
+"""
 mutable struct WeightDecay
   wd::Real
 end

 WeightDecay() = WeightDecay(0)

-function update!(o::WeightDecay, x, Δ)
+function apply!(o::WeightDecay, x, Δ)
   wd = o.wd
   @. Δ += wd * x
 end
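
The rename above separates the update rule (`apply!`, which rescales and returns the step Δ) from the in-place parameter update. A sketch of the contract, assuming the optimisers live in `Flux.Optimise`:

```julia
using Flux
using Flux.Optimise: apply!

x = param(rand(3))
Δ = rand(3)  # a gradient for x

opt = Descent(0.1)
step = apply!(opt, x.data, copy(Δ))  # apply! rescales Δ in place and returns it
x.data .-= step                      # the caller performs the actual update
```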
@ -1,10 +1,14 @@
|
|||
using Juno
|
||||
using Flux.Tracker: data, grad, back!
|
||||
import Flux.Tracker: data, grad, back!, update!
|
||||
import Base.depwarn
|
||||
|
||||
function update!(opt, xs)
|
||||
function update!(opt, x, x̄)
|
||||
update!(x, apply!(opt, x, copy(data(x̄))))
|
||||
end
|
||||
|
||||
function _update_params!(opt, xs)
|
||||
for x in xs
|
||||
Δ = update!(opt, x.data, x.grad)
|
||||
Δ = apply!(opt, x.data, x.grad)
|
||||
x.data .-= Δ
|
||||
Δ .= 0
|
||||
end
|
||||
|
@ -45,7 +49,7 @@ function stop()
|
|||
end
|
||||
|
||||
"""
|
||||
train!(model, loss, data, opt)
|
||||
train!(loss, params, data, opt; cb)
|
||||
|
||||
For each datapoint `d` in `data` computes the gradient of `loss(d...)` through
|
||||
backpropagation and calls the optimizer `opt`.
|
||||
|
@ -54,11 +58,11 @@ Takes a callback as keyword argument `cb`. For example, this will print "trainin
|
|||
every 10 seconds:
|
||||
|
||||
```julia
|
||||
Flux.train!(model, loss, data, opt,
|
||||
Flux.train!(loss, params, data, opt,
|
||||
cb = throttle(() -> println("training"), 10))
|
||||
```
|
||||
|
||||
The callback can return `:stop` to interrupt the training loop.
|
||||
The callback can call `Flux.stop()` to interrupt the training loop.
|
||||
|
||||
Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
|
||||
"""
|
||||
|
@ -69,7 +73,7 @@ function train!(loss, ps, data, opt; cb = () -> ())
|
|||
try
|
||||
l = loss(d...)
|
||||
@interrupts back!(l)
|
||||
update!(opt, ps)
|
||||
_update_params!(opt, ps)
|
||||
if cb() == :stop
|
||||
depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
|
||||
break
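
A hedged end-to-end sketch of the new call signature (the model, loss and data below are invented for illustration):

```julia
using Flux
using Flux: throttle

m = Dense(10, 2)
loss(x, y) = Flux.mse(m(x), y)
data = [(rand(10), rand(2)) for _ = 1:100]

# Prints "training" at most every 10 seconds while iterating over data.
Flux.train!(loss, params(m), data, Descent(0.1),
            cb = throttle(() -> println("training"), 10))
```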

@@ -6,7 +6,7 @@ using MacroTools: @q, @forward
import Base: ==

export TrackedArray, TrackedVector, TrackedMatrix, Params, gradient,
-  param, back!
+  jacobian, hessian, param, back!

tracker(x) = nothing

@@ -61,24 +61,20 @@ macro grad(ex)
  @q(Tracker._forward($(args...)) where $(T...) = $body) |> esc
end

-function update!(x, Δ)
-  x.data .+= data(Δ)
-  tracker(x).grad .= 0
-  return x
-end
-
include("idset.jl")
include("back.jl")
include("numeric.jl")
include("lib/real.jl")
include("lib/array.jl")
+include("forward.jl")

"""
    hook(f, x) -> x′

Hook into gradient backpropagation. `x` is unmodified, but when backpropagating
`f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse
-the sign of the gradient applied to `x`."""
+the sign of the gradient applied to `x`.
+"""
hook(f, x) = istracked(x) ? track(hook, f, x) : x
@grad hook(f, x) = data(x), Δ -> (nothing, f(Δ))
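
A hedged sketch of `hook` in action — the forward value is untouched while the incoming gradient is negated:

```julia
using Flux.Tracker: gradient, hook

gradient(2) do x
  hook(-, x) * 3   # forward: 2 * 3; backward: the incoming gradient 3 passes through `-`
end                # == (-3,)
```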

@@ -67,7 +67,7 @@ function back!(x, Δ; once = true)
end

function gradient_(f, xs...)
-  xs = param.(xs)
+  xs = param.(data.(xs))
  l = f(xs...)
  losscheck(l)
  back!(l)
@@ -147,8 +147,10 @@ end

back(::Grads, ::Nothing, _) = return

+collectmemaybe(xs) = xs
+
function forward(f, ps::Params)
-  y = f()
+  y = collectmemaybe(f())
  y, function (Δ)
    g = Grads(ps)
    if istracked(y)
@@ -179,3 +181,30 @@ end

gradient(f, xs...; nest = false) =
  nest ? gradient_nested(f, xs...) : gradient_(f, xs...)

gradient(f, ps::Params) = gradient_nested(f, ps)
+
+# Jacobians and Hessians
+
+import ..Flux
+
+"""
+    J = jacobian(m, x)
+
+Calculate the output Jacobian `J = d/dx m(x)` such that each row `i` of `J`
+corresponds to the gradient `J[i, :] = ∇ₓ(m(x)[i])`.
+"""
+function jacobian(m, x)
+  xp = param(x)
+  y = m(xp)
+  k = length(y)
+  n = length(x)
+  J = Matrix{eltype(x)}(undef, k, n)
+  for i = 1:k
+    Flux.back!(y[i], once = false) # Populate gradient accumulator
+    J[i, :] = xp.grad
+    xp.grad .= 0 # Reset gradient accumulator
+  end
+  J
+end
+
+hessian(f, x) = jacobian(x -> gradient(f, x, nest = true)[1], x)
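
A hedged usage sketch (the function `m` below is made up; `jacobian` and `hessian` are the definitions just added):

```julia
using Flux.Tracker: jacobian, hessian

m(x) = [sum(x), prod(x)]
J = jacobian(m, [1.0, 2.0, 3.0])          # 2×3; the first row should be all ones
H = hessian(x -> sum(x.^2), [1.0, 2.0])   # ≈ [2.0 0.0; 0.0 2.0]
```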

@@ -0,0 +1,53 @@
using ForwardDiff

seed(x::Real, ::Val) = Dual(x, true)

function seed(x, ::Val{N}, offset = 0) where N
  map(x, reshape(1:length(x), size(x))) do x, i
    Dual(x, ntuple(j -> j+offset == i, Val(N)))
  end
end

extract(x::ForwardDiff.Dual) = x.value, [x.partials...]

function extract(xs::AbstractArray{ForwardDiff.Dual{T,V,N}}) where {T,V,N}
  J = similar(xs, V, N, length(xs))
  for i = 1:length(xs), j = 1:N
    J[j, i] = xs[i].partials.values[j]
  end
  return map(x -> x.value, xs), J
end

function forward_jacobian(f, x, ::Val{N}) where N
  y, _J = extract(f(seed(x, Val(N))))
  J = similar(_J, length(x), length(y))
  J[1:N,:] = _J
  offset = 0
  while offset + N < length(x)
    offset += N
    _, _J = extract(f(seed(x, Val(N), offset)))
    range = (1+offset):min(N+offset,length(x))
    J[range,:] = @view _J[range.-offset,:]
  end
  return y, J
end

function forward_jacobian(f, x)
  if length(x) < ForwardDiff.DEFAULT_CHUNK_THRESHOLD
    forward_jacobian(f, x, Val(length(x)))
  else
    forward_jacobian(f, x, Val(ForwardDiff.DEFAULT_CHUNK_THRESHOLD))
  end
end

forwarddiff(f, x) = istracked(x) ? track(forwarddiff, f, x) : f(x)

vec_scalar(x) = vec(x)
vec_scalar(x::Real) = [x]
reshape_scalar(x, y) = reshape(y, size(x))
reshape_scalar(x::Real, y) = y[]

@grad function forwarddiff(f, x)
  y, J = forward_jacobian(f, data(x))
  return y, ȳ -> (nothing, reshape_scalar(x, J*vec_scalar(ȳ)))
end
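
A hedged sketch mirroring the test suite: `forwarddiff` computes this sub-expression's Jacobian with ForwardDiff and splices it into the surrounding reverse pass:

```julia
using Flux.Tracker: gradient, forwarddiff

gradient([2.0, 3.0]) do x
  forwarddiff(x) do x
    x[1] * x[2]      # differentiated forward-mode; the result rejoins the tracker
  end
end  # == ([3.0, 2.0],)
```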

@@ -1,7 +1,7 @@
import Base: *

import LinearAlgebra
-import LinearAlgebra: inv, \, /
+import LinearAlgebra: inv, det, logdet, logabsdet, \, /

using Statistics
using LinearAlgebra: Transpose, Adjoint, diagm, diag

@@ -65,6 +65,12 @@ Base.setindex!(xs::TrackedArray, v, i...) =

back!(::TrackedArray) = error("Value is not scalar; use `back!(sum(x))` or `back!(x, Δ)`")

+function update!(x::TrackedArray, Δ)
+  x.data .+= data(Δ)
+  tracker(x).grad .= 0
+  return x
+end
+
# Fallthrough methods

for f in :[Base.size, Base.ndims, Base.collect].args
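
The in-place semantics, as a hedged sketch (this mirrors the "Updates" testset added later in this commit): the parameter is shifted by `Δ` and its accumulated gradient is zeroed:

```julia
using Flux, Flux.Tracker

xs = param([1.0, 2.0, 3.0])
Tracker.update!(xs, [4.0, 5.0, 6.0])
# xs.data is now [5.0, 7.0, 9.0] and the gradient buffer of xs is reset to zero.
```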

@@ -115,8 +121,17 @@ Base.:-(xs::TrackedArray) = track(-, xs)
Base.transpose(xs::TrackedArray) = track(transpose, xs)
Base.adjoint(xs::TrackedArray) = track(adjoint, xs)

-@grad transpose(xs) = transpose(data(xs)), Δ -> (reshape(transpose(Δ), size(xs)),)
-@grad adjoint(xs) = data(xs)', Δ -> (reshape(Δ', size(xs)),)
+@grad transpose(xs) = transpose(data(xs)), Δ -> (trim(xs, transpose(Δ)),)
+@grad adjoint(xs) = data(xs)', Δ -> (trim(xs, Δ'),)
+
+det(xs::TrackedArray) = track(det, xs)
+@grad det(xs) = det(data(xs)), Δ -> (Δ * det(xs) * transpose(inv(xs)),)
+
+logdet(xs::TrackedArray) = track(logdet, xs)
+@grad logdet(xs) = logdet(data(xs)), Δ -> (Δ * transpose(inv(xs)),)
+
+logabsdet(xs::TrackedArray) = track(logabsdet, xs)
+@grad logabsdet(xs) = logabsdet(data(xs)), Δ -> (Δ[1] * transpose(inv(xs)),)

Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...)
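
These pullbacks encode the standard matrix-calculus identities `∂det(X)/∂X = det(X) * inv(X)'` and `∂logdet(X)/∂X = inv(X)'`; a hedged spot-check:

```julia
using LinearAlgebra
using Flux.Tracker: gradient

A = [2.0 1.0; 1.0 3.0]
g = gradient(X -> logdet(X), A)[1]
g ≈ transpose(inv(A))   # expected true, per the @grad rule above
```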

@@ -142,11 +157,9 @@ function combinations(xs, n)
  [[x, c...] for x in xs, c in cs]
end

-combinations([AbstractArray, TrackedArray], 2)
-
-for i = 0:2, c = combinations([:AbstractArray, :TrackedArray], i), f = [:hcat, :vcat]
+for i = 0:2, c = combinations([:AbstractArray, :TrackedArray, :Number], i), f = [:hcat, :vcat]
  cnames = map(_ -> gensym(), c)
-  @eval Base.$f($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::TrackedArray, xs::AbstractArray...) =
+  @eval Base.$f($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::Union{TrackedArray,TrackedReal}, xs::Union{AbstractArray,Number}...) =
    track($f, $(cnames...), x, xs...)
end

@@ -219,8 +232,11 @@ Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64}}) = track(reshape, xs,

@grad reshape(xs, dims) = reshape(data(xs), dims), Δ -> (reshape(Δ, size(xs)),nothing)

-Base.permutedims(xs::TrackedArray, dims) = track(permutedims, xs, dims)
-@grad permutedims(xs, dims) = permutedims(data(xs), dims), Δ -> (permutedims(Δ, invperm(dims)),nothing)
+Base.permutedims(xs::TrackedArray, perm) = track(permutedims, xs, perm)
+@grad permutedims(xs, perm) = permutedims(data(xs), perm), Δ -> (permutedims(Δ, invperm(perm)),nothing)
+
+Base.PermutedDimsArray(xs::TrackedArray, perm) = track(PermutedDimsArray, xs, perm)
+@grad PermutedDimsArray(xs, perm) = PermutedDimsArray(data(xs), perm), Δ -> (PermutedDimsArray(Δ, invperm(perm)),nothing)

function _kron(mat1::AbstractMatrix,mat2::AbstractMatrix)
  m1, n1 = size(mat1)

@@ -305,9 +321,9 @@ dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys)
@grad dot(xs, ys) = dot(data(xs), data(ys)), Δ -> (Δ .* ys, Δ .* xs)

# Hacks to get std working
-Statistics.std(x::TrackedArray; dims = :, mean = Statistics.mean(x, dims = dims)) = _std(x,mean,dims)
-_std(x::TrackedArray, mean, dims) = sqrt.(sum((x .- mean).^2, dims = dims) ./ (mapreduce(i -> size(x,i),*, dims) - 1))
-_std(x::TrackedArray, mean, ::Colon) = sqrt.(sum((x .- mean).^2) ./ (length(x) - 1))
+Statistics.std(x::TrackedArray; dims = :, mean = Statistics.mean(x, dims = dims), corrected::Bool = true) = _std(x,mean,dims,corrected)
+_std(x::TrackedArray, mean, dims, corrected) = sqrt.(sum((x .- mean).^2, dims = dims) ./ (mapreduce(i -> size(x,i),*, dims) - corrected))
+_std(x::TrackedArray, mean, ::Colon, corrected) = sqrt.(sum((x .- mean).^2) ./ (length(x) - corrected))

LinearAlgebra.norm(x::TrackedArray, p::Real = 2) =
  sum(abs.(x).^p .+ eps(0f0))^(1/p) # avoid d(sqrt(x))/dx == Inf at 0

@@ -357,7 +373,7 @@ x::TrackedVector * y::TrackedVector = track(*, x, y)
# NNlib

using NNlib
-import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, depthwiseconv, maxpool, meanpool
+import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, ∇conv_data, depthwiseconv, maxpool, meanpool

softmax(xs::TrackedArray) = track(softmax, xs)

@@ -384,8 +400,18 @@ conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...)
@grad conv(x, w; kw...) =
  conv(data(x), data(w); kw...),
    Δ -> nobacksies(:conv,
-      (NNlib.∇conv_data(data.((Δ, x, w))...; kw...),
-       NNlib.∇conv_filter(data.((Δ, x, w))...; kw...)))
+      (NNlib.∇conv_data(data.((Δ, w))...; size=size(x), kw...),
+       NNlib.∇conv_filter(data.((Δ, x))...; size=size(w), kw...)))
+
+∇conv_data(x::TrackedArray, w::TrackedArray; kw...) = track(∇conv_data, x, w; kw...)
+∇conv_data(x::AbstractArray, w::TrackedArray; kw...) = track(∇conv_data, x, w; kw...)
+∇conv_data(x::TrackedArray, w::AbstractArray; kw...) = track(∇conv_data, x, w; kw...)
+
+@grad ∇conv_data(x, w; kw...) =
+  ∇conv_data(data(x), data(w); kw...),
+    Δ -> nobacksies(:conv,
+      (NNlib.conv(data.((Δ, w))...; size=size(x), kw...),
+       NNlib.∇conv_filter(data.((x, Δ))...; size=size(w), kw...)))

maxpool(x::TrackedArray, k; kw...) = track(maxpool, x, k; kw...)

@@ -1,4 +1,4 @@
-struct TrackedReal{T<:Real} <: Real
+mutable struct TrackedReal{T<:Real} <: Real
  data::T
  tracker::Tracked{T}
end

@@ -16,6 +16,12 @@ function back!(x::TrackedReal; once = true)
  return back!(x, 1, once = once)
end

+function update!(x::TrackedReal, Δ)
+  x.data += data(Δ)
+  tracker(x).grad = 0
+  return x
+end
+
function Base.show(io::IO, x::TrackedReal)
  T = get(io, :typeinfo, Any)
  show(io, data(x))

@@ -33,6 +39,8 @@ Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} =
  error("Not implemented: convert tracked $S to tracked $T")

+(T::Type{<:TrackedReal})(x::Real) = convert(T, x)
+
for op in [:(==), :≈, :<, :(<=)]
  @eval Base.$op(x::TrackedReal, y::Real) = Base.$op(data(x), y)
  @eval Base.$op(x::Real, y::TrackedReal) = Base.$op(x, data(y))

@@ -53,6 +61,12 @@ Base.float(x::TrackedReal) = x
Base.promote_rule(::Type{TrackedReal{S}},::Type{T}) where {S,T} =
  TrackedReal{promote_type(S,T)}

+using Random
+
+for f in :[rand, randn, randexp].args
+  @eval Random.$f(rng::AbstractRNG,::Type{TrackedReal{T}}) where {T} = param(rand(rng,T))
+end
+
using DiffRules, SpecialFunctions, NaNMath

for (M, f, arity) in DiffRules.diffrules()

@@ -87,6 +101,13 @@ import Base:^

^(a::TrackedReal, b::Integer) = track(^, a, b)

+# Hack for conversions
+
+using ForwardDiff: Dual
+
+(T::Type{<:Real})(x::Dual) = Dual(T(x.value), map(T, x.partials.values))
+(Dual{T,V,N})(x::Dual) where {T,V,N} = invoke(Dual{T,V,N}, Tuple{Number}, x)

# Tuples

struct TrackedTuple{T<:Tuple}

@@ -134,3 +155,6 @@ end
function back_(g::Grads, c::Call{typeof(collect)}, Δ)
  foreach((x, Δ) -> back(g, x, Δ), c.args[1], Δ)
end
+
+collectmemaybe(xs::AbstractArray{>:TrackedReal}) = collect(xs)
+collectmemaybe(xs::AbstractArray{<:TrackedReal}) = collect(xs)

@@ -1,11 +1,13 @@
-import Adapt: adapt
+import Adapt: adapt, adapt_storage
import .Tracker: IdSet

children(x) = ()
mapchildren(f, x) = x

children(x::Tuple) = x
+children(x::NamedTuple) = x
mapchildren(f, x::Tuple) = map(f, x)
+mapchildren(f, x::NamedTuple) = map(f, x)

function treelike(m::Module, T, fs = fieldnames(T))
  @eval m begin

@@ -14,11 +16,6 @@ function treelike(m::Module, T, fs = fieldnames(T))
  end
end

-function treelike(T, fs = fieldnames(T))
-  Base.depwarn("`treelike(T)` is deprecated, use `@treelike T`", :treelike)
-  treelike(Base._current_module(), T, fs)
-end
-
macro treelike(T, fs = nothing)
  fs == nothing || isexpr(fs, :tuple) || error("@treelike T (a, b)")
  fs = fs == nothing ? [] : [:($(map(QuoteNode, fs.args)...),)]

@@ -69,3 +66,22 @@ gpu_adaptor = identity
end

gpu(x) = mapleaves(gpu_adaptor, x)
+
+# Precision
+
+adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Real}) = convert.(T, xs)
+
+paramtype(T::Type{<:Real}, m) = mapleaves(x -> adapt(T, x), m)
+
+f32(m) = paramtype(Float32, m)
+f64(m) = paramtype(Float64, m)
+
+# General parameter map
+
+function mapparams(f, m)
+  mapleaves(m) do x
+    Tracker.istracked(x) ? param(f(Tracker.data(x))) :
+    x isa Union{AbstractArray,Number} ? f(x) :
+    x
+  end
+end
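
A hedged sketch of the new precision helpers (this mirrors the "Precision" testset added later in this commit):

```julia
using Flux

m = Chain(Dense(10, 5, relu), Dense(5, 2))  # Float32 parameters by default
m64 = f64(m)   # every tracked parameter converted to Float64
m32 = f32(m64) # and back again; structure and tracking are preserved
```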

src/utils.jl

@@ -10,8 +10,8 @@ zeros(dims...) = Base.zeros(Float32, dims...)

unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))

-stack(xs, dim) = cat(unsqueeze.(xs, dim)...; dims=dim)
-unstack(xs, dim) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)]
+stack(xs, dim) = cat(unsqueeze.(xs, dim)..., dims=dim)
+unstack(xs, dim) = [copy(selectdim(xs, dim, i)) for i in 1:size(xs, dim)]

"""
    chunk(xs, n)

@@ -139,25 +139,6 @@ function throttle(f, timeout; leading=true, trailing=false)
  end
end

-"""
-    J = jacobian(m,x)
-
-Calculate the output jacobian `J = d/dx m(x)` such that each row `i` of `J` corresponds to the gradient `J[i,:] = ∇ₓ(m(x)[i])`
-"""
-function jacobian(m,x)
-  xp = param(x)
-  y = m(xp)
-  k = length(y)
-  n = length(x)
-  J = Matrix{eltype(x)}(undef,n,k)
-  for i = 1:k
-    Flux.back!(y[i], once = false) # Populate gradient accumulator
-    J[:,i] = xp.grad
-    xp.grad .= 0 # Reset gradient accumulator
-  end
-  J'
-end
-
"""
    @jit ...
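
The `stack`/`unstack` change above swaps the removed `slicedim` for Julia 1.0's `selectdim` (with a `copy`, so the slices are plain arrays); a hedged round-trip sketch:

```julia
using Flux: stack, unstack

xs = [rand(3, 3) for _ = 1:2]
A = stack(xs, 2)        # size(A) == (3, 2, 3): a new axis inserted at dim 2
unstack(A, 2) == xs     # true: unstacking recovers the original slices
```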

@@ -11,6 +11,8 @@ x = param(randn(5, 5))
cx = gpu(x)
@test cx isa TrackedArray && cx.data isa CuArray

+@test Flux.onecold(param(gpu([1.,2.,3.]))) == 3
+
x = Flux.onehotbatch([1, 2, 3], 1:3)
cx = gpu(x)
@test cx isa Flux.OneHotMatrix && cx.data isa CuArray

@@ -21,3 +21,15 @@ end

  @test size(m(r)) == (10, 5)
end

+@testset "Depthwise Conv" begin
+  r = zeros(Float32, 28, 28, 3, 5)
+
+  m1 = DepthwiseConv((2, 2), 3=>5)
+  @test size(m1(r), 3) == 15
+
+  m2 = DepthwiseConv((2, 2), 3)
+  @test size(m2(r), 3) == 3
+end

@@ -98,4 +98,9 @@ end
    y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5])
    @test m(x) == y
  end

+  let m = BatchNorm(32), x = randn(Float32, 416, 416, 32, 1);
+    m(x)
+    @test (@allocated m(x)) < 100_000_000
+  end
end

@@ -49,4 +49,16 @@ const ϵ = 1e-7
  @testset "logitbinarycrossentropy" begin
    @test logitbinarycrossentropy.(logŷ, y) ≈ binarycrossentropy.(σ.(logŷ), y; ϵ=0)
  end

+  @testset "no spurious promotions" begin
+    for T in (Float16, Float32, Float64)
+      y = rand(T, 2)
+      ŷ = rand(T, 2)
+      for f in (mse, crossentropy, logitcrossentropy)
+        fwd, back = Flux.Tracker.forward(f, ŷ, y)  # use the loop variable `f`, not a hard-coded `mse`
+        @test typeof(fwd) == Flux.Tracker.TrackedReal{T}
+        @test eltype(back(one(T))[1]) == Flux.Tracker.TrackedReal{T}
+      end
+    end
+  end
end

@@ -4,7 +4,7 @@ using Flux.Tracker
using Test
@testset "Optimise" begin
  w = randn(10, 10)
-  @testset for Opt in [ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, Descent, ADAM, Nesterov, RMSProp, Momentum]
+  @testset for Opt in [ADAMW, ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, Descent, ADAM, Nesterov, RMSProp, Momentum]
    w′ = param(randn(10, 10))
    loss(x) = Flux.mse(w*x, w′*x)
    opt = Opt(0.001)

@@ -17,7 +17,7 @@ using Test
    for t = 1:10^5
      l = loss(rand(10))
      back!(l)
-      delta = Optimise.update!(opt, w′.data, w′.grad)
+      delta = Optimise.apply!(opt, w′.data, w′.grad)
      w′.data .-= delta
    end
    @test Flux.mse(w, w′) < 0.01

@@ -33,7 +33,7 @@ end
    for t = 1:10^5
      l = loss(rand(10))
      back!(l)
-      delta = Optimise.update!(opt, w′.data, w′.grad)
+      delta = Optimise.apply!(opt, w′.data, w′.grad)
      w′.data .-= delta
    end
    @test Flux.mse(w, w′) < 0.01

@@ -1,18 +1,3 @@
-# Pkg.test runs with --check_bounds=1, forcing all bounds checks.
-# This is incompatible with CUDAnative (see JuliaGPU/CUDAnative.jl#98)
-if Base.JLOptions().check_bounds == 1
-  file = @__FILE__
-  run(```
-    $(Base.julia_cmd())
-    --color=$(Base.have_color ? "yes" : "no")
-    --compiled-modules=$(Bool(Base.JLOptions().use_compiled_modules) ? "yes" : "no")
-    --startup-file=$(Base.JLOptions().startupfile != 2 ? "yes" : "no")
-    --code-coverage=$(["none", "user", "all"][1+Base.JLOptions().code_coverage])
-    $(file)
-  ```)
-  exit()
-end
-
using Flux, Test, Random, Statistics
using Random

@@ -1,9 +1,9 @@
using Flux
using Flux.Tracker, Test, NNlib
-using Flux.Tracker: TrackedReal, gradcheck, grad, checkpoint
-using NNlib: conv, depthwiseconv
+using Flux.Tracker: TrackedReal, gradient, gradcheck, grad, checkpoint, forwarddiff
+using NNlib: conv, ∇conv_data, depthwiseconv
using Printf: @sprintf
-using LinearAlgebra: diagm, dot, LowerTriangular, norm
+using LinearAlgebra: diagm, dot, LowerTriangular, norm, det, logdet, logabsdet
using Statistics: mean, std
using Random
# using StatsBase

@@ -34,6 +34,10 @@ gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)

@test gradtest(x -> x', rand(5))
+@test gradtest(det, (4, 4))
+@test gradtest(logdet, map((x) -> x*x', (rand(4, 4),))[1])
+@test gradtest((x) -> logabsdet(x)[1], (4, 4))

@testset "indexing & slicing" begin
  gradtest(x->view(x, 1:2, 1:2), rand(4, 4))
end

@@ -113,9 +117,17 @@ end
  promotiontest((x...) -> cat(x..., dims = 3), rand(4,5,3), rand(4,5,1), rand(4,5,2))
end

+@testset "scalars" begin
+  @test vcat(param([1, 2, 3]), 1) isa TrackedArray
+  @test vcat(1, param([1, 2, 3])) isa TrackedArray
+  @test hcat(1, param([1 2 3;])) isa TrackedArray
+  @test vcat(param(1), 2) isa TrackedArray
+end
+
end

@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))
+@test gradtest(x -> PermutedDimsArray(x, [3,1,2]), rand(4,5,6))

@test gradtest(x -> repeat(x; inner=2), rand(5))
@test gradtest(x -> repeat(x; inner=2, outer=3), rand(5))

@@ -166,6 +178,10 @@ end

@test gradtest(x -> std(x), rand(5,5))
@test gradtest(x -> std(x, dims = 1), rand(5,5))
+@test gradtest(x -> std(x, dims = 1, corrected = false), rand(5,5))

@test gradtest(x -> Flux.normalise(x), rand(4,3))
+@test gradtest(x -> Flux.normalise(x, dims = 2), rand(3,4))

@test gradtest((x, y) -> x .* y, rand(5), rand(5))
@test gradtest(dot, rand(5), rand(5))

@@ -177,18 +193,28 @@ end
  2y + x
end

-@test gradtest(conv, rand(10, 3, 2), randn(Float64,2, 3, 2))
-@test gradtest(conv, rand(10, 10, 3, 2), randn(Float64,2, 2, 3, 2))
-@test gradtest(conv, rand(10, 10, 10, 3, 2), randn(Float64,2, 2, 2, 3, 2))
+@test gradtest(conv, rand(10, 3, 2), randn(Float64, 2, 3, 2))
+@test gradtest(conv, rand(10, 10, 3, 2), randn(Float64, 2, 2, 3, 2))
+@test gradtest(conv, rand(10, 10, 10, 3, 2), randn(Float64, 2, 2, 2, 3, 2))

-@test gradtest(∇conv_data, rand(10, 3, 2), randn(Float64, 2, 2, 3))
-@test gradtest(∇conv_data, rand(10, 10, 3, 2), randn(Float64,2, 2, 2, 3))
-@test gradtest(∇conv_data, rand(10, 10, 10, 3, 2), randn(Float64,2, 2, 2, 2, 3))

@test gradtest(depthwiseconv, rand(10,10,3,2), randn(2, 2, 2, 3))

+@test gradtest(∇conv_data, rand(10, 3, 2), randn(Float64, 2, 2, 3))
+@test gradtest(∇conv_data, rand(10, 10, 3, 2), randn(Float64, 2, 2, 2, 3))
+@test gradtest(∇conv_data, rand(10, 10, 10, 3, 2), randn(Float64, 2, 2, 2, 2, 3))

@test gradtest(x -> maxpool(x, (2,2)), rand(10, 10, 3, 2))
@test gradtest(x -> maxpool(x, (2,2,2)), rand(10, 10, 10, 3, 2))

@test gradtest(x -> meanpool(x, (2,2)), rand(10, 10, 3, 2))
@test gradtest(x -> meanpool(x, (2,2,2)), rand(5, 5, 5, 3, 2))

@test gradtest(x -> Float64.(x), 5)

@testset "equality & order" begin
  # TrackedReal
  @test param(2)^2 == param(4)

@@ -260,7 +286,7 @@ Tracker.back!(b)
back!(z)
@test grad.((x,y)) == (3, 2)

-@test Tracker.gradient(2, 3) do x, y
+@test gradient(2, 3) do x, y
  xy = Tracker.collect([x, y])
  xy[1]*xy[2]
end == (3, 2)

@@ -286,4 +312,36 @@ end
  @test count == 3
end

+@testset "Updates" begin
+  xs = param([1, 2, 3])
+  Tracker.update!(xs, param([4, 5, 6]))
+  @test xs == [5, 7, 9]
+  x = param(3)
+  Tracker.update!(x, param(4))
+  @test x == 7
+end
+
+@testset "Params" begin
+  W = param(randn(5, 10))
+  x = rand(10)
+  dW = gradient(W -> sum(W*x), W)[1]
+  gs = gradient(() -> sum(W*x), Tracker.Params([W]))
+  @test gs[W] == dW
+end
+
+@testset "Forward" begin
+  @test @inferred(Tracker.forward_jacobian(x -> [sum(x)], rand(5,5), Val(12)))[2] ==
+    reshape(ones(25), :, 1)
+  @test gradient([2, 3]) do x
+    forwarddiff(x) do x
+      x[1]*x[2]
+    end
+  end == ([3, 2],)
+end
+
+@testset "Custom Sensitivities" begin
+  y, back = Tracker.forward(x -> [3x^2, 2x], 5)
+  @test back([1, 1]) == (32,)
+end
+
end #testset

@@ -1,5 +1,5 @@
using Flux
-using Flux: throttle, jacobian, glorot_uniform, glorot_normal, stack
+using Flux: throttle, jacobian, glorot_uniform, glorot_normal, stack, unstack
using StatsBase: std
using Random
using Test

@@ -87,8 +87,27 @@ end
  @test size.(params(m)) == [(5, 10), (5, 5), (5,), (5,)]
end

-@testset "Basic" begin
+@testset "Basic Stacking" begin
  x = randn(3,3)
  stacked = stack([x, x], 2)
  @test size(stacked) == (3,2,3)
end

+@testset "Precision" begin
+  m = Chain(Dense(10, 5, relu), Dense(5, 2))
+  x = rand(10)
+  @test eltype(m[1].W.data) == Float32
+  @test eltype(m(x).data) == Float32
+  @test eltype(f64(m)(x).data) == Float64
+  @test eltype(f64(m)[1].W.data) == Float64
+  @test eltype(f32(f64(m))[1].W.data) == Float32
+  @test Tracker.isleaf(f32(f64(m))[1].W)
+end
+
+@testset "Stacking" begin
+  stacked_array=[ 8 9 3 5; 9 6 6 9; 9 1 7 2; 7 4 10 6 ]
+  unstacked_array=[[8, 9, 9, 7], [9, 6, 1, 4], [3, 6, 7, 10], [5, 9, 2, 6]]
+  @test unstack(stacked_array, 2) == unstacked_array
+  @test stack(unstacked_array, 2) == stacked_array
+  @test stack(unstack(stacked_array, 1), 1) == stacked_array
+end