Compare commits
1 commit: master...sf/gpu_bat
Commit: a740dadf6a

.gitattributes (vendored) — 1 changed line
@@ -1,2 +1 @@
paper/* linguist-documentation
CITATION.bib linguist-detectable=false

.github/FUNDING.yml (vendored) — 1 changed line
@@ -1 +0,0 @@
custom: https://numfocus.salsalabs.org/donate-to-julia/index.html

.github/pull_request_template.md (vendored) — 12 changed lines
@@ -1,12 +0,0 @@
[Please delete this text and describe your change here.
For bugfixes, please detail the bug and include a test case which your patch fixes.
If you are adding a new feature, please clearly describe the design, its rationale, the possible alternatives considered.
It is easiest to merge new features when there is clear precedent in other systems; we need to know we're taking
the right direction since it can be hard to change later.]

### PR Checklist

- [ ] Tests are added
- [ ] Entry in NEWS.md
- [ ] Documentation, if applicable
- [ ] Final review from `@MikeInnes` or `@dhairyagandhi96` (for API changes).

.github/workflows/CompatHelper.yml (vendored) — 16 changed lines
@@ -1,16 +0,0 @@
name: CompatHelper

on:
  schedule:
    - cron: '00 00 * * *'

jobs:
  CompatHelper:
    runs-on: ubuntu-latest
    steps:
      - name: Pkg.add("CompatHelper")
        run: julia -e 'using Pkg; Pkg.add("CompatHelper")'
      - name: CompatHelper.main()
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        run: julia -e 'using CompatHelper; CompatHelper.main()'

.github/workflows/TagBot.yml (vendored) — 11 changed lines
@@ -1,11 +0,0 @@
name: TagBot
on:
  schedule:
    - cron: 0 * * * *
jobs:
  TagBot:
    runs-on: ubuntu-latest
    steps:
      - uses: JuliaRegistries/TagBot@v1
        with:
          token: ${{ secrets.GITHUB_TOKEN }}

.gitignore (vendored) — 1 changed line
@@ -3,4 +3,5 @@
*.jl.mem
docs/build/
docs/site/
docs/flux.css
deps

.gitlab-ci.yml — 41 changed lines
@@ -1,41 +0,0 @@
include:
  - 'https://raw.githubusercontent.com/JuliaGPU/gitlab-ci/master/templates/v6.yml'

image: nvidia/cuda:10.1-cudnn7-devel-ubuntu18.04

# julia:1.0:
#   extends:
#     - .julia:1.0
#     - .test
#   tags:
#     - nvidia
#
# julia:1.1:
#   extends:
#     - .julia:1.1
#     - .test
#   tags:
#     - nvidia
#
# julia:1.2:
#   extends:
#     - .julia:1.2
#     - .test
#   tags:
#     - nvidia

julia:1.3:
  extends:
    - .julia:1.3
    - .test
  tags:
    - nvidia

julia:nightly:
  extends:
    - .julia:nightly
    - .test
  tags:
    - nvidia
  allow_failure: true

.travis.yml — 32 changed lines
@@ -1,32 +1,18 @@
# Documentation: http://docs.travis-ci.com/user/languages/julia/
language: julia

os:
  - linux
# - osx

julia:
  - 1.3
  - 1
  - 1.0
  - nightly

notifications:
  email: false

jobs:
  include:
    - stage: "Documentation"
      julia: 1.3
      os: linux
      script:
        - julia --project=docs/ -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd()));
                                     Pkg.instantiate()'
        - julia --project=docs/ docs/make.jl
      after_success: skip

# uncomment the following lines to override the default test script
# script:
#   - if [[ -a .git/shallow ]]; then git fetch --unshallow; fi
#   - julia -e 'Pkg.clone(pwd()); Pkg.build("Flux"); Pkg.test("Flux"; coverage=true)'
matrix:
  allow_failures:
    - julia: nightly

## uncomment the following lines to override the default test script
script:
  - julia --color=yes -e 'using Pkg; Pkg.activate(); Pkg.instantiate(); Pkg.test()'
after_success:
  - julia -e 'using Pkg; ps=Pkg.PackageSpec(name="Documenter", version="0.19"); Pkg.add(ps); Pkg.pin(ps); Pkg.add("NNlib")'
  - julia -e 'using Pkg; cd(Pkg.dir("Flux")); include(joinpath("docs", "make.jl"))'

CITATION.bib — 29 changed lines
@@ -1,29 +0,0 @@
@article{Flux.jl-2018,
  author    = {Michael Innes and
               Elliot Saba and
               Keno Fischer and
               Dhairya Gandhi and
               Marco Concetto Rudilosso and
               Neethu Mariya Joy and
               Tejan Karmali and
               Avik Pal and
               Viral Shah},
  title     = {Fashionable Modelling with Flux},
  journal   = {CoRR},
  volume    = {abs/1811.01457},
  year      = {2018},
  url       = {http://arxiv.org/abs/1811.01457},
  archivePrefix = {arXiv},
  eprint    = {1811.01457},
  timestamp = {Thu, 22 Nov 2018 17:58:30 +0100},
  biburl    = {https://dblp.org/rec/bib/journals/corr/abs-1811-01457},
  bibsource = {dblp computer science bibliography, https://dblp.org}
}

@article{innes:2018,
  author  = {Mike Innes},
  title   = {Flux: Elegant Machine Learning with Julia},
  journal = {Journal of Open Source Software},
  year    = {2018},
  doi     = {10.21105/joss.00602},
}

LICENSE.md
@@ -1,6 +1,6 @@
The Flux.jl package is licensed under the MIT "Expat" License:

> Copyright (c) 2016-19: Julia Computing, Inc., Mike Innes and Contributors
> Copyright (c) 2016: Mike Innes.
>
> Permission is hereby granted, free of charge, to any person obtaining
> a copy of this software and associated documentation files (the

Manifest.toml — 297 changed lines
@@ -1,84 +1,49 @@ through @@ -363,25 +266,7 @@ (nine hunks)
This is the machine-generated dependency lockfile ("This file is machine-generated - editing it directly is not advised"). The two sides of the comparison record different versions, git-tree-sha1 hashes and dependency lists for nearly every entry, including AbstractFFTs, AbstractTrees, Adapt, ArrayLayouts, BinDeps, BinaryProvider, CEnum, CUDAapi, CUDAdrv, CUDAnative, CodeTracking, CodecZlib, ColorTypes, Colors, Compat, CompilerSupportLibraries_jll, Cthulhu, CuArrays, DataAPI, DataStructures, DiffResults, DiffRules, ExprTools, FillArrays, FixedPointNumbers, ForwardDiff, Functors, GPUArrays, GPUCompiler, IRTools, Juno, LLVM, MacroTools, Missings, NNlib, NaNMath, OpenSpecFun_jll, OrderedCollections, Pkg, Requires, SpecialFunctions, StaticArrays, StatsBase, TimerOutputs, TranscodingStreams, URIParser, ZipFile, Zlib_jll, Zygote and ZygoteRules (for example CuArrays 2.2.1, NNlib 0.6.6 and Zygote 0.4.20 on one side against an older NNlib 0.4.3 / Test-dependent stack on the other).

NEWS.md — 61 changed lines
@@ -1,61 +0,0 @@
# v0.11
* Change to `DataLoader`'s constructor [https://github.com/FluxML/Flux.jl/pull/1152]
* Use `DataLoader` with `NamedTuple`s, so that tensors can be accessed by name [https://github.com/FluxML/Flux.jl/pull/1221].
* Error if Dense layers weights and biases are not arrays [https://github.com/FluxML/Flux.jl/pull/1218].

# v0.10.5
* Add option for [same padding](https://github.com/FluxML/Flux.jl/pull/901) to conv and pooling layers by setting `pad=SamePad()`.
* Added option to set `bias` to [Flux.Zeros](https://github.com/FluxML/Flux.jl/pull/873) to exclude `bias` from being trained.
* Added `GlobalMaxPool` and `GlobalMeanPool` [layers](https://github.com/FluxML/Flux.jl/pull/950) for performing global pooling operations.
* Added `ClipValue` and `ClipNorm` in this [pr](https://github.com/FluxML/Flux.jl/pull/1133) to `Flux.Optimise` to provide a cleaner API for gradient clipping.
* Added new kwarg-only [constructors](https://github.com/FluxML/Flux.jl/pull/873) for the various convolutional layers.
* Documented the convolutional layer constructors accepting `weight` and `bias` keyword arguments to supply custom arrays for those fields.
* Testing suite improvements now test for gradients of all layers along with GPU support.
* Functors have now moved to [Functors.jl](https://github.com/FluxML/Flux.jl/pull/1174) to allow for their use outside of Flux.
* Added [helper functions](https://github.com/FluxML/Flux.jl/pull/873) `Flux.convfilter` and `Flux.depthwiseconvfilter` to construct weight arrays for convolutions outside of layer constructors so as to not have to depend on the default layers for custom implementations.
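
As an illustration of the `pad=SamePad()` option noted above — a minimal sketch against the Flux v0.10.5+ API, with arbitrary layer sizes:

```julia
using Flux

# SamePad() picks padding so that, at stride 1, the spatial size is preserved.
layer = Conv((3, 3), 3 => 8, relu; pad = SamePad())
x = rand(Float32, 32, 32, 3, 1)   # width × height × channels × batch
size(layer(x))                    # (32, 32, 8, 1)
```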

# v0.10.0
* The default AD engine has switched from [Tracker to Zygote.jl](https://github.com/FluxML/Flux.jl/pull/669)
  - The dependency on Tracker.jl has been removed.
  - This means Flux now does not depend on using a specialised `TrackedArray` type, and can be used with normal Array implementations directly.
  - Tracker compatibility is maintained in most common cases, but Zygote will be the preferred AD backend for Flux from now on.
* The CUDNN wrappers have been [moved from Flux into CuArrays](https://github.com/FluxML/Flux.jl/pull/874), to allow for better supporting the CUDA backend, and improve user experience, not to mention making Flux lean.
* `*crossentropy` functions now [work as expected with CuArrays](https://github.com/FluxML/Flux.jl/pull/926). [PR for binarycrossentropy](https://github.com/FluxML/Flux.jl/pull/940).
* Added [clearer docs](https://github.com/FluxML/Flux.jl/pull/904) around training and the Optimiser interface.
* [Layer initialisations](https://github.com/FluxML/Flux.jl/pull/937) have been improved with a clearer API on how to extend it for other purposes.
* [Better messaging around CUDA availability](https://github.com/FluxML/Flux.jl/pull/924), with hooks to initialize the GPU as default where possible.
* `@treelike` has been formalised as a [functor](https://github.com/FluxML/Flux.jl/pull/865), with an effective deprecation.
* `testmode!` is deprecated in favour of [istraining](https://github.com/FluxML/Flux.jl/pull/669)

# v0.9.0
* [Depthwise convolutional layer API changes](https://github.com/FluxML/Flux.jl/pull/756) from `in => mult` channel specification to `in => out` channel specification, and deprecates implicit `out` constructor.
* New [SkipConnection](https://github.com/FluxML/Flux.jl/pull/446), which can be used to train residual neural network architectures.
* New [RADAM](https://github.com/FluxML/Flux.jl/pull/842) optimiser.
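
A minimal sketch of the `SkipConnection` layer mentioned above (assuming the Flux ≥ 0.9 API; the layer sizes are arbitrary):

```julia
using Flux

# SkipConnection(layer, connection) computes connection(layer(x), x),
# so with `+` it acts as a residual block.
block = SkipConnection(Dense(10, 10, relu), +)
x = rand(Float32, 10)
y = block(x)    # equals the Dense output plus the input, elementwise
size(y)         # (10,)
```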

# v0.8.0

* [Dropout now has a `dims` argument for specifying the unbroadcast dimensions.](https://github.com/FluxML/Flux.jl/pull/563)
* New [ConvTranspose layer](https://github.com/FluxML/Flux.jl/pull/311).
* New [Maxout layer](https://github.com/FluxML/Flux.jl/pull/647)
* Datasets are now [hash verified on download](https://github.com/FluxML/Flux.jl/pull/585) to avoid corruption.
* We now [zero the initial state for RNNs](https://github.com/FluxML/Flux.jl/pull/590/).
* [Normalisation can now work on arbitrary `dims`.](https://github.com/FluxML/Flux.jl/pull/592)
* Many docs and bugfixes thanks to @KristofferC and others.
* [NamedTuples now work like Tuples](https://github.com/FluxML/Flux.jl/pull/603) when doing `mapleaves`.
* New "performance tips" [section of the docs](https://github.com/FluxML/Flux.jl/pull/615).
* The training loop is [now more readable](https://github.com/FluxML/Flux.jl/pull/651) and better shows how to use the lower-level APIs.
* New [AlphaDropout](https://github.com/FluxML/Flux.jl/pull/656).
* [Data.Iris](https://github.com/FluxML/Flux.jl/pull/652) makes Fisher's Iris dataset available with `Iris.labels` and `Iris.features`.
* New [InstanceNorm](https://github.com/FluxML/Flux.jl/pull/634), as popularized by [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022).
* New [GroupNorm](https://github.com/FluxML/Flux.jl/pull/696), as described in [Group Normalization](https://arxiv.org/abs/1803.08494).
* New [CrossCor](https://github.com/FluxML/Flux.jl/pull/762).

AD Changes:

* `det`, `logdet` and `logabsdet` [now have adjoints](https://github.com/FluxML/Flux.jl/pull/596/files).
* Support for [PermuteDimsArray](https://github.com/FluxML/Flux.jl/pull/576).
* Flux.Tracker is now its [own package](https://github.com/FluxML/Tracker.jl), in preparation for replacing it with Zygote.

# v0.7.0

Despite the heroic efforts of scholars and archeologists, pre-0.7 history is lost to the sands of time.

Project.toml — 37 changed lines
@@ -1,51 +1,24 @@
name = "Flux"
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
version = "0.11.0-DEV"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SHA = "ea8e919c-243c-51af-8825-aaa63cd721ce"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
AbstractTrees = "0.2, 0.3"
Adapt = "1, 2.0"
CodecZlib = "0.5, 0.6, 0.7"
Colors = "0.8, 0.9, 0.10, 0.11, 0.12"
CuArrays = "2"
Functors = "0.1"
Juno = "0.5, 0.6, 0.7, 0.8"
MacroTools = "0.3, 0.4, 0.5"
NNlib = "0.6"
Reexport = "0.2"
StatsBase = "0"
ZipFile = "0.7, 0.8, 0.9"
Zygote = "0.4.13"
julia = "1.3"

[extras]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "Documenter", "IterTools", "LinearAlgebra"]

README.md — 92 changed lines
@@ -2,14 +2,98 @@
<img width="400px" src="https://raw.githubusercontent.com/FluxML/fluxml.github.io/master/logo.png"/>
</p>

[Build Status](https://travis-ci.org/FluxML/Flux.jl) [Documentation](https://fluxml.github.io/Flux.jl/stable/) [Julia Slack](https://slackinvite.julialang.org/) [DOI](https://doi.org/10.21105/joss.00602)
[Build Status](https://travis-ci.org/FluxML/Flux.jl) [Documentation](https://fluxml.github.io/Flux.jl/stable/) [Julia Slack](https://slackinvite.julialang.org/) [DOI](https://doi.org/10.21105/joss.00602)

Flux is an elegant approach to machine learning. It's a 100% pure-Julia stack, and provides lightweight abstractions on top of Julia's native GPU and AD support. Flux makes the easy things easy while remaining fully hackable.

```julia
] add Flux
julia> Pkg.add("Flux")
```

See the [documentation](https://fluxml.github.io/Flux.jl/) or the [model zoo](https://github.com/FluxML/model-zoo/) for examples.
See the [documentation](http://fluxml.github.io/Flux.jl/) or the [model zoo](https://github.com/FluxML/model-zoo/) for examples.

If you use Flux in your research, please [cite](CITATION.bib) our work.
If you use Flux in research, please cite the following paper:

```
@article{innes:2018,
  author  = {Mike Innes},
  title   = {Flux: Elegant Machine Learning with Julia},
  journal = {Journal of Open Source Software},
  year    = {2018},
  doi     = {10.21105/joss.00602},
}
```

## Features

Flux has powerful high-level features, and common architectures can be defined in a few lines.

```julia
model = Chain(
  Dense(768, 128, σ),
  LSTM(128, 256),
  LSTM(256, 128),
  Dense(128, 10),
  softmax)

loss(x, y) = crossentropy(model(x), y)

Flux.train!(loss, data, ADAM(...))
```

Yet you can easily strip away the layers, and directly write the mathematics for your problem. Flux will seamlessly take gradients of any Julia code, so your model looks just like the paper.

```julia
W = param(randn(2, 10))
b = param(randn(2))

y(x) = σ.(W * x .+ b)
```

If that's *still* not enough, you can go as deep as you want, even writing your own CUDA kernels with [CUDAnative](https://github.com/JuliaGPU/CUDAnative.jl)! All this can be freely mixed-and-matched in a single model or script, and it all runs interactively via Jupyter or Juno.

```julia
function gpu_add(a, b, c)
  i = (blockIdx().x-1) * blockDim().x + threadIdx().x
  c[i] = a[i] + b[i]
  return nothing
end
```
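
For context, a hypothetical launch of the `gpu_add` kernel above — a sketch assuming the CuArrays/CUDAnative stack referenced in this README and a CUDA-capable GPU:

```julia
using CuArrays, CUDAnative

a = cu(rand(Float32, 1024))
b = cu(rand(Float32, 1024))
c = similar(a)

# one thread per element: 4 blocks of 256 threads cover the 1024 elements
@cuda threads=256 blocks=4 gpu_add(a, b, c)
```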

Unusual architectures are no problem in Flux, as you can use all the loops, control flow and even macros that you're used to. Here's a Tree RNN in 4 lines.

```julia
tree() = rand() < 0.5 ? rand(10) : (tree(), tree()) # dummy data

shrink = Dense(20, 10)
combine(a, b) = shrink([a; b])

model(x) = x
model(x::Tuple) = combine(model(x[1]), model(x[2]))

model(tree()) # Sample output
```

Despite this flexibility, Julia's advanced compiler lets us do some powerful optimisations. For example, this definition of `sigmoid` automatically gets fused into a *single* GPU kernel – so it's really fast.

```julia
sigmoid(xs) = 1 ./ (1 .+ exp.(.-xs))
```

Similarly, Flux is the first dynamic framework to support [compiling to the browser](https://fluxml.github.io/experiments/) and model import via [formats like ONNX](https://github.com/FluxML/ONNX.jl/), both of which are thinly-veiled compiler problems.

For more on our philosophy on machine learning, check out our article [On Machine Learning & Programming Languages](https://julialang.org/blog/2017/12/ml&pl).

## Contributing & Help

For general questions and help, check out Julia's [community forum](https://discourse.julialang.org/c/domain/ML).

Flux development is carried out via our [GitHub issues](https://github.com/FluxML/Flux.jl/issues), so feel free to open feature requests or PRs here.

For more informal discussions we'd love to have you on the [Julia slack](https://slackinvite.julialang.org/), where we hang out on the #machine-learning channel.

## Related Packages

Check out [Metalhead.jl](https://github.com/FluxML/Metalhead.jl) for common computer vision datasets and trained models.

[MLDatasets.jl](https://github.com/JuliaML/MLDatasets.jl) provides further common datasets.

REQUIRE (new file) — 18 changed lines
@@ -0,0 +1,18 @@
julia 1.0
Juno
MacroTools 0.3.3
NNlib
Requires
Adapt 0.4
CodecZlib
Colors
ZipFile
AbstractTrees
Reexport
StatsBase

# AD
ForwardDiff 0.5.0
DiffRules
SpecialFunctions
NaNMath

docs/Project.toml
@@ -1,6 +0,0 @@
[deps]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"

[compat]
Documenter = "0.24"

docs/make.jl — 37 changed lines
@@ -1,36 +1,31 @@
using Documenter, Flux, NNlib

DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true)
makedocs(modules=[Flux, NNlib],
         doctest = VERSION >= v"1.4",
         doctest = false,
         format = :html,
         analytics = "UA-36890222-9",
         sitename = "Flux",
         assets = ["../flux.css"],
         pages = ["Home" => "index.md",
                  "Building Models" =>
                    ["Basics" => "models/basics.md",
                     "Recurrence" => "models/recurrence.md",
                     "Regularisation" => "models/regularisation.md",
                     "Model Reference" => "models/layers.md",
                     "Advanced Model Building" => "models/advanced.md",
                     "NNlib" => "models/nnlib.md"],
                  "Handling Data" =>
                    ["One-Hot Encoding" => "data/onehot.md",
                     "DataLoader" => "data/dataloader.md"],
                     "Model Reference" => "models/layers.md"],
                  "Training Models" =>
                    ["Optimisers" => "training/optimisers.md",
                     "Training" => "training/training.md"],
                  "One-Hot Encoding" => "data/onehot.md",
                  "GPU Support" => "gpu.md",
                  "Saving & Loading" => "saving.md",
                  "The Julia Ecosystem" => "ecosystem.md",
                  "Utility Functions" => "utilities.md",
                  "Performance Tips" => "performance.md",
                  "Datasets" => "datasets.md",
                  "Community" => "community.md"],
         format = Documenter.HTML(
             analytics = "UA-36890222-9",
             assets = ["assets/flux.css"],
             prettyurls = get(ENV, "CI", nothing) == "true"),
)
                  "Internals" =>
                    ["Backpropagation" => "internals/tracker.md"],
                  "Community" => "community.md"])

deploydocs(repo = "github.com/FluxML/Flux.jl.git",
           target = "build",
           push_preview = true)
deploydocs(
  repo = "github.com/FluxML/Flux.jl.git",
  target = "build",
  osname = "linux",
  julia = "1.0",
  deps = nothing,
  make = nothing)

docs stylesheet (flux.css; full path not shown)
@@ -1,113 +0,0 @@
113 lines of CSS on one side of the comparison only: the Lato font import and the old Documenter theme's rules for the body font, the green table-of-contents banner, ToC text sizes and colours, link and article-link colours, docstring sections, and code-block backgrounds and highlighting.

docs/src/community.md
@@ -1,5 +1,5 @@
# Community

All Flux users are welcome to join our community on the [Julia forum](https://discourse.julialang.org/), or the [slack](https://discourse.julialang.org/t/announcing-a-julia-slack/4866) (channel #machine-learning). If you have questions or issues we'll try to help you out.
All Flux users are welcome to join our community on the [Julia forum](https://discourse.julialang.org/), the [slack](https://discourse.julialang.org/t/announcing-a-julia-slack/4866) (channel #machine-learning), or Flux's [Gitter](https://gitter.im/FluxML/Lobby). If you have questions or issues we'll try to help you out.

If you're interested in hacking on Flux, the [source code](https://github.com/FluxML/Flux.jl) is open and easy to understand -- it's all just the same Julia code you work with normally. You might be interested in our [intro issues](https://github.com/FluxML/Flux.jl/issues?q=is%3Aopen+is%3Aissue+label%3A%22help+wanted%22) to get started.

docs/src/data/dataloader.md
@@ -1,6 +0,0 @@
# DataLoader
Flux provides the `DataLoader` type in the `Flux.Data` module to handle iteration over mini-batches of data.

```@docs
Flux.Data.DataLoader
```
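
To make the `DataLoader` description above concrete, here is a small usage sketch (assuming the Flux v0.11-era `Flux.Data.DataLoader` API; the arrays are made up):

```julia
using Flux.Data: DataLoader

X = rand(Float32, 10, 100)   # 100 samples of 10 features
Y = rand(Float32, 1, 100)

# iterate (x, y) mini-batches of 20 samples, reshuffled each epoch
loader = DataLoader((X, Y), batchsize = 20, shuffle = true)
for (x, y) in loader
    # size(x) == (10, 20), size(y) == (1, 20)
end

# NamedTuples are also supported, so batches can be accessed by field name
for batch in DataLoader((x = X, y = Y), batchsize = 20)
    batch.x, batch.y
end
```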
@ -7,15 +7,15 @@ julia> using Flux: onehot, onecold
|
||||
|
||||
julia> onehot(:b, [:a, :b, :c])
|
||||
3-element Flux.OneHotVector:
|
||||
0
|
||||
1
|
||||
0
|
||||
false
|
||||
true
|
||||
false
|
||||
|
||||
julia> onehot(:c, [:a, :b, :c])
|
||||
3-element Flux.OneHotVector:
|
||||
0
|
||||
0
|
||||
1
|
||||
false
|
||||
false
|
||||
true
|
||||
```
|
||||
|
||||
The inverse is `onecold` (which can take a general probability distribution, as well as just booleans).
|
||||
@ -31,11 +31,6 @@ julia> onecold([0.3, 0.2, 0.5], [:a, :b, :c])
|
||||
:c
|
||||
```
|
||||
|
||||
```@docs
|
||||
Flux.onehot
|
||||
Flux.onecold
|
||||
```
|
||||
|
||||
## Batches
|
||||
|
||||
`onehotbatch` creates a batch (matrix) of one-hot vectors, and `onecold` treats matrices as batches.
|
||||
@ -57,7 +52,3 @@ julia> onecold(ans, [:a, :b, :c])
|
||||
```
|
||||
|
||||
Note that these operations returned `OneHotVector` and `OneHotMatrix` rather than `Array`s. `OneHotVector`s behave like normal vectors but avoid any unnecessary cost compared to using an integer index directly. For example, multiplying a matrix with a one-hot vector simply slices out the relevant row of the matrix under the hood.
|
||||
|
||||
```@docs
|
||||
Flux.onehotbatch
|
||||
```
|
||||
|
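
A small sketch of the efficiency point above (illustrative only; `W` is an arbitrary matrix):

```julia
using Flux: onehot, onehotbatch

W = rand(5, 3)
v = onehot(:b, [:a, :b, :c])                   # OneHotVector, stores only the position of :b
W * v                                          # picked out by indexing, no dense multiply

xs = onehotbatch([:b, :a, :b], [:a, :b, :c])   # 3×3 OneHotMatrix
W * xs                                         # 5×3 result, again computed by indexing
```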

docs/src/datasets.md
@@ -1,20 +0,0 @@
# Datasets

Flux includes several standard machine learning datasets.

```@docs
Flux.Data.Iris.features()
Flux.Data.Iris.labels()
Flux.Data.MNIST.images()
Flux.Data.MNIST.labels()
Flux.Data.FashionMNIST.images()
Flux.Data.FashionMNIST.labels()
Flux.Data.CMUDict.phones()
Flux.Data.CMUDict.symbols()
Flux.Data.CMUDict.rawdict()
Flux.Data.CMUDict.cmudict()
Flux.Data.Sentiment.train()
Flux.Data.Sentiment.test()
Flux.Data.Sentiment.dev()
```
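
A brief usage sketch for the accessors listed above (hedged: exact return types are as documented in `Flux.Data`; the datasets are downloaded on first use):

```julia
using Flux
using Flux.Data: MNIST, Iris

imgs   = MNIST.images()     # vector of 28×28 images
labels = MNIST.labels()     # vector of digit labels 0–9

features = Iris.features()  # 4×150 matrix of measurements
species  = Iris.labels()    # 150 string labels
```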

docs/src/ecosystem.md
@@ -1,21 +0,0 @@
# The Julia Ecosystem

One of the main strengths of Julia lies in an ecosystem of packages
globally providing a rich and consistent user experience.

This is a non-exhaustive list of Julia packages, nicely complementing `Flux` in typical
machine learning and deep learning workflows:

- [ArgParse.jl](https://github.com/carlobaldassi/ArgParse.jl): package for parsing command-line arguments to Julia programs.
- [Augmentor.jl](https://github.com/Evizero/Augmentor.jl): a fast image augmentation library in Julia for machine learning.
- [BSON.jl](https://github.com/JuliaIO/BSON.jl): package for working with the Binary JSON serialisation format
- [DataFrames.jl](https://github.com/JuliaData/DataFrames.jl): in-memory tabular data in Julia
- [DrWatson.jl](https://github.com/JuliaDynamics/DrWatson.jl): a scientific project assistant software
- [MLDatasets.jl](https://github.com/JuliaML/MLDatasets.jl): utility package for accessing common machine learning datasets
- [OnlineStats.jl](https://github.com/joshday/OnlineStats.jl): single-pass algorithms for statistics
- [Parameters.jl](https://github.com/mauro3/Parameters.jl): types with default field values, keyword constructors and (un-)pack macros
- [ProgressMeter.jl](https://github.com/timholy/ProgressMeter.jl): progress meters for long-running computations
- [TensorBoardLogger.jl](https://github.com/PhilipVinc/TensorBoardLogger.jl): easy peasy logging to [tensorboard](https://www.tensorflow.org/tensorboard) in Julia

This tight integration among Julia packages is shown in some of the examples in the [model-zoo](https://github.com/FluxML/model-zoo) repository.

docs/src/gpu.md
@@ -1,9 +1,5 @@
# GPU Support

NVIDIA GPU support should work out of the box on systems with CUDA and CUDNN installed. For more details see the [CuArrays](https://github.com/JuliaGPU/CuArrays.jl) readme.

## GPU Usage

Support for array operations on other hardware backends, like GPUs, is provided by external packages like [CuArrays](https://github.com/JuliaGPU/CuArrays.jl). Flux is agnostic to array types, so we simply need to move model weights and data to the GPU and Flux will handle it.

For example, we can use `CuArrays` (with the `cu` converter) to run our [basic example](models/basics.md) on an NVIDIA GPU.
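
As a rough sketch of the kind of `cu` conversion being described (illustrative only, not the elided lines of the original example; assumes CuArrays and a CUDA-capable GPU):

```julia
using Flux, CuArrays

W = cu(rand(2, 5))
b = cu(rand(2))
x, y = cu(rand(5)), cu(rand(2))

predict(x) = W * x .+ b
loss(x, y) = sum((predict(x) .- y) .^ 2)

loss(x, y)   # everything above now runs on the GPU
```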

@@ -25,16 +21,16 @@ loss(x, y) # ~ 3

Note that we convert both the parameters (`W`, `b`) and the data set (`x`, `y`) to cuda arrays. Taking derivatives and training works exactly as before.

If you define a structured model, like a `Dense` layer or `Chain`, you just need to convert the internal parameters. Flux provides `fmap`, which allows you to alter all parameters of a model at once.
If you define a structured model, like a `Dense` layer or `Chain`, you just need to convert the internal parameters. Flux provides `mapleaves`, which allows you to alter all parameters of a model at once.

```julia
d = Dense(10, 5, σ)
d = fmap(cu, d)
d.W # CuArray
d = mapleaves(cu, d)
d.W # Tracked CuArray
d(cu(rand(10))) # CuArray output

m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
m = fmap(cu, m)
m = mapleaves(cu, m)
d(cu(rand(10)))
```

@@ -53,7 +49,7 @@ julia> x = rand(10) |> gpu
 0.511655

julia> m(x)
5-element CuArray{Float32,1}:
Tracked 5-element CuArray{Float32,1}:
 -0.30535
 ⋮
 -0.618002

docs/src/internals/tracker.md (new file) — 184 changed lines
@@ -0,0 +1,184 @@
# Flux.Tracker

Backpropagation, or reverse-mode automatic differentiation, is handled by the `Flux.Tracker` module.

```julia
julia> using Flux.Tracker
```

Here we discuss some more advanced uses of this module, as well as covering its internals.

## Taking Gradients

In the [basics section](../models/basics.md) we covered basic usage of the `gradient` function.

```julia
using Flux.Tracker

Tracker.gradient((a, b) -> a*b, 2, 3) # (3.0 (tracked), 2.0 (tracked))
```

`gradient` is actually just a thin wrapper around the backpropagator-based interface, `forward`.

```julia
using Flux.Tracker: forward

y, back = forward((a, b) -> a*b, 2, 3) # (6.0 (tracked), Flux.Tracker.#9)

back(1) # (3.0 (tracked), 2.0 (tracked))
```

The `forward` function returns two results. The first, `y`, is the original value of the function (perhaps with tracking applied). The second, `back`, is a new function which, given a sensitivity, returns the sensitivity of the inputs to `forward` (we call this a "backpropagator"). One use of this interface is to provide custom sensitivities when outputs are not scalar.

```julia
julia> y, back = forward((a, b) -> a.*b, [1,2,3],[4,5,6])
(param([4.0, 10.0, 18.0]), Flux.Tracker.#9)

julia> back([1,1,1])
(param([4.0, 5.0, 6.0]), param([1.0, 2.0, 3.0]))
```

We can also take gradients in-place. This can be useful if you only care about first-order gradients.

```julia
a, b = param(2), param(3)

c = a*b # 6.0 (tracked)

Tracker.back!(c)

Tracker.grad(a), Tracker.grad(b) # (3.0, 2.0)
```

## Tracked Arrays

The `param` function converts a normal Julia array into a new object that, while behaving like an array, tracks extra information that allows us to calculate derivatives. For example, say we multiply two parameters:

```julia
julia> W = param([1 2; 3 4])
Tracked 2×2 Array{Float64,2}:
 1.0  2.0
 3.0  4.0

julia> x = param([5, 6])
Tracked 2-element Array{Float64,1}:
 5.0
 6.0

julia> y = W*x
Tracked 2-element Array{Float64,1}:
 17.0
 39.0
```

The output `y` is also a `TrackedArray` object. We can now backpropagate sensitivities to `W` and `x` via the `back!` function, and see the gradients accumulated in the `W` and `x` tracked arrays:

```julia
julia> Tracker.back!(y, [1, -1])

julia> W.grad
2×2 Array{Float64,2}:
  5.0   6.0
 -5.0  -6.0

julia> x.grad
2-element Array{Float64,1}:
 -2.0
 -2.0
```

You may sometimes want to drop derivative information and just get the plain value back. You can do this by calling `Tracker.data(W)`.

## Custom Gradients

We can hook in to the processes above to implement custom gradients for a function or kernel. For a toy example, imagine a custom implementation of `minus`:

```julia
minus(a, b) = a - b
```

Firstly, we must tell the tracker system to stop when it sees a call to `minus`, and record it. We can do this using dispatch:

```julia
using Flux.Tracker: TrackedArray, track, @grad

minus(a::TrackedArray, b::TrackedArray) = track(minus, a, b)
```

`track` takes care of building a new `Tracked` object and recording the operation on the tape. We just need to provide a gradient definition.

```julia
@grad function minus(a, b)
  return minus(data(a), data(b)), Δ -> (Δ, -Δ)
end
```

This is essentially just a way of overloading the `forward` function we saw above. We strip tracking from `a` and `b` so that we are calling the original definition of `minus` (otherwise, we'd just try to track the call again and hit an infinite regress).

Note that in the backpropagator we don't call `data(a)`; we *do* in fact want to track this, since nested AD will take a derivative through the backpropagator itself. For example, the gradient of `*` might look like this.

```julia
@grad a * b = data(a)*data(b), Δ -> (Δ*b, a*Δ)
```

We can then calculate the first derivative of `minus` as follows:

```julia
a = param([1,2,3])
b = param([3,2,1])

c = minus(a, b)  # [-2.0 (tracked), 0.0 (tracked), 2.0 (tracked)]

Tracker.back!(c, 1)
Tracker.grad(a)  # [1.00, 1.00, 1.00]
Tracker.grad(b)  # [-1.00, -1.00, -1.00]
```

For multi-argument functions with custom gradients, you likely want to catch not just `minus(::TrackedArray, ::TrackedArray)` but also `minus(::Array, TrackedArray)` and so on. To do so, just define those extra signatures as needed:

```julia
minus(a::AbstractArray, b::TrackedArray) = Tracker.track(minus, a, b)
minus(a::TrackedArray, b::AbstractArray) = Tracker.track(minus, a, b)
```

## Tracked Internals

All `Tracked*` objects (`TrackedArray`, `TrackedReal`) are light wrappers around the `Tracked` type, which you can access via the `.tracker` field.

```julia
julia> x.tracker
Flux.Tracker.Tracked{Array{Float64,1}}(0x00000000, Flux.Tracker.Call{Nothing,Tuple{}}(nothing, ()), true, [5.0, 6.0], [-2.0, -2.0])
```

The `Tracker` stores the gradient of a given object, which we've seen before.

```julia
julia> x.tracker.grad
2-element Array{Float64,1}:
 -2.0
 -2.0
```

The tracker also contains a `Call` object, which simply represents a function call that was made at some point during the forward pass. For example, the `+` call would look like this:

```julia
julia> Tracker.Call(+, 1, 2)
Flux.Tracker.Call{Base.#+,Tuple{Int64,Int64}}(+, (1, 2))
```

In the case of the `y` we produced above, we can see that it stores the call that produced it -- that is, `W*x`.

```julia
julia> y.tracker.f
Flux.Tracker.Call{...}(*, (param([1.0 2.0; 3.0 4.0]), param([5.0, 6.0])))
```

Notice that because the arguments to the call may also be tracked arrays, storing their own calls, this means that `Tracker` ends up forming a data structure that records everything that happened during the forward pass (often known as a *tape*).

When we call `back!(y, [1, -1])`, the sensitivities `[1, -1]` simply get forwarded to `y`'s call (`*`), effectively calling

```julia
Tracker.back(*, [1, -1], W, x)
```

which in turn calculates the sensitivities of the arguments (`W` and `x`) and back-propagates through their calls. This is recursive, so it will walk the entire program graph and propagate gradients to the original model parameters.
@ -1,73 +0,0 @@
|
||||
# Advanced Model Building and Customisation
|
||||
|
||||
Here we will try and describe usage of some more advanced features that Flux provides to give more control over model building.
|
||||
|
||||
## Customising Parameter Collection for a Model
|
||||
|
||||
Taking reference from our example `Affine` layer from the [basics](basics.md#Building-Layers-1).
|
||||
|
||||
By default all the fields in the `Affine` type are collected as its parameters, however, in some cases it may be desired to hold other metadata in our "layers" that may not be needed for training, and are hence supposed to be ignored while the parameters are collected. With Flux, it is possible to mark the fields of our layers that are trainable in two ways.
|
||||
|
||||
The first way is to overload the `trainable` function.
|
||||
|
||||
```julia-repl
|
||||
julia> @functor Affine
|
||||
|
||||
julia> a = Affine(rand(3,3), rand(3))
|
||||
Affine{Array{Float64,2},Array{Float64,1}}([0.66722 0.774872 0.249809; 0.843321 0.403843 0.429232; 0.683525 0.662455 0.065297], [0.42394, 0.0170927, 0.544955])
|
||||
|
||||
julia> Flux.params(a) # default behavior
|
||||
Params([[0.66722 0.774872 0.249809; 0.843321 0.403843 0.429232; 0.683525 0.662455 0.065297], [0.42394, 0.0170927, 0.544955]])
|
||||
|
||||
julia> Flux.trainable(a::Affine) = (a.W,)
|
||||
|
||||
julia> Flux.params(a)
|
||||
Params([[0.66722 0.774872 0.249809; 0.843321 0.403843 0.429232; 0.683525 0.662455 0.065297]])
|
||||
```
|
||||
|
||||
Only the fields returned by `trainable` will be collected as trainable parameters of the layer when calling `Flux.params`.
|
||||
|
||||
Another way of achieving this is through the `@functor` macro directly. Here, we can mark the fields we are interested in by grouping them in the second argument:
|
||||
|
||||
```julia
|
||||
Flux.@functor Affine (W,)
|
||||
```
|
||||
|
||||
However, doing this requires the `struct` to have a corresponding constructor that accepts those parameters.
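
For illustration, one hypothetical way to satisfy this for `Affine` when only `W` is marked; the single-field constructor and the zero-initialised bias are our own choices here, not part of the original example:

```julia
struct Affine
  W
  b
end

# A constructor over only the marked field, so the struct can be rebuilt from `(W,)`
Affine(W) = Affine(W, zeros(size(W, 1)))

Flux.@functor Affine (W,)
```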
|
||||
|
||||
## Freezing Layer Parameters
|
||||
|
||||
When you do not want to include all of the model parameters (e.g. for transfer learning), simply leave those layers out of the call to `params`.
|
||||
|
||||
Consider a simple multi-layer perceptron model where we want to avoid optimising the first two `Dense` layers. We can obtain
|
||||
this using the slicing features `Chain` provides:
|
||||
|
||||
```julia
|
||||
m = Chain(
|
||||
Dense(784, 64, relu),
|
||||
Dense(64, 32, relu),
|
||||
Dense(32, 10)
|
||||
)
|
||||
|
||||
ps = Flux.params(m[3:end])
|
||||
```
|
||||
|
||||
The `Zygote.Params` object `ps` now holds a reference to only the parameters of the layers passed to it.
|
||||
|
||||
During training, gradients will only be computed for (and applied to) the last `Dense` layer, so only its parameters will be updated.
|
||||
|
||||
`Flux.params` also accepts multiple inputs, making it easy to collect parameters from heterogeneous models in a single call. For example, to omit optimising only the second `Dense` layer of the previous example:
|
||||
|
||||
```julia
|
||||
Flux.params(m[1], m[3:end])
|
||||
```
|
||||
|
||||
Sometimes finer-grained control is needed.
|
||||
We can freeze a specific parameter of a specific layer which already entered a `Params` object `ps`,
|
||||
by simply deleting it from `ps`:
|
||||
|
||||
```julia
|
||||
ps = params(m)
|
||||
delete!(ps, m[2].b)
|
||||
```
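
A sketch of a subsequent training step using the reduced set `ps` (here `loss`, `x`, `y` and the choice of optimiser are assumed); since `m[2].b` was deleted from `ps`, it receives no update:

```julia
opt = Descent(0.1)

gs = gradient(() -> loss(x, y), ps)   # gradients only for what remains in ps
Flux.Optimise.update!(opt, ps, gs)
```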
|
||||
|
@ -4,55 +4,49 @@
|
||||
|
||||
Flux's core feature is taking gradients of Julia code. The `gradient` function takes another Julia function `f` and a set of arguments, and returns the gradient with respect to each argument. (It's a good idea to try pasting these examples in the Julia terminal.)
|
||||
|
||||
```jldoctest basics
|
||||
julia> using Flux
|
||||
```julia
|
||||
using Flux.Tracker
|
||||
|
||||
julia> f(x) = 3x^2 + 2x + 1;
|
||||
f(x) = 3x^2 + 2x + 1
|
||||
|
||||
julia> df(x) = gradient(f, x)[1]; # df/dx = 6x + 2
|
||||
# df/dx = 6x + 2
|
||||
df(x) = Tracker.gradient(f, x)[1]
|
||||
|
||||
julia> df(2)
|
||||
14
|
||||
df(2) # 14.0 (tracked)
|
||||
|
||||
julia> d2f(x) = gradient(df, x)[1]; # d²f/dx² = 6
|
||||
# d²f/dx² = 6
|
||||
d2f(x) = Tracker.gradient(df, x)[1]
|
||||
|
||||
julia> d2f(2)
|
||||
6
|
||||
d2f(2) # 6.0 (tracked)
|
||||
```
|
||||
|
||||
When a function has many parameters, we can get gradients of each one at the same time:
|
||||
(We'll learn more about why these numbers show up as `(tracked)` below.)
|
||||
|
||||
```jldoctest basics
|
||||
julia> f(x, y) = sum((x .- y).^2);
|
||||
When a function has many parameters, we can pass them all in explicitly:
|
||||
|
||||
julia> gradient(f, [2, 1], [2, 0])
|
||||
([0, 2], [0, -2])
|
||||
```julia
|
||||
f(W, b, x) = W * x + b
|
||||
|
||||
Tracker.gradient(f, 2, 3, 4)
|
||||
(4.0 (tracked), 1.0, 2.0 (tracked))
|
||||
```
|
||||
|
||||
But machine learning models can have *hundreds* of parameters! To handle this, Flux lets you work with collections of parameters, via `params`. You can get the gradient of all parameters used in a program without explicitly passing them in.
|
||||
But machine learning models can have *hundreds* of parameters! Flux offers a nice way to handle this. We can tell Flux to treat something as a parameter via `param`. Then we can collect these together and tell `gradient` to collect the gradients of all of them at once.
|
||||
|
||||
```jldoctest basics
|
||||
julia> x = [2, 1];
|
||||
```julia
|
||||
W = param(2) # 2.0 (tracked)
|
||||
b = param(3) # 3.0 (tracked)
|
||||
|
||||
julia> y = [2, 0];
|
||||
f(x) = W * x + b
|
||||
|
||||
julia> gs = gradient(params(x, y)) do
|
||||
f(x, y)
|
||||
end
|
||||
Grads(...)
|
||||
params = Params([W, b])
|
||||
grads = Tracker.gradient(() -> f(4), params)
|
||||
|
||||
julia> gs[x]
|
||||
2-element Array{Int64,1}:
|
||||
0
|
||||
2
|
||||
|
||||
julia> gs[y]
|
||||
2-element Array{Int64,1}:
|
||||
0
|
||||
-2
|
||||
grads[W] # 4.0
|
||||
grads[b] # 1.0
|
||||
```
|
||||
|
||||
Here, `gradient` takes a zero-argument function; no arguments are necessary because the `params` tell it what to differentiate.
|
||||
There are a few things to notice here. Firstly, `W` and `b` now show up as *tracked*. Tracked things behave like normal numbers or arrays, but keep records of everything you do with them, allowing Flux to calculate their gradients. `gradient` takes a zero-argument function; no arguments are necessary because the `Params` tell it what to differentiate.
|
||||
|
||||
This will come in really handy when dealing with big, complicated models. For now, though, let's start with something simple.
|
||||
|
||||
@ -67,28 +61,34 @@ b = rand(2)
|
||||
predict(x) = W*x .+ b
|
||||
|
||||
function loss(x, y)
|
||||
ŷ = predict(x)
|
||||
sum((y .- ŷ).^2)
|
||||
ŷ = predict(x)
|
||||
sum((y .- ŷ).^2)
|
||||
end
|
||||
|
||||
x, y = rand(5), rand(2) # Dummy data
|
||||
loss(x, y) # ~ 3
|
||||
```
|
||||
|
||||
To improve the prediction we can take the gradients of `W` and `b` with respect to the loss and perform gradient descent.
|
||||
To improve the prediction we can take the gradients of `W` and `b` with respect to the loss and perform gradient descent. Let's tell Flux that `W` and `b` are parameters, just like we did above.
|
||||
|
||||
```julia
|
||||
using Flux
|
||||
using Flux.Tracker
|
||||
|
||||
gs = gradient(() -> loss(x, y), params(W, b))
|
||||
W = param(W)
|
||||
b = param(b)
|
||||
|
||||
gs = Tracker.gradient(() -> loss(x, y), Params([W, b]))
|
||||
```
|
||||
|
||||
Now that we have gradients, we can pull them out and update `W` to train the model.
|
||||
Now that we have gradients, we can pull them out and update `W` to train the model. The `update!(W, Δ)` function applies `W = W + Δ`, which we can use for gradient descent.
|
||||
|
||||
```julia
|
||||
W̄ = gs[W]
|
||||
using Flux.Tracker: update!
|
||||
|
||||
W .-= 0.1 .* W̄
|
||||
Δ = gs[W]
|
||||
|
||||
# Update the parameter and reset the gradient
|
||||
update!(W, -0.1Δ)
|
||||
|
||||
loss(x, y) # ~ 2.5
|
||||
```
|
||||
@ -102,14 +102,12 @@ All deep learning in Flux, however complex, is a simple generalisation of this e
|
||||
It's common to create more complex models than the linear regression above. For example, we might want to have two linear layers with a nonlinearity like [sigmoid](https://en.wikipedia.org/wiki/Sigmoid_function) (`σ`) in between them. In the above style we could write this as:
|
||||
|
||||
```julia
|
||||
using Flux
|
||||
|
||||
W1 = rand(3, 5)
|
||||
b1 = rand(3)
|
||||
W1 = param(rand(3, 5))
|
||||
b1 = param(rand(3))
|
||||
layer1(x) = W1 * x .+ b1
|
||||
|
||||
W2 = rand(2, 3)
|
||||
b2 = rand(2)
|
||||
W2 = param(rand(2, 3))
|
||||
b2 = param(rand(2))
|
||||
layer2(x) = W2 * x .+ b2
|
||||
|
||||
model(x) = layer2(σ.(layer1(x)))
|
||||
@ -121,8 +119,8 @@ This works but is fairly unwieldy, with a lot of repetition – especially as we
|
||||
|
||||
```julia
|
||||
function linear(in, out)
|
||||
W = randn(out, in)
|
||||
b = randn(out)
|
||||
W = param(randn(out, in))
|
||||
b = param(randn(out))
|
||||
x -> W * x .+ b
|
||||
end
|
||||
|
||||
@ -143,7 +141,7 @@ struct Affine
|
||||
end
|
||||
|
||||
Affine(in::Integer, out::Integer) =
|
||||
Affine(randn(out, in), randn(out))
|
||||
Affine(param(randn(out, in)), param(randn(out)))
|
||||
|
||||
# Overload call, so the object can be used as a function
|
||||
(m::Affine)(x) = m.W * x .+ m.b
|
||||
@ -213,30 +211,7 @@ m(5) # => 26
|
||||
Flux provides a set of helpers for custom layers, which you can enable by calling
|
||||
|
||||
```julia
|
||||
Flux.@functor Affine
|
||||
Flux.@treelike Affine
|
||||
```
|
||||
|
||||
This enables a useful extra set of functionality for our `Affine` layer, such as [collecting its parameters](../training/optimisers.md) or [moving it to the GPU](../gpu.md).
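
For example, a small sketch of what this enables (the GPU move only takes effect when a working GPU setup is available):

```julia
a = Affine(10, 5)

Flux.params(a)   # collect the layer's trainable parameters
gpu(a)           # move the layer to the GPU, when one is available
```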
|
||||
|
||||
For some more helpful tricks, including parameter freezing, please check out the [advanced usage guide](advanced.md).
|
||||
|
||||
## Utility functions
|
||||
|
||||
Flux provides some utility functions to help you generate models in an automated fashion.
|
||||
|
||||
`outdims` enables you to calculate the spatial output dimensions of layers like `Conv` when applied to input images of a given size.
|
||||
Currently it is limited to the following layers (a usage sketch follows the list):
|
||||
- `Chain`
|
||||
- `Dense`
|
||||
- `Conv`
|
||||
- `Diagonal`
|
||||
- `Maxout`
|
||||
- `ConvTranspose`
|
||||
- `DepthwiseConv`
|
||||
- `CrossCor`
|
||||
- `MaxPool`
|
||||
- `MeanPool`
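
For example, a minimal sketch, assuming the `Flux.outdims(layer, insize)` form documented below:

```julia
m = Chain(Conv((3, 3), 1=>16), MaxPool((2, 2)))

Flux.outdims(m, (10, 10))   # spatial output size for a 10×10 input, here (4, 4)
```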
|
||||
|
||||
```@docs
|
||||
Flux.outdims
|
||||
```
|
||||
|
@ -5,26 +5,15 @@ These core layers form the foundation of almost all neural networks.
|
||||
```@docs
|
||||
Chain
|
||||
Dense
|
||||
```
|
||||
|
||||
## Convolution and Pooling Layers
|
||||
|
||||
These layers are used to build convolutional neural networks (CNNs).
|
||||
|
||||
```@docs
|
||||
Conv
|
||||
MaxPool
|
||||
GlobalMaxPool
|
||||
MeanPool
|
||||
GlobalMeanPool
|
||||
```
|
||||
|
||||
## Additional Convolution Layers
|
||||
|
||||
```@docs
|
||||
DepthwiseConv
|
||||
ConvTranspose
|
||||
CrossCor
|
||||
SamePad
|
||||
flatten
|
||||
Flux.Zeros
|
||||
Flux.convfilter
|
||||
Flux.depthwiseconvfilter
|
||||
```
|
||||
|
||||
## Recurrent Layers
|
||||
@ -36,57 +25,29 @@ RNN
|
||||
LSTM
|
||||
GRU
|
||||
Flux.Recur
|
||||
Flux.reset!
|
||||
```
|
||||
|
||||
## Other General Purpose Layers
|
||||
These are marginally more obscure than the Basic Layers.
|
||||
But in contrast to the layers described in the other sections, they are not readily grouped around a particular purpose (e.g. CNNs or RNNs).
|
||||
## Activation Functions
|
||||
|
||||
Non-linearities that go between layers of your model. Most of these functions are defined in [NNlib](https://github.com/FluxML/NNlib.jl) but are available by default in Flux.
|
||||
|
||||
Note that, unless otherwise stated, activation functions operate on scalars. To apply them to an array you can call `σ.(xs)`, `relu.(xs)` and so on.
|
||||
|
||||
```@docs
|
||||
Maxout
|
||||
SkipConnection
|
||||
σ
|
||||
relu
|
||||
leakyrelu
|
||||
elu
|
||||
swish
|
||||
```
|
||||
|
||||
|
||||
## Normalisation & Regularisation
|
||||
|
||||
These layers don't affect the structure of the network but may improve training times or reduce overfitting.
|
||||
|
||||
```@docs
|
||||
Flux.normalise
|
||||
BatchNorm
|
||||
Flux.dropout
|
||||
Dropout
|
||||
AlphaDropout
|
||||
LayerNorm
|
||||
InstanceNorm
|
||||
GroupNorm
|
||||
```
|
||||
|
||||
### Testmode
|
||||
|
||||
Many normalisation layers behave differently under training and inference (testing). By default, Flux will automatically determine when a layer evaluation is part of training or inference. Still, depending on your use case, it may be helpful to manually specify when these layers should be treated as being trained or not. For this, Flux provides `Flux.testmode!`. When called on a model (e.g. a layer or chain of layers), this function will place the model into the mode specified.
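
For example, a minimal sketch:

```julia
m = Chain(Dense(10, 5), Dropout(0.5), Dense(5, 2))

Flux.testmode!(m)    # treat m as being tested: Dropout is disabled
Flux.trainmode!(m)   # treat m as being trained again
```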
|
||||
|
||||
```@docs
|
||||
Flux.testmode!
|
||||
trainmode!
|
||||
```
|
||||
|
||||
## Cost Functions
|
||||
```@docs
|
||||
Flux.mae
|
||||
Flux.mse
|
||||
Flux.msle
|
||||
Flux.huber_loss
|
||||
Flux.crossentropy
|
||||
Flux.logitcrossentropy
|
||||
Flux.binarycrossentropy
|
||||
Flux.logitbinarycrossentropy
|
||||
Flux.kldivergence
|
||||
Flux.poisson
|
||||
Flux.hinge
|
||||
Flux.squared_hinge
|
||||
Flux.dice_coeff_loss
|
||||
Flux.tversky_loss
|
||||
BatchNorm
|
||||
Dropout
|
||||
LayerNorm
|
||||
```
|
||||
|
@ -1,61 +0,0 @@
|
||||
# NNlib
|
||||
|
||||
Flux re-exports all of the functions exported by the [NNlib](https://github.com/FluxML/NNlib.jl) package.
|
||||
|
||||
## Activation Functions
|
||||
|
||||
Non-linearities that go between layers of your model. Note that, unless otherwise stated, activation functions operate on scalars. To apply them to an array you can call `σ.(xs)`, `relu.(xs)` and so on.
|
||||
|
||||
```@docs
|
||||
NNlib.celu
|
||||
NNlib.elu
|
||||
NNlib.gelu
|
||||
NNlib.hardsigmoid
|
||||
NNlib.hardtanh
|
||||
NNlib.leakyrelu
|
||||
NNlib.lisht
|
||||
NNlib.logcosh
|
||||
NNlib.logsigmoid
|
||||
NNlib.mish
|
||||
NNlib.relu
|
||||
NNlib.relu6
|
||||
NNlib.rrelu
|
||||
NNlib.selu
|
||||
NNlib.sigmoid
|
||||
NNlib.softplus
|
||||
NNlib.softshrink
|
||||
NNlib.softsign
|
||||
NNlib.swish
|
||||
NNlib.tanhshrink
|
||||
NNlib.trelu
|
||||
```
|
||||
|
||||
## Softmax
|
||||
|
||||
```@docs
|
||||
NNlib.softmax
|
||||
NNlib.logsoftmax
|
||||
```
|
||||
|
||||
## Pooling
|
||||
|
||||
```@docs
|
||||
NNlib.maxpool
|
||||
NNlib.meanpool
|
||||
```
|
||||
|
||||
## Convolution
|
||||
|
||||
```@docs
|
||||
NNlib.conv
|
||||
NNlib.depthwiseconv
|
||||
```
|
||||
|
||||
## Batched Operations
|
||||
|
||||
```@docs
|
||||
NNlib.batched_mul
|
||||
NNlib.batched_mul!
|
||||
NNlib.batched_adjoint
|
||||
NNlib.batched_transpose
|
||||
```
|
@ -77,7 +77,7 @@ If you use the `RNN(10, 5)` constructor – as opposed to `RNNCell` – you'll s
|
||||
|
||||
```julia
|
||||
julia> RNN(10, 5)
|
||||
Recur(RNNCell(10, 5, tanh))
|
||||
Recur(RNNCell(Dense(15, 5)))
|
||||
```
|
||||
|
||||
## Sequences
|
||||
@ -101,4 +101,16 @@ m = Chain(LSTM(10, 15), Dense(15, 5))
|
||||
m.(seq)
|
||||
```
|
||||
|
||||
Finally, we can reset the hidden state of the cell back to its initial value using `reset!(m)`.
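
For example, after processing one sequence with the model above, a fresh, independent sequence would start with:

```julia
Flux.reset!(m)   # restore the hidden state of every recurrent cell in m
m.(seq2)         # seq2 is a second, independent sequence (assumed defined)
```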
|
||||
## Truncating Gradients
|
||||
|
||||
By default, calculating the gradients in a recurrent layer involves its entire history. For example, if we call the model on 100 inputs, we'll have to calculate the gradient for those 100 calls. If we then calculate another 10 inputs we have to calculate 110 gradients – this accumulates and quickly becomes expensive.
|
||||
|
||||
To avoid this we can *truncate* the gradient calculation, forgetting the history.
|
||||
|
||||
```julia
|
||||
truncate!(m)
|
||||
```
|
||||
|
||||
Calling `truncate!` wipes the slate clean, so we can call the model with more inputs without building up an expensive gradient computation.
|
||||
|
||||
`truncate!` makes sense when you are working with multiple chunks of a large sequence, but we may also want to work with a set of independent sequences. In this case the hidden state should be completely reset to its original value, throwing away any accumulated information. `reset!` does this for you.
|
||||
|
@ -15,8 +15,6 @@ loss(x, y) = crossentropy(softmax(m(x)), y)
|
||||
We can regularise this by taking the (L2) norm of the parameters, `m.W` and `m.b`.
|
||||
|
||||
```julia
|
||||
using LinearAlgebra
|
||||
|
||||
penalty() = norm(m.W) + norm(m.b)
|
||||
loss(x, y) = crossentropy(softmax(m(x)), y) + penalty()
|
||||
```
|
||||
@ -31,7 +29,7 @@ julia> params(m)
|
||||
param([0.0, 0.0, 0.0, 0.0, 0.0])
|
||||
|
||||
julia> sum(norm, params(m))
|
||||
26.01749952921026
|
||||
26.01749952921026 (tracked)
|
||||
```
|
||||
|
||||
Here's a larger example with a multi-layer perceptron.
|
||||
@ -50,21 +48,15 @@ loss(rand(28^2), rand(10))
|
||||
One can also easily add per-layer regularisation via the `activations` function:
|
||||
|
||||
```julia
|
||||
julia> using Flux: activations
|
||||
|
||||
julia> c = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
|
||||
Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
|
||||
julia> c = Chain(Dense(10,5,σ),Dense(5,2),softmax)
|
||||
Chain(Dense(10, 5, NNlib.σ), Dense(5, 2), NNlib.softmax)
|
||||
|
||||
julia> activations(c, rand(10))
|
||||
3-element Array{Any,1}:
|
||||
Float32[0.84682214, 0.6704139, 0.42177814, 0.257832, 0.36255655]
|
||||
Float32[0.1501253, 0.073269576]
|
||||
Float32[0.5192045, 0.48079553]
|
||||
param([0.71068, 0.831145, 0.751219, 0.227116, 0.553074])
|
||||
param([0.0330606, -0.456104])
|
||||
param([0.61991, 0.38009])
|
||||
|
||||
julia> sum(norm, ans)
|
||||
2.1166067f0
|
||||
```
|
||||
|
||||
```@docs
|
||||
Flux.activations
|
||||
2.639678767773633 (tracked)
|
||||
```
|
||||
|
@ -1,76 +0,0 @@
|
||||
# Performance Tips
|
||||
|
||||
All the usual [Julia performance tips apply](https://docs.julialang.org/en/v1/manual/performance-tips/).
|
||||
As always [profiling your code](https://docs.julialang.org/en/v1/manual/profile/#Profiling-1) is generally a useful way of finding bottlenecks.
|
||||
Below follow some Flux-specific tips/reminders.
|
||||
|
||||
## Don't use more precision than you need
|
||||
|
||||
Flux works great with all kinds of number types.
But often you do not need to be working with, say, `Float64` (let alone `BigFloat`).
Switching to `Float32` can give you a significant speed-up,
not because the operations themselves are faster, but because the memory usage is halved,
which means allocations are cheaper and you use less memory overall.
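
For example, a minimal sketch (a model `m` and input data `x` are assumed; `f32` is the Flux helper for converting a model's parameters):

```julia
x32 = Float32.(x)   # convert the data
m32 = f32(m)        # convert the model's parameters to Float32
```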
|
||||
|
||||
|
||||
## Preserve inputs' types
|
||||
|
||||
Not only should your activation and loss functions be [type-stable](https://docs.julialang.org/en/v1/manual/performance-tips/#Write-%22type-stable%22-functions-1),
|
||||
they should also preserve the type of their inputs.
|
||||
|
||||
A very artificial example using an activation function like
|
||||
|
||||
```
|
||||
my_tanh(x) = Float64(tanh(x))
|
||||
```
|
||||
|
||||
will result in performance on `Float32` input that is orders of magnitude slower than the normal `tanh` would be,
because it forces the dense layers to use slow mixed-type multiplication.
|
||||
Similar situations can occur in the loss function during backpropagation.
|
||||
|
||||
This means that if you change your data, say, from `Float64` to `Float32` (which should give a speed-up: see above),
you will instead see a large slow-down.
|
||||
|
||||
This can occur sneakily, because you can cause type promotion by interacting with numeric literals.
E.g. the following runs into the same problem as above:
|
||||
|
||||
```
|
||||
leaky_tanh(x) = 0.01*x + tanh(x)
|
||||
```
|
||||
|
||||
While one could change the activation function (e.g. to use `0.01f0*x`), the idiomatic (and safe) way to avoid type casts whenever the input type changes is to use `oftype`:
|
||||
```
|
||||
leaky_tanh(x) = oftype(x/1, 0.01)*x + tanh(x)
|
||||
```
|
||||
|
||||
|
||||
## Evaluate batches as Matrices of features
|
||||
|
||||
While it can sometimes be tempting to process your observations (feature vectors) one at a time
|
||||
e.g.
|
||||
```julia
|
||||
function loss_total(xs::AbstractVector{<:Vector}, ys::AbstractVector{<:Vector})
|
||||
sum(zip(xs, ys)) do (x, y_target)
|
||||
y_pred = model(x) # evaluate the model
|
||||
return loss(y_pred, y_target)
|
||||
end
|
||||
end
|
||||
```
|
||||
|
||||
It is much faster to concatenate them into a matrix,
|
||||
as this will hit BLAS matrix-matrix multiplication, which is much faster than the equivalent sequence of matrix-vector multiplications.
|
||||
The improvement is enough that it is worthwhile allocating new memory to store them contiguously.
|
||||
|
||||
```julia
|
||||
x_batch = reduce(hcat, xs)
|
||||
y_batch = reduce(hcat, ys)
|
||||
...
|
||||
function loss_total(x_batch::Matrix, y_batch::Matrix)
|
||||
y_preds = model(x_batch)
|
||||
sum(loss.(y_preds, y_batch))
|
||||
end
|
||||
```
|
||||
|
||||
When doing this kind of concatenation use `reduce(hcat, xs)` rather than `hcat(xs...)`.
|
||||
This will avoid the splatting penalty, and will hit the optimised `reduce` method.
|
@ -53,7 +53,7 @@ julia> using Flux
|
||||
julia> model = Chain(Dense(10,5,relu),Dense(5,2),softmax)
|
||||
Chain(Dense(10, 5, NNlib.relu), Dense(5, 2), NNlib.softmax)
|
||||
|
||||
julia> weights = params(model);
|
||||
julia> weights = Tracker.data.(params(model));
|
||||
|
||||
julia> using BSON: @save
|
||||
|
||||
@ -113,6 +113,6 @@ You can even store optimiser state alongside the model, to resume training
|
||||
exactly where you left off.
|
||||
|
||||
```julia
|
||||
opt = ADAM()
|
||||
opt = ADAM(params(model))
|
||||
@save "model-$(now()).bson" model opt
|
||||
```
|
||||
|
@ -3,25 +3,25 @@
|
||||
Consider a [simple linear regression](../models/basics.md). We create some dummy data, calculate a loss, and backpropagate to calculate gradients for the parameters `W` and `b`.
|
||||
|
||||
```julia
|
||||
using Flux
|
||||
using Flux.Tracker
|
||||
|
||||
W = rand(2, 5)
|
||||
b = rand(2)
|
||||
W = param(rand(2, 5))
|
||||
b = param(rand(2))
|
||||
|
||||
predict(x) = (W * x) .+ b
|
||||
predict(x) = W*x .+ b
|
||||
loss(x, y) = sum((predict(x) .- y).^2)
|
||||
|
||||
x, y = rand(5), rand(2) # Dummy data
|
||||
l = loss(x, y) # ~ 3
|
||||
|
||||
θ = Params([W, b])
|
||||
grads = gradient(() -> loss(x, y), θ)
|
||||
params = Params([W, b])
|
||||
grads = Tracker.gradient(() -> loss(x, y), params)
|
||||
```
|
||||
|
||||
We want to update each parameter, using the gradient, in order to improve (reduce) the loss. Here's one way to do that:
|
||||
|
||||
```julia
|
||||
using Flux.Optimise: update!
|
||||
using Flux.Tracker: grad, update!
|
||||
|
||||
η = 0.1 # Learning Rate
|
||||
for p in (W, b)
|
||||
@ -35,7 +35,7 @@ Running this will alter the parameters `W` and `b` and our loss should go down.
|
||||
opt = Descent(0.1) # Gradient descent with learning rate 0.1
|
||||
|
||||
for p in (W, b)
|
||||
update!(opt, p, grads[p])
|
||||
update!(opt, p, -η * grads[p])
|
||||
end
|
||||
```
|
||||
|
||||
@ -46,110 +46,8 @@ An optimiser `update!` accepts a parameter and a gradient, and updates the param
|
||||
All optimisers return an object that, when passed to `train!`, will update the parameters passed to it.
|
||||
|
||||
```@docs
|
||||
Flux.Optimise.update!
|
||||
Descent
|
||||
SGD
|
||||
Momentum
|
||||
Nesterov
|
||||
RMSProp
|
||||
ADAM
|
||||
RADAM
|
||||
AdaMax
|
||||
ADAGrad
|
||||
ADADelta
|
||||
AMSGrad
|
||||
NADAM
|
||||
ADAMW
|
||||
```
|
||||
|
||||
## Optimiser Interface
|
||||
|
||||
Flux's optimisers are built around a `struct` that holds all the optimiser parameters along with a definition of how to apply the update rule associated with it. We do this via the `apply!` function, which takes the optimiser as the first argument, followed by the parameter and its corresponding gradient.
|
||||
|
||||
In this manner Flux also allows one to create custom optimisers that can be used seamlessly. Let's work through a simple example.
|
||||
|
||||
```julia
|
||||
mutable struct Momentum
|
||||
eta
|
||||
rho
|
||||
velocity
|
||||
end
|
||||
|
||||
Momentum(eta::Real, rho::Real) = Momentum(eta, rho, IdDict())
|
||||
```
|
||||
|
||||
The `Momentum` type will act as our optimiser in this case. Notice that we have added all the hyperparameters as fields, along with the velocity, which we will use as our state dictionary. Each parameter in our models will get an entry in there. We can now define the rule applied when this optimiser is invoked.
|
||||
|
||||
```julia
|
||||
function Flux.Optimise.apply!(o::Momentum, x, Δ)
|
||||
η, ρ = o.eta, o.rho
|
||||
v = get!(o.velocity, x, zero(x))::typeof(x)
|
||||
@. v = ρ * v - η * Δ
|
||||
@. Δ = -v
|
||||
end
|
||||
```
|
||||
|
||||
This is the basic definition of a Momentum update rule given by:
|
||||
|
||||
```math
|
||||
v = ρ * v - η * Δ
|
||||
w = w - v
|
||||
```
|
||||
|
||||
`apply!` defines the update rule for an optimiser `opt`, given a parameter and its gradient, and returns the modified gradient. Here, the velocity `v` associated with each parameter `x` is looked up in the running state and updated, so each call also advances the state of the optimiser.
|
||||
|
||||
Flux internally calls on this function via the `update!` function. It shares the API with `apply!` but ensures that multiple parameters are handled gracefully.
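
A sketch of using the custom optimiser defined above, assuming `loss`, `ps` and `data` are already set up as in the training docs:

```julia
opt = Momentum(0.01, 0.9)

# train! calls update!, which in turn calls our apply! for every parameter
Flux.train!(loss, ps, data, opt)
```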
|
||||
|
||||
## Composing Optimisers
|
||||
|
||||
Flux defines a special kind of optimiser simply called `Optimiser`, which takes arbitrary optimisers as input. Its behaviour is similar to the usual optimisers, but it acts by calling the optimisers listed in it sequentially. Each optimiser produces a modified gradient
that will be fed into the next, and the resultant update will be applied to the parameter as usual. A classic use case is adding decays: Flux defines some basic decays, including `ExpDecay`, `InvDecay`, etc.
|
||||
|
||||
```julia
|
||||
opt = Optimiser(ExpDecay(0.001, 0.1, 1000, 1e-4), Descent())
|
||||
```
|
||||
|
||||
Here we apply exponential decay to the `Descent` optimiser. The defaults of `ExpDecay` say that its learning rate will be decayed every 1000 steps.
|
||||
It is then applied like any optimiser.
|
||||
|
||||
```julia
|
||||
w = randn(10, 10)
|
||||
w1 = randn(10,10)
|
||||
ps = Params([w, w1])
|
||||
|
||||
loss(x) = Flux.mse(w * x, w1 * x)
|
||||
|
||||
loss(rand(10)) # around 9
|
||||
|
||||
for t = 1:10^5
|
||||
θ = Params([w, w1])
|
||||
θ̄ = gradient(() -> loss(rand(10)), θ)
|
||||
Flux.Optimise.update!(opt, θ, θ̄)
|
||||
end
|
||||
|
||||
loss(rand(10)) # around 0.9
|
||||
```
|
||||
|
||||
In this manner it is possible to compose optimisers for some added flexibility.
|
||||
|
||||
## Decays
|
||||
|
||||
Similar to optimisers, Flux also defines some simple decays that can be used in conjunction with other optimisers, or standalone.
|
||||
|
||||
```@docs
|
||||
ExpDecay
|
||||
InvDecay
|
||||
WeightDecay
|
||||
```
|
||||
|
||||
## Gradient Clipping
|
||||
|
||||
Gradient clipping is useful for training recurrent neural networks, which have a tendency to suffer from the exploding gradient problem. An example usage is
|
||||
|
||||
```julia
|
||||
opt = Optimiser(ClipValue(1e-3), ADAM(1e-3))
|
||||
```
|
||||
|
||||
```@docs
|
||||
ClipValue
|
||||
ClipNorm
|
||||
```
|
@ -1,16 +1,15 @@
|
||||
# Training
|
||||
|
||||
To actually train a model we need four things:
|
||||
To actually train a model we need three things:
|
||||
|
||||
* An *objective function* that evaluates how well a model is doing given some input data.
|
||||
* The trainable parameters of the model.
|
||||
* A collection of data points that will be provided to the objective function.
|
||||
* An [optimiser](optimisers.md) that will update the model parameters appropriately.
|
||||
|
||||
With these we can call `train!`:
|
||||
With these we can call `Flux.train!`:
|
||||
|
||||
```@docs
|
||||
Flux.Optimise.train!
|
||||
```julia
|
||||
Flux.train!(objective, params, data, opt)
|
||||
```
|
||||
|
||||
There are plenty of examples in the [model zoo](https://github.com/FluxML/model-zoo).
|
||||
@ -32,17 +31,6 @@ Flux.train!(loss, ps, data, opt)
|
||||
```
|
||||
|
||||
The objective will almost always be defined in terms of some *cost function* that measures the distance of the prediction `m(x)` from the target `y`. Flux has several of these built in, like `mse` for mean squared error or `crossentropy` for cross entropy loss, but you can calculate it however you want.
|
||||
For a list of all built-in loss functions, check out the [layer reference](../models/layers.md).
|
||||
|
||||
At first glance it may seem strange that the model we want to train is not an input argument of `Flux.train!` as well. However, the target of the optimiser is not the model itself, but the objective function that represents the departure between modelled and observed data. In other words, the model is implicitly defined in the objective function, and there is no need to give it explicitly. Passing the objective function instead of the model and a cost function separately provides more flexibility, and the possibility of optimising the calculations.
|
||||
|
||||
## Model parameters
|
||||
|
||||
The model to be trained must have a set of tracked parameters that are used to calculate the gradients of the objective function. In the [basics](../models/basics.md) section it is explained how to create models with such parameters. The second argument of the function `Flux.train!` must be an object containing those parameters, which can be obtained from a model `m` as `params(m)`.
|
||||
|
||||
Such an object contains a reference to the model's parameters, not a copy, so that after training the model behaves according to the updated values.
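
For instance, a small sketch showing that `params` holds the very same arrays as the model:

```julia
m = Dense(10, 2)
ps = Flux.params(m)

any(p -> p === m.W, ps)   # true: ps references m.W rather than copying it
```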
|
||||
|
||||
Handling all the parameters on a layer-by-layer basis is explained in the [Layer Helpers](../models/basics.md) section. Also, for freezing model parameters, see the [Advanced Usage Guide](../models/advanced.md).
|
||||
|
||||
## Datasets
|
||||
|
||||
@ -59,8 +47,7 @@ data = [(x, y)]
|
||||
```julia
|
||||
data = [(x, y), (x, y), (x, y)]
|
||||
# Or equivalently
|
||||
using IterTools: ncycle
|
||||
data = ncycle([(x, y)], 3)
|
||||
data = Iterators.repeated((x, y), 3)
|
||||
```
|
||||
|
||||
It's common to load the `x`s and `y`s separately. In this case you can use `zip`:
|
||||
@ -71,14 +58,6 @@ ys = [rand( 10), rand( 10), rand( 10)]
|
||||
data = zip(xs, ys)
|
||||
```
|
||||
|
||||
Training data can be conveniently partitioned for mini-batch training using the [`Flux.Data.DataLoader`](@ref) type:
|
||||
|
||||
```julia
|
||||
X = rand(28, 28, 60000)
|
||||
Y = rand(0:9, 60000)
|
||||
data = DataLoader(X, Y, batchsize=128)
|
||||
```
|
||||
|
||||
Note that, by default, `train!` only loops over the data once (a single "epoch").
|
||||
A convenient way to run multiple epochs from the REPL is provided by `@epochs`.
|
||||
|
||||
@ -95,10 +74,6 @@ julia> @epochs 2 Flux.train!(...)
|
||||
# Train for two epochs
|
||||
```
|
||||
|
||||
```@docs
|
||||
Flux.@epochs
|
||||
```
|
||||
|
||||
## Callbacks
|
||||
|
||||
`train!` takes an additional argument, `cb`, that's used for callbacks so that you can observe the training process. For example:
|
||||
@ -118,38 +93,3 @@ evalcb() = @show(loss(test_x, test_y))
|
||||
Flux.train!(objective, ps, data, opt,
|
||||
cb = throttle(evalcb, 5))
|
||||
```
|
||||
|
||||
Calling `Flux.stop()` in a callback will exit the training loop early.
|
||||
|
||||
```julia
|
||||
cb = function ()
|
||||
accuracy() > 0.9 && Flux.stop()
|
||||
end
|
||||
```
|
||||
|
||||
## Custom Training loops
|
||||
|
||||
The `Flux.train!` function can be very convenient, especially for simple problems.
It is also very flexible through the use of callbacks.
But for some problems it is much cleaner to write your own custom training loop.
An example follows that works similarly to the default `Flux.train!` but with no callbacks.
You don't need callbacks if you just code the calls to your functions directly into the loop,
e.g. in the places marked with comments.
|
||||
|
||||
```julia
|
||||
function my_custom_train!(loss, ps, data, opt)
|
||||
ps = Params(ps)
|
||||
for d in data
|
||||
gs = gradient(ps) do
|
||||
training_loss = loss(d...)
|
||||
# Insert whatever code you want here that needs the training loss, e.g. logging
|
||||
return training_loss
|
||||
end
|
||||
# Insert whatever code you want here that needs the gradients
|
||||
# E.g. logging them with TensorBoardLogger.jl as a histogram so you can see if they are becoming huge
|
||||
update!(opt, ps, gs)
|
||||
# Here you might like to check validation set accuracy, and break out to do early stopping
|
||||
end
|
||||
end
|
||||
```
|
||||
You could simplify this further, for example by hard-coding in the loss function.
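
For example, a minimal sketch of such a simplification, with the mean squared error hard-coded (a model `m`, a `data` iterator, parameters `ps` and an optimiser `opt` are assumed):

```julia
for d in data
  gs = gradient(() -> Flux.mse(m(d[1]), d[2]), ps)
  Flux.Optimise.update!(opt, ps, gs)
end
```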
|
||||
|
@ -1,49 +0,0 @@
|
||||
# Utility Functions
|
||||
|
||||
Flux contains some utility functions for working with data; these functions
|
||||
help create inputs for your models or batch your dataset.
|
||||
Other functions can be used to initialize your layers or to regularly execute
|
||||
callback functions.
|
||||
|
||||
## Working with Data
|
||||
|
||||
```@docs
|
||||
Flux.unsqueeze
|
||||
Flux.stack
|
||||
Flux.unstack
|
||||
Flux.chunk
|
||||
Flux.frequencies
|
||||
Flux.batch
|
||||
Flux.batchseq
|
||||
Base.rpad(v::AbstractVector, n::Integer, p)
|
||||
```
|
||||
|
||||
## Layer Initialization
|
||||
|
||||
These are primarily useful if you are planning to write your own layers.
|
||||
Flux initializes convolutional layers and recurrent cells with `glorot_uniform`
|
||||
by default.
|
||||
To change the default on an applicable layer, pass the desired function with the
|
||||
`init` keyword. For example:
|
||||
```jldoctest; setup = :(using Flux)
|
||||
julia> conv = Conv((3, 3), 1 => 8, relu; init=Flux.glorot_normal)
|
||||
Conv((3, 3), 1=>8, relu)
|
||||
```
|
||||
|
||||
```@docs
|
||||
Flux.glorot_uniform
|
||||
Flux.glorot_normal
|
||||
```
|
||||
|
||||
## Model Abstraction
|
||||
|
||||
```@docs
|
||||
Flux.destructure
|
||||
```
|
||||
|
||||
## Callback Helpers
|
||||
|
||||
```@docs
|
||||
Flux.throttle
|
||||
Flux.stop
|
||||
```
|
@ -14,7 +14,7 @@
|
||||
journal = {arXiv},
|
||||
volume = {abs/1712.03112},
|
||||
year = {2017},
|
||||
url = {https://arxiv.org/abs/1712.03112},
|
||||
url = {http://arxiv.org/abs/1712.03112},
|
||||
}
|
||||
|
||||
@online{MLPL,
|
||||
@ -29,7 +29,7 @@
|
||||
author = {Mike Innes and others},
|
||||
title = {Generic GPU Kernels},
|
||||
year = 2017,
|
||||
url = {https://mikeinnes.github.io/2017/08/24/cudanative.html},
|
||||
url = {http://mikeinnes.github.io/2017/08/24/cudanative.html},
|
||||
urldate = {2018-02-16}
|
||||
}
|
||||
|
||||
|
44
src/Flux.jl
44
src/Flux.jl
@ -3,35 +3,30 @@ module Flux
|
||||
# Zero Flux Given
|
||||
|
||||
using Base: tail
|
||||
using Statistics, Random, LinearAlgebra
|
||||
using Zygote, MacroTools, Juno, Reexport
|
||||
using MacroTools, Juno, Requires, Reexport, Statistics, Random
|
||||
using MacroTools: @forward
|
||||
|
||||
export Chain, Dense, RNN, LSTM, GRU, Conv, MaxPool, MeanPool,
|
||||
DepthwiseConv, Dropout, LayerNorm, BatchNorm,
|
||||
params, mapleaves, cpu, gpu
|
||||
|
||||
@reexport using NNlib
|
||||
using Zygote: Params, @adjoint, gradient, pullback, @nograd
|
||||
|
||||
export gradient
|
||||
|
||||
export Chain, Dense, Maxout, RNN, LSTM, GRU, SamePad, Conv, CrossCor, ConvTranspose,
|
||||
GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool, flatten,
|
||||
DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
|
||||
SkipConnection, params, fmap, cpu, gpu, f32, f64, testmode!, trainmode!
|
||||
include("tracker/Tracker.jl")
|
||||
using .Tracker
|
||||
using .Tracker: data
|
||||
export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param
|
||||
|
||||
include("optimise/Optimise.jl")
|
||||
using .Optimise
|
||||
using .Optimise: @epochs
|
||||
export Descent, ADAM, Momentum, Nesterov, RMSProp,
|
||||
export SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
|
||||
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
|
||||
ADAMW, RADAM, InvDecay, ExpDecay, WeightDecay,
|
||||
ClipValue, ClipNorm
|
||||
|
||||
|
||||
using CuArrays
|
||||
const use_cuda = Ref(false)
|
||||
ADAMW, InvDecay, ExpDecay, WeightDecay
|
||||
|
||||
include("utils.jl")
|
||||
include("zeros.jl")
|
||||
include("onehot.jl")
|
||||
include("functor.jl")
|
||||
include("treelike.jl")
|
||||
|
||||
include("layers/stateless.jl")
|
||||
include("layers/basic.jl")
|
||||
@ -41,17 +36,6 @@ include("layers/normalise.jl")
|
||||
|
||||
include("data/Data.jl")
|
||||
|
||||
include("deprecations.jl")
|
||||
|
||||
include("cuda/cuda.jl")
|
||||
|
||||
function __init__()
|
||||
use_cuda[] = CuArrays.functional() # Can be overridden after load with `Flux.use_cuda[] = false`
|
||||
if CuArrays.functional()
|
||||
if !CuArrays.has_cudnn()
|
||||
@warn "CuArrays.jl found cuda, but did not find libcudnn. Some functionality will not be available."
|
||||
end
|
||||
end
|
||||
end
|
||||
@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" include("cuda/cuda.jl")
|
||||
|
||||
end # module
|
||||
|
@ -2,8 +2,20 @@ module CUDA
|
||||
|
||||
using ..CuArrays
|
||||
|
||||
using CuArrays: CUDNN
|
||||
include("curnn.jl")
|
||||
include("cudnn.jl")
|
||||
if !applicable(CuArray{UInt8}, undef, 1)
|
||||
(T::Type{<:CuArray})(::UndefInitializer, sz...) = T(sz...)
|
||||
end
|
||||
|
||||
if CuArrays.libcudnn != nothing
|
||||
if isdefined(CuArrays, :libcudnn_handle)
|
||||
handle() = CuArrays.libcudnn_handle[]
|
||||
else
|
||||
handle() = CuArrays.CUDNN.handle()
|
||||
end
|
||||
include("curnn.jl")
|
||||
include("cudnn.jl")
|
||||
else
|
||||
@warn("CUDNN is not installed, some functionality will not be available.")
|
||||
end
|
||||
|
||||
end
|
||||
|
@ -1,8 +1,228 @@
|
||||
using .CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, cudnnTensorDescriptor_t,
|
||||
cudnnBatchNormMode_t, cudnnHandle_t, cudnnDataType, TensorDesc, FilterDesc
|
||||
import ..Flux: data
|
||||
import CuArrays.CUDNN: batchnorm, ∇batchnorm
|
||||
using LinearAlgebra
|
||||
|
||||
(BN::Flux.BatchNorm)(x::Union{CuArray{T,2},CuArray{T,4},CuArray{T,5}}, cache = nothing) where T<:Union{Float32, Float64} =
|
||||
BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum; cache = cache, alpha = 1, beta = 0, eps = BN.ϵ, training = Flux.istraining()))
|
||||
mutable struct DropoutDesc
|
||||
ptr::Ptr{Nothing}
|
||||
states::CuVector{UInt8}
|
||||
end
|
||||
|
||||
@adjoint batchnorm(g, b, x, running_mean, running_var, momentum; kw...) =
|
||||
batchnorm(g, b, x, running_mean, running_var, momentum; kw...), Δ -> (∇batchnorm(g, b, x, Δ, running_mean, running_var, momentum; kw...)..., nothing, nothing, nothing)
|
||||
Base.unsafe_convert(::Type{Ptr{Nothing}}, dd::DropoutDesc) = dd.ptr
|
||||
|
||||
function DropoutDesc(ρ::Real; seed::Integer=0)
|
||||
d = [C_NULL]
|
||||
s = Csize_t[0]
|
||||
@check ccall((:cudnnCreateDropoutDescriptor,libcudnn), cudnnStatus_t, (Ptr{Ptr{Nothing}},), d)
|
||||
@check ccall((:cudnnDropoutGetStatesSize,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Csize_t}),handle(),s)
|
||||
states = CuArray{UInt8}(undef, s[]) # TODO: can we drop this when ρ=0?
|
||||
desc = DropoutDesc(d[], states)
|
||||
@check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Nothing},Cfloat,Ptr{Nothing},Csize_t,Culonglong),
|
||||
desc,handle(),ρ,states,length(states),seed)
|
||||
finalizer(desc) do x
|
||||
@check ccall((:cudnnDestroyDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x)
|
||||
end
|
||||
return desc
|
||||
end
|
||||
|
||||
const BATCHNORM_SPATIAL = 1
|
||||
const BATCHNORM_ACTIVATION = 0
|
||||
const BATCHNORM_MIN_EPS = 1e-5
|
||||
|
||||
@inline _wsize(y) = (map(_ -> 1, size(y)[1:end-2])..., size(y)[end-1], 1)
|
||||
|
||||
@inline _reddims(y) = (collect(1:ndims(y)-2)..., ndims(y))
|
||||
|
||||
mutable struct BNCache
|
||||
mean
|
||||
ivar
|
||||
end
|
||||
|
||||
BNCache() = BNCache(nothing, nothing)
|
||||
|
||||
# NOTE: CuDNN supports only 4D and 5D Tensors for BatchNorm Operations
|
||||
# so reshape a 2D Tensor into 4D
|
||||
batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T, 2},
|
||||
running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
|
||||
cache = nothing, alpha = T(1), beta = T(0),
|
||||
eps = T(1e-5), training = true) where T<:Union{Float32, Float64} =
|
||||
dropdims(batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), running_mean, running_var, momentum,
|
||||
cache = cache, alpha = alpha, beta = beta, eps = eps, training = training), dims = (1, 2))
|
||||
|
||||
function batchnorm(g::CuArray{T}, b::CuArray{T}, x::Union{CuArray{T, 4},CuArray{T,5}},
|
||||
running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
|
||||
cache = nothing, alpha = T(1), beta = T(0),
|
||||
eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
|
||||
y = similar(x)
|
||||
cudnnBNForward!(y, g, b, x, running_mean, running_var, momentum, cache = cache,
|
||||
alpha = alpha, beta = beta, eps = eps, training = training)
|
||||
y
|
||||
end
|
||||
|
||||
function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray{T},
|
||||
running_mean::CuArray{T}, running_var::CuArray{T},
|
||||
momentum; cache = nothing,
|
||||
alpha = T(1), beta = T(0),
|
||||
eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
|
||||
dims = _wsize(x)
|
||||
if eps < BATCHNORM_MIN_EPS
|
||||
# warn("eps ",eps," is too small for CuDNN so eps has been assigned the value ", BATCHNORM_MIN_EPS)
|
||||
eps = BATCHNORM_MIN_EPS
|
||||
end
|
||||
xd = TensorDesc(x)
|
||||
yd = TensorDesc(y)
|
||||
gd = TensorDesc(T, dims)
|
||||
|
||||
if training
|
||||
|
||||
if cache !== nothing
|
||||
mean = zeros(CuArray{T}, dims...)
|
||||
ivar = ones(CuArray{T}, dims...)
|
||||
else
|
||||
mean = C_NULL
|
||||
ivar = C_NULL
|
||||
end
|
||||
|
||||
@check ccall((:cudnnBatchNormalizationForwardTraining, libcudnn), cudnnStatus_t,
|
||||
(cudnnHandle_t,cudnnBatchNormMode_t,
|
||||
Ptr{T}, Ptr{T},
|
||||
Ptr{Nothing}, Ptr{T},
|
||||
Ptr{Nothing}, Ptr{T},
|
||||
Ptr{Nothing}, Ptr{T}, Ptr{T},
|
||||
Cdouble, Ptr{T}, Ptr{T},
|
||||
Cdouble, Ptr{T}, Ptr{T}),
|
||||
handle(), BATCHNORM_SPATIAL,
|
||||
Ref(T(alpha)), Ref(T(beta)),
|
||||
xd, x,
|
||||
yd, y,
|
||||
gd, g, b,
|
||||
momentum, running_mean, running_var,
|
||||
eps, mean, ivar)
|
||||
|
||||
if cache !== nothing
|
||||
cache.mean = mean
|
||||
cache.ivar = ivar
|
||||
end
|
||||
else
|
||||
@check ccall((:cudnnBatchNormalizationForwardInference, libcudnn), cudnnStatus_t,
|
||||
(Ptr{cudnnHandle_t},cudnnBatchNormMode_t,
|
||||
Ptr{T}, Ptr{T},
|
||||
Ptr{Nothing}, Ptr{T},
|
||||
Ptr{Nothing}, Ptr{T},
|
||||
Ptr{Nothing}, Ptr{T}, Ptr{T},
|
||||
Ptr{T}, Ptr{T},
|
||||
Cdouble),
|
||||
handle(), BATCHNORM_SPATIAL,
|
||||
Ref(T(alpha)), Ref(T(beta)),
|
||||
xd, x,
|
||||
yd, y,
|
||||
gd, g, b,
|
||||
running_mean, running_var,
|
||||
eps)
|
||||
end
|
||||
end
|
||||
|
||||
function ∇batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T, 2}, dy::CuArray{T, 2},
|
||||
running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
|
||||
cache = nothing, eps = T(1e-5), alpha = T(1),
|
||||
beta = T(0), training = true) where T<:Union{Float32, Float64}
|
||||
dg, db, dx = ∇batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), reshape(dy, 1, 1, size(dy, 1),
|
||||
size(dy, 2)), running_mean, running_var, momentum, cache = cache, eps = eps,
|
||||
alpha = alpha, beta = beta, training = training)
|
||||
(dg, db, dropdims(dx, dims = (1, 2)))
|
||||
end
|
||||
|
||||
function ∇batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
|
||||
running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
|
||||
cache = nothing, eps = T(1e-5), alpha = T(1),
|
||||
beta = T(0), training = true) where T<:Union{Float32, Float64}
|
||||
dg = similar(g)
|
||||
db = similar(b)
|
||||
dx = similar(x)
|
||||
cudnnBNBackward!(dg, g, db, dx, x, dy, running_mean, running_var, T(momentum),
|
||||
training = training, cache = cache, eps = eps, alpha = alpha, beta = beta)
|
||||
(dg, db, dx)
|
||||
end
|
||||
|
||||
function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
|
||||
dx::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
|
||||
running_mean::CuArray{T}, running_var::CuArray{T},
|
||||
momentum; cache = nothing, eps = T(1e-5),
|
||||
alpha = T(1), beta = T(0),
|
||||
dalpha = T(1), dbeta = T(0), training = true) where T<:Union{Float32, Float64}
|
||||
if training
|
||||
xd = TensorDesc(x)
|
||||
dyd = TensorDesc(dy)
|
||||
dxd = TensorDesc(dx)
|
||||
gd = TensorDesc(T, _wsize(x))
|
||||
if cache !== nothing
|
||||
mean, ivar = cache.mean, cache.ivar
|
||||
info("mean and ivar are fetched from the cache")
|
||||
else
|
||||
mean, ivar = C_NULL, C_NULL
|
||||
end
|
||||
|
||||
if eps < BATCHNORM_MIN_EPS
|
||||
eps = BATCHNORM_MIN_EPS
|
||||
end
|
||||
|
||||
@check ccall((:cudnnBatchNormalizationBackward, libcudnn), cudnnStatus_t,
|
||||
(cudnnHandle_t,cudnnBatchNormMode_t,
|
||||
Ptr{T}, Ptr{T},
|
||||
Ptr{T}, Ptr{T},
|
||||
Ptr{Nothing}, Ptr{T},
|
||||
Ptr{Nothing}, Ptr{T},
|
||||
Ptr{Nothing}, Ptr{T},
|
||||
Ptr{Nothing}, Ptr{T}, Ptr{T}, Ptr{T},
|
||||
Cdouble, Ptr{T}, Ptr{T}),
|
||||
handle(), BATCHNORM_SPATIAL,
|
||||
Ref(T(alpha)), Ref(T(beta)),
|
||||
Ref(T(dalpha)), Ref(T(dbeta)),
|
||||
xd, x,
|
||||
dyd, dy,
|
||||
dxd, dx,
|
||||
gd, g, dg, db,
|
||||
eps, mean, ivar)
|
||||
else
|
||||
ivar = 1 ./ sqrt.(reshape(running_var, _wsize(x)) .+ eps)
|
||||
dx .= dy .* reshape(g, _wsize(x)) .* ivar
|
||||
dg .= squeeze(sum(dy .* (x .- reshape(running_mean, _wsize(x))) .* ivar, _reddims(dy)), dims = (1,2,4))
|
||||
db .= squeeze(sum(dy, _reddims(dy)), dims = (1,2,4))
|
||||
end
|
||||
end
|
||||
|
||||
# Flux Interface
|
||||
|
||||
(BN::Flux.BatchNorm)(x::Union{CuParam{T,2},CuParam{T,4},CuParam{T,5}}, cache = nothing) where T<:Union{Float32, Float64} =
|
||||
batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum; cache = cache, alpha = 1, beta = 0, eps = BN.ϵ, training = BN.active)
|
||||
|
||||
batchnorm(g::TrackedArray, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T},
|
||||
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
||||
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
||||
|
||||
batchnorm(g::TrackedArray, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T},
|
||||
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
||||
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
||||
|
||||
batchnorm(g::TrackedArray, b::CuArray{T}, x::TrackedArray, running_mean::CuArray{T},
|
||||
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
||||
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
||||
|
||||
batchnorm(g::CuArray{T}, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T},
|
||||
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
||||
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
||||
|
||||
batchnorm(g::CuArray{T}, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T},
|
||||
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
||||
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
||||
|
||||
batchnorm(g::TrackedArray, b::CuArray{T}, x::CuArray{T}, running_mean::CuArray{T},
|
||||
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
||||
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
||||
|
||||
batchnorm(g::CuArray{T}, b::CuArray{T}, x::TrackedArray, running_mean::CuArray{T},
|
||||
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
||||
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
||||
|
||||
@grad batchnorm(g, b, x, running_mean, running_var, momentum; kw...) =
|
||||
batchnorm(data.((g, b, x))..., running_mean, running_var, momentum; kw...), Δ -> (nobacksies(:batchnorm, ∇batchnorm(data.((g, b, x, Δ))..., running_mean, running_var, momentum; kw...))..., nothing, nothing, nothing)
|
||||
|
@ -1,90 +1,325 @@
|
||||
import ..Flux: Flux, relu
|
||||
using CuArrays.CUDAnative
|
||||
using .CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, cudnnTensorDescriptor_t,
|
||||
cudnnBatchNormMode_t, cudnnHandle_t, cudnnDataType, TensorDesc, FilterDesc
|
||||
using LinearAlgebra
|
||||
|
||||
CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuArray{T,2},<:CuArray{T,1}}
|
||||
CuGRU{T} = Flux.GRUCell{<:CuArray{T,2},<:CuArray{T,1}}
|
||||
CuLSTM{T} = Flux.LSTMCell{<:CuArray{T,2},<:CuArray{T,1}}
|
||||
const RNN_RELU = 0 # Stock RNN with ReLu activation
|
||||
const RNN_TANH = 1 # Stock RNN with tanh activation
|
||||
const LSTM = 2 # LSTM with no peephole connections
|
||||
const GRU = 3 # Using h' = tanh(r * Uh(t-1) + Wx) and h = (1 - z) * h' + z * h(t-1)
|
||||
|
||||
const LINEAR_INPUT = 0
|
||||
const SKIP_INPUT = 1
|
||||
|
||||
const UNIDIRECTIONAL = 0
|
||||
const BIDIRECTIONAL = 1
|
||||
|
||||
const RNN_ALGO_STANDARD = 0
|
||||
const RNN_ALGO_PERSIST_STATIC = 1
|
||||
const RNN_ALGO_PERSIST_DYNAMIC = 2
|
||||
|
||||
# param layout:
|
||||
# RNN: [weight, bias] × [input, hidden]
|
||||
# GRU: [weight, bias] × [input, hidden] × [reset, update, newmem]
|
||||
# LSTM: [weight, bias] × [input, hidden] × [input, forget, newmem, output]
|
||||
|
||||
function params(w::CuVector, input, hidden, n = 1)
|
||||
slice(offset, shape) = reshape(view(w, offset.+(1:prod(shape))), shape)
|
||||
wx = slice(0, (input, hidden*n))
|
||||
wh = slice(length(wx), (hidden, hidden*n))
|
||||
bias = view(w, length(wx)+length(wh) .+ (1:hidden*n))
|
||||
(wx, wh), bias
|
||||
end
|
||||
|
||||
mutable struct RNNDesc{T}
|
||||
mode::Int
|
||||
input::Int
|
||||
hidden::Int
|
||||
params::CuVector{T}
|
||||
weights::NTuple{2,CuMatrix{T}}
|
||||
bias::CuVector{T}
|
||||
ptr::Ptr{Nothing}
|
||||
end
|
||||
|
||||
Base.unsafe_convert(::Type{Ptr{Nothing}}, d::RNNDesc) = d.ptr
|
||||
|
||||
function rnnParamSize(T, r, input)
|
||||
size = Csize_t[0]
|
||||
@check ccall((:cudnnGetRNNParamsSize, libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Ptr{Nothing},Ptr{Csize_t},Cint),
|
||||
handle(), r, TensorDesc(T, (1,input,1)), size, cudnnDataType(T))
|
||||
return Int(size[])÷sizeof(T)
|
||||
end
|
||||
|
||||
ngates(mode) = [1, 1, 4, 3][mode+1]
|
||||
ngates(r::RNNDesc) = ngates(r.mode)
|
||||
|
||||
function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T
|
||||
d = [C_NULL]
|
||||
@check ccall((:cudnnCreateRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Ptr{Nothing}},),d)
|
||||
|
||||
dropoutDesc = DropoutDesc(0)
|
||||
inputMode = LINEAR_INPUT
|
||||
direction = UNIDIRECTIONAL
|
||||
algo = RNN_ALGO_STANDARD
|
||||
@check ccall((:cudnnSetRNNDescriptor_v6,libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Cint,Cint,Ptr{Nothing},Cint,Cint,Cint,Cint,Cint),
|
||||
handle(),d[],hidden,layers,dropoutDesc,inputMode,direction,mode,algo,cudnnDataType(T))
|
||||
|
||||
w = cuzeros(T, rnnParamSize(T, d[], input))
|
||||
# TODO: avoid reserve allocation here
|
||||
rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., d[])
|
||||
finalizer(rd) do x
|
||||
@check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x)
|
||||
end
|
||||
return rd
|
||||
end
|
||||
|
||||
function rnnWorkspaceSize(r::RNNDesc, seqlen, xdesc)
|
||||
size = Csize_t[0]
|
||||
@check ccall((:cudnnGetRNNWorkspaceSize, libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Cint,Ptr{Ptr{Nothing}},Ptr{Csize_t}),
|
||||
handle(), r, seqlen, xdesc, size)
|
||||
return Int(size[])
|
||||
end
|
||||
|
||||
const workspace = [CuVector{UInt8}(undef, 1)]
|
||||
|
||||
getworkspace(bytes) =
|
||||
length(workspace[]) ≥ bytes ?
|
||||
workspace[] :
|
||||
(workspace[] = CuVector{UInt8}(undef, bytes))
|
||||
|
||||
getworkspace(r::RNNDesc, seqlen, xdesc) =
|
||||
getworkspace(rnnWorkspaceSize(r, seqlen, xdesc))
|
||||
|
||||
function rnnTrainingReserveSize(r::RNNDesc, seqlen, xdesc)
|
||||
size = Csize_t[0]
|
||||
@check ccall((:cudnnGetRNNTrainingReserveSize,libcudnn), cudnnStatus_t, (Ptr{Nothing}, Ptr{Nothing}, Cint, Ptr{Ptr{Nothing}}, Ptr{Csize_t}),
|
||||
handle(), r, seqlen, xdesc, size)
|
||||
return Int(size[])
|
||||
end
|
||||
|
||||
function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
|
||||
workspace, reserve=nothing) where T
|
||||
if reserve == nothing
|
||||
@check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t,
|
||||
(Ptr{Nothing}, Ptr{Nothing}, Cint,
|
||||
Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
|
||||
Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T},
|
||||
Ptr{Nothing}, Ptr{T},
|
||||
Ptr{Nothing}, Csize_t),
|
||||
handle(), rnn, seqlen,
|
||||
xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
|
||||
workspace, length(workspace))
|
||||
else
|
||||
@check ccall((:cudnnRNNForwardTraining, libcudnn), cudnnStatus_t,
|
||||
(Ptr{Nothing}, Ptr{Nothing}, Cint,
|
||||
Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
|
||||
Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t),
|
||||
handle(), rnn, seqlen,
|
||||
xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
|
||||
workspace, length(workspace), reserve, length(reserve))
|
||||
end
|
||||
end
|
||||
|
||||
xDesc(x) = [TensorDesc(eltype(x), (1, size(x, 1), size(x, 2)))]
|
||||
|
||||
hDesc(h::Nothing) = C_NULL, C_NULL
|
||||
hDesc(x::Integer) = (@assert x == 0; hDesc(nothing))
|
||||
function hDesc(h::CuArray)
|
||||
TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h
|
||||
end
|
||||
|
||||
# TODO: can we just manipulate strides here?
|
||||
# TODO: should use repmat, but this isn't implemented.
|
||||
hBatch(x::AbstractVector, h::CuVector) = h
|
||||
hBatch(x::AbstractMatrix, h::CuVector) = h .* cuones(1, size(x, 2))
|
||||
hBatch(x::AbstractMatrix, h::CuMatrix) = h .* cuones(1, size(h,2) == 1 ? size(x,2) : 1)
|
||||
|
||||
function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing, train = Val{false}) where T
|
||||
h = hBatch(x, h_)
|
||||
c = c_ == nothing ? nothing : hBatch(x, c_)
|
||||
@assert size(x, 1) == rnn.input
|
||||
@assert size(h, 1) == rnn.hidden
|
||||
@assert size(x, 2) == size(h, 2)
|
||||
seqLength = 1
|
||||
xdesc = xDesc(x)
|
||||
y = x isa AbstractVector ? similar(x, rnn.hidden) : similar(x, rnn.hidden, size(x, 2))
|
||||
ho = similar(h)
|
||||
ydesc = xDesc(y)
|
||||
workspace = getworkspace(rnn, seqLength, xdesc)
|
||||
reserve = train == Val{true} ?
|
||||
CuVector{UInt8}(undef, rnnTrainingReserveSize(rnn, seqLength, xdesc)) :
|
||||
nothing
|
||||
co = c == nothing ? c : similar(c)
|
||||
cudnnRNNForward(rnn, seqLength,
|
||||
xdesc, x,
|
||||
hDesc(h)...,
|
||||
hDesc(c)...,
|
||||
FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
|
||||
ydesc, y,
|
||||
hDesc(ho)...,
|
||||
hDesc(co)...,
|
||||
workspace, reserve)
|
||||
result = c == nothing ? (y, ho) : (y, ho, co)
|
||||
return train == Val{true} ? (reserve, result) : result
|
||||
end
|
||||
|
||||
forwardTrain(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing) where T =
|
||||
forward(rnn, x, h, c, Val{true})
|
||||
|
||||
function cudnnRNNBackwardData(rnn::RNNDesc{T}, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
|
||||
wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, rs) where T
|
||||
@check ccall((:cudnnRNNBackwardData,libcudnn),cudnnStatus_t,
|
||||
(Ptr{Nothing}, Ptr{Nothing}, Cint,
|
||||
Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T},
|
||||
Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing},
|
||||
Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
|
||||
Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t),
|
||||
handle(), rnn, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
|
||||
wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, length(ws), rs, length(rs))
|
||||
end
|
||||
|
||||
function backwardData(rnn::RNNDesc{T}, y, dy_, dho, dco, h, c, reserve) where T
|
||||
# Same as above, any more efficient way?
|
||||
dy = dy_ isa Integer ? zero(y) : dy_
|
||||
yd = xDesc(y)
|
||||
dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2))
|
||||
dh = similar(h)
|
||||
dc = c == nothing ? nothing : similar(c)
|
||||
cudnnRNNBackwardData(rnn, 1,
|
||||
yd, y, yd, dy, hDesc(dho)..., hDesc(dco)...,
|
||||
FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
|
||||
hDesc(h)..., hDesc(c)..., xDesc(dx), dx, hDesc(dh)..., hDesc(dc)...,
|
||||
workspace[], reserve)
|
||||
return c == nothing ? (dx, dh) : (dx, dh, dc)
|
||||
end
|
||||
|
||||
backwardData(rnn, y, dy, dho, hx, reserve) =
|
||||
backwardData(rnn, y, dy, dho, nothing, hx, nothing, reserve)
|
||||
|
||||
function cudnnRNNBackwardWeights(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, yd, y, dwd, dw,
|
||||
workspace, reserve) where T
|
||||
@check ccall((:cudnnRNNBackwardWeights,libcudnn), cudnnStatus_t,
|
||||
(Ptr{Nothing}, Ptr{Nothing}, Cint, # handle, rnnDesc, seqLength
|
||||
Ptr{Ptr{Nothing}}, Ptr{T}, #x
|
||||
Ptr{Nothing}, Ptr{T}, #hx
|
||||
Ptr{Ptr{Nothing}}, Ptr{T}, #y
|
||||
Ptr{Nothing}, Csize_t, #ws
|
||||
Ptr{Nothing}, Ptr{T}, #dw
|
||||
Ptr{Nothing}, Csize_t), #rs
|
||||
handle(), rnn, seqlen, xd, x, hd, h, yd, y,
|
||||
workspace, length(workspace), dwd, dw, reserve, length(reserve))
|
||||
end
|
||||
|
||||
function backwardWeights(rnn::RNNDesc{T}, x, h, y, reserve) where T
|
||||
dw = zero(rnn.params)
|
||||
cudnnRNNBackwardWeights(rnn, 1,
|
||||
xDesc(x), x, hDesc(h)..., xDesc(y), y,
|
||||
FilterDesc(T, (1, 1, length(dw))), dw,
|
||||
workspace[], reserve)
|
||||
return params(dw, rnn.input, rnn.hidden, ngates(rnn))
|
||||
end
|
||||
|
||||
# Interface
|
||||
|
||||
import ..Flux: Flux, relu
|
||||
import ..Tracker: TrackedArray
|
||||
using .CuArrays.CUDAnative
|
||||
using .CuArrays: @cuindex, cudims
|
||||
|
||||
function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
|
||||
function kernel(dst, src)
|
||||
I = @cuindex dst
|
||||
dst[I...] = src[reverse(I)...]
|
||||
return
|
||||
end
|
||||
blk, thr = cudims(dst)
|
||||
@cuda blocks=blk threads=thr kernel(dst, src)
|
||||
return dst
|
||||
end
|
||||
|
||||
CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}}
|
||||
CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuParam{T,2},<:CuParam{T,1}}
|
||||
CuGRU{T} = Flux.GRUCell{<:CuParam{T,2},<:CuParam{T,1}}
|
||||
CuLSTM{T} = Flux.LSTMCell{<:CuParam{T,2},<:CuParam{T,1}}
|
||||
CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}}
|
||||
|
||||
function CUDNN.RNNDesc(m::CuRNNs{T}) where T
|
||||
function copyparams!(m::CuRNNs, d::RNNDesc)
|
||||
Wi, Wh = d.weights
|
||||
copy_transpose!(Wi, Flux.data(m.Wi))
|
||||
copy_transpose!(Wh, Flux.data(m.Wh))
|
||||
copy_transpose!(d.bias, Flux.data(m.b))
|
||||
return
|
||||
end
|
||||
|
||||
function RNNDesc(m::CuRNNs{T}) where T
|
||||
h, i = length(m.h), size(m.Wi, 2)
|
||||
mode = m isa CuRNN ?
|
||||
(m.σ == tanh ? CUDNN.CUDNN_RNN_TANH : CUDNN.CUDNN_RNN_RELU) :
|
||||
m isa CuGRU ? CUDNN.CUDNN_GRU : CUDNN.CUDNN_LSTM
|
||||
r = CUDNN.RNNDesc{T}(mode, i, h)
|
||||
(m.σ == tanh ? RNN_TANH : RNN_RELU) :
|
||||
m isa CuGRU ? GRU : LSTM
|
||||
r = RNNDesc{T}(mode, i, h)
|
||||
return r
|
||||
end
|
||||
|
||||
const descs = WeakKeyDict()
|
||||
|
||||
function desc(rnn)
|
||||
d = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = CUDNN.RNNDesc(rnn))
|
||||
CUDNN.setweights!(d, rnn.Wi, rnn.Wh, rnn.b)
|
||||
d = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = RNNDesc(rnn))
|
||||
copyparams!(rnn, d)
|
||||
return d
|
||||
end
|
||||
|
||||
import Zygote
|
||||
using Zygote: @adjoint
|
||||
import Flux.Tracker
|
||||
import Flux.Tracker: data, istracked, track, unbroadcast, @grad, nobacksies
|
||||
|
||||
function (m::CuRNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
||||
y, h′ = CUDNN.forward(desc(m), x, h)
|
||||
return h′, y
|
||||
istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...))
|
||||
|
||||
function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
|
||||
result = istrain(m, h, x) ?
|
||||
track(m, x, h, m.Wi, m.Wh, m.b) :
|
||||
forward(desc(m), x, h)
|
||||
return result[2], result[1]
|
||||
end
|
||||
|
||||
function (m::CuGRU{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
||||
y, h′ = CUDNN.forward(desc(m), x, h)
|
||||
return h′, y
|
||||
function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
|
||||
result = istrain(m, h, x) ?
|
||||
track(m, x, h, m.Wi, m.Wh, m.b) :
|
||||
forward(desc(m), x, h)
|
||||
return result[2], result[1]
|
||||
end
|
||||
|
||||
function (m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
||||
y, h′, c′ = CUDNN.forward(desc(m), x, h[1], h[2])
|
||||
return (h′, c′), y
|
||||
function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64}
|
||||
result = istrain(m, h, x) ?
|
||||
track(m, x, h[1], h[2], m.Wi, m.Wh, m.b) :
|
||||
forward(desc(m), x, h[1], h[2])
|
||||
return (result[2], result[3]), result[1]
|
||||
end
|
||||
|
||||
(m::CuRNN{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||
(m::CuGRU{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||
(m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||
(m::CuRNN{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||
(m::CuGRU{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||
(m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||
|
||||
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))

unbroadcast(x::AbstractArray, Δ) =
  size(x) == size(Δ) ? Δ :
  length(x) == length(Δ) ? trim(x, Δ) :
    trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
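A worked example of the `unbroadcast` rule above (a sketch, assuming CuArrays is loaded): a gradient that arrives with an extra batch dimension is summed back down to the shape of the stored state.

```julia
h  = CuArrays.zeros(Float32, 5)     # state stored as a plain vector
Δh = CuArrays.ones(Float32, 5, 3)   # gradient arrives batched (5×3)

size(unbroadcast(h, Δh)) == (5,)    # summed over the broadcast (batch) dimension
```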
|
||||
|
||||
coerce_cuda(x::Union{CuArray,Nothing}) = x
|
||||
coerce_cuda(x::Tuple) = coerce_cuda.(x)
|
||||
|
||||
coerce_cuda(x::AbstractArray) = x .+ CuArrays.fill(0)
|
||||
|
||||
function struct_grad!(cx::Zygote.Context, x, x̄)
|
||||
for f in fieldnames(typeof(x))
|
||||
Zygote.accum_param(cx, getfield(x, f), getfield(x̄, f))
|
||||
end
|
||||
dx = Zygote.grad_mut(cx, x)
|
||||
dx[] = Zygote.accum(dx[], x̄)
|
||||
return dx
|
||||
end
|
||||
|
||||
for RNN in (CuRNN, CuGRU)
|
||||
@eval @adjoint function (m::$RNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
||||
(y, ho), back = CUDNN.pullback(desc(m), x, h)
|
||||
(ho, y), function (Δ)
|
||||
dho, dy = coerce_cuda(Δ) # Support FillArrays etc.
|
||||
m̄ = back(dy, dho)
|
||||
dm = struct_grad!(__context__, m, (σ=nothing,Wi=transpose(m̄.Wi),Wh=transpose(m̄.Wh),b=m̄.b,h=nothing))
|
||||
(dm, unbroadcast(h, m̄.h), m̄.x)
|
||||
end
|
||||
@grad function (m::Union{CuRNN,CuGRU})(x, h, Wi, Wh, b)
|
||||
reserve, result = forwardTrain(desc(m), data(x), data(h))
|
||||
result, function (Δ)
|
||||
y, ho = result
|
||||
dy, dho = Δ
|
||||
h_ = hBatch(x, data(h))
|
||||
dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
|
||||
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
|
||||
nobacksies(:RNN, (dx, unbroadcast(h, dh), transpose(dWi), transpose(dWh), db))
|
||||
end
|
||||
end
|
||||
|
||||
@adjoint function (m::CuLSTM)((h, c)::Tuple{CuArray{T},CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
||||
(y, ho, co), back = CUDNN.pullback(desc(m), x, h, c)
|
||||
((ho, co), y), function (Δ)
|
||||
dhc, dy = coerce_cuda(Δ) # Support FillArrays etc.
|
||||
dho, dco = dhc === nothing ? (nothing, nothing) : dhc
|
||||
m̄ = back(dy, dho, dco)
|
||||
dm = struct_grad!(__context__, m, (σ=nothing,Wi=transpose(m̄.Wi),Wh=transpose(m̄.Wh),b=m̄.b,h=nothing,c=nothing))
|
||||
(dm, (unbroadcast(h, m̄.h), unbroadcast(c, m̄.c)), m̄.x)
|
||||
@grad function (m::CuLSTM)(x, h, c, Wi, Wh, b)
|
||||
reserve, result = forwardTrain(desc(m), data.((x, h, c))...)
|
||||
result, function (Δ)
|
||||
y, ho = result
|
||||
dy, dho, dco = Δ
|
||||
h_ = hBatch(x, data(h))
|
||||
c_ = hBatch(x, data(c))
|
||||
dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
|
||||
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
|
||||
nobacksies(:RNN,
|
||||
(dx, unbroadcast(h, dh), unbroadcast(c, dc),
|
||||
transpose(dWi), transpose(dWh), db))
|
||||
end
|
||||
end
|
||||
|
@ -1,37 +1,15 @@
|
||||
module Data
|
||||
|
||||
import ..Flux
|
||||
import SHA
|
||||
|
||||
using Random: shuffle!
|
||||
using Base: @propagate_inbounds
|
||||
|
||||
export CMUDict, cmudict
|
||||
|
||||
deps(path...) = joinpath(@__DIR__, "..", "..", "deps", path...)
|
||||
|
||||
function download_and_verify(url, path, hash)
|
||||
tmppath = tempname()
|
||||
download(url, tmppath)
|
||||
hash_download = open(tmppath) do f
|
||||
bytes2hex(SHA.sha256(f))
|
||||
end
|
||||
if hash_download !== hash
|
||||
msg = "Hash Mismatch!\n"
|
||||
msg *= " Expected sha256: $hash\n"
|
||||
msg *= " Calculated sha256: $hash_download"
|
||||
error(msg)
|
||||
end
|
||||
mv(tmppath, path; force=true)
|
||||
end
|
||||
|
||||
function __init__()
|
||||
mkpath(deps())
|
||||
end
|
||||
|
||||
include("dataloader.jl")
|
||||
export DataLoader
|
||||
|
||||
include("mnist.jl")
|
||||
export MNIST
|
||||
|
||||
@ -45,12 +23,4 @@ include("tree.jl")
|
||||
include("sentiment.jl")
|
||||
using .Sentiment
|
||||
|
||||
include("iris.jl")
|
||||
export Iris
|
||||
|
||||
include("housing.jl")
|
||||
export Housing
|
||||
|
||||
@deprecate DataLoader(x...; kws...) DataLoader(x; kws...)
|
||||
|
||||
end
|
||||
|
@ -2,57 +2,38 @@ module CMUDict
|
||||
|
||||
export cmudict
|
||||
|
||||
using ..Data: deps, download_and_verify
|
||||
using ..Data: deps
|
||||
|
||||
const version = "0.7b"
|
||||
const cache_prefix = "https://cache.julialang.org"
|
||||
|
||||
function load()
|
||||
suffixes_and_hashes = [("" , "209a8b4cd265013e96f4658632a9878103b0c5abf62b50d4ef3ae1be226b29e4"),
|
||||
(".phones" , "ffb588a5e55684723582c7256e1d2f9fadb130011392d9e59237c76e34c2cfd6"),
|
||||
(".symbols", "408ccaae803641c6d7b626b6299949320c2dbca96b2220fd3fb17887b023b027")]
|
||||
suffixes = ["", ".phones", ".symbols"]
|
||||
if isdir(deps("cmudict"))
|
||||
if all(isfile(deps("cmudict", "cmudict$x")) for (x, _) in suffixes_and_hashes)
|
||||
if all(isfile(deps("cmudict", "cmudict$x")) for x in suffixes)
|
||||
return
|
||||
end
|
||||
end
|
||||
@info "Downloading CMUDict dataset"
|
||||
mkpath(deps("cmudict"))
|
||||
for (x, hash) in suffixes_and_hashes
|
||||
download_and_verify("$cache_prefix/https://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-$version$x",
|
||||
deps("cmudict", "cmudict$x"), hash)
|
||||
for x in suffixes
|
||||
download("$cache_prefix/http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-$version$x",
|
||||
deps("cmudict", "cmudict$x"))
|
||||
end
|
||||
end
|
||||
|
||||
"""
|
||||
phones()
|
||||
|
||||
Return a `Vector` containing the phones used in the CMU Pronouncing Dictionary.
|
||||
"""
|
||||
function phones()
|
||||
load()
|
||||
Symbol.(first.(split.(split(read(deps("cmudict", "cmudict.phones"),String),
|
||||
"\n", keepempty = false), "\t")))
|
||||
end
|
||||
|
||||
"""
|
||||
symbols()
|
||||
|
||||
Return a `Vector` containing the symbols used in the CMU Pronouncing Dictionary.
|
||||
A symbol is a phone with optional auxiliary symbols, indicating for example the
|
||||
amount of stress on the phone.
|
||||
"""
|
||||
function symbols()
|
||||
load()
|
||||
Symbol.(split(read(deps("cmudict", "cmudict.symbols"),String),
|
||||
"\n", keepempty = false))
|
||||
end
|
||||
|
||||
"""
|
||||
rawdict()
|
||||
|
||||
Return the unfiltered CMU Pronouncing Dictionary.
|
||||
"""
|
||||
function rawdict()
|
||||
load()
|
||||
Dict(String(xs[1]) => Symbol.(xs[2:end]) for xs in
|
||||
@ -61,14 +42,6 @@ end
|
||||
|
||||
validword(s) = isascii(s) && occursin(r"^[\w\-\.]+$", s)
|
||||
|
||||
"""
|
||||
cmudict()
|
||||
|
||||
Return a filtered CMU Pronouncing Dictionary.
|
||||
|
||||
It is filtered so each word contains only ASCII characters and a combination of
|
||||
word characters (as determined by the regex engine using `\\w`), '-' and '.'.
|
||||
"""
|
||||
cmudict() = filter(p -> validword(p.first), rawdict())
|
||||
|
||||
alphabet() = ['A':'Z'..., '0':'9'..., '_', '-', '.']
|
||||
|
@ -1,110 +0,0 @@
|
||||
# Adapted from Knet's src/data.jl (author: Deniz Yuret)
|
||||
|
||||
struct DataLoader{D}
|
||||
data::D
|
||||
batchsize::Int
|
||||
nobs::Int
|
||||
partial::Bool
|
||||
imax::Int
|
||||
indices::Vector{Int}
|
||||
shuffle::Bool
|
||||
end
|
||||
|
||||
"""
|
||||
DataLoader(data; batchsize=1, shuffle=false, partial=true)
|
||||
|
||||
An object that iterates over mini-batches of `data`, each mini-batch containing `batchsize` observations
|
||||
(except possibly the last one).
|
||||
|
||||
Takes as input a single data tensor, or a tuple (or a named tuple) of tensors.
|
||||
The last dimension in each tensor is considered to be the observation dimension.
|
||||
|
||||
If `shuffle=true`, shuffles the observations each time iterations are re-started.
|
||||
If `partial=false`, drops the last mini-batch if it is smaller than the batchsize.
|
||||
|
||||
The original data is preserved in the `data` field of the DataLoader.
|
||||
|
||||
Usage example:
|
||||
|
||||
Xtrain = rand(10, 100)
|
||||
train_loader = DataLoader(Xtrain, batchsize=2)
|
||||
# iterate over 50 mini-batches of size 2
|
||||
for x in train_loader
|
||||
@assert size(x) == (10, 2)
|
||||
...
|
||||
end
|
||||
|
||||
train_loader.data # original dataset
|
||||
|
||||
# similar, but yielding tuples
|
||||
train_loader = DataLoader((Xtrain,), batchsize=2)
|
||||
for (x,) in train_loader
|
||||
@assert size(x) == (10, 2)
|
||||
...
|
||||
end
|
||||
|
||||
Xtrain = rand(10, 100)
|
||||
Ytrain = rand(100)
|
||||
train_loader = DataLoader((Xtrain, Ytrain), batchsize=2, shuffle=true)
|
||||
for epoch in 1:100
|
||||
for (x, y) in train_loader
|
||||
@assert size(x) == (10, 2)
|
||||
@assert size(y) == (2,)
|
||||
...
|
||||
end
|
||||
end
|
||||
|
||||
# train for 10 epochs
|
||||
using IterTools: ncycle
|
||||
Flux.train!(loss, ps, ncycle(train_loader, 10), opt)
|
||||
|
||||
# can use NamedTuple to name tensors
|
||||
train_loader = DataLoader((images=Xtrain, labels=Ytrain), batchsize=2, shuffle=true)
|
||||
for datum in train_loader
|
||||
@assert size(datum.images) == (10, 2)
|
||||
@assert size(datum.labels) == (2,)
|
||||
end
|
||||
"""
|
||||
function DataLoader(data; batchsize=1, shuffle=false, partial=true)
|
||||
batchsize > 0 || throw(ArgumentError("Need positive batchsize"))
|
||||
|
||||
n = _nobs(data)
|
||||
if n < batchsize
|
||||
@warn "Number of observations less than batchsize, decreasing the batchsize to $n"
|
||||
batchsize = n
|
||||
end
|
||||
imax = partial ? n : n - batchsize + 1
|
||||
DataLoader(data, batchsize, n, partial, imax, [1:n;], shuffle)
|
||||
end
|
||||
|
||||
@propagate_inbounds function Base.iterate(d::DataLoader, i=0) # returns data in d.indices[i+1:i+batchsize]
|
||||
i >= d.imax && return nothing
|
||||
if d.shuffle && i == 0
|
||||
shuffle!(d.indices)
|
||||
end
|
||||
nexti = min(i + d.batchsize, d.nobs)
|
||||
ids = d.indices[i+1:nexti]
|
||||
batch = _getobs(d.data, ids)
|
||||
return (batch, nexti)
|
||||
end
|
||||
|
||||
function Base.length(d::DataLoader)
  n = d.nobs / d.batchsize
  d.partial ? ceil(Int,n) : floor(Int,n)
end
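A quick worked example of the batch count (a sketch using only the `DataLoader` shown above): with 10 observations and `batchsize = 3`, the default `partial = true` keeps the short final batch, while `partial = false` drops it.

```julia
X = rand(2, 10)                                       # 10 observations

length(DataLoader(X, batchsize = 3))                  # 4 = ceil(10 / 3)
length(DataLoader(X, batchsize = 3, partial = false)) # 3 = floor(10 / 3)
```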
|
||||
|
||||
_nobs(data::AbstractArray) = size(data)[end]
|
||||
|
||||
function _nobs(data::Union{Tuple, NamedTuple})
|
||||
length(data) > 0 || throw(ArgumentError("Need at least one data input"))
|
||||
n = _nobs(data[1])
|
||||
if !all(x -> _nobs(x) == n, Base.tail(data))
|
||||
throw(DimensionMismatch("All data should contain same number of observations"))
|
||||
end
|
||||
return n
|
||||
end
|
||||
|
||||
_getobs(data::AbstractArray, i) = data[ntuple(i -> Colon(), Val(ndims(data) - 1))..., i]
|
||||
_getobs(data::Union{Tuple, NamedTuple}, i) = map(Base.Fix2(_getobs, i), data)
|
||||
|
||||
Base.eltype(::DataLoader{D}) where D = D
|
@ -1,20 +1,19 @@
|
||||
module FashionMNIST
|
||||
|
||||
using ..MNIST: gzopen, imageheader, rawimage, labelheader, rawlabel
|
||||
using ..Data: download_and_verify
|
||||
|
||||
const dir = joinpath(@__DIR__, "../../deps/fashion-mnist")
|
||||
|
||||
function load()
|
||||
mkpath(dir)
|
||||
cd(dir) do
|
||||
for (file, hash) in [("train-images-idx3-ubyte", "3aede38d61863908ad78613f6a32ed271626dd12800ba2636569512369268a84"),
|
||||
("train-labels-idx1-ubyte", "a04f17134ac03560a47e3764e11b92fc97de4d1bfaf8ba1a3aa29af54cc90845"),
|
||||
("t10k-images-idx3-ubyte" , "346e55b948d973a97e58d2351dde16a484bd415d4595297633bb08f03db6a073"),
|
||||
("t10k-labels-idx1-ubyte" , "67da17c76eaffca5446c3361aaab5c3cd6d1c2608764d35dfb1850b086bf8dd5")]
|
||||
for file in ["train-images-idx3-ubyte",
|
||||
"train-labels-idx1-ubyte",
|
||||
"t10k-images-idx3-ubyte",
|
||||
"t10k-labels-idx1-ubyte"]
|
||||
isfile(file) && continue
|
||||
@info "Downloading Fashion-MNIST dataset"
|
||||
download_and_verify("http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/$file.gz", "$file.gz", hash)
|
||||
download("http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/$file.gz", "$file.gz")
|
||||
open(file, "w") do io
|
||||
write(io, gzopen(read, "$file.gz"))
|
||||
end
|
||||
@ -33,10 +32,9 @@ const TESTLABELS = joinpath(dir, "t10k-labels-idx1-ubyte")
|
||||
|
||||
Load the Fashion-MNIST images.
|
||||
|
||||
Each image is a 28×28 array of `Gray` colour values
|
||||
(see [Colors.jl](https://github.com/JuliaGraphics/Colors.jl)).
|
||||
Each image is a 28×28 array of `Gray` colour values (see Colors.jl).
|
||||
|
||||
Return the 60,000 training images by default; pass `:test` to retrieve the
|
||||
Returns the 60,000 training images by default; pass `:test` to retrieve the
|
||||
10,000 test images.
|
||||
"""
|
||||
function images(set = :train)
|
||||
@ -50,10 +48,10 @@ end
|
||||
labels()
|
||||
labels(:test)
|
||||
|
||||
Load the labels corresponding to each of the images returned from [`images()`](@ref).
|
||||
Load the labels corresponding to each of the images returned from `images()`.
|
||||
Each label is a number from 0-9.
|
||||
|
||||
Return the 60,000 training labels by default; pass `:test` to retrieve the
|
||||
Returns the 60,000 training labels by default; pass `:test` to retrieve the
|
||||
10,000 test labels.
|
||||
"""
|
||||
function labels(set = :train)
|
||||
|
@ -1,136 +0,0 @@
|
||||
"""
|
||||
1. Title: Boston Housing Data
|
||||
|
||||
2. Sources:
|
||||
(a) Origin: This dataset was taken from the StatLib library which is
|
||||
maintained at Carnegie Mellon University.
|
||||
(b) Creator: Harrison, D. and Rubinfeld, D.L. 'Hedonic prices and the
|
||||
demand for clean air', J. Environ. Economics & Management,
|
||||
vol.5, 81-102, 1978.
|
||||
(c) Date: July 7, 1993
|
||||
|
||||
3. Number of Instances: 506
|
||||
|
||||
4. Number of Attributes: 13 continuous attributes (including "class"
|
||||
attribute "MEDV"), 1 binary-valued attribute.
|
||||
|
||||
5. Attribute Information:
|
||||
|
||||
1. CRIM per capita crime rate by town
|
||||
2. ZN proportion of residential land zoned for lots over
|
||||
25,000 sq.ft.
|
||||
3. INDUS proportion of non-retail business acres per town
|
||||
4. CHAS Charles River dummy variable (= 1 if tract bounds
|
||||
river; 0 otherwise)
|
||||
5. NOX nitric oxides concentration (parts per 10 million)
|
||||
6. RM average number of rooms per dwelling
|
||||
7. AGE proportion of owner-occupied units built prior to 1940
|
||||
8. DIS weighted distances to five Boston employment centres
|
||||
9. RAD index of accessibility to radial highways
|
||||
10. TAX full-value property-tax rate per 10,000 dollars
|
||||
11. PTRATIO pupil-teacher ratio by town
|
||||
12. B 1000(Bk - 0.63)^2 where Bk is the proportion of blacks
|
||||
by town
|
||||
13. LSTAT % lower status of the population
|
||||
14. MEDV Median value of owner-occupied homes in 1000's of dollars
|
||||
|
||||
Downloaded From: https://archive.ics.uci.edu/ml/machine-learning-databases/housing/housing.data
|
||||
|
||||
"""
|
||||
module Housing
|
||||
|
||||
using DelimitedFiles
|
||||
using ..Data: deps, download_and_verify
|
||||
|
||||
#Uncomment if package exists
|
||||
#const cache_prefix = "https://cache.julialang.org/"
|
||||
const cache_prefix = ""
|
||||
|
||||
function load()
|
||||
isfile(deps("housing.data")) && return
|
||||
|
||||
@info "Downloading the Boston housing Dataset"
|
||||
download_and_verify("$(cache_prefix)http://archive.ics.uci.edu/ml/machine-learning-databases/housing/housing.data",
|
||||
deps("housing.data"),
|
||||
"baadf72995725d76efe787b664e1f083388c79ba21ef9a7990d87f774184735a")
|
||||
|
||||
#@info "Download complete. Working on the files"
|
||||
path = deps()
|
||||
isfile(deps("housing.data")) && touch(joinpath(path, "tempfile.data"))
|
||||
open(joinpath(path, "tempfile.data"), "a") do fout
|
||||
open(deps("housing.data"), "r") do fin
|
||||
for line in eachline(fin)
|
||||
line = replace(lstrip(line), r" +" => s",")
|
||||
println(fout, line)
|
||||
end
|
||||
end
|
||||
end
|
||||
mv(joinpath(path, "tempfile.data"), deps("housing.data"), force=true)
|
||||
end
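For illustration (the numbers are just the kind of row found in `housing.data`), this is what the cleanup loop above does to each whitespace-delimited line before the file is rewritten as comma-separated values:

```julia
line = "   0.00632  18.00   2.310"
replace(lstrip(line), r" +" => s",") == "0.00632,18.00,2.310"
```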
|
||||
|
||||
"""
|
||||
Gets the targets for the Boston Housing dataset, a 506-element array listing the target for each example.
|
||||
|
||||
```jldoctest
|
||||
julia> using Flux
|
||||
|
||||
julia> target = Flux.Data.Housing.targets()
|
||||
|
||||
julia> summary(target)
|
||||
506×1 Array{Float64,2}
|
||||
|
||||
julia> target[1]
|
||||
24.0
```

"""
|
||||
function targets()
|
||||
load()
|
||||
housing = readdlm(deps("housing.data"), ',')
|
||||
reshape(Vector{Float64}(housing[1:end,end]), (506, 1))
|
||||
end
|
||||
|
||||
|
||||
"""
|
||||
Gets the names of the features provided in the dataset
|
||||
|
||||
"""
|
||||
function feature_names()
|
||||
["crim","zn","indus","chas","nox","rm","age","dis","rad","tax","ptratio","b","lstat"]
|
||||
end
|
||||
|
||||
|
||||
"""
|
||||
Gets the features of the Boston Housing dataset. This is a 506×13 Matrix of Float64 values.
|
||||
The values are in the order ["crim","zn","indus","chas","nox","rm","age","dis","rad","tax","ptratio","b","lstat"].
|
||||
It has 506 examples.
|
||||
|
||||
```jldoctest
|
||||
julia> using Flux
|
||||
|
||||
julia> features = Flux.Data.Housing.features()
|
||||
|
||||
julia> summary(features)
|
||||
506×13 Array{Float64,2}
|
||||
|
||||
julia> features[1, :]
|
||||
13-element Array{Float64,1}:
|
||||
0.00632
|
||||
18.0
|
||||
2.31
|
||||
0.0
|
||||
0.538
|
||||
⋮
|
||||
296.0
|
||||
15.3
|
||||
396.9
|
||||
4.98
```

"""
|
||||
function features()
|
||||
load()
|
||||
housing = readdlm(deps("housing.data"), ',')
|
||||
Matrix{Float64}(housing[1:end, 1:13])
|
||||
end
|
||||
|
||||
|
||||
end
|
@ -1,78 +0,0 @@
|
||||
"""
|
||||
Fisher's classic iris dataset.
|
||||
|
||||
Measurements from 3 different species of iris: setosa, versicolor and
|
||||
virginica. There are 50 examples of each species.
|
||||
|
||||
There are 4 measurements for each example: sepal length, sepal width,
|
||||
petal length and petal width. The measurements are in centimeters.
|
||||
|
||||
The module retrieves the data from the [UCI Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/iris).
|
||||
"""
|
||||
module Iris
|
||||
|
||||
using DelimitedFiles
|
||||
using ..Data: deps, download_and_verify
|
||||
|
||||
# Uncomment if the iris.data file is cached to cache.julialang.org.
|
||||
const cache_prefix = "https://cache.julialang.org/"
|
||||
|
||||
function load()
|
||||
isfile(deps("iris.data")) && return
|
||||
|
||||
@info "Downloading iris dataset."
|
||||
download_and_verify("$(cache_prefix)https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data",
|
||||
deps("iris.data"),
|
||||
"6f608b71a7317216319b4d27b4d9bc84e6abd734eda7872b71a458569e2656c0")
|
||||
end
|
||||
|
||||
"""
|
||||
labels()
|
||||
|
||||
Get the labels of the iris dataset, a 150 element array of strings listing the
|
||||
species of each example.
|
||||
|
||||
```jldoctest; setup = :(Flux.Data.Iris.load())
|
||||
julia> labels = Flux.Data.Iris.labels();
|
||||
|
||||
julia> summary(labels)
|
||||
"150-element Array{String,1}"
|
||||
|
||||
julia> labels[1]
|
||||
"Iris-setosa"
|
||||
```
|
||||
"""
|
||||
function labels()
|
||||
load()
|
||||
iris = readdlm(deps("iris.data"), ',')
|
||||
Vector{String}(iris[1:end, end])
|
||||
end
|
||||
|
||||
"""
|
||||
features()
|
||||
|
||||
Get the features of the iris dataset. This is a 4x150 matrix of Float64
|
||||
elements. It has a row for each feature (sepal length, sepal width,
|
||||
petal length, petal width) and a column for each example.
|
||||
|
||||
```jldoctest; setup = :(Flux.Data.Iris.load())
|
||||
julia> features = Flux.Data.Iris.features();
|
||||
|
||||
julia> summary(features)
|
||||
"4×150 Array{Float64,2}"
|
||||
|
||||
julia> features[:, 1]
|
||||
4-element Array{Float64,1}:
|
||||
5.1
|
||||
3.5
|
||||
1.4
|
||||
0.2
|
||||
```
|
||||
"""
|
||||
function features()
|
||||
load()
|
||||
iris = readdlm(deps("iris.data"), ',')
|
||||
Matrix{Float64}(iris[1:end, 1:4]')
|
||||
end
|
||||
|
||||
end
|
@ -1,7 +1,6 @@
|
||||
module MNIST
|
||||
|
||||
using CodecZlib, Colors
|
||||
using ..Data: download_and_verify
|
||||
|
||||
const Gray = Colors.Gray{Colors.N0f8}
|
||||
|
||||
@ -16,13 +15,13 @@ end
|
||||
function load()
|
||||
mkpath(dir)
|
||||
cd(dir) do
|
||||
for (file, hash) in [("train-images-idx3-ubyte", "440fcabf73cc546fa21475e81ea370265605f56be210a4024d2ca8f203523609"),
|
||||
("train-labels-idx1-ubyte", "3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c"),
|
||||
("t10k-images-idx3-ubyte" , "8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6"),
|
||||
("t10k-labels-idx1-ubyte" , "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6")]
|
||||
for file in ["train-images-idx3-ubyte",
|
||||
"train-labels-idx1-ubyte",
|
||||
"t10k-images-idx3-ubyte",
|
||||
"t10k-labels-idx1-ubyte"]
|
||||
isfile(file) && continue
|
||||
@info "Downloading MNIST dataset"
|
||||
download_and_verify("https://cache.julialang.org/http://yann.lecun.com/exdb/mnist/$file.gz", "$file.gz", hash)
|
||||
download("https://cache.julialang.org/http://yann.lecun.com/exdb/mnist/$file.gz", "$file.gz")
|
||||
open(file, "w") do io
|
||||
write(io, gzopen(read, "$file.gz"))
|
||||
end
|
||||
@ -83,10 +82,9 @@ getfeatures(io::IO, index::Integer) = vec(getimage(io, index))
|
||||
|
||||
Load the MNIST images.
|
||||
|
||||
Each image is a 28×28 array of `Gray` colour values
|
||||
(see [Colors.jl](https://github.com/JuliaGraphics/Colors.jl)).
|
||||
Each image is a 28×28 array of `Gray` colour values (see Colors.jl).
|
||||
|
||||
Return the 60,000 training images by default; pass `:test` to retrieve the
|
||||
Returns the 60,000 training images by default; pass `:test` to retrieve the
|
||||
10,000 test images.
|
||||
"""
|
||||
function images(set = :train)
|
||||
@ -100,10 +98,10 @@ end
|
||||
labels()
|
||||
labels(:test)
|
||||
|
||||
Load the labels corresponding to each of the images returned from [`images()`](@ref).
|
||||
Load the labels corresponding to each of the images returned from `images()`.
|
||||
Each label is a number from 0-9.
|
||||
|
||||
Return the 60,000 training labels by default; pass `:test` to retrieve the
|
||||
Returns the 60,000 training labels by default; pass `:test` to retrieve the
|
||||
10,000 test labels.
|
||||
"""
|
||||
function labels(set = :train)
|
||||
|
@ -1,14 +1,13 @@
|
||||
"Stanford Sentiment Treebank dataset."
|
||||
module Sentiment
|
||||
|
||||
using ZipFile
|
||||
using ..Data: deps, download_and_verify
|
||||
using ..Data: deps
|
||||
|
||||
function load()
|
||||
isfile(deps("sentiment.zip")) && return
|
||||
@info "Downloading sentiment treebank dataset"
|
||||
download_and_verify("https://cache.julialang.org/https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip",
|
||||
deps("sentiment.zip"), "5c613a4f673fc74097d523a2c83f38e0cc462984d847b82c7aaf36b01cbbbfcc")
|
||||
download("https://cache.julialang.org/https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip",
|
||||
deps("sentiment.zip"))
|
||||
end
|
||||
|
||||
getfile(r, name) = r.files[findfirst(x -> x.name == name, r.files)]
|
||||
@ -40,28 +39,8 @@ function gettrees(name)
|
||||
return parsetree.(ss)
|
||||
end
|
||||
|
||||
"""
|
||||
train()
|
||||
|
||||
Return the train split of the Stanford Sentiment Treebank.
|
||||
The data is in [treebank](https://en.wikipedia.org/wiki/Treebank) format.
|
||||
"""
|
||||
train() = gettrees("train")
|
||||
|
||||
"""
|
||||
test()
|
||||
|
||||
Return the test split of the Stanford Sentiment Treebank.
|
||||
The data is in [treebank](https://en.wikipedia.org/wiki/Treebank) format.
|
||||
"""
|
||||
test() = gettrees("test")
|
||||
|
||||
"""
|
||||
dev()
|
||||
|
||||
Return the dev split of the Stanford Sentiment Treebank.
|
||||
The data is in [treebank](https://en.wikipedia.org/wiki/Treebank) format.
|
||||
"""
|
||||
dev() = gettrees("dev")
|
||||
|
||||
end
|
||||
|
@ -1,2 +0,0 @@
|
||||
@deprecate param(x) x
|
||||
@deprecate data(x) x
|
@ -1,82 +0,0 @@
|
||||
import Adapt: adapt, adapt_storage
|
||||
using Zygote: IdSet
|
||||
import Functors: @functor, functor, fmap
|
||||
|
||||
trainable(m) = functor(m)[1]
|
||||
|
||||
"""
|
||||
testmode!(m, mode = true)
|
||||
|
||||
Set a layer or model's test mode (see below).
|
||||
Using `:auto` mode will treat any gradient computation as training.
|
||||
|
||||
_Note_: if you manually set a model into test mode, you need to manually place
|
||||
it back into train mode during training phase.
|
||||
|
||||
Possible values include:
|
||||
- `false` for training
|
||||
- `true` for testing
|
||||
- `:auto` or `nothing` for Flux to detect the mode automatically
|
||||
"""
|
||||
testmode!(m, mode = true) = m
|
||||
|
||||
"""
|
||||
trainmode!(m, mode = true)
|
||||
|
||||
Set a layer or model's train mode (see below).
|
||||
Symmetric to [`testmode!`](@ref) (i.e. `trainmode!(m, mode) == testmode!(m, !mode)`).
|
||||
|
||||
_Note_: if you manually set a model into train mode, you need to manually place
|
||||
it into test mode during testing phase.
|
||||
|
||||
Possible values include:
|
||||
- `true` for training
|
||||
- `false` for testing
|
||||
- `:auto` or `nothing` for Flux to detect the mode automatically
|
||||
"""
|
||||
trainmode!(m, mode = true) = mode isa Bool ? testmode!(m, !mode) : testmode!(m, mode)
|
||||
|
||||
params!(p::Params, x::AbstractArray{<:Number}, seen = IdSet()) = push!(p, x)
|
||||
|
||||
function params!(p::Params, x, seen = IdSet())
|
||||
x in seen && return
|
||||
push!(seen, x)
|
||||
for child in trainable(x)
|
||||
params!(p, child, seen)
|
||||
end
|
||||
end
|
||||
|
||||
function params(m...)
|
||||
ps = Params()
|
||||
params!(ps, m)
|
||||
return ps
|
||||
end
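A small sketch of what `params` collects (using the `Chain` and `Dense` layers defined later in this diff): every trainable array reachable through `trainable`/`functor`.

```julia
m  = Chain(Dense(10, 5), Dense(5, 2))
ps = params(m)                 # W and b of each Dense layer

length(collect(ps)) == 4
```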
|
||||
|
||||
# Deprecated stuff
|
||||
macro treelike(args...)
|
||||
functorm(args...)
|
||||
end
|
||||
mapleaves(f, x) = fmap(f, x)
|
||||
|
||||
function loadparams!(m, xs)
|
||||
for (p, x) in zip(params(m), xs)
|
||||
size(p) == size(x) ||
|
||||
error("Expected param size $(size(p)), got $(size(x))")
|
||||
copyto!(p, x)
|
||||
end
|
||||
end
|
||||
|
||||
# CPU/GPU movement conveniences
|
||||
|
||||
cpu(m) = fmap(x -> adapt(Array, x), m)
|
||||
|
||||
gpu(x) = use_cuda[] ? fmap(CuArrays.cu, x) : x
|
||||
|
||||
# Precision
|
||||
|
||||
adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Real}) = convert.(T, xs)
|
||||
|
||||
paramtype(T::Type{<:Real}, m) = fmap(x -> adapt(T, x), m)
|
||||
|
||||
f32(m) = paramtype(Float32, m)
|
||||
f64(m) = paramtype(Float64, m)
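A one-line sketch of the precision helpers just above: `f32`/`f64` walk the model with `fmap` and convert every parameter array.

```julia
m = f32(Chain(Dense(3, 2)))    # all parameters converted to Float32

eltype(m[1].W) == Float32
eltype(m[1].b) == Float32
```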
|
@ -4,33 +4,28 @@
|
||||
Chain multiple layers / functions together, so that they are called in sequence
|
||||
on a given input.
|
||||
|
||||
```julia
|
||||
m = Chain(x -> x^2, x -> x+1)
|
||||
m(5) == 26
|
||||
|
||||
m = Chain(Dense(10, 5), Dense(5, 2))
|
||||
x = rand(10)
|
||||
m(x) == m[2](m[1](x))
|
||||
```
|
||||
|
||||
`Chain` also supports indexing and slicing, e.g. `m[2]` or `m[1:end-1]`.
|
||||
`m[1:3](x)` will calculate the output of the first three layers.
|
||||
|
||||
# Examples
|
||||
```jldoctest
|
||||
julia> m = Chain(x -> x^2, x -> x+1);
|
||||
|
||||
julia> m(5) == 26
|
||||
true
|
||||
|
||||
julia> m = Chain(Dense(10, 5), Dense(5, 2));
|
||||
|
||||
julia> x = rand(10);
|
||||
|
||||
julia> m(x) == m[2](m[1](x))
|
||||
true
|
||||
```
|
||||
"""
|
||||
struct Chain{T<:Tuple}
|
||||
layers::T
|
||||
Chain(xs...) = new{typeof(xs)}(xs)
|
||||
end
|
||||
|
||||
@forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
|
||||
Base.iterate, Base.lastindex
|
||||
@forward Chain.layers Base.getindex, Base.first, Base.last, Base.lastindex
|
||||
@forward Chain.layers Base.iterate
|
||||
|
||||
functor(::Type{<:Chain}, c) = c.layers, ls -> Chain(ls...)
|
||||
children(c::Chain) = c.layers
|
||||
mapchildren(f, c::Chain) = Chain(f.(c.layers)...)
|
||||
|
||||
applychain(::Tuple{}, x) = x
|
||||
applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x))
|
||||
@ -39,70 +34,35 @@ applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x))
|
||||
|
||||
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
|
||||
|
||||
testmode!(m::Chain, mode = true) = (map(x -> testmode!(x, mode), m.layers); m)
|
||||
|
||||
function Base.show(io::IO, c::Chain)
|
||||
print(io, "Chain(")
|
||||
join(io, c.layers, ", ")
|
||||
print(io, ")")
|
||||
end
|
||||
|
||||
"""
|
||||
outdims(c::Chain, isize)
|
||||
|
||||
Calculate the output dimensions given the input dimensions, `isize`.
|
||||
|
||||
```julia
|
||||
m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32))
|
||||
outdims(m, (10, 10)) == (6, 6)
|
||||
```
|
||||
"""
|
||||
outdims(c::Chain, isize) = foldl(∘, map(l -> (x -> outdims(l, x)), c.layers))(isize)
|
||||
|
||||
# This is a temporary and naive implementation
|
||||
# it might be replaced in the future for better performance
|
||||
# see issue https://github.com/FluxML/Flux.jl/issues/702
|
||||
# Johnny Chen -- @johnnychen94
|
||||
# only slightly changed to better handle interaction with Zygote @dsweber2
|
||||
"""
|
||||
activations(c::Chain, input)
|
||||
|
||||
Calculate the forward results of each layers in Chain `c` with `input` as model input.
|
||||
"""
|
||||
function activations(c::Chain, input)
|
||||
extraChain(c.layers, input)
|
||||
end
|
||||
|
||||
function extraChain(fs::Tuple, x)
|
||||
res = first(fs)(x)
|
||||
return (res, extraChain(Base.tail(fs), res)...)
|
||||
end
|
||||
|
||||
extraChain(::Tuple{}, x) = ()
|
||||
|
||||
|
||||
activations(c::Chain, x) = accumulate((x, m) -> m(x), c.layers, init = x)
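A shape-only sketch of `activations` (either implementation above returns one entry per layer):

```julia
m  = Chain(Dense(10, 5, relu), Dense(5, 2))
xs = activations(m, rand(Float32, 10))

length(xs) == 2        # one output per layer
size(xs[1]) == (5,)    # after the first Dense
size(xs[2]) == (2,)    # after the whole chain
```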
|
||||
|
||||
"""
|
||||
Dense(in::Integer, out::Integer, σ = identity)
|
||||
|
||||
Create a traditional `Dense` layer with parameters `W` and `b`.
|
||||
Creates a traditional `Dense` layer with parameters `W` and `b`.
|
||||
|
||||
y = σ.(W * x .+ b)
|
||||
|
||||
The input `x` must be a vector of length `in`, or a batch of vectors represented
|
||||
as an `in × N` matrix. The output `y` will be a vector or batch of length `out`.
|
||||
|
||||
# Examples
|
||||
```jldoctest; setup = :(using Random; Random.seed!(0))
|
||||
```julia
|
||||
julia> d = Dense(5, 2)
|
||||
Dense(5, 2)
|
||||
|
||||
julia> d(rand(5))
|
||||
2-element Array{Float32,1}:
|
||||
-0.16210233
|
||||
0.12311903
```
|
||||
Tracked 2-element Array{Float64,1}:
|
||||
0.00257447
|
||||
-0.00449443
|
||||
```
|
||||
"""
|
||||
struct Dense{F,S<:AbstractArray,T<:AbstractArray}
|
||||
struct Dense{F,S,T}
|
||||
W::S
|
||||
b::T
|
||||
σ::F
|
||||
@ -112,10 +72,10 @@ Dense(W, b) = Dense(W, b, identity)
|
||||
|
||||
function Dense(in::Integer, out::Integer, σ = identity;
|
||||
initW = glorot_uniform, initb = zeros)
|
||||
return Dense(initW(out, in), initb(out), σ)
|
||||
return Dense(param(initW(out, in)), param(initb(out)), σ)
|
||||
end
|
||||
|
||||
@functor Dense
|
||||
@treelike Dense
|
||||
|
||||
function (a::Dense)(x::AbstractArray)
|
||||
W, b, σ = a.W, a.b, a.σ
|
||||
@ -128,31 +88,10 @@ function Base.show(io::IO, l::Dense)
|
||||
print(io, ")")
|
||||
end
|
||||
|
||||
# Try to avoid hitting generic matmul in some simple cases
|
||||
# Base's matmul is so slow that it's worth the extra conversion to hit BLAS
|
||||
(a::Dense{<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
|
||||
invoke(a, Tuple{AbstractArray}, x)
|
||||
|
||||
(a::Dense{<:Any,W})(x::AbstractArray{<:AbstractFloat}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
|
||||
a(T.(x))
|
||||
|
||||
"""
|
||||
outdims(l::Dense, isize)
|
||||
|
||||
Calculate the output dimensions given the input dimensions, `isize`.
|
||||
|
||||
```julia
|
||||
m = Dense(10, 5)
|
||||
outdims(m, (5, 2)) == (5,)
|
||||
outdims(m, (10,)) == (5,)
|
||||
```
|
||||
"""
|
||||
outdims(l::Dense, isize) = (size(l.W)[1],)
|
||||
|
||||
"""
|
||||
Diagonal(in::Integer)
|
||||
|
||||
Create an element-wise linear transformation layer with learnable
|
||||
Creates an element-wise linear transformation layer with learnable
|
||||
vectors `α` and `β`:
|
||||
|
||||
y = α .* x .+ β
|
||||
@ -165,9 +104,9 @@ struct Diagonal{T}
|
||||
end
|
||||
|
||||
Diagonal(in::Integer; initα = ones, initβ = zeros) =
|
||||
Diagonal(initα(in), initβ(in))
|
||||
Diagonal(param(initα(in)), param(initβ(in)))
|
||||
|
||||
@functor Diagonal
|
||||
@treelike Diagonal
|
||||
|
||||
function (a::Diagonal)(x)
|
||||
α, β = a.α, a.β
|
||||
@ -178,83 +117,10 @@ function Base.show(io::IO, l::Diagonal)
|
||||
print(io, "Diagonal(", length(l.α), ")")
|
||||
end
|
||||
|
||||
outdims(l::Diagonal, isize) = (length(l.α),)
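A minimal usage sketch of the `Diagonal` layer defined above (an elementwise affine transform):

```julia
d = Diagonal(3)               # learnable elementwise scale α and shift β
x = rand(Float32, 3)

d(x) == d.α .* x .+ d.β
```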
|
||||
# Try to avoid hitting generic matmul in some simple cases
|
||||
# Base's matmul is so slow that it's worth the extra conversion to hit BLAS
|
||||
(a::Dense{<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
|
||||
invoke(a, Tuple{AbstractArray}, x)
|
||||
|
||||
"""
|
||||
Maxout(over)
|
||||
|
||||
The [Maxout](https://arxiv.org/pdf/1302.4389.pdf) layer has a number of
|
||||
internal layers which all receive the same input. It returns the elementwise
|
||||
maximum of the internal layers' outputs.
|
||||
|
||||
Maxout over linear dense layers satisfies the universal approximation theorem.
|
||||
"""
|
||||
struct Maxout{FS<:Tuple}
|
||||
over::FS
|
||||
end
|
||||
|
||||
"""
|
||||
Maxout(f, n_alts)
|
||||
|
||||
Construct a Maxout layer over `n_alts` instances of the layer given by `f`.
|
||||
The function takes no arguments and should return some callable layer.
|
||||
Conventionally, this is a linear dense layer.
|
||||
|
||||
# Examples
|
||||
|
||||
This constructs a `Maxout` layer over 4 internal dense linear layers, each
|
||||
identical in structure (784 inputs, 128 outputs):
|
||||
```julia
|
||||
insize = 784
|
||||
outsize = 128
|
||||
Maxout(()->Dense(insize, outsize), 4)
|
||||
```
|
||||
"""
|
||||
function Maxout(f, n_alts)
|
||||
over = Tuple(f() for _ in 1:n_alts)
|
||||
return Maxout(over)
|
||||
end
|
||||
|
||||
@functor Maxout
|
||||
|
||||
function (mo::Maxout)(input::AbstractArray)
|
||||
mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
|
||||
end
|
||||
|
||||
outdims(l::Maxout, isize) = outdims(first(l.over), isize)
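A shape-only sketch of `Maxout`: several identically shaped layers see the same input, and the elementwise maximum of their outputs is returned.

```julia
m = Maxout(() -> Dense(5, 7), 3)   # 3 internal Dense layers over the same input
y = m(rand(Float32, 5))

size(y) == (7,)
```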
|
||||
|
||||
"""
|
||||
SkipConnection(layer, connection)
|
||||
|
||||
Create a skip connection which consists of a layer or `Chain` of consecutive
|
||||
layers and a shortcut connection linking the block's input to the output
|
||||
through a user-supplied 2-argument callable. The first argument to the callable
|
||||
will be propagated through the given `layer` while the second is the unchanged,
|
||||
"skipped" input.
|
||||
|
||||
The simplest "ResNet"-type connection is just `SkipConnection(layer, +)`,
|
||||
and requires the output of the layers to be the same shape as the input.
|
||||
Here is a more complicated example:
|
||||
```julia
|
||||
m = Conv((3,3), 4=>7, pad=(1,1))
|
||||
x = ones(5,5,4,10);
|
||||
size(m(x)) == (5, 5, 7, 10)
|
||||
|
||||
sm = SkipConnection(m, (mx, x) -> cat(mx, x, dims=3))
|
||||
size(sm(x)) == (5, 5, 11, 10)
|
||||
```
|
||||
"""
|
||||
struct SkipConnection
|
||||
layers
|
||||
connection #user can pass arbitrary connections here, such as (a,b) -> a + b
|
||||
end
|
||||
|
||||
@functor SkipConnection
|
||||
|
||||
function (skip::SkipConnection)(input)
|
||||
skip.connection(skip.layers(input), input)
|
||||
end
|
||||
|
||||
function Base.show(io::IO, b::SkipConnection)
|
||||
print(io, "SkipConnection(", b.layers, ", ", b.connection, ")")
|
||||
end
|
||||
(a::Dense{<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
|
||||
a(T.(x))
|
||||
|
@ -1,140 +1,47 @@
|
||||
using NNlib: conv, ∇conv_data, depthwiseconv, output_size
|
||||
using NNlib: conv, depthwiseconv
|
||||
|
||||
# pad dims of x with dims of y until ndims(x) == ndims(y)
|
||||
_paddims(x::Tuple, y::Tuple) = (x..., y[(end - (length(y) - length(x) - 1)):end]...)
|
||||
|
||||
_convtransoutdims(isize, ksize, ssize, dsize, pad) = (isize .- 1).*ssize .+ 1 .+ (ksize .- 1).*dsize .- (pad[1:2:end] .+ pad[2:2:end])
|
||||
@generated sub2(::Val{N}) where N = :(Val($(N-2)))
|
||||
|
||||
expand(N, i::Tuple) = i
|
||||
expand(N, i::Integer) = ntuple(_ -> i, N)
|
||||
|
||||
"""
|
||||
SamePad
|
||||
Conv(size, in=>out)
|
||||
Conv(size, in=>out, relu)
|
||||
|
||||
Padding for convolutional layers will be calculated so that outputshape == inputshape when stride = 1.
|
||||
|
||||
For stride > 1 the output shape depends on the type of convolution layer.
|
||||
"""
|
||||
struct SamePad end
|
||||
|
||||
calc_padding(pad, k::NTuple{N,T}, dilation, stride) where {T,N}= expand(Val(2*N), pad)
|
||||
function calc_padding(::SamePad, k::NTuple{N,T}, dilation, stride) where {N,T}
|
||||
#Ref: "A guide to convolution arithmetic for deep learning" https://arxiv.org/pdf/1603.07285
|
||||
|
||||
# Effective kernel size, including dilation
|
||||
k_eff = @. k + (k - 1) * (dilation - 1)
|
||||
# How much total padding needs to be applied?
|
||||
pad_amt = @. k_eff - 1
|
||||
# In case amount of padding is odd we need to apply different amounts to each side.
|
||||
return Tuple(mapfoldl(i -> [ceil(Int, i/2), floor(Int, i/2)], vcat, pad_amt))
|
||||
end
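A worked check of the `SamePad` arithmetic in the internal helper above: with `dilation = 1`, a 3×3 kernel needs 1 pixel of padding on every side, while an even kernel size splits the padding unevenly.

```julia
calc_padding(SamePad(), (3, 3), 1, 1) == (1, 1, 1, 1)   # odd kernel: symmetric padding
calc_padding(SamePad(), (4, 4), 1, 1) == (2, 1, 2, 1)   # even kernel: 2 before, 1 after
```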
|
||||
|
||||
"""
|
||||
Conv(filter, in => out, σ = identity; init = glorot_uniform,
|
||||
stride = 1, pad = 0, dilation = 1)
|
||||
|
||||
filter = (2,2)
|
||||
in = 1
|
||||
out = 16
|
||||
Conv((2, 2), 1=>16, relu)
|
||||
|
||||
Standard convolutional layer. `filter` should be a tuple like `(2, 2)`.
|
||||
Standard convolutional layer. `size` should be a tuple like `(2, 2)`.
|
||||
`in` and `out` specify the number of input and output channels respectively.
|
||||
|
||||
Data should be stored in WHCN order (width, height, # channels, batch size).
|
||||
In other words, a 100×100 RGB image would be a `100×100×3×1` array,
|
||||
and a batch of 50 would be a `100×100×3×50` array.
|
||||
|
||||
Accepts keyword arguments `weight` and `bias` to set the corresponding fields.
|
||||
Setting `bias` to `Flux.Zeros()` will switch bias off for the layer.
|
||||
Data should be stored in WHCN order. In other words, a 100×100 RGB image would
|
||||
be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
|
||||
|
||||
Takes the keyword arguments `pad`, `stride` and `dilation`.
|
||||
Use `pad=SamePad()` to apply padding so that outputsize == inputsize / stride.
|
||||
|
||||
# Examples
|
||||
|
||||
Apply a `Conv` layer to a 1-channel input using a 2×2 window filter size, giving us a
|
||||
16-channel output. Output is activated with ReLU.
|
||||
```julia
|
||||
filter = (2,2)
|
||||
in = 1
|
||||
out = 16
|
||||
Conv(filter, in => out, relu)
|
||||
```
|
||||
"""
|
||||
struct Conv{N,M,F,A,V}
|
||||
struct Conv{N,F,A,V}
|
||||
σ::F
|
||||
weight::A
|
||||
bias::V
|
||||
stride::NTuple{N,Int}
|
||||
pad::NTuple{M,Int}
|
||||
pad::NTuple{N,Int}
|
||||
dilation::NTuple{N,Int}
|
||||
end
|
||||
|
||||
"""
|
||||
Conv(weight::AbstractArray, bias::AbstractArray)
|
||||
Conv(weight::AbstractArray, bias::AbstractArray, activation)
|
||||
Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
|
||||
stride = 1, pad = 0, dilation = 1) where {T,N} =
|
||||
Conv(σ, w, b, expand.(sub2(Val(N)), (stride, pad, dilation))...)
|
||||
|
||||
Constructs the convolutional layer with user defined weight and bias arrays.
|
||||
Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
|
||||
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
|
||||
Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ,
|
||||
stride = stride, pad = pad, dilation = dilation)
|
||||
|
||||
Setting `bias` to `Flux.Zeros()` would switch `bias` off for the layer.
|
||||
|
||||
Takes the keyword arguments `pad`, `stride` and `dilation`.
|
||||
|
||||
There is also a keyword-only constructor available for all convolutional
layers.
|
||||
|
||||
```julia
|
||||
weight = rand(Float32, 3, 3, 5)
|
||||
bias = zeros(Float32, 5)
|
||||
Conv(weight = weight,
|
||||
bias = bias,
|
||||
σ = sigmoid)
|
||||
```
|
||||
"""
|
||||
function Conv(w::AbstractArray{T,N}, b::Union{Zeros, AbstractVector{T}}, σ = identity;
|
||||
stride = 1, pad = 0, dilation = 1) where {T,N}
|
||||
stride = expand(Val(N-2), stride)
|
||||
dilation = expand(Val(N-2), dilation)
|
||||
pad = calc_padding(pad, size(w)[1:N-2], dilation, stride)
|
||||
return Conv(σ, w, b, stride, pad, dilation)
|
||||
end
|
||||
|
||||
function Conv(;weight::AbstractArray{T,N}, bias::Union{Zeros, AbstractVector{T}},
|
||||
activation = identity, stride = 1, pad = 0, dilation = 1) where {T,N}
|
||||
Conv(weight, bias, activation, stride = stride, pad = pad, dilation = dilation)
|
||||
end
|
||||
|
||||
"""
|
||||
convfilter(filter::Tuple, in=>out)
|
||||
|
||||
Constructs a standard convolutional weight matrix with given `filter` and
|
||||
channels from `in` to `out`.
|
||||
|
||||
Accepts the keyword `init` (default: `glorot_uniform`) to control the sampling
|
||||
distribution.
|
||||
|
||||
See also: [`depthwiseconvfilter`](@ref)
|
||||
"""
|
||||
convfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
|
||||
init = glorot_uniform) where N = init(filter..., ch...)
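For reference, a sketch of the weight shape `convfilter` produces (filter dimensions, then input channels, then output channels):

```julia
w = convfilter((3, 3), 3 => 16)

size(w) == (3, 3, 3, 16)
```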
|
||||
|
||||
function Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
|
||||
init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
|
||||
weight = convfilter(k, ch, init = init), bias = zeros(ch[2])) where N
|
||||
|
||||
Conv(weight, bias, σ,
|
||||
stride = stride, pad = pad, dilation = dilation)
|
||||
end
|
||||
|
||||
@functor Conv
|
||||
@treelike Conv
|
||||
|
||||
function (c::Conv)(x::AbstractArray)
|
||||
# TODO: breaks gpu broadcast :(
|
||||
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
|
||||
σ, b = c.σ, reshape(c.bias, ntuple(_->1, length(c.stride))..., :, 1)
|
||||
cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
|
||||
σ.(conv(x, c.weight, cdims) .+ b)
|
||||
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
|
||||
σ.(conv(x, c.weight, stride = c.stride, pad = c.pad, dilation = c.dilation) .+ b)
|
||||
end
|
||||
|
||||
function Base.show(io::IO, l::Conv)
|
||||
@ -151,442 +58,96 @@ end
|
||||
a(T.(x))
|
||||
|
||||
"""
|
||||
outdims(l::Conv, isize::Tuple)
|
||||
DepthwiseConv(size, in)
|
||||
DepthwiseConv(size, in=>mul)
|
||||
DepthwiseConv(size, in=>mul, relu)
|
||||
|
||||
Calculate the output dimensions given the input dimensions `isize`.
|
||||
Batch size and channel size are ignored as per [NNlib.jl](https://github.com/FluxML/NNlib.jl).
|
||||
Depthwise convolutional layer. `size` should be a tuple like `(2, 2)`.
|
||||
`in` and `mul` specify the number of input channels and channel multiplier respectively.
|
||||
In case the `mul` is not specified it is taken as 1.
|
||||
|
||||
```julia
|
||||
m = Conv((3, 3), 3 => 16)
|
||||
outdims(m, (10, 10)) == (8, 8)
|
||||
outdims(m, (10, 10, 1, 3)) == (8, 8)
|
||||
```
|
||||
Data should be stored in WHCN order. In other words, a 100×100 RGB image would
|
||||
be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
|
||||
|
||||
Takes the keyword arguments `pad` and `stride`.
|
||||
"""
|
||||
outdims(l::Conv, isize) =
|
||||
output_size(DenseConvDims(_paddims(isize, size(l.weight)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation))
|
||||
|
||||
"""
|
||||
ConvTranspose(filter, in=>out)
|
||||
ConvTranspose(filter, in=>out, activation)
|
||||
ConvTranspose(filter, in => out, σ = identity; init = glorot_uniform,
|
||||
stride = 1, pad = 0, dilation = 1)
|
||||
|
||||
Standard convolutional transpose layer. `filter` should be a tuple like `(2, 2)`.
|
||||
`in` and `out` specify the number of input and output channels respectively.
|
||||
|
||||
Data should be stored in WHCN order (width, height, # channels, batch size).
|
||||
In other words, a 100×100 RGB image would be a `100×100×3×1` array,
|
||||
and a batch of 50 would be a `100×100×3×50` array.
|
||||
|
||||
Accepts keyword arguments `weight` and `bias` to set the corresponding fields.
|
||||
Setting `bias` to `Flux.Zeros()` will switch bias off for the layer.
|
||||
|
||||
Takes the keyword arguments `pad`, `stride` and `dilation`.
|
||||
Use `pad=SamePad()` to apply padding so that outputsize == stride * inputsize - stride + 1.
|
||||
"""
|
||||
struct ConvTranspose{N,M,F,A,V}
|
||||
struct DepthwiseConv{N,F,A,V}
|
||||
σ::F
|
||||
weight::A
|
||||
bias::V
|
||||
stride::NTuple{N,Int}
|
||||
pad::NTuple{M,Int}
|
||||
dilation::NTuple{N,Int}
|
||||
pad::NTuple{N,Int}
|
||||
end
|
||||
|
||||
"""
|
||||
ConvTranspose(weight::AbstractArray, bias::AbstractArray)
|
||||
ConvTranspose(weight::AbstractArray, bias::AbstractArray, activation)
|
||||
DepthwiseConv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
|
||||
stride = 1, pad = 0) where {T,N} =
|
||||
DepthwiseConv(σ, w, b, expand.(sub2(Val(N)), (stride, pad))...)
|
||||
|
||||
Constructs the convolutional transpose layer with user defined weight and bias arrays.
|
||||
forward pass.
|
||||
DepthwiseConv(k::NTuple{N,Integer}, ch::Integer, σ = identity; init = initn,
|
||||
stride = 1, pad = 0) where N =
|
||||
DepthwiseConv(param(init(k..., 1, ch)), param(zeros(ch)), σ,
|
||||
stride = stride, pad = pad)
|
||||
|
||||
Setting `bias` to `Flux.Zeros()` would switch `bias` off for the layer.
|
||||
DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
|
||||
stride::NTuple{N,Integer} = map(_->1,k),
|
||||
pad::NTuple{N,Integer} = map(_->0,k)) where N =
|
||||
DepthwiseConv(param(init(k..., ch[2], ch[1])), param(zeros(ch[2]*ch[1])), σ,
|
||||
stride = stride, pad = pad)
|
||||
|
||||
Takes the keyword arguments `pad`, `stride` and `dilation`.
|
||||
@treelike DepthwiseConv
|
||||
|
||||
For the keyword-only constructor, see also [`Conv`](@ref)
|
||||
"""
|
||||
function ConvTranspose(w::AbstractArray{T,N}, b::Union{Zeros, AbstractVector{T}}, σ = identity;
|
||||
stride = 1, pad = 0, dilation = 1) where {T,N}
|
||||
stride = expand(Val(N-2), stride)
|
||||
dilation = expand(Val(N-2), dilation)
|
||||
pad = calc_padding(pad, size(w)[1:N-2], dilation, stride)
|
||||
return ConvTranspose(σ, w, b, stride, pad, dilation)
|
||||
end
|
||||
|
||||
function ConvTranspose(;weight::AbstractArray{T,N}, bias::Union{Zeros, AbstractVector{T}},
|
||||
activation = identity, stride = 1, pad = 0, dilation = 1) where {T,N}
|
||||
ConvTranspose(weight, bias, activation, stride = stride, pad = pad, dilation = dilation)
|
||||
end
|
||||
|
||||
function ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
|
||||
init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
|
||||
weight = convfilter(k, reverse(ch), init = init), bias = zeros(ch[2])) where N
|
||||
|
||||
ConvTranspose(weight, bias, σ,
|
||||
stride = stride, pad = pad, dilation = dilation)
|
||||
end
|
||||
|
||||
@functor ConvTranspose
|
||||
|
||||
function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
|
||||
# Calculate size of "input", from ∇conv_data()'s perspective...
|
||||
combined_pad = (c.pad[1:2:end] .+ c.pad[2:2:end])
|
||||
I = (size(x)[1:end-2] .- 1).*c.stride .+ 1 .+ (size(c.weight)[1:end-2] .- 1).*c.dilation .- combined_pad
|
||||
C_in = size(c.weight)[end-1]
|
||||
batch_size = size(x)[end]
|
||||
# Create DenseConvDims() that looks like the corresponding conv()
|
||||
return DenseConvDims((I..., C_in, batch_size), size(c.weight);
|
||||
stride=c.stride,
|
||||
padding=c.pad,
|
||||
dilation=c.dilation,
|
||||
)
|
||||
end
|
||||
|
||||
# TODO: Find proper fix for https://github.com/FluxML/Flux.jl/issues/900
|
||||
@nograd conv_transpose_dims
|
||||
|
||||
function (c::ConvTranspose)(x::AbstractArray)
|
||||
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
|
||||
function (c::DepthwiseConv)(x)
|
||||
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
|
||||
cdims = conv_transpose_dims(c, x)
|
||||
σ.(∇conv_data(x, c.weight, cdims) .+ b)
|
||||
σ.(depthwiseconv(x, c.weight, stride = c.stride, pad = c.pad) .+ b)
|
||||
end
|
||||
|
||||
function Base.show(io::IO, l::ConvTranspose)
|
||||
print(io, "ConvTranspose(", size(l.weight)[1:ndims(l.weight)-2])
|
||||
function Base.show(io::IO, l::DepthwiseConv)
|
||||
print(io, "DepthwiseConv(", size(l.weight)[1:ndims(l.weight)-2])
|
||||
print(io, ", ", size(l.weight, ndims(l.weight)), "=>", size(l.weight, ndims(l.weight)-1))
|
||||
l.σ == identity || print(io, ", ", l.σ)
|
||||
print(io, ")")
|
||||
end
|
||||
|
||||
(a::ConvTranspose{<:Any,<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
|
||||
invoke(a, Tuple{AbstractArray}, x)
|
||||
|
||||
(a::ConvTranspose{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
|
||||
a(T.(x))
|
||||
|
||||
outdims(l::ConvTranspose{N}, isize) where N = _convtransoutdims(isize[1:2], size(l.weight)[1:N], l.stride, l.dilation, l.pad)
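A worked check of the transposed-convolution output-size formula used by `outdims` above (stride 1, no padding, no dilation): a 2×2 kernel grows a 10×10 input to 11×11.

```julia
m = ConvTranspose((2, 2), 3 => 16)

outdims(m, (10, 10)) == (11, 11)   # (10 - 1)*1 + 1 + (2 - 1)*1 - 0 = 11
```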
|
||||
|
||||
"""
|
||||
DepthwiseConv(filter::Tuple, in=>out)
|
||||
DepthwiseConv(filter::Tuple, in=>out, activation)
|
||||
DepthwiseConv(filter, in => out, σ = identity; init = glorot_uniform,
|
||||
stride = 1, pad = 0, dilation = 1)
|
||||
MaxPool(k)
|
||||
|
||||
Depthwise convolutional layer. `filter` should be a tuple like `(2, 2)`.
|
||||
`in` and `out` specify the number of input and output channels respectively.
|
||||
Note that `out` must be an integer multiple of `in`.
|
||||
Max pooling layer. `k` stands for the size of the window for each dimension of the input.
|
||||
|
||||
Data should be stored in WHCN order (width, height, # channels, batch size).
|
||||
In other words, a 100×100 RGB image would be a `100×100×3×1` array,
|
||||
and a batch of 50 would be a `100×100×3×50` array.
|
||||
|
||||
Accepts keyword arguments `weight` and `bias` to set the corresponding fields.
|
||||
Setting `bias` to `Flux.Zeros()` will switch bias off for the layer.
|
||||
|
||||
Takes the keyword arguments `pad`, `stride` and `dilation`.
|
||||
Use `pad=SamePad()` to apply padding so that outputsize == inputsize / stride.
|
||||
Takes the keyword arguments `pad` and `stride`.
|
||||
"""
|
||||
struct DepthwiseConv{N,M,F,A,V}
|
||||
σ::F
|
||||
weight::A
|
||||
bias::V
|
||||
stride::NTuple{N,Int}
|
||||
pad::NTuple{M,Int}
|
||||
dilation::NTuple{N,Int}
|
||||
end
|
||||
|
||||
"""
|
||||
DepthwiseConv(weight::AbstractArray, bias::AbstractArray)
|
||||
DepthwiseConv(weight::AbstractArray, bias::AbstractArray, activation)
|
||||
|
||||
Constructs the `DepthwiseConv` layer with user defined weight and bias arrays.
|
||||
forward pass.
|
||||
|
||||
Setting `bias` to `Flux.Zeros()` would switch `bias` off for the layer.
|
||||
|
||||
Takes the keyword arguments `pad`, `stride` and `dilation`.
|
||||
|
||||
For the keyword-only constructor, see also [`Conv`](@ref)
|
||||
"""
|
||||
function DepthwiseConv(w::AbstractArray{T,N}, b::Union{Zeros, AbstractVector{T}}, σ = identity;
|
||||
stride = 1, pad = 0, dilation = 1) where {T,N}
|
||||
stride = expand(Val(N-2), stride)
|
||||
dilation = expand(Val(N-2), dilation)
|
||||
pad = calc_padding(pad, size(w)[1:N-2], dilation, stride)
|
||||
return DepthwiseConv(σ, w, b, stride, pad, dilation)
|
||||
end
|
||||
|
||||
function DepthwiseConv(;weight::AbstractArray{T,N}, bias::Union{Zeros, AbstractVector{T}},
|
||||
activation = identity, stride = 1, pad = 0, dilation = 1) where {T,N}
|
||||
DepthwiseConv(weight, bias, activation, stride = stride, pad = pad, dilation = dilation)
|
||||
end
|
||||
|
||||
"""
|
||||
depthwiseconvfilter(filter::Tuple, in=>out)
|
||||
|
||||
Constructs a depthwise convolutional weight array defined by `filter` and channels
|
||||
from `in` to `out`.
|
||||
|
||||
Accepts the keyword `init` (default: `glorot_uniform`) to control the sampling
|
||||
distribution.
|
||||
|
||||
See also: [`convfilter`](@ref)
|
||||
"""
|
||||
depthwiseconvfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
|
||||
init = glorot_uniform) where N = init(filter..., div(ch[2], ch[1]), ch[1])
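And the corresponding shape sketch for `depthwiseconvfilter`, where the output channel count must be a multiple of the input channel count:

```julia
w = depthwiseconvfilter((3, 3), 3 => 9)

size(w) == (3, 3, 3, 3)   # (filter..., out ÷ in, in)
```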
|
||||
|
||||
function DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
|
||||
init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
|
||||
weight = depthwiseconvfilter(k, ch, init = init), bias = zeros(ch[2])) where N
|
||||
@assert ch[2] % ch[1] == 0 "Output channels must be integer multiple of input channels"
|
||||
|
||||
return DepthwiseConv(
|
||||
weight,
|
||||
bias,
|
||||
σ;
|
||||
stride = stride,
|
||||
pad = pad,
|
||||
dilation = dilation
|
||||
)
|
||||
end
|
||||
|
||||
@functor DepthwiseConv
|
||||
|
||||
function (c::DepthwiseConv)(x)
|
||||
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
|
||||
cdims = DepthwiseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
|
||||
σ.(depthwiseconv(x, c.weight, cdims) .+ b)
|
||||
end
|
||||
|
||||
function Base.show(io::IO, l::DepthwiseConv)
|
||||
print(io, "DepthwiseConv(", size(l.weight)[1:end-2])
|
||||
print(io, ", ", size(l.weight)[end], "=>", prod(size(l.weight)[end-1:end]))
|
||||
l.σ == identity || print(io, ", ", l.σ)
|
||||
print(io, ")")
|
||||
end
|
||||
|
||||
(a::DepthwiseConv{<:Any,<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
|
||||
invoke(a, Tuple{AbstractArray}, x)
|
||||
|
||||
(a::DepthwiseConv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
|
||||
a(T.(x))
|
||||
|
||||
outdims(l::DepthwiseConv, isize) =
|
||||
output_size(DepthwiseConvDims(_paddims(isize, (1, 1, size(l.weight)[end], 1)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation))
|
||||
|
||||
"""
|
||||
CrossCor(filter, in=>out)
|
||||
CrossCor(filter, in=>out, activation)
|
||||
CrossCor(filter, in => out, σ = identity; init = glorot_uniform,
|
||||
stride = 1, pad = 0, dilation = 1)
|
||||
|
||||
Standard cross convolutional layer. `filter` should be a tuple like `(2, 2)`.
|
||||
`in` and `out` specify the number of input and output channels respectively.
|
||||
|
||||
Data should be stored in WHCN order (width, height, # channels, batch size).
|
||||
In other words, a 100×100 RGB image would be a `100×100×3×1` array,
|
||||
and a batch of 50 would be a `100×100×3×50` array.
|
||||
|
||||
Accepts keyword arguments `weight` and `bias` to set the corresponding fields.
|
||||
Setting `bias` to `Flux.Zeros()` will switch bias off for the layer.
|
||||
|
||||
Takes the keyword arguments `pad`, `stride` and `dilation`.
|
||||
Use `pad=SamePad()` to apply padding so that outputsize == inputsize / stride.
|
||||
|
||||
# Examples
|
||||
|
||||
Apply a `CrossCor` layer to a 1-channel input using a 2×2 filter, producing a
16-channel output. The output is activated with ReLU.
|
||||
```julia
|
||||
filter = (2,2)
|
||||
in = 1
|
||||
out = 16
|
||||
CrossCor(filter, in => out, relu)
|
||||
```
|
||||
"""
|
||||
struct CrossCor{N,M,F,A,V}
|
||||
σ::F
|
||||
weight::A
|
||||
bias::V
|
||||
stride::NTuple{N,Int}
|
||||
pad::NTuple{M,Int}
|
||||
dilation::NTuple{N,Int}
|
||||
end
|
||||
|
||||
"""
|
||||
CrossCor(weight::AbstractArray, bias::AbstractArray)
|
||||
CrossCor(weight::AbstractArray, bias::AbstractArray, activation)
|
||||
|
||||
Constructs the standard cross convolutional layer with user defined weight and bias
|
||||
arrays.
|
||||
|
||||
Setting `bias` to `Flux.Zeros()` would switch `bias` off for the layer.
|
||||
|
||||
Takes the keyword arguments `pad`, `stride` and `dilation`.
|
||||
|
||||
For the keyword-only constructor, see also [`Conv`](@ref).
|
||||
"""
|
||||
function CrossCor(w::AbstractArray{T,N}, b::Union{Zeros, AbstractVector{T}}, σ = identity;
|
||||
stride = 1, pad = 0, dilation = 1) where {T,N}
|
||||
stride = expand(Val(N-2), stride)
|
||||
dilation = expand(Val(N-2), dilation)
|
||||
pad = calc_padding(pad, size(w)[1:N-2], dilation, stride)
|
||||
return CrossCor(σ, w, b, stride, pad, dilation)
|
||||
end
|
||||
|
||||
function CrossCor(;weight::AbstractArray{T,N}, bias::Union{Zeros, AbstractVector{T}},
|
||||
activation = identity, stride = 1, pad = 0, dilation = 1) where {T,N}
|
||||
CrossCor(weight, bias, activation, stride = stride, pad = pad, dilation = dilation)
|
||||
end
|
||||
|
||||
function CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
|
||||
init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
|
||||
weight = convfilter(k, ch, init = init), bias = zeros(ch[2])) where N
|
||||
|
||||
CrossCor(weight, bias, σ,
|
||||
stride = stride, pad = pad, dilation = dilation)
|
||||
end
|
||||
|
||||
@functor CrossCor
|
||||
|
||||
function crosscor(x, w, ddims::DenseConvDims)
|
||||
ddims = DenseConvDims(ddims, F=true)
|
||||
return conv(x, w, ddims)
|
||||
end
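The `F=true` flag above requests a flipped-kernel evaluation from NNlib, which is what turns `conv` into cross-correlation. A small sketch of the relationship between `CrossCor` and `Conv` (sizes are arbitrary; the equality is expected to hold up to floating-point error):

```julia
using Flux

w = randn(Float32, 3, 3, 1, 1)
b = zeros(Float32, 1)
x = randn(Float32, 8, 8, 1, 1)

# Cross-correlation with `w` matches convolution with `w` reversed in both
# spatial dimensions.
CrossCor(w, b)(x) ≈ Conv(reverse(reverse(w, dims = 1), dims = 2), b)(x)
```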
|
||||
|
||||
function (c::CrossCor)(x::AbstractArray)
|
||||
# TODO: breaks gpu broadcast :(
|
||||
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
|
||||
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
|
||||
cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
|
||||
σ.(crosscor(x, c.weight, cdims) .+ b)
|
||||
end
|
||||
|
||||
function Base.show(io::IO, l::CrossCor)
|
||||
print(io, "CrossCor(", size(l.weight)[1:ndims(l.weight)-2])
|
||||
print(io, ", ", size(l.weight, ndims(l.weight)-1), "=>", size(l.weight, ndims(l.weight)))
|
||||
l.σ == identity || print(io, ", ", l.σ)
|
||||
print(io, ")")
|
||||
end
|
||||
|
||||
(a::CrossCor{<:Any,<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
|
||||
invoke(a, Tuple{AbstractArray}, x)
|
||||
|
||||
(a::CrossCor{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
|
||||
a(T.(x))
|
||||
|
||||
outdims(l::CrossCor, isize) =
|
||||
output_size(DenseConvDims(_paddims(isize, size(l.weight)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation))
|
||||
|
||||
"""
|
||||
GlobalMaxPool()
|
||||
|
||||
Global max pooling layer.
|
||||
|
||||
Transforms (w,h,c,b)-shaped input into (1,1,c,b)-shaped output,
|
||||
by performing max pooling on the complete (w,h)-shaped feature maps.
|
||||
"""
|
||||
struct GlobalMaxPool end
|
||||
|
||||
function (g::GlobalMaxPool)(x)
|
||||
# Input size
|
||||
x_size = size(x)
|
||||
# Kernel size
|
||||
k = x_size[1:end-2]
|
||||
# Pooling dimensions
|
||||
pdims = PoolDims(x, k)
|
||||
|
||||
return maxpool(x, pdims)
|
||||
end
|
||||
|
||||
function Base.show(io::IO, g::GlobalMaxPool)
|
||||
print(io, "GlobalMaxPool()")
|
||||
end
|
||||
|
||||
"""
|
||||
GlobalMeanPool()
|
||||
|
||||
Global mean pooling layer.
|
||||
|
||||
Transforms (w,h,c,b)-shaped input into (1,1,c,b)-shaped output,
|
||||
by performing mean pooling on the complete (w,h)-shaped feature maps.
|
||||
"""
|
||||
struct GlobalMeanPool end
|
||||
|
||||
function (g::GlobalMeanPool)(x)
|
||||
# Input size
|
||||
x_size = size(x)
|
||||
# Kernel size
|
||||
k = x_size[1:end-2]
|
||||
# Pooling dimensions
|
||||
pdims = PoolDims(x, k)
|
||||
|
||||
return meanpool(x, pdims)
|
||||
end
|
||||
|
||||
function Base.show(io::IO, g::GlobalMeanPool)
|
||||
print(io, "GlobalMeanPool()")
|
||||
end
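A quick illustration of the shape transformation described by the two global pooling docstrings above (the input size is an arbitrary example):

```julia
using Flux

x = rand(Float32, 32, 32, 16, 8)   # (w, h, c, b)

size(GlobalMaxPool()(x))           # (1, 1, 16, 8)
size(GlobalMeanPool()(x))          # (1, 1, 16, 8)
```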
|
||||
|
||||
"""
|
||||
MaxPool(k; pad = 0, stride = k)
|
||||
|
||||
Max pooling layer. `k` is the size of the window for each dimension of the input.
|
||||
|
||||
Use `pad=SamePad()` to apply padding so that outputsize == inputsize / stride.
|
||||
|
||||
"""
|
||||
struct MaxPool{N,M}
|
||||
struct MaxPool{N}
|
||||
k::NTuple{N,Int}
|
||||
pad::NTuple{M,Int}
|
||||
pad::NTuple{N,Int}
|
||||
stride::NTuple{N,Int}
|
||||
end
|
||||
|
||||
function MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N
|
||||
stride = expand(Val(N), stride)
|
||||
pad = calc_padding(pad, k, 1, stride)
|
||||
return MaxPool(k, pad, stride)
|
||||
end
|
||||
MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =
|
||||
MaxPool(k, expand(Val(N), pad), expand(Val(N), stride))
|
||||
|
||||
function (m::MaxPool)(x)
|
||||
pdims = PoolDims(x, m.k; padding=m.pad, stride=m.stride)
|
||||
return maxpool(x, pdims)
|
||||
end
|
||||
(m::MaxPool)(x) = maxpool(x, m.k; pad = m.pad, stride = m.stride)
|
||||
|
||||
function Base.show(io::IO, m::MaxPool)
|
||||
print(io, "MaxPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
|
||||
end
|
||||
|
||||
outdims(l::MaxPool{N}, isize) where N = output_size(PoolDims(_paddims(isize, (l.k..., 1, 1)), l.k; stride = l.stride, padding = l.pad))
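A short `MaxPool` sketch; with the default `stride = k`, a `(2, 2)` window halves each spatial dimension (the input size is chosen for illustration):

```julia
using Flux

x = rand(Float32, 28, 28, 3, 4)
m = MaxPool((2, 2))     # stride defaults to the window size

size(m(x))              # (14, 14, 3, 4)
```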
|
||||
|
||||
"""
|
||||
MeanPool(k; pad = 0, stride = k)
|
||||
MeanPool(k)
|
||||
|
||||
Mean pooling layer. `k` is the size of the window for each dimension of the input.
|
||||
Mean pooling layer. `k` stands for the size of the window for each dimension of the input.
|
||||
|
||||
Use `pad=SamePad()` to apply padding so that outputsize == inputsize / stride.
|
||||
Takes the keyword arguments `pad` and `stride`.
|
||||
"""
|
||||
struct MeanPool{N,M}
|
||||
struct MeanPool{N}
|
||||
k::NTuple{N,Int}
|
||||
pad::NTuple{M,Int}
|
||||
pad::NTuple{N,Int}
|
||||
stride::NTuple{N,Int}
|
||||
end
|
||||
|
||||
function MeanPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N
|
||||
stride = expand(Val(N), stride)
|
||||
pad = calc_padding(pad, k, 1, stride)
|
||||
return MeanPool(k, pad, stride)
|
||||
end
|
||||
MeanPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =
|
||||
MeanPool(k, expand(Val(N), pad), expand(Val(N), stride))
|
||||
|
||||
function (m::MeanPool)(x)
|
||||
pdims = PoolDims(x, m.k; padding=m.pad, stride=m.stride)
|
||||
return meanpool(x, pdims)
|
||||
end
|
||||
(m::MeanPool)(x) = meanpool(x, m.k; pad = m.pad, stride = m.stride)
|
||||
|
||||
function Base.show(io::IO, m::MeanPool)
|
||||
print(io, "MeanPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
|
||||
end
|
||||
|
||||
outdims(l::MeanPool{N}, isize) where N = output_size(PoolDims(_paddims(isize, (l.k..., 1, 1)), l.k; stride = l.stride, padding = l.pad))
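A numeric sanity check contrasting mean and max pooling on a single 2×2 window (illustrative values):

```julia
using Flux

x = reshape(Float32[1 2; 3 4], 2, 2, 1, 1)

MeanPool((2, 2))(x)[1]  # 2.5f0, the mean of the window
MaxPool((2, 2))(x)[1]   # 4.0f0, the maximum of the window
```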
|
||||
|
@ -1,108 +1,54 @@
|
||||
istraining() = false
|
||||
"""
|
||||
testmode!(m)
|
||||
testmode!(m, false)
|
||||
|
||||
@adjoint istraining() = true, _ -> nothing
|
||||
Put layers like [`Dropout`](@ref) and [`BatchNorm`](@ref) into testing mode
|
||||
(or back to training mode with `false`).
|
||||
"""
|
||||
function testmode!(m, val::Bool=true)
|
||||
prefor(x -> _testmode!(x, val), m)
|
||||
return m
|
||||
end
|
||||
|
||||
_isactive(m) = isnothing(m.active) ? istraining() : m.active
|
||||
_testmode!(m, test) = nothing
|
||||
|
||||
_dropout_shape(s, ::Colon) = size(s)
|
||||
_dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(size(s)))...)
|
||||
"""
|
||||
Dropout(p)
|
||||
|
||||
A Dropout layer. For each input, either sets that input to `0` (with probability
|
||||
`p`) or scales it by `1/(1-p)`. This is used as a regularisation, i.e. it
|
||||
reduces overfitting during training.
|
||||
|
||||
Does nothing to the input once in [`testmode!`](@ref).
|
||||
"""
|
||||
mutable struct Dropout{F}
|
||||
p::F
|
||||
active::Bool
|
||||
end
|
||||
|
||||
function Dropout(p)
|
||||
@assert 0 ≤ p ≤ 1
|
||||
Dropout{typeof(p)}(p, true)
|
||||
end
|
||||
|
||||
_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
|
||||
|
||||
"""
|
||||
dropout(x, p; dims = :)
|
||||
|
||||
The dropout function. For each input, either sets that input to `0` (with probability
|
||||
`p`) or scales it by `1 / (1 - p)`. `dims` specifies the unbroadcasted dimensions,
|
||||
e.g. `dims=1` applies dropout along columns and `dims=2` along rows.
|
||||
This is used as a regularisation, i.e. it reduces overfitting during training.
|
||||
|
||||
See also the [`Dropout`](@ref) layer.
|
||||
"""
|
||||
dropout(x, p; dims = :) = x
|
||||
|
||||
@adjoint function dropout(x, p; dims = :)
|
||||
y = rand!(similar(x, _dropout_shape(x, dims)))
|
||||
y .= _dropout_kernel.(y, p, 1 - p)
|
||||
return x .* y, Δ -> (Δ .* y, nothing)
|
||||
end
|
||||
|
||||
"""
|
||||
Dropout(p, dims = :)
|
||||
|
||||
Dropout layer. In the forward pass, apply the [`Flux.dropout`](@ref) function on the input.
|
||||
|
||||
Does nothing to the input once [`Flux.testmode!`](@ref) is `true`.
|
||||
"""
|
||||
mutable struct Dropout{F,D}
|
||||
p::F
|
||||
dims::D
|
||||
active::Union{Bool, Nothing}
|
||||
end
|
||||
|
||||
# TODO: deprecate in v0.11
|
||||
Dropout(p, dims) = Dropout(p, dims, nothing)
|
||||
|
||||
function Dropout(p; dims = :)
|
||||
@assert 0 ≤ p ≤ 1
|
||||
Dropout{typeof(p),typeof(dims)}(p, dims, nothing)
|
||||
end
|
||||
|
||||
function (a::Dropout)(x)
|
||||
_isactive(a) || return x
|
||||
return dropout(x, a.p; dims = a.dims)
|
||||
a.active || return x
|
||||
y = similar(x)
|
||||
rand!(y)
|
||||
y .= _dropout_kernel.(y, a.p, 1 - a.p)
|
||||
return x .* y
|
||||
end
|
||||
|
||||
testmode!(m::Dropout, mode = true) =
|
||||
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
|
||||
|
||||
function Base.show(io::IO, d::Dropout)
|
||||
print(io, "Dropout(", d.p)
|
||||
d.dims != (:) && print(io, ", dims = $(repr(d.dims))")
|
||||
print(io, ")")
|
||||
end
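With the updated `Dropout` layer shown above, switching to test mode makes the layer a no-op (a minimal sketch):

```julia
using Flux

d = Dropout(0.3)
x = rand(Float32, 5)

Flux.testmode!(d)   # disable dropout for evaluation
d(x) == x           # true: the input passes through unchanged
```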
|
||||
|
||||
"""
|
||||
AlphaDropout(p)
|
||||
|
||||
A dropout layer. Used in
|
||||
[Self-Normalizing Neural Networks](https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf).
|
||||
The AlphaDropout layer ensures that mean and variance of activations
|
||||
remain the same as before.
|
||||
|
||||
Does nothing to the input once [`testmode!`](@ref) is true.
|
||||
"""
|
||||
mutable struct AlphaDropout{F}
|
||||
p::F
|
||||
active::Union{Bool, Nothing}
|
||||
function AlphaDropout(p, active = nothing)
|
||||
@assert 0 ≤ p ≤ 1
|
||||
new{typeof(p)}(p, active)
|
||||
end
|
||||
end
|
||||
|
||||
function (a::AlphaDropout)(x)
|
||||
_isactive(a) || return x
|
||||
λ = eltype(x)(1.0507009873554804934193349852946)
|
||||
α = eltype(x)(1.6732632423543772848170429916717)
|
||||
α1 = eltype(x)(-λ*α)
|
||||
noise = randn(eltype(x), size(x))
|
||||
x = @. x*(noise > (1 - a.p)) + α1 * (noise < (1 - a.p))
|
||||
A = (a.p + a.p * (1 - a.p) * α1 ^ 2)^0.5
|
||||
B = -A * α1 * (1 - a.p)
|
||||
x = @. A * x + B
|
||||
return x
|
||||
end
|
||||
|
||||
testmode!(m::AlphaDropout, mode = true) =
|
||||
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
|
||||
_testmode!(a::Dropout, test) = (a.active = !test)
|
||||
|
||||
"""
|
||||
LayerNorm(h::Integer)
|
||||
|
||||
A [normalisation layer](https://arxiv.org/pdf/1607.06450.pdf) designed to be
|
||||
used with recurrent hidden states of size `h`. Normalises the mean and standard
|
||||
deviation of each input before applying a per-neuron gain/bias.
|
||||
used with recurrent hidden states of size `h`. Normalises the mean/stddev of
|
||||
each input before applying a per-neuron gain/bias.
|
||||
"""
|
||||
struct LayerNorm{T}
|
||||
diag::Diagonal{T}
|
||||
@ -111,7 +57,7 @@ end
|
||||
LayerNorm(h::Integer) =
|
||||
LayerNorm(Diagonal(h))
|
||||
|
||||
@functor LayerNorm
|
||||
@treelike LayerNorm
|
||||
|
||||
(a::LayerNorm)(x) = a.diag(normalise(x))
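Since the `Diagonal` gain/bias start at ones/zeros, a freshly constructed `LayerNorm` simply standardises its input (a minimal sketch):

```julia
using Flux, Statistics

ln = LayerNorm(5)
y = ln(randn(Float32, 5))

mean(y)   # ≈ 0; the standard deviation is likewise normalised to ≈ 1
```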
|
||||
|
||||
@ -124,8 +70,8 @@ end
|
||||
initβ = zeros, initγ = ones,
|
||||
ϵ = 1e-8, momentum = .1)
|
||||
|
||||
[Batch Normalization](https://arxiv.org/pdf/1502.03167.pdf) layer.
|
||||
`channels` should be the size of the channel dimension in your data (see below).
|
||||
Batch Normalization layer. The `channels` input should be the size of the
|
||||
channel dimension in your data (see below).
|
||||
|
||||
Given an array with `N` dimensions, call the `N-1`th the channel dimension. (For
|
||||
a batch of feature vectors this is just the data dimension, for `WHCN` images
|
||||
@ -135,9 +81,10 @@ it's the usual channel dimension.)
|
||||
shifts them to have a new mean and variance (corresponding to the learnable,
|
||||
per-channel `bias` and `scale` parameters).
|
||||
|
||||
Use [`testmode!`](@ref) during inference.
|
||||
See [Batch Normalization: Accelerating Deep Network Training by Reducing
|
||||
Internal Covariate Shift](https://arxiv.org/pdf/1502.03167.pdf).
|
||||
|
||||
# Examples
|
||||
Example:
|
||||
```julia
|
||||
m = Chain(
|
||||
Dense(28^2, 64),
|
||||
@ -155,262 +102,59 @@ mutable struct BatchNorm{F,V,W,N}
|
||||
σ²::W # moving std
|
||||
ϵ::N
|
||||
momentum::N
|
||||
active::Union{Bool, Nothing}
|
||||
active::Bool
|
||||
end
|
||||
|
||||
# TODO: deprecate in v0.11
|
||||
BatchNorm(λ, β, γ, μ, σ², ϵ, momentum) = BatchNorm(λ, β, γ, μ, σ², ϵ, momentum, nothing)
|
||||
|
||||
BatchNorm(chs::Integer, λ = identity;
|
||||
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
|
||||
BatchNorm(λ, initβ(chs), initγ(chs),
|
||||
zeros(chs), ones(chs), ϵ, momentum, nothing)
|
||||
|
||||
trainable(bn::BatchNorm) = (bn.β, bn.γ)
|
||||
initβ = (i) -> zeros(i), initγ = (i) -> ones(i), ϵ = 1e-5, momentum = .1) =
|
||||
BatchNorm(λ, param(initβ(chs)), param(initγ(chs)),
|
||||
zeros(chs), ones(chs), ϵ, momentum, true)
|
||||
|
||||
function (BN::BatchNorm)(x)
|
||||
size(x, ndims(x)-1) == length(BN.β) ||
|
||||
error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))")
|
||||
γ, β = BN.γ, BN.β
|
||||
dims = length(size(x))
|
||||
channels = size(x, dims-1)
|
||||
affine_shape = ntuple(i->i == ndims(x) - 1 ? size(x, i) : 1, ndims(x))
|
||||
m = div(prod(size(x)), channels)
|
||||
γ = reshape(BN.γ, affine_shape...)
|
||||
β = reshape(BN.β, affine_shape...)
|
||||
if !_isactive(BN)
|
||||
affine_shape = ones(Int, dims)
|
||||
affine_shape[end-1] = channels
|
||||
m = prod(size(x)[1:end-2]) * size(x)[end]
|
||||
|
||||
if !BN.active
|
||||
μ = reshape(BN.μ, affine_shape...)
|
||||
σ² = reshape(BN.σ², affine_shape...)
|
||||
ϵ = BN.ϵ
|
||||
else
|
||||
T = eltype(x)
|
||||
|
||||
ϵ = data(convert(T, BN.ϵ))
|
||||
axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
|
||||
μ = mean(x, dims = axes)
|
||||
σ² = sum((x .- μ) .^ 2, dims = axes) ./ m
|
||||
ϵ = convert(T, BN.ϵ)
|
||||
|
||||
# update moving mean/std
|
||||
mtm = BN.momentum
|
||||
S = eltype(BN.μ)
|
||||
BN.μ = (1 - mtm) .* BN.μ .+ mtm .* S.(reshape(μ, :))
|
||||
BN.σ² = (1 - mtm) .* BN.σ² .+ (mtm * m / (m - 1)) .* S.(reshape(σ², :))
|
||||
mtm = data(convert(T, BN.momentum))
|
||||
BN.μ = (1 - mtm) .* BN.μ .+ mtm .* reshape(data(μ), :)
|
||||
BN.σ² = ((1 - mtm) .* BN.σ² .+ mtm .* reshape(data(σ²), :) .* m ./ (m - 1))
|
||||
end
|
||||
|
||||
let λ = BN.λ
|
||||
x̂ = (x .- μ) ./ sqrt.(σ² .+ ϵ)
|
||||
λ.(γ .* x̂ .+ β)
|
||||
# Break this up with a temporary variable to fix GPU launch bug
|
||||
# https://github.com/FluxML/Flux.jl/issues/385
|
||||
temp = reshape(γ, affine_shape...) .* ((x .- μ) ./ sqrt.(σ² .+ BN.ϵ))
|
||||
return λ.(temp .+ reshape(β, affine_shape...))
|
||||
end
|
||||
end
|
||||
|
||||
@functor BatchNorm
|
||||
children(BN::BatchNorm) =
|
||||
(BN.λ, BN.β, BN.γ, BN.μ, BN.σ², BN.ϵ, BN.momentum, BN.active)
|
||||
|
||||
testmode!(m::BatchNorm, mode = true) =
|
||||
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
|
||||
mapchildren(f, BN::BatchNorm) = # e.g. mapchildren(cu, BN)
|
||||
BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ²), BN.ϵ, BN.momentum, BN.active)
|
||||
|
||||
_testmode!(BN::BatchNorm, test) = (BN.active = !test)
|
||||
|
||||
function Base.show(io::IO, l::BatchNorm)
|
||||
print(io, "BatchNorm($(join(size(l.β), ", "))")
|
||||
(l.λ == identity) || print(io, ", λ = $(l.λ)")
|
||||
print(io, ")")
|
||||
end
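A minimal `BatchNorm` sketch using the updated API above (channel count and input size are illustrative):

```julia
using Flux

bn = BatchNorm(3)               # one bias/scale pair per channel
x = rand(Float32, 4, 4, 3, 8)   # WHCN input with 3 channels

size(bn(x)) == size(x)          # true: normalisation is shape-preserving
Flux.testmode!(bn)              # use the stored moving statistics at inference
```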
|
||||
|
||||
expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...)
|
||||
|
||||
mutable struct InstanceNorm{F,V,W,N}
|
||||
λ::F # activation function
|
||||
β::V # bias
|
||||
γ::V # scale
|
||||
μ::W # moving mean
|
||||
σ²::W # moving std
|
||||
ϵ::N
|
||||
momentum::N
|
||||
active::Union{Bool, Nothing}
|
||||
end
|
||||
|
||||
# TODO: deprecate in v0.11
|
||||
"""
|
||||
InstanceNorm(channels::Integer, σ = identity;
|
||||
initβ = zeros, initγ = ones,
|
||||
ϵ = 1e-8, momentum = .1)
|
||||
|
||||
[Instance Normalization](https://arxiv.org/abs/1607.08022) layer.
|
||||
`channels` should be the size of the channel dimension in your data (see below).
|
||||
|
||||
Given an array with `N` dimensions, call the `N-1`th the channel dimension. (For
|
||||
a batch of feature vectors this is just the data dimension, for `WHCN` images
|
||||
it's the usual channel dimension.)
|
||||
|
||||
`InstanceNorm` computes the mean and variance for each each `W×H×1×1` slice and
|
||||
shifts them to have a new mean and variance (corresponding to the learnable,
|
||||
per-channel `bias` and `scale` parameters).
|
||||
|
||||
Use [`testmode!`](@ref) during inference.
|
||||
|
||||
# Examples
|
||||
```julia
|
||||
m = Chain(
|
||||
Dense(28^2, 64),
|
||||
InstanceNorm(64, relu),
|
||||
Dense(64, 10),
|
||||
InstanceNorm(10),
|
||||
softmax)
|
||||
```
|
||||
"""
|
||||
InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum) = InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum, nothing)
|
||||
|
||||
InstanceNorm(chs::Integer, λ = identity;
|
||||
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
|
||||
InstanceNorm(λ, initβ(chs), initγ(chs),
|
||||
zeros(chs), ones(chs), ϵ, momentum, nothing)
|
||||
|
||||
trainable(in::InstanceNorm) = (in.β, in.γ)
|
||||
|
||||
function (in::InstanceNorm)(x)
|
||||
size(x, ndims(x)-1) == length(in.β) ||
|
||||
error("InstanceNorm expected $(length(in.β)) channels, got $(size(x, ndims(x)-1))")
|
||||
ndims(x) > 2 ||
|
||||
error("InstanceNorm requires at least 3 dimensions. With 2 dimensions an array of zeros would be returned")
|
||||
# these are repeated later on depending on the batch size
|
||||
dims = length(size(x))
|
||||
c = size(x, dims-1)
|
||||
bs = size(x, dims)
|
||||
affine_shape = ntuple(i->i == ndims(x) - 1 || i == ndims(x) ? size(x, i) : 1, ndims(x))
|
||||
m = div(prod(size(x)), c*bs)
|
||||
γ, β = expand_inst(in.γ, affine_shape), expand_inst(in.β, affine_shape)
|
||||
|
||||
if !_isactive(in)
|
||||
μ = expand_inst(in.μ, affine_shape)
|
||||
σ² = expand_inst(in.σ², affine_shape)
|
||||
ϵ = in.ϵ
|
||||
else
|
||||
T = eltype(x)
|
||||
|
||||
ϵ = convert(T, in.ϵ)
|
||||
axes = 1:dims-2 # axes to reduce along (all but channels and batch size axes)
|
||||
μ = mean(x, dims = axes)
|
||||
σ² = mean((x .- μ) .^ 2, dims = axes)
|
||||
S = eltype(in.μ)
|
||||
# update moving mean/std
|
||||
mtm = in.momentum
|
||||
in.μ = dropdims(mean(repeat((1 - mtm) .* in.μ, outer=[1, bs]) .+ mtm .* S.(reshape(μ, (c, bs))), dims = 2), dims=2)
|
||||
in.σ² = dropdims(mean((repeat((1 - mtm) .* in.σ², outer=[1, bs]) .+ (mtm * m / (m - 1)) .* S.(reshape(σ², (c, bs)))), dims = 2), dims=2)
|
||||
end
|
||||
|
||||
let λ = in.λ
|
||||
x̂ = (x .- μ) ./ sqrt.(σ² .+ ϵ)
|
||||
λ.(γ .* x̂ .+ β)
|
||||
end
|
||||
end
|
||||
|
||||
@functor InstanceNorm
|
||||
|
||||
testmode!(m::InstanceNorm, mode = true) =
|
||||
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
|
||||
|
||||
function Base.show(io::IO, l::InstanceNorm)
|
||||
print(io, "InstanceNorm($(join(size(l.β), ", "))")
|
||||
(l.λ == identity) || print(io, ", λ = $(l.λ)")
|
||||
print(io, ")")
|
||||
end
|
||||
|
||||
"""
|
||||
GroupNorm(chs::Integer, G::Integer, λ = identity;
|
||||
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
|
||||
ϵ = 1f-5, momentum = 0.1f0)
|
||||
|
||||
[Group Normalization](https://arxiv.org/pdf/1803.08494.pdf) layer.
|
||||
This layer can outperform Batch Normalization and Instance Normalization.
|
||||
|
||||
`chs` is the number of channels, the channel dimension of your input.
|
||||
For an array of N dimensions, the `N-1`th index is the channel dimension.
|
||||
|
||||
`G` is the number of groups along which the statistics are computed.
|
||||
The number of channels must be an integer multiple of the number of groups.
|
||||
|
||||
Use [`testmode!`](@ref) during inference.
|
||||
|
||||
# Examples
|
||||
```julia
|
||||
m = Chain(Conv((3,3), 1=>32, leakyrelu;pad = 1),
|
||||
GroupNorm(32,16))
|
||||
# 32 channels, 16 groups (G = 16), thus 2 channels per group used
|
||||
```
|
||||
"""
|
||||
mutable struct GroupNorm{F,V,W,N,T}
|
||||
G::T # number of groups
|
||||
λ::F # activation function
|
||||
β::V # bias
|
||||
γ::V # scale
|
||||
μ::W # moving mean
|
||||
σ²::W # moving std
|
||||
ϵ::N
|
||||
momentum::N
|
||||
active::Union{Bool, Nothing}
|
||||
end
|
||||
|
||||
# TODO: deprecate in v0.11
|
||||
GroupNorm(G, λ, β, γ, μ, σ², ϵ, momentum) = GroupNorm(G, λ, β, γ, μ, σ², ϵ, momentum, nothing)
|
||||
|
||||
GroupNorm(chs::Integer, G::Integer, λ = identity;
|
||||
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
|
||||
GroupNorm(G, λ, initβ(chs), initγ(chs),
|
||||
zeros(G,1), ones(G,1), ϵ, momentum, nothing)
|
||||
|
||||
trainable(gn::GroupNorm) = (gn.β, gn.γ)
|
||||
|
||||
function(gn::GroupNorm)(x)
|
||||
size(x,ndims(x)-1) == length(gn.β) || error("Group Norm expected $(length(gn.β)) channels, but got $(size(x,ndims(x)-1)) channels")
|
||||
ndims(x) > 2 || error("Need to pass at least 3 channels for Group Norm to work")
|
||||
(size(x,ndims(x) -1))%gn.G == 0 || error("The number of groups ($(gn.G)) must divide the number of channels ($(size(x,ndims(x) -1)))")
|
||||
|
||||
dims = length(size(x))
|
||||
groups = gn.G
|
||||
channels = size(x, dims-1)
|
||||
batches = size(x,dims)
|
||||
channels_per_group = div(channels,groups)
|
||||
affine_shape = ntuple(i->i == ndims(x) - 1 ? size(x, i) : 1, ndims(x))
|
||||
|
||||
# Output reshaped to (W,H...,C/G,G,N)
|
||||
μ_affine_shape = ntuple(i->i == ndims(x) ? groups : 1, ndims(x) + 1)
|
||||
|
||||
m = prod(size(x)[1:end-2]) * channels_per_group
|
||||
γ = reshape(gn.γ, affine_shape...)
|
||||
β = reshape(gn.β, affine_shape...)
|
||||
|
||||
y = reshape(x,((size(x))[1:end-2]...,channels_per_group,groups,batches))
|
||||
if !_isactive(gn)
|
||||
og_shape = size(x)
|
||||
μ = reshape(gn.μ, μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
|
||||
σ² = reshape(gn.σ², μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
|
||||
ϵ = gn.ϵ
|
||||
else
|
||||
T = eltype(x)
|
||||
og_shape = size(x)
|
||||
axes = [(1:ndims(y)-2)...] # axes to reduce along (all but channels axis)
|
||||
μ = mean(y, dims = axes)
|
||||
σ² = mean((y .- μ) .^ 2, dims = axes)
|
||||
|
||||
ϵ = convert(T, gn.ϵ)
|
||||
# update moving mean/std
|
||||
mtm = gn.momentum
|
||||
S = eltype(gn.μ)
|
||||
gn.μ = mean((1 - mtm) .* gn.μ .+ mtm .* S.(reshape(μ, (groups,batches))),dims=2)
|
||||
gn.σ² = mean((1 - mtm) .* gn.σ² .+ (mtm * m / (m - 1)) .* S.(reshape(σ², (groups,batches))),dims=2)
|
||||
end
|
||||
|
||||
let λ = gn.λ
|
||||
x̂ = (y .- μ) ./ sqrt.(σ² .+ ϵ)
|
||||
|
||||
# Reshape x̂
|
||||
x̂ = reshape(x̂,og_shape)
|
||||
λ.(γ .* x̂ .+ β)
|
||||
end
|
||||
end
|
||||
|
||||
@functor GroupNorm
|
||||
|
||||
testmode!(m::GroupNorm, mode = true) =
|
||||
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
|
||||
|
||||
function Base.show(io::IO, l::GroupNorm)
|
||||
print(io, "GroupNorm($(join(size(l.β), ", "))")
|
||||
(l.λ == identity) || print(io, ", λ = $(l.λ)")
|
||||
print(io, ")")
|
||||
end
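A minimal `GroupNorm` sketch matching the docstring's example configuration (32 channels in 16 groups):

```julia
using Flux

gn = GroupNorm(32, 16)          # 32 channels, 16 groups, so 2 channels per group
x = rand(Float32, 8, 8, 32, 2)

size(gn(x)) == size(x)          # true
```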
|
||||
|
@ -1,5 +1,5 @@
|
||||
gate(h, n) = (1:h) .+ h*(n-1)
|
||||
gate(x::AbstractVector, h, n) = @view x[gate(h,n)]
|
||||
gate(x::AbstractVector, h, n) = x[gate(h,n)]
|
||||
gate(x::AbstractMatrix, h, n) = x[gate(h,n),:]
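The `gate` helpers above slice the `n`-th block of `h` entries out of a stacked gate vector or matrix, e.g. the blocks an LSTM cell packs into `Wi * x .+ b`. A tiny sketch:

```julia
using Flux

Flux.gate(3, 2)                 # 4:6, the second block of three rows
Flux.gate(collect(1:12), 3, 2)  # [4, 5, 6]
```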
|
||||
|
||||
# Stateful recurrence
|
||||
@ -12,16 +12,16 @@ in the background. `cell` should be a model of the form:
|
||||
|
||||
h, y = cell(h, x...)
|
||||
|
||||
For example, here's a recurrent network that keeps a running total of its inputs:
|
||||
For example, here's a recurrent network that keeps a running total of its inputs.
|
||||
|
||||
```julia
|
||||
accum(h, x) = (h + x, x)
|
||||
accum(h, x) = (h+x, x)
|
||||
rnn = Flux.Recur(accum, 0)
|
||||
rnn(2) # 2
|
||||
rnn(3) # 3
|
||||
rnn.state # 5
|
||||
rnn.(1:10) # apply to a sequence
|
||||
rnn.state # 60
|
||||
rnn(2) # 2
|
||||
rnn(3) # 3
|
||||
rnn.state # 5
|
||||
rnn.(1:10) # apply to a sequence
|
||||
rnn.state # 60
|
||||
```
|
||||
"""
|
||||
mutable struct Recur{T}
|
||||
@ -38,22 +38,36 @@ function (m::Recur)(xs...)
|
||||
return y
|
||||
end
|
||||
|
||||
@functor Recur cell, init
|
||||
@treelike Recur cell, init
|
||||
|
||||
Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
|
||||
|
||||
_truncate(x::AbstractArray) = Tracker.data(x)
|
||||
_truncate(x::Tuple) = _truncate.(x)
|
||||
|
||||
"""
|
||||
truncate!(rnn)
|
||||
|
||||
Truncates the gradient of the hidden state in recurrent layers. The value of the
|
||||
state is preserved. See also `reset!`.
|
||||
|
||||
Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to
|
||||
|
||||
rnn.state = Tracker.data(rnn.state)
|
||||
"""
|
||||
truncate!(m) = prefor(x -> x isa Recur && (x.state = _truncate(x.state)), m)
|
||||
|
||||
"""
|
||||
reset!(rnn)
|
||||
|
||||
Reset the hidden state of a recurrent layer back to its original value.
|
||||
Reset the hidden state of a recurrent layer back to its original value. See also
|
||||
`truncate!`.
|
||||
|
||||
Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to:
|
||||
```julia
|
||||
rnn.state = hidden(rnn.cell)
|
||||
```
|
||||
Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to
|
||||
|
||||
rnn.state = hidden(rnn.cell)
|
||||
"""
|
||||
reset!(m::Recur) = (m.state = m.init)
|
||||
reset!(m) = foreach(reset!, functor(m)[1])
|
||||
reset!(m) = prefor(x -> x isa Recur && (x.state = x.init), m)
|
||||
|
||||
flip(f, xs) = reverse(f.(reverse(xs)))
|
||||
|
||||
@ -69,8 +83,8 @@ end
|
||||
|
||||
RNNCell(in::Integer, out::Integer, σ = tanh;
|
||||
init = glorot_uniform) =
|
||||
RNNCell(σ, init(out, in), init(out, out),
|
||||
init(out), zeros(out))
|
||||
RNNCell(σ, param(init(out, in)), param(init(out, out)),
|
||||
param(zeros(out)), param(init(out)))
|
||||
|
||||
function (m::RNNCell)(h, x)
|
||||
σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
|
||||
@ -80,7 +94,7 @@ end
|
||||
|
||||
hidden(m::RNNCell) = m.h
|
||||
|
||||
@functor RNNCell
|
||||
@treelike RNNCell
|
||||
|
||||
function Base.show(io::IO, l::RNNCell)
|
||||
print(io, "RNNCell(", size(l.Wi, 2), ", ", size(l.Wi, 1))
|
||||
@ -108,9 +122,9 @@ end
|
||||
|
||||
function LSTMCell(in::Integer, out::Integer;
|
||||
init = glorot_uniform)
|
||||
cell = LSTMCell(init(out * 4, in), init(out * 4, out), init(out * 4),
|
||||
zeros(out), zeros(out))
|
||||
cell.b[gate(out, 2)] .= 1
|
||||
cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(zeros(out*4)),
|
||||
param(init(out)), param(init(out)))
|
||||
cell.b.data[gate(out, 2)] .= 1
|
||||
return cell
|
||||
end
|
||||
|
||||
@ -128,7 +142,7 @@ end
|
||||
|
||||
hidden(m::LSTMCell) = (m.h, m.c)
|
||||
|
||||
@functor LSTMCell
|
||||
@treelike LSTMCell
|
||||
|
||||
Base.show(io::IO, l::LSTMCell) =
|
||||
print(io, "LSTMCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷4, ")")
|
||||
@ -136,10 +150,10 @@ Base.show(io::IO, l::LSTMCell) =
|
||||
"""
|
||||
LSTM(in::Integer, out::Integer)
|
||||
|
||||
[Long Short Term Memory](https://www.researchgate.net/publication/13853244_Long_Short-term_Memory)
|
||||
recurrent layer. Behaves like an RNN but generally exhibits a longer memory span over sequences.
|
||||
Long Short Term Memory recurrent layer. Behaves like an RNN but generally
|
||||
exhibits a longer memory span over sequences.
|
||||
|
||||
See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
|
||||
See [this article](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
|
||||
for a good overview of the internals.
|
||||
"""
|
||||
LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
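A minimal usage sketch of the stateful wrapper this creates (feature sizes are illustrative):

```julia
using Flux

lstm = LSTM(10, 5)                     # 10 input features, 5 hidden units
xs = [rand(Float32, 10) for _ in 1:7]

ys = lstm.(xs)      # state carries across the sequence; each output has length 5
Flux.reset!(lstm)   # restore the initial state before the next sequence
```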
|
||||
@ -154,8 +168,8 @@ mutable struct GRUCell{A,V}
|
||||
end
|
||||
|
||||
GRUCell(in, out; init = glorot_uniform) =
|
||||
GRUCell(init(out * 3, in), init(out * 3, out),
|
||||
init(out * 3), zeros(out))
|
||||
GRUCell(param(init(out*3, in)), param(init(out*3, out)),
|
||||
param(zeros(out*3)), param(init(out)))
|
||||
|
||||
function (m::GRUCell)(h, x)
|
||||
b, o = m.b, size(h, 1)
|
||||
@ -169,7 +183,7 @@ end
|
||||
|
||||
hidden(m::GRUCell) = m.h
|
||||
|
||||
@functor GRUCell
|
||||
@treelike GRUCell
|
||||
|
||||
Base.show(io::IO, l::GRUCell) =
|
||||
print(io, "GRUCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷3, ")")
|
||||
@ -177,10 +191,10 @@ Base.show(io::IO, l::GRUCell) =
|
||||
"""
|
||||
GRU(in::Integer, out::Integer)
|
||||
|
||||
[Gated Recurrent Unit](https://arxiv.org/abs/1406.1078) layer. Behaves like an
|
||||
RNN but generally exhibits a longer memory span over sequences.
|
||||
Gated Recurrent Unit layer. Behaves like an RNN but generally
|
||||
exhibits a longer memory span over sequences.
|
||||
|
||||
See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
|
||||
See [this article](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
|
||||
for a good overview of the internals.
|
||||
"""
|
||||
GRU(a...; ka...) = Recur(GRUCell(a...; ka...))
|
||||
|
@ -1,296 +1,53 @@
|
||||
using NNlib: logsoftmax, logσ
|
||||
|
||||
# Cost functions
|
||||
"""
|
||||
mae(ŷ, y)
|
||||
|
||||
Return the mean absolute error; calculated as
|
||||
`sum(abs.(ŷ .- y)) / length(y)`.
|
||||
"""
|
||||
mae(ŷ, y) = sum(abs.(ŷ .- y)) * 1 // length(y)
|
||||
mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)
|
||||
|
||||
|
||||
"""
|
||||
mse(ŷ, y)
|
||||
|
||||
Return the mean squared error between ŷ and y; calculated as
|
||||
`sum((ŷ .- y).^2) / length(y)`.
|
||||
|
||||
# Examples
|
||||
```jldoctest
|
||||
julia> Flux.mse([0, 2], [1, 1])
|
||||
1//1
|
||||
```
|
||||
"""
|
||||
mse(ŷ, y) = sum((ŷ .- y).^2) * 1 // length(y)
|
||||
|
||||
|
||||
"""
|
||||
msle(ŷ, y; ϵ=eps(eltype(ŷ)))
|
||||
|
||||
Return the mean of the squared logarithmic errors; calculated as
|
||||
`sum((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2) / length(y)`.
|
||||
The `ϵ` term provides numerical stability.
|
||||
|
||||
Penalizes an under-predicted estimate greater than an over-predicted estimate.
|
||||
"""
|
||||
msle(ŷ, y; ϵ=eps(eltype(ŷ))) = sum((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2) * 1 // length(y)
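A small numeric check of the asymmetry mentioned above: with the same absolute error, the under-prediction incurs the larger loss.

```julia
using Flux

Flux.msle([0.9], [1.0]) > Flux.msle([1.1], [1.0])   # true
```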
|
||||
|
||||
|
||||
|
||||
"""
|
||||
huber_loss(ŷ, y; δ=1.0)
|
||||
|
||||
Return the mean of the [Huber loss](https://en.wikipedia.org/wiki/Huber_loss)
|
||||
given the prediction `ŷ` and true values `y`.
|
||||
|
||||
                 | 0.5 * |ŷ - y|^2,           for |ŷ - y| <= δ
    Huber loss = |
                 | δ * (|ŷ - y| - 0.5 * δ),   otherwise
|
||||
"""
|
||||
#TODO: remove dropgrad when Zygote can handle this function with CuArrays
|
||||
function huber_loss(ŷ, y; δ=eltype(ŷ)(1))
|
||||
abs_error = abs.(ŷ .- y)
|
||||
temp = Zygote.dropgrad(abs_error .< δ)
|
||||
x = eltype(ŷ)(0.5)
|
||||
hub_loss = sum(((abs_error.^2) .* temp) .* x .+ δ*(abs_error .- x*δ) .* (1 .- temp)) * 1 // length(y)
|
||||
function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
|
||||
-sum(y .* log.(ŷ) .* weight) / size(y, 2)
|
||||
end
|
||||
|
||||
function _crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat, weight::Nothing)
|
||||
return -sum(xlogy.(y, ŷ)) * 1 // size(y, 2)
|
||||
end
|
||||
@deprecate logloss(x, y) crossentropy(x, y)
|
||||
|
||||
function _crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat, weight::Number)
|
||||
return -sum(xlogy.(y, ŷ)) .* weight * 1 // size(y, 2)
|
||||
end
|
||||
|
||||
function _crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat, weight::AbstractVector)
|
||||
return -sum(xlogy.(y, ŷ) .* weight) * 1 // size(y, 2)
|
||||
end
|
||||
|
||||
"""
|
||||
crossentropy(ŷ, y; weight = nothing)
|
||||
|
||||
Return the cross entropy between the given probability distributions;
|
||||
calculated as `-sum(y .* log.(ŷ) .* weight) / size(y, 2)`.
|
||||
|
||||
`weight` can be `Nothing`, a `Number` or an `AbstractVector`.
|
||||
`weight=nothing` acts like `weight=1` but is faster.
|
||||
|
||||
See also: [`Flux.logitcrossentropy`](@ref), [`Flux.binarycrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref)
|
||||
|
||||
# Examples
|
||||
```jldoctest
|
||||
julia> Flux.crossentropy(softmax([-1.1491, 0.8619, 0.3127]), [1, 1, 0])
|
||||
3.085467254747739
|
||||
```
|
||||
"""
|
||||
crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight=nothing) = _crossentropy(ŷ, y, weight)
|
||||
|
||||
"""
|
||||
logitcrossentropy(ŷ, y; weight = 1)
|
||||
|
||||
Return the crossentropy computed after a [`Flux.logsoftmax`](@ref) operation;
|
||||
calculated as `-sum(y .* logsoftmax(ŷ) .* weight) / size(y, 2)`.
|
||||
|
||||
`logitcrossentropy(ŷ, y)` is mathematically equivalent to
|
||||
[`Flux.crossentropy(softmax(ŷ), y)`](@ref) but it is more numerically stable.
|
||||
|
||||
See also: [`Flux.crossentropy`](@ref), [`Flux.binarycrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref)
|
||||
|
||||
# Examples
|
||||
```jldoctest
|
||||
julia> Flux.logitcrossentropy([-1.1491, 0.8619, 0.3127], [1, 1, 0])
|
||||
3.085467254747738
|
||||
```
|
||||
"""
|
||||
function logitcrossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
|
||||
return -sum(y .* logsoftmax(ŷ) .* weight) * 1 // size(y, 2)
|
||||
function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
|
||||
return -sum(y .* logsoftmax(logŷ) .* weight) / size(y, 2)
|
||||
end
|
||||
|
||||
"""
|
||||
binarycrossentropy(ŷ, y; ϵ=eps(ŷ))
|
||||
|
||||
Return ``-y*\\log(ŷ + ϵ) - (1-y)*\\log(1-ŷ + ϵ)``. The `ϵ` term provides numerical stability.
|
||||
Return `-y*log(ŷ + ϵ) - (1-y)*log(1-ŷ + ϵ)`. The ϵ term provides numerical stability.
|
||||
|
||||
Typically, the prediction `ŷ` is given by the output of a [`sigmoid`](@ref) activation.
|
||||
|
||||
See also: [`Flux.crossentropy`](@ref), [`Flux.logitcrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref)
|
||||
|
||||
# Examples
|
||||
```jldoctest
|
||||
julia> Flux.binarycrossentropy.(σ.([-1.1491, 0.8619, 0.3127]), [1, 1, 0])
|
||||
3-element Array{Float64,1}:
|
||||
1.424397097347566
|
||||
0.35231664672364077
|
||||
0.8616703662235441
|
||||
```
|
||||
julia> binarycrossentropy.(σ.([-1.1491, 0.8619, 0.3127]), [1, 1, 0.])
|
||||
3-element Array{Float64,1}:
|
||||
1.4244
|
||||
0.352317
|
||||
0.86167
|
||||
"""
|
||||
binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -xlogy(y, ŷ + ϵ) - xlogy(1 - y, 1 - ŷ + ϵ)
|
||||
|
||||
# Re-definition to fix interaction with CuArrays.
|
||||
CuArrays.@cufunc binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ)
|
||||
binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ)
|
||||
|
||||
"""
|
||||
logitbinarycrossentropy(ŷ, y)
|
||||
logitbinarycrossentropy(logŷ, y)
|
||||
|
||||
`logitbinarycrossentropy(ŷ, y)` is mathematically equivalent to
|
||||
[`Flux.binarycrossentropy(σ(ŷ), y)`](@ref) but it is more numerically stable.
|
||||
`logitbinarycrossentropy(logŷ, y)` is mathematically equivalent to `binarycrossentropy(σ(logŷ), y)`
|
||||
but it is more numerically stable.
|
||||
|
||||
See also: [`Flux.crossentropy`](@ref), [`Flux.logitcrossentropy`](@ref), [`Flux.binarycrossentropy`](@ref)
|
||||
|
||||
# Examples
|
||||
```jldoctest
|
||||
julia> Flux.logitbinarycrossentropy.([-1.1491, 0.8619, 0.3127], [1, 1, 0])
|
||||
3-element Array{Float64,1}:
|
||||
1.4243970973475661
|
||||
0.35231664672364094
|
||||
0.8616703662235443
|
||||
```
|
||||
julia> logitbinarycrossentropy.([-1.1491, 0.8619, 0.3127], [1, 1, 0.])
|
||||
3-element Array{Float64,1}:
|
||||
1.4244
|
||||
0.352317
|
||||
0.86167
|
||||
"""
|
||||
logitbinarycrossentropy(ŷ, y) = (1 - y)*ŷ - logσ(ŷ)
|
||||
|
||||
# Re-definition to fix interaction with CuArrays.
|
||||
CuArrays.@cufunc logitbinarycrossentropy(ŷ, y) = (1 - y)*ŷ - logσ(ŷ)
|
||||
logitbinarycrossentropy(logŷ, y) = (1 - y)*logŷ - logσ(logŷ)
|
||||
|
||||
"""
|
||||
normalise(x; dims=1)
|
||||
normalise(x::AbstractVecOrMat)
|
||||
|
||||
Normalise `x` to mean 0 and standard deviation 1 across the dimensions given by `dims`.
|
||||
Defaults to normalising over columns.
|
||||
|
||||
```jldoctest
|
||||
julia> a = reshape(collect(1:9), 3, 3)
|
||||
3×3 Array{Int64,2}:
|
||||
1 4 7
|
||||
2 5 8
|
||||
3 6 9
|
||||
|
||||
julia> Flux.normalise(a)
|
||||
3×3 Array{Float64,2}:
|
||||
-1.22474 -1.22474 -1.22474
|
||||
0.0 0.0 0.0
|
||||
1.22474 1.22474 1.22474
|
||||
|
||||
julia> Flux.normalise(a, dims=2)
|
||||
3×3 Array{Float64,2}:
|
||||
-1.22474 0.0 1.22474
|
||||
-1.22474 0.0 1.22474
|
||||
-1.22474 0.0 1.22474
|
||||
```
|
||||
Normalise each column of `x` to mean 0 and standard deviation 1.
|
||||
"""
|
||||
function normalise(x::AbstractArray; dims=1)
|
||||
μ′ = mean(x, dims = dims)
|
||||
σ′ = std(x, dims = dims, mean = μ′, corrected=false)
|
||||
function normalise(x::AbstractVecOrMat)
|
||||
μ′ = mean(x, dims = 1)
|
||||
σ′ = std(x, dims = 1, mean = μ′)
|
||||
return (x .- μ′) ./ σ′
|
||||
end
|
||||
|
||||
"""
|
||||
kldivergence(ŷ, y)
|
||||
|
||||
Return the
|
||||
[Kullback-Leibler divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence)
|
||||
between the given probability distributions.
|
||||
|
||||
KL divergence is a measure of how much one probability distribution is different
|
||||
from the other.
|
||||
It is always non-negative and zero only when both the distributions are equal
|
||||
everywhere.
|
||||
"""
|
||||
function kldivergence(ŷ, y)
|
||||
entropy = sum(xlogx.(y)) * 1 //size(y,2)
|
||||
cross_entropy = crossentropy(ŷ, y)
|
||||
return entropy + cross_entropy
|
||||
end
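A quick check of the properties stated above: the divergence vanishes for identical distributions and is positive otherwise (the example distributions are arbitrary):

```julia
using Flux

p = [0.1, 0.9]

Flux.kldivergence(p, p)                # ≈ 0
Flux.kldivergence([0.5, 0.5], p) > 0   # true
```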
|
||||
|
||||
"""
|
||||
poisson(ŷ, y)
|
||||
|
||||
Return how much the predicted distribution `ŷ` diverges from the expected Poisson
|
||||
distribution `y`; calculated as `sum(ŷ .- y .* log.(ŷ)) / size(y, 2)`.
|
||||
|
||||
[More information.](https://peltarion.com/knowledge-center/documentation/modeling-view/build-an-ai-model/loss-functions/poisson).
|
||||
"""
|
||||
poisson(ŷ, y) = sum(ŷ .- xlogy.(y, ŷ)) * 1 // size(y,2)
|
||||
|
||||
"""
|
||||
hinge(ŷ, y)
|
||||
|
||||
Return the [hinge loss](https://en.wikipedia.org/wiki/Hinge_loss) given the
|
||||
prediction `ŷ` and true labels `y` (containing 1 or -1); calculated as
|
||||
`sum(max.(0, 1 .- ŷ .* y)) / size(y, 2)`.
|
||||
|
||||
See also: [`squared_hinge`](@ref)
|
||||
"""
|
||||
hinge(ŷ, y) = sum(max.(0, 1 .- ŷ .* y)) * 1 // size(y, 2)
|
||||
|
||||
"""
|
||||
squared_hinge(ŷ, y)
|
||||
|
||||
Return the squared hinge loss given the prediction `ŷ` and true labels `y`
|
||||
(containing 1 or -1); calculated as `sum((max.(0, 1 .- ŷ .* y)).^2) / size(y, 2)`.
|
||||
|
||||
See also: [`hinge`](@ref)
|
||||
"""
|
||||
squared_hinge(ŷ, y) = sum((max.(0, 1 .- ŷ .* y)).^2) * 1 // size(y, 2)
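A worked hinge-loss example with one sample per column (values chosen for illustration); the division is by the number of columns, `size(y, 2)`:

```julia
using Flux

y = [1 -1 1]              # true labels in {-1, 1}
ŷ = [0.8 -1.2 -0.3]

Flux.hinge(ŷ, y)          # (0.2 + 0.0 + 1.3) / 3 ≈ 0.5
```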
|
||||
|
||||
"""
|
||||
dice_coeff_loss(ŷ, y; smooth=1)
|
||||
|
||||
Return a loss based on the dice coefficient.
|
||||
Used in the [V-Net](https://arxiv.org/pdf/1606.04797v1.pdf) image segmentation
|
||||
architecture.
|
||||
Similar to the F1_score. Calculated as:
|
||||
    1 - (2*sum(y .* ŷ) + smooth) / (sum(y.^2) + sum(ŷ.^2) + smooth)
|
||||
"""
|
||||
dice_coeff_loss(ŷ, y; smooth=eltype(ŷ)(1.0)) = 1 - (2*sum(y .* ŷ) + smooth) / (sum(y.^2) + sum(ŷ.^2) + smooth)
|
||||
|
||||
"""
|
||||
tversky_loss(ŷ, y; β=0.7)
|
||||
|
||||
Return the [Tversky loss](https://arxiv.org/pdf/1706.05721.pdf).
|
||||
Used with imbalanced data to give more weight to false negatives.
|
||||
A larger β weighs recall more heavily than precision (by placing more emphasis on false negatives).
Calculated as:
    1 - (sum(y .* ŷ) + 1) / (sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1)
|
||||
"""
|
||||
tversky_loss(ŷ, y; β=eltype(ŷ)(0.7)) = 1 - (sum(y .* ŷ) + 1) / (sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1)
|
||||
|
||||
"""
|
||||
flatten(x::AbstractArray)
|
||||
|
||||
Transform (w, h, c, b)-shaped input into (w × h × c, b)-shaped output
|
||||
by linearizing all values for each element in the batch.
|
||||
"""
|
||||
function flatten(x::AbstractArray)
|
||||
return reshape(x, :, size(x)[end])
|
||||
end
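The effect of `flatten` on a WHCN batch (the input size is an arbitrary example):

```julia
using Flux

x = rand(Float32, 28, 28, 3, 4)

size(Flux.flatten(x))   # (2352, 4), since 28 * 28 * 3 == 2352
```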
|
||||
|
||||
"""
|
||||
xlogx(x)
|
||||
Return `x * log(x)` for `x ≥ 0`, handling `x = 0` by taking the downward limit.
|
||||
"""
|
||||
function xlogx(x)
|
||||
result = x * log(x)
|
||||
ifelse(iszero(x), zero(result), result)
|
||||
end
|
||||
CuArrays.@cufunc function xlogx(x)
|
||||
result = x * log(x)
|
||||
ifelse(iszero(x), zero(result), result)
|
||||
end
|
||||
|
||||
"""
|
||||
xlogy(x, y)
|
||||
Return `x * log(y)` for `y > 0` with correct limit at `x = 0`.
|
||||
"""
|
||||
function xlogy(x, y)
|
||||
result = x * log(y)
|
||||
ifelse(iszero(x), zero(result), result)
|
||||
end
|
||||
CuArrays.@cufunc function xlogy(x, y)
|
||||
result = x * log(y)
|
||||
ifelse(iszero(x), zero(result), result)
|
||||
end
|
||||
|
||||
@adjoint function broadcasted(::typeof(xlogy), x::Zygote.Numeric, y::Zygote.Numeric)
|
||||
res = xlogy.(x, y)
|
||||
res, Δ -> (nothing, Zygote.unbroadcast(x, xlogy.(Δ, y)), Zygote.unbroadcast(y, Δ .* x ./ y))
|
||||
end
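The whole point of `xlogx`/`xlogy` is the well-behaved limit at zero, which keeps losses such as `crossentropy` finite on hard labels:

```julia
using Flux

Flux.xlogx(0.0)        # 0.0 instead of the NaN from 0 * log(0)
Flux.xlogy(0.0, 0.0)   # 0.0
```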
|
||||
|
@ -9,8 +9,6 @@ Base.size(xs::OneHotVector) = (Int64(xs.of),)
|
||||
|
||||
Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix
|
||||
|
||||
Base.getindex(xs::OneHotVector, ::Colon) = OneHotVector(xs.ix, xs.of)
|
||||
|
||||
A::AbstractMatrix * b::OneHotVector = A[:, b.ix]
|
||||
|
||||
struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool}
|
||||
@ -20,15 +18,11 @@ end
|
||||
|
||||
Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data))
|
||||
|
||||
Base.getindex(xs::OneHotMatrix, i::Union{Integer, AbstractVector}, j::Integer) = xs.data[j][i]
|
||||
Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs.data[j][i]
|
||||
Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i]
|
||||
Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i])
|
||||
Base.getindex(xs::OneHotMatrix, ::Colon, ::Colon) = OneHotMatrix(xs.height, copy(xs.data))
|
||||
|
||||
Base.getindex(xs::OneHotMatrix, i::Integer, ::Colon) = map(x -> x[i], xs.data)
|
||||
|
||||
# remove workaround when https://github.com/JuliaGPU/CuArrays.jl/issues/676 is fixed
|
||||
A::AbstractMatrix * B::OneHotMatrix = A[:, cpu(map(x->x.ix, B.data))]
|
||||
A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)]
|
||||
|
||||
Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...])
|
||||
|
||||
@ -38,34 +32,13 @@ import Adapt: adapt, adapt_structure
|
||||
|
||||
adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
|
||||
|
||||
import .CuArrays: CuArray, CuArrayStyle, cudaconvert
|
||||
import Base.Broadcast: BroadcastStyle, ArrayStyle
|
||||
BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = CuArrayStyle{2}()
|
||||
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
|
||||
@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
|
||||
import .CuArrays: CuArray, cudaconvert
|
||||
import Base.Broadcast: BroadcastStyle, ArrayStyle
|
||||
BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = ArrayStyle{CuArray}()
|
||||
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
|
||||
end
|
||||
|
||||
"""
|
||||
onehot(l, labels[, unk])
|
||||
|
||||
Create a `OneHotVector` with its `l`-th element `true` based on the
|
||||
possible set of `labels`.
|
||||
If `unk` is given, return `onehot(unk, labels)` if the input label `l` is not found
|
||||
in `labels`; otherwise, it will raise an error.
|
||||
|
||||
# Examples
|
||||
```jldoctest
|
||||
julia> Flux.onehot(:b, [:a, :b, :c])
|
||||
3-element Flux.OneHotVector:
|
||||
0
|
||||
1
|
||||
0
|
||||
|
||||
julia> Flux.onehot(:c, [:a, :b, :c])
|
||||
3-element Flux.OneHotVector:
|
||||
0
|
||||
0
|
||||
1
|
||||
```
|
||||
"""
|
||||
function onehot(l, labels)
|
||||
i = something(findfirst(isequal(l), labels), 0)
|
||||
i > 0 || error("Value $l is not in labels")
|
||||
@ -78,48 +51,20 @@ function onehot(l, labels, unk)
|
||||
OneHotVector(i, length(labels))
|
||||
end
|
||||
|
||||
"""
|
||||
onehotbatch(ls, labels[, unk...])
|
||||
|
||||
Create a `OneHotMatrix` with a batch of labels based on the
|
||||
possible set of `labels`.
|
||||
If `unk` is given, return [`onehot(unk, labels)`](@ref) if one of the input
|
||||
labels `ls` is not found in `labels`; otherwise it will error.
|
||||
|
||||
# Examples
|
||||
```jldoctest
|
||||
julia> Flux.onehotbatch([:b, :a, :b], [:a, :b, :c])
|
||||
3×3 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}:
|
||||
0 1 0
|
||||
1 0 1
|
||||
0 0 0
|
||||
```
|
||||
"""
|
||||
onehotbatch(ls, labels, unk...) =
|
||||
OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])
|
||||
|
||||
Base.argmax(xs::OneHotVector) = xs.ix
|
||||
|
||||
"""
|
||||
onecold(y[, labels = 1:length(y)])
|
||||
|
||||
Inverse operations of [`onehot`](@ref).
|
||||
|
||||
# Examples
|
||||
```jldoctest
|
||||
julia> Flux.onecold([true, false, false], [:a, :b, :c])
|
||||
:a
|
||||
|
||||
julia> Flux.onecold([0.3, 0.2, 0.5], [:a, :b, :c])
|
||||
:c
|
||||
```
|
||||
"""
|
||||
onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)]
|
||||
|
||||
onecold(y::AbstractMatrix, labels...) =
|
||||
dropdims(mapslices(y -> onecold(y, labels...), y, dims=1), dims=1)
|
||||
|
||||
onecold(y::OneHotMatrix, labels...) =
|
||||
mapreduce(x -> Flux.onecold(x, labels...), |, y.data, dims = 2, init = 0)
|
||||
function argmax(xs...)
|
||||
Base.depwarn("`argmax(...) is deprecated, use `onecold(...)` instead.", :argmax)
|
||||
return onecold(xs...)
|
||||
end
|
||||
|
||||
@nograd onecold, onehot, onehotbatch
|
||||
# Ambiguity hack
|
||||
|
||||
a::TrackedMatrix * b::OneHotVector = invoke(*, Tuple{AbstractMatrix,OneHotVector}, a, b)
|
||||
a::TrackedMatrix * b::OneHotMatrix = invoke(*, Tuple{AbstractMatrix,OneHotMatrix}, a, b)
|
||||
|
@ -1,14 +1,12 @@
|
||||
module Optimise
|
||||
|
||||
using LinearAlgebra
|
||||
|
||||
export train!, update!,
|
||||
Descent, ADAM, Momentum, Nesterov, RMSProp,
|
||||
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,RADAM,
|
||||
InvDecay, ExpDecay, WeightDecay, stop, Optimiser,
|
||||
ClipValue, ClipNorm
|
||||
export train!,
|
||||
SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
|
||||
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,
|
||||
InvDecay, ExpDecay, WeightDecay, stop, Optimiser
|
||||
|
||||
include("optimisers.jl")
|
||||
include("train.jl")
|
||||
include("deprecations.jl")
|
||||
|
||||
end
|
||||
|
126
src/optimise/deprecations.jl
Normal file
126
src/optimise/deprecations.jl
Normal file
@ -0,0 +1,126 @@
|
||||
using Base: depwarn
|
||||
using Flux: Params
|
||||
|
||||
check_decay(opt, decay) = decay == 0 ? opt : Optimiser(opt, InvDecay(decay))
|
||||
|
||||
# legacy update rule
|
||||
updaterule(opt, ps) = () -> update!(opt, ps)
|
||||
|
||||
function SGD(params::Union{AbstractArray, Params}, η = 0.1; decay = 0.)
|
||||
depwarn("SGD(params) is deprecated; use Descent(η::Float64) instead", :SGD)
|
||||
|
||||
ps = params
|
||||
opt = Descent(η)
|
||||
opt = check_decay(opt, decay)
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
function Momentum(params::Union{AbstractArray, Params}, η = 0.01; ρ = 0.9, decay = 0.)
|
||||
depwarn("Momentum(params) is deprecated; use Momentum(η::Float64) instead", :Momentum)
|
||||
|
||||
ps = params
|
||||
opt = Momentum(η, ρ)
|
||||
opt = check_decay(opt, decay)
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
function Nesterov(params::Union{AbstractArray, Params}, η = 0.001; ρ = 0.9, decay = 0.)
|
||||
depwarn("Nesterov(params) is deprecated; use Nesterov(η::Float64) instead", :Nesterov)
|
||||
|
||||
ps = params
|
||||
opt = Nesterov(η, ρ)
|
||||
opt = check_decay(opt, decay)
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
function RMSProp(params::Union{AbstractArray, Params}, η = 0.001; ρ = 0.9, decay = 0.)
|
||||
depwarn("RMSProp(params) is deprecated; use RMSProp(η::Float64) instead", :RMSProp)
|
||||
|
||||
ps = params
|
||||
opt = RMSProp(η, ρ)
|
||||
opt = check_decay(opt, decay)
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
function ADAM(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
|
||||
depwarn("ADAM(params) is deprecated; use ADAM(η::Float64) instead", :ADAM)
|
||||
|
||||
ps = params
|
||||
β = (β1, β2)
|
||||
opt = ADAM(η, β)
|
||||
opt = check_decay(opt, decay)
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
function ADAGrad(params::Union{AbstractArray, Params}, η::Float64 = 0.1; decay = 0.)
|
||||
depwarn("ADAGrad(params) is deprecated; use ADAGrad(η::Float64) instead", :ADAGrad)
|
||||
|
||||
ps = params
|
||||
opt = ADAGrad(η)
|
||||
opt = check_decay(opt, decay)
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
function ADADelta(params::Union{AbstractArray, Params}, ρ::Float64 = 0.9; decay = 0.)
|
||||
depwarn("ADADelta(params) is deprecated; use ADADelta(η::Float64) instead", :ADADelta)
|
||||
|
||||
ps = params
|
||||
opt = ADADelta(ρ)
|
||||
opt = check_decay(opt, decay)
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
function AdaMax(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
|
||||
depwarn("AdaMax(params) is deprecated; use AdaMax(η::Float64) instead", :AdaMax)
|
||||
|
||||
ps = params
|
||||
β = (β1, β2)
|
||||
opt = AdaMax(η, β)
|
||||
opt = check_decay(opt, decay)
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
function AMSGrad(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
|
||||
depwarn("AMSGrad(params) is deprecated; use AMSGrad(η::Float64) instead", :AMSGrad)
|
||||
|
||||
ps = params
|
||||
β = (β1, β2)
|
||||
opt = AMSGrad(η, β)
|
||||
opt = check_decay(opt, decay)
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
function NADAM(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
|
||||
depwarn("NADAM(params) is deprecated; use NADAM(η::Float64) instead", :NADAM)
|
||||
|
||||
ps = params
|
||||
β = (β1, β2)
|
||||
opt = NADAM(η, β)
|
||||
opt = check_decay(opt, decay)
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
function ADAMW(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
|
||||
depwarn("ADAMW(params) is deprecated; use ADAMW(η::Float64) instead", :ADAMW)
|
||||
|
||||
ps = params
|
||||
β = (β1, β2)
|
||||
opt = ADAMW(η, β)
|
||||
opt = check_decay(opt, decay)
|
||||
decay != 0 && (opt = Optimiser(opt, WeightDecay(decay)))
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
# Old training loop
|
||||
|
||||
struct OldOptimiser
|
||||
func
|
||||
end
|
||||
|
||||
update!(opt::OldOptimiser, ps) = opt.func()
|
||||
|
||||
# Train function
|
||||
function train!(loss, data, opt; cb = () -> ())
|
||||
depwarn("train!(loss, data, opt) is deprecated; use train!(loss, params, data, opt) instead", :train!)
|
||||
train!(loss, (), data, OldOptimiser(opt); cb = cb)
|
||||
end
|
@ -1,4 +1,5 @@
|
||||
using Flux
|
||||
using Base: @get!
|
||||
using MacroTools: @forward
|
||||
|
||||
const ϵ = 1e-8
|
||||
@ -6,29 +7,10 @@ const ϵ = 1e-8
|
||||
# TODO: should use weak refs
|
||||
|
||||
"""
|
||||
Descent(η = 0.1)
|
||||
Descent(η)
|
||||
|
||||
Classic gradient descent optimiser with learning rate `η`.
|
||||
For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`
|
||||
|
||||
# Parameters
|
||||
- Learning rate (`η`): Amount by which gradients are discounted before updating
|
||||
the weights.
|
||||
|
||||
# Examples
|
||||
```julia
|
||||
opt = Descent()
|
||||
|
||||
opt = Descent(0.3)
|
||||
|
||||
ps = params(model)
|
||||
|
||||
gs = gradient(ps) do
|
||||
loss(x, y)
|
||||
end
|
||||
|
||||
Flux.Optimise.update!(opt, ps, gs)
|
||||
```
|
||||
For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`.
|
||||
"""
|
||||
mutable struct Descent
|
||||
eta::Float64
|
||||
@ -36,27 +18,14 @@ end
|
||||
|
||||
Descent() = Descent(0.1)
|
||||
|
||||
function apply!(o::Descent, x, Δ)
|
||||
function update!(o::Descent, x, Δ)
|
||||
Δ .*= o.eta
|
||||
end
|
||||
|
||||
"""
|
||||
Momentum(η = 0.01, ρ = 0.9)
|
||||
Momentum(params, η = 0.01; ρ = 0.9)
|
||||
|
||||
Gradient descent optimizer with learning rate `η` and momentum `ρ`.
|
||||
|
||||
# Parameters
|
||||
- Learning rate (`η`): Amount by which gradients are discounted before updating
|
||||
the weights.
|
||||
- Momentum (`ρ`): Controls the acceleration of gradient descent in the
|
||||
prominent direction, in effect dampening oscillations.
|
||||
|
||||
# Examples
|
||||
```julia
|
||||
opt = Momentum()
|
||||
|
||||
opt = Momentum(0.01, 0.99)
|
||||
```
|
||||
Gradient descent with learning rate `η` and momentum `ρ`.
|
||||
"""
|
||||
mutable struct Momentum
|
||||
eta::Float64
|
||||
@ -66,7 +35,7 @@ end
|
||||
|
||||
Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict())
|
||||
|
||||
function apply!(o::Momentum, x, Δ)
|
||||
function update!(o::Momentum, x, Δ)
|
||||
η, ρ = o.eta, o.rho
|
||||
v = get!(o.velocity, x, zero(x))::typeof(x)
|
||||
@. v = ρ * v - η * Δ
|
||||
@ -74,22 +43,9 @@ function apply!(o::Momentum, x, Δ)
|
||||
end
|
||||
|
||||
"""
|
||||
Nesterov(η = 0.001, ρ = 0.9)
|
||||
Nesterov(eta, ρ = 0.9)
|
||||
|
||||
Gradient descent optimizer with learning rate `η` and Nesterov momentum `ρ`.
|
||||
|
||||
# Parameters
|
||||
- Learning rate (`η`): Amount by which gradients are discounted before updating
|
||||
the weights.
|
||||
- Nesterov momentum (`ρ`): Controls the acceleration of gradient descent in the
|
||||
prominent direction, in effect dampening oscillations.
|
||||
|
||||
# Examples
|
||||
```julia
|
||||
opt = Nesterov()
|
||||
|
||||
opt = Nesterov(0.003, 0.95)
|
||||
```
|
||||
Gradient descent with learning rate `η` and Nesterov momentum `ρ`.
|
||||
"""
|
||||
mutable struct Nesterov
|
||||
eta::Float64
|
||||
@ -99,7 +55,7 @@ end
|
||||
|
||||
Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict())
|
||||
|
||||
function apply!(o::Nesterov, x, Δ)
|
||||
function update!(o::Nesterov, x, Δ)
|
||||
η, ρ = o.eta, o.rho
|
||||
v = get!(o.velocity, x, zero(x))::typeof(x)
|
||||
d = @. ρ^2 * v - (1+ρ) * η * Δ
|
||||
@ -110,23 +66,9 @@ end
|
||||
"""
|
||||
RMSProp(η = 0.001, ρ = 0.9)
|
||||
|
||||
Optimizer using the
|
||||
[RMSProp](https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
|
||||
algorithm. Often a good choice for recurrent networks. Parameters other than learning rate
|
||||
generally don't need tuning.
|
||||
|
||||
# Parameters
|
||||
- Learning rate (`η`): Amount by which gradients are discounted before updating
|
||||
the weights.
|
||||
- Momentum (`ρ`): Controls the acceleration of gradient descent in the
|
||||
prominent direction, in effect dampening oscillations.
|
||||
|
||||
# Examples
|
||||
```julia
|
||||
opt = RMSProp()
|
||||
|
||||
opt = RMSProp(0.002, 0.95)
|
||||
```
|
||||
[RMSProp](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
|
||||
optimiser. Parameters other than learning rate don't need tuning. Often a good
|
||||
choice for recurrent networks.
|
||||
"""
|
||||
mutable struct RMSProp
|
||||
eta::Float64
|
||||
@ -136,7 +78,7 @@ end
|
||||
|
||||
RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict())
|
||||
|
||||
function apply!(o::RMSProp, x, Δ)
|
||||
function update!(o::RMSProp, x, Δ)
|
||||
η, ρ = o.eta, o.rho
|
||||
acc = get!(o.acc, x, zero(x))::typeof(x)
|
||||
@. acc = ρ * acc + (1 - ρ) * Δ^2
|
||||
@ -144,22 +86,9 @@ function apply!(o::RMSProp, x, Δ)
|
||||
end
|
||||
|
||||
"""
|
||||
ADAM(η = 0.001, β::Tuple = (0.9, 0.999))
|
||||
ADAM(η = 0.001, β = (0.9, 0.999))
|
||||
|
||||
[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
|
||||
|
||||
# Parameters
|
||||
- Learning rate (`η`): Amount by which gradients are discounted before updating
|
||||
the weights.
|
||||
- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
|
||||
second (β2) momentum estimate.
|
||||
|
||||
# Examples
|
||||
```julia
|
||||
opt = ADAM()
|
||||
|
||||
opt = ADAM(0.001, (0.9, 0.8))
|
||||
```
|
||||
"""
|
||||
mutable struct ADAM
|
||||
eta::Float64
|
||||
@ -169,7 +98,7 @@ end
|
||||
|
||||
ADAM(η = 0.001, β = (0.9, 0.999)) = ADAM(η, β, IdDict())
|
||||
|
||||
function apply!(o::ADAM, x, Δ)
|
||||
function update!(o::ADAM, x, Δ)
|
||||
η, β = o.eta, o.beta
|
||||
mt, vt, βp = get!(o.state, x, (zero(x), zero(x), β))
|
||||
@. mt = β[1] * mt + (1 - β[1]) * Δ
|
||||
@ -180,65 +109,10 @@ function apply!(o::ADAM, x, Δ)
|
||||
end
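The body of the ADAM rule is truncated by the hunk header above. As a hedged reference, the textbook bias-corrected step (mirroring the moment updates shown, but not necessarily Flux's exact code) looks like this:

```julia
# Textbook ADAM step on one array; mt/vt are the running first and second
# moment estimates, βp the current powers of β used for bias correction.
function adam_step!(w, δ, mt, vt, βp; η = 0.001, β = (0.9, 0.999), ϵ = 1e-8)
    @. mt = β[1] * mt + (1 - β[1]) * δ
    @. vt = β[2] * vt + (1 - β[2]) * δ^2
    @. w -= η * (mt / (1 - βp[1])) / (√(vt / (1 - βp[2])) + ϵ)
    return βp .* β   # powers of β to carry into the next step
end
```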
|
||||
|
||||
"""
|
||||
RADAM(η = 0.001, β::Tuple = (0.9, 0.999))
|
||||
AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08)
|
||||
|
||||
[Rectified ADAM](https://arxiv.org/pdf/1908.03265v1.pdf) optimizer.
|
||||
|
||||
# Parameters
|
||||
- Learning rate (`η`): Amount by which gradients are discounted before updating
|
||||
the weights.
|
||||
- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
|
||||
second (β2) momentum estimate.
|
||||
|
||||
# Examples
|
||||
```julia
|
||||
opt = RADAM()
|
||||
|
||||
opt = RADAM(0.001, (0.9, 0.8))
|
||||
```
|
||||
"""
|
||||
mutable struct RADAM
|
||||
eta::Float64
|
||||
beta::Tuple{Float64,Float64}
|
||||
state::IdDict
|
||||
end
|
||||
|
||||
RADAM(η = 0.001, β = (0.9, 0.999)) = RADAM(η, β, IdDict())
|
||||
|
||||
function apply!(o::RADAM, x, Δ)
|
||||
η, β = o.eta, o.beta
|
||||
ρ∞ = 2/(1-β[2])-1
|
||||
mt, vt, βp, t = get!(o.state, x, (zero(x), zero(x), β, 1))
|
||||
@. mt = β[1] * mt + (1 - β[1]) * Δ
|
||||
@. vt = β[2] * vt + (1 - β[2]) * Δ^2
|
||||
ρ = ρ∞ - 2t*βp[2]/(1-βp[2])
|
||||
if ρ > 4
|
||||
r = sqrt((ρ-4)*(ρ-2)*ρ∞/((ρ∞-4)*(ρ∞-2)*ρ))
|
||||
@. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η * r
|
||||
else
|
||||
@. Δ = mt / (1 - βp[1]) * η
|
||||
end
|
||||
o.state[x] = (mt, vt, βp .* β, t+1)
|
||||
return Δ
|
||||
end
|
||||
|
||||
"""
|
||||
AdaMax(η = 0.001, β::Tuple = (0.9, 0.999))
|
||||
|
||||
[AdaMax](https://arxiv.org/abs/1412.6980v9) is a variant of ADAM based on the ∞-norm.
|
||||
|
||||
# Parameters
|
||||
- Learning rate (`η`): Amount by which gradients are discounted before updating
|
||||
the weights.
|
||||
- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
|
||||
second (β2) momentum estimate.
|
||||
|
||||
# Examples
|
||||
```julia
|
||||
opt = AdaMax()
|
||||
|
||||
opt = AdaMax(0.001, (0.9, 0.995))
|
||||
```
|
||||
[AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser. Variant of ADAM based on
|
||||
the ∞-norm.
|
||||
"""
|
||||
mutable struct AdaMax
|
||||
eta::Float64
|
||||
@ -248,7 +122,7 @@ end
|
||||
|
||||
AdaMax(η = 0.001, β = (0.9, 0.999)) = AdaMax(η, β, IdDict())
|
||||
|
||||
function apply!(o::AdaMax, x, Δ)
|
||||
function update!(o::AdaMax, x, Δ)
|
||||
η, β = o.eta, o.beta
|
||||
mt, ut, βp = get!(o.state, x, (zero(x), zero(x), β))
|
||||
@. mt = β[1] * mt + (1 - β[1]) * Δ
|
||||
@ -259,22 +133,10 @@ function apply!(o::AdaMax, x, Δ)
|
||||
end
|
||||
|
||||
"""
|
||||
ADAGrad(η = 0.1)
|
||||
ADAGrad(η = 0.1; ϵ = 1e-8)
|
||||
|
||||
[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimizer. It has
|
||||
parameter-specific learning rates based on how frequently each parameter is updated.
|
||||
[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
|
||||
Parameters don't need tuning.
|
||||
|
||||
# Parameters
|
||||
- Learning rate (`η`): Amount by which gradients are discounted before updating
|
||||
the weights.
|
||||
|
||||
# Examples
|
||||
```julia
|
||||
opt = ADAGrad()
|
||||
|
||||
opt = ADAGrad(0.001)
|
||||
```
|
||||
"""
|
||||
mutable struct ADAGrad
|
||||
eta::Float64
|
||||
@ -283,29 +145,18 @@ end
|
||||
|
||||
ADAGrad(η = 0.1) = ADAGrad(η, IdDict())
|
||||
|
||||
function apply!(o::ADAGrad, x, Δ)
|
||||
function update!(o::ADAGrad, x, Δ)
|
||||
η = o.eta
|
||||
acc = get!(o.acc, x, fill!(zero(x), ϵ))::typeof(x)
|
||||
acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(x)
|
||||
@. acc += Δ^2
|
||||
@. Δ *= η / (√acc + ϵ)
|
||||
end
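A tiny numerical illustration of the accumulator above, on hypothetical gradients: entries that keep receiving large gradients grow a large `acc` and therefore get a smaller effective rate.

```julia
# ADAGrad's per-parameter rate in miniature.
η, ϵ = 0.1, 1e-8
δ    = [1.0, 0.01]         # one frequently-large, one mostly-small gradient entry
acc  = zero(δ)
for _ in 1:100
    @. acc += δ^2           # the accumulation performed above on every step
end
rate = @. η / (√acc + ϵ)    # ≈ [0.01, 1.0]: the heavily-updated entry gets the smaller rate
```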
|
||||
|
||||
"""
|
||||
ADADelta(ρ = 0.9)
|
||||
ADADelta(ρ = 0.9, ϵ = 1e-8)
|
||||
|
||||
[ADADelta](https://arxiv.org/abs/1212.5701) is a version of ADAGrad adapting its learning
|
||||
rate based on a window of past gradient updates.
|
||||
Parameters don't need tuning.
|
||||
|
||||
# Parameters
|
||||
- Rho (`ρ`): Factor by which the gradient is decayed at each time step.
|
||||
|
||||
# Examples
|
||||
```julia
|
||||
opt = ADADelta()
|
||||
|
||||
opt = ADADelta(0.89)
|
||||
```
|
||||
[ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
|
||||
tuning.
|
||||
"""
|
||||
mutable struct ADADelta
|
||||
rho::Float64
|
||||
@ -314,7 +165,7 @@ end
|
||||
|
||||
ADADelta(ρ = 0.9) = ADADelta(ρ, IdDict())
|
||||
|
||||
function apply!(o::ADADelta, x, Δ)
|
||||
function update!(o::ADADelta, x, Δ)
|
||||
ρ = o.rho
|
||||
acc, Δacc = get!(o.state, x, (zero(x), zero(x)))
|
||||
@. acc = ρ * acc + (1 - ρ) * Δ^2
|
||||
@ -324,23 +175,10 @@ function apply!(o::ADADelta, x, Δ)
|
||||
end
|
||||
|
||||
"""
|
||||
AMSGrad(η = 0.001, β::Tuple = (0.9, 0.999))
|
||||
AMSGrad(η = 0.001, β = (0.9, 0.999))
|
||||
|
||||
The [AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) version of the ADAM
|
||||
optimiser. Parameters don't need tuning.
|
||||
|
||||
# Parameters
|
||||
- Learning rate (`η`): Amount by which gradients are discounted before updating
|
||||
the weights.
|
||||
- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
|
||||
second (β2) momentum estimate.
|
||||
|
||||
# Examples
|
||||
```julia
|
||||
opt = AMSGrad()
|
||||
|
||||
opt = AMSGrad(0.001, (0.89, 0.995))
|
||||
```
|
||||
[AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need
|
||||
tuning.
|
||||
"""
|
||||
mutable struct AMSGrad
|
||||
eta::Float64
|
||||
@ -350,33 +188,20 @@ end
|
||||
|
||||
AMSGrad(η = 0.001, β = (0.9, 0.999)) = AMSGrad(η, β, IdDict())
|
||||
|
||||
function apply!(o::AMSGrad, x, Δ)
|
||||
function update!(o::AMSGrad, x, Δ)
|
||||
η, β = o.eta, o.beta
|
||||
mt, vt, v̂t = get!(o.state, x, (fill!(zero(x), ϵ), fill!(zero(x), ϵ), fill!(zero(x), ϵ)))
|
||||
mt, vt, v̂t = get!(o.state, x, (fill(ϵ, size(x)), fill(ϵ, size(x)), fill(ϵ, size(x))))
|
||||
@. mt = β[1] * mt + (1 - β[1]) * Δ
|
||||
@. vt = β[2] * vt + (1 - β[2]) * Δ ^ 2
|
||||
@. v̂t = max(v̂t, vt)
|
||||
@. v̂t = max.(v̂t, vt)
|
||||
@. Δ = η * mt / (√v̂t + ϵ)
|
||||
end
|
||||
|
||||
"""
|
||||
NADAM(η = 0.001, β::Tuple = (0.9, 0.999))
|
||||
NADAM(η = 0.001, β = (0.9, 0.999))
|
||||
|
||||
[NADAM](http://cs229.stanford.edu/proj2015/054_report.pdf) is a Nesterov variant of ADAM.
|
||||
Parameters don't need tuning.
|
||||
|
||||
# Parameters
|
||||
- Learning rate (`η`): Amount by which gradients are discounted before updating
|
||||
the weights.
|
||||
- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
|
||||
second (β2) momentum estimate.
|
||||
|
||||
# Examples
|
||||
```julia
|
||||
opt = NADAM()
|
||||
|
||||
opt = NADAM(0.002, (0.89, 0.995))
|
||||
```
|
||||
[NADAM](http://cs229.stanford.edu/proj2015/054_report.pdf) optimiser. Parameters don't need
|
||||
tuning.
|
||||
"""
|
||||
mutable struct NADAM
|
||||
eta::Float64
|
||||
@ -386,9 +211,10 @@ end
|
||||
|
||||
NADAM(η = 0.001, β = (0.9, 0.999)) = NADAM(η, β, IdDict())
|
||||
|
||||
function apply!(o::NADAM, x, Δ)
|
||||
function update!(o::NADAM, x, Δ)
|
||||
η, β = o.eta, o.beta
|
||||
mt, vt, (β1p, β2p) = get!(o.state, x, (zero(x), zero(x), o.beta))
|
||||
β1p, β2p = o.beta
|
||||
mt, vt = get!(o.state, x, (zero(x), zero(x)))
|
||||
@. mt = β[1] * mt + (1 - β[1]) * Δ
|
||||
@. vt = β[2] * vt + (1 - β[2]) * Δ^2
|
||||
@. Δ = (β[1] * mt / (1 - β[1] * β1p) + (1 - β[1]) * Δ / (1 - β1p)) / (√(vt * β[2] / (1 - β2p)) + ϵ) * η
|
||||
@ -397,27 +223,12 @@ function apply!(o::NADAM, x, Δ)
|
||||
end
|
||||
|
||||
"""
|
||||
ADAMW(η = 0.001, β::Tuple = (0.9, 0.999), decay = 0)
|
||||
ADAMW(η = 0.001, β = (0.9, 0.999), decay = 0)
|
||||
|
||||
[ADAMW](https://arxiv.org/abs/1711.05101) is a variant of ADAM fixing (as in repairing) its
|
||||
weight decay regularization.
|
||||
|
||||
# Parameters
|
||||
- Learning rate (`η`): Amount by which gradients are discounted before updating
|
||||
the weights.
|
||||
- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
|
||||
second (β2) momentum estimate.
|
||||
- `decay`: Decay applied to weights during optimisation.
|
||||
|
||||
# Examples
|
||||
```julia
|
||||
opt = ADAMW()
|
||||
|
||||
opt = ADAMW(0.001, (0.89, 0.995), 0.1)
|
||||
```
|
||||
[ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam.
|
||||
"""
|
||||
ADAMW(η = 0.001, β = (0.9, 0.999), decay = 0) =
|
||||
Optimiser(ADAM(η, β), WeightDecay(decay))
|
||||
Optimiser(ADAM(η, β), WeightDecay(wd))
|
||||
|
||||
# Compose optimizers
|
||||
|
||||
@ -439,23 +250,19 @@ Optimiser(o...) = Optimiser(Any[o...])
|
||||
|
||||
Base.getindex(c::Optimiser, i::AbstractArray) = Optimiser(c.os[i]...)
|
||||
|
||||
function apply!(o::Optimiser, x, Δ)
|
||||
function update!(o::Optimiser, x, Δ)
|
||||
for opt in o.os
|
||||
Δ = apply!(opt, x, Δ)
|
||||
Δ = update!(opt, x, Δ)
|
||||
end
|
||||
return Δ
|
||||
end
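`Optimiser` threads the gradient through each wrapped rule in order, so schedules and regularisers compose with a base optimiser. A hedged usage sketch with the types defined in this file:

```julia
# Weight decay is applied to the gradient first, then ADAM scales the result.
opt = Optimiser(WeightDecay(1e-4), ADAM(0.001, (0.9, 0.999)))
```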
|
||||
|
||||
"""
|
||||
InvDecay(γ = 0.001)
|
||||
`InvDecay(γ)`
|
||||
|
||||
Apply inverse time decay to an optimiser, so that the effective step size at
|
||||
iteration `n` is `eta / (1 + γ * n)` where `eta` is the initial step size.
|
||||
The wrapped optimiser's step size is not modified.
|
||||
|
||||
# Examples
|
||||
Apply inverse time decay to an optimiser
|
||||
```julia
|
||||
Optimiser(InvDecay(..), Opt(..))
|
||||
Optimiser(InvDecay(..), Opt(..))
|
||||
```
|
||||
"""
|
||||
mutable struct InvDecay
|
||||
@ -465,7 +272,7 @@ end
|
||||
|
||||
InvDecay(γ = 0.001) = InvDecay(γ, IdDict())
|
||||
|
||||
function apply!(o::InvDecay, x, Δ)
|
||||
function update!(o::InvDecay, x, Δ)
|
||||
γ = o.gamma
|
||||
n = get!(o.state, x, 1)
|
||||
Δ .*= 1 / (1 + γ * n)
|
||||
@ -474,25 +281,13 @@ function apply!(o::InvDecay, x, Δ)
|
||||
end
|
||||
|
||||
"""
|
||||
ExpDecay(η = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4)
|
||||
`ExpDecay(eta, decay, decay_step, clip)`
|
||||
|
||||
Discount the learning rate `η` by the factor `decay` every `decay_step` steps till
|
||||
a minimum of `clip`.
|
||||
Schedule the learning rate `eta` by `decay` every `decay_step` till a minimum of `clip`.
|
||||
|
||||
# Parameters
|
||||
- Learning rate (`η`): Amount by which gradients are discounted before updating
|
||||
the weights.
|
||||
- `decay`: Factor by which the learning rate is discounted.
|
||||
- `decay_step`: Schedule decay operations by setting the number of steps between
|
||||
two decay operations.
|
||||
- `clip`: Minimum value of learning rate.
|
||||
|
||||
# Examples
|
||||
To apply exponential decay to an optimiser:
|
||||
```julia
|
||||
Optimiser(ExpDecay(..), Opt(..))
|
||||
|
||||
opt = Optimiser(ExpDecay(), ADAM())
|
||||
Optimiser(ExpDecay(..), Opt(..))
|
||||
```
|
||||
"""
|
||||
mutable struct ExpDecay
|
||||
@ -505,23 +300,20 @@ end
|
||||
|
||||
ExpDecay(opt = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4) = ExpDecay(opt, decay, decay_step, clip, IdDict())
|
||||
|
||||
function apply!(o::ExpDecay, x, Δ)
|
||||
function update!(o::ExpDecay, x, Δ)
|
||||
η, s, decay = o.eta, o.step, o.decay
|
||||
n = o.current[x] = get(o.current, x, 0) + 1
|
||||
if o.current[x]%s == 0 && count(x -> x%s == 0, values(o.current)) == 1
|
||||
η = max(η * decay, o.clip)
|
||||
η = max(η * decay^(s / n), o.clip)
|
||||
o.eta = η
|
||||
end
|
||||
@. Δ *= η
|
||||
@. Δ *= decay
|
||||
end
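Worked out for the defaults, the schedule multiplies the rate by `decay` each time the step counter crosses `decay_step`, clipped from below (a hedged reading of the `η * decay` branch shown above):

```julia
# Effective rate after k decay events, with the default hyperparameters.
eta_after(k; η = 0.001, decay = 0.1, clip = 1e-4) = max(η * decay^k, clip)
eta_after.(0:3)   # 0.001, 0.0001, 0.0001, 0.0001 — the clip is reached after one decay
```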
|
||||
|
||||
"""
|
||||
WeightDecay(wd = 0)
|
||||
`WeightDecay(wd)`
|
||||
|
||||
Decay weights by `wd`.
|
||||
|
||||
# Parameters
|
||||
- Weight decay (`wd`)
|
||||
Decay the weight parameter by `wd`
|
||||
"""
|
||||
mutable struct WeightDecay
|
||||
wd::Real
|
||||
@ -529,35 +321,7 @@ end
|
||||
|
||||
WeightDecay() = WeightDecay(0)
|
||||
|
||||
function apply!(o::WeightDecay, x, Δ)
|
||||
function update!(o::WeightDecay, x, Δ)
|
||||
wd = o.wd
|
||||
@. Δ += wd * x
|
||||
end
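Adding `wd * x` to the gradient, as above, is exactly the gradient of an L2 penalty `(wd/2) * sum(abs2, x)`, so `WeightDecay` is plain L2 regularisation applied through the gradient. A quick check on hypothetical values:

```julia
wd, x, Δ = 0.01, randn(3), zeros(3)
Δ .+= wd .* x      # what apply!/update! adds to the gradient
Δ ≈ wd .* x        # true: identical to d/dx of (wd/2) * sum(abs2, x)
```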
|
||||
|
||||
"""
|
||||
ClipValue(thresh)
|
||||
|
||||
Clip gradients when their absolute value exceeds `thresh`.
|
||||
"""
|
||||
mutable struct ClipValue{T}
|
||||
thresh::T
|
||||
end
|
||||
|
||||
apply!(o::ClipValue, x, Δ) = clamp!(Δ, -o.thresh, o.thresh)
|
||||
|
||||
"""
|
||||
ClipNorm(thresh)
|
||||
|
||||
Clip gradients when their L2 norm exceeds `thresh`.
|
||||
"""
|
||||
mutable struct ClipNorm{T}
|
||||
thresh::T
|
||||
end
|
||||
|
||||
function apply!(o::ClipNorm, x, Δ)
|
||||
Δnrm = norm(Δ)
|
||||
if Δnrm > o.thresh
|
||||
rmul!(Δ, o.thresh / Δnrm)
|
||||
end
|
||||
return Δ
|
||||
end
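The same rescaling in isolation, on a raw vector (self-contained sketch; `ClipNorm` above does this to the gradient `Δ`):

```julia
using LinearAlgebra: norm, rmul!

g = [3.0, 4.0]                                   # ‖g‖ = 5
thresh = 1.0
norm(g) > thresh && rmul!(g, thresh / norm(g))   # g is now [0.6, 0.8], ‖g‖ = 1
```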
|
@ -1,34 +1,12 @@
|
||||
using Juno
|
||||
import Zygote: Params, gradient
|
||||
using Flux.Tracker: data, grad, back!
|
||||
import Base.depwarn
|
||||
|
||||
|
||||
|
||||
"""
|
||||
update!(x, x̄)
|
||||
|
||||
Update the array `x` according to `x .-= x̄`.
|
||||
"""
|
||||
function update!(x::AbstractArray, x̄)
|
||||
x .-= x̄
|
||||
end
|
||||
|
||||
"""
|
||||
update!(opt, p, g)
|
||||
update!(opt, ps::Params, gs)
|
||||
|
||||
Perform an update step of the parameters `ps` (or the single parameter `p`)
|
||||
according to optimizer `opt` and the gradients `gs` (the gradient `g`).
|
||||
|
||||
As a result, the parameters are mutated and the optimizer's internal state may change.
|
||||
"""
|
||||
function update!(opt, x, x̄)
|
||||
x .-= apply!(opt, x, x̄)
|
||||
end
|
||||
|
||||
function update!(opt, xs::Params, gs)
|
||||
function update!(opt, xs)
|
||||
for x in xs
|
||||
gs[x] == nothing && continue
|
||||
update!(opt, x, gs[x])
|
||||
Δ = update!(opt, x.data, x.grad)
|
||||
x.data .-= Δ
|
||||
Δ .= 0
|
||||
end
|
||||
end
|
||||
|
||||
@ -37,16 +15,26 @@ call(f, xs...) = f(xs...)
|
||||
runall(f) = f
|
||||
runall(fs::AbstractVector) = () -> foreach(call, fs)
|
||||
|
||||
struct StopException <: Exception end
|
||||
# The AD generates fairly large backtraces that are unhelpful if you interrupt
|
||||
# while training; this just cleans that up.
|
||||
macro interrupts(ex)
|
||||
:(try $(esc(ex))
|
||||
catch e
|
||||
e isa InterruptException || rethrow()
|
||||
throw(e)
|
||||
end)
|
||||
end
|
||||
|
||||
struct StopException <: Exception end
|
||||
"""
|
||||
stop()
|
||||
|
||||
Call `Flux.stop()` in a callback to indicate when a callback condition is met.
|
||||
This will trigger the train loop to stop and exit.
|
||||
This would trigger the train loop to stop and exit.
|
||||
|
||||
# Examples
|
||||
```julia
|
||||
# Example callback:
|
||||
|
||||
cb = function ()
|
||||
accuracy() > 0.9 && Flux.stop()
|
||||
end
|
||||
@ -59,37 +47,33 @@ end
|
||||
"""
|
||||
train!(loss, params, data, opt; cb)
|
||||
|
||||
For each datapoint `d` in `data` compute the gradient of `loss(d...)` through
|
||||
backpropagation and call the optimizer `opt`.
|
||||
For each datapoint `d` in `data` computes the gradient of `loss(d...)` through
|
||||
backpropagation and calls the optimizer `opt`.
|
||||
|
||||
In case datapoints `d` are of numeric array type, assume no splatting is needed
|
||||
and compute the gradient of `loss(d)`.
|
||||
Takes a callback as keyword argument `cb`. For example, this will print "training"
|
||||
every 10 seconds:
|
||||
|
||||
A callback is given with the keyword argument `cb`. For example, this will print
|
||||
"training" every 10 seconds (using [`Flux.throttle`](@ref)):
|
||||
```julia
|
||||
Flux.train!(loss, params, data, opt,
|
||||
cb = throttle(() -> println("training"), 10))
|
||||
```
|
||||
|
||||
train!(loss, params, data, opt, cb = throttle(() -> println("training"), 10))
|
||||
|
||||
The callback can call [`Flux.stop`](@ref) to interrupt the training loop.
|
||||
The callback can call `Flux.stop()` to interrupt the training loop.
|
||||
|
||||
Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
|
||||
"""
|
||||
function train!(loss, ps, data, opt; cb = () -> ())
|
||||
ps = Params(ps)
|
||||
cb = runall(cb)
|
||||
opt = runall(opt)
|
||||
@progress for d in data
|
||||
try
|
||||
if d isa AbstractArray{<:Number}
|
||||
gs = gradient(ps) do
|
||||
loss(d)
|
||||
end
|
||||
else
|
||||
gs = gradient(ps) do
|
||||
loss(d...)
|
||||
end
|
||||
l = loss(d...)
|
||||
@interrupts back!(l)
|
||||
update!(opt, ps)
|
||||
if cb() == :stop
|
||||
depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
|
||||
break
|
||||
end
|
||||
update!(opt, ps, gs)
|
||||
cb()
|
||||
catch ex
|
||||
if ex isa StopException
|
||||
break
|
||||
@ -106,12 +90,11 @@ end
|
||||
Run `body` `N` times. Mainly useful for quickly doing multiple epochs of
|
||||
training in a REPL.
|
||||
|
||||
# Examples
|
||||
```jldoctest
|
||||
julia> Flux.@epochs 2 println("hello")
|
||||
[ Info: Epoch 1
|
||||
```julia
|
||||
julia> @epochs 2 println("hello")
|
||||
INFO: Epoch 1
|
||||
hello
|
||||
[ Info: Epoch 2
|
||||
INFO: Epoch 2
|
||||
hello
|
||||
```
|
||||
"""
|
||||
|
111
src/tracker/Tracker.jl
Normal file
@ -0,0 +1,111 @@
|
||||
module Tracker
|
||||
|
||||
using MacroTools
|
||||
using MacroTools: @q, @forward
|
||||
|
||||
import Base: ==
|
||||
|
||||
export TrackedArray, TrackedVector, TrackedMatrix, Params, gradient,
|
||||
param, back!
|
||||
|
||||
tracker(x) = nothing
|
||||
|
||||
istracked(x) = tracker(x) ≠ nothing
|
||||
isleaf(x) = !istracked(x) || isleaf(tracker(x))
|
||||
grad(x) = grad(tracker(x))
|
||||
grad(::Nothing) = nothing
|
||||
data(x) = x
|
||||
|
||||
struct Call{F,As<:Tuple}
|
||||
func::F
|
||||
args::As
|
||||
end
|
||||
|
||||
Call(f::F, args::T) where {F,T} = Call{F,T}(f, args)
|
||||
Call() = Call(nothing, ())
|
||||
|
||||
# When deserialising, the object_id changes
|
||||
a::Call == b::Call = a.func == b.func && a.args == b.args
|
||||
|
||||
@inline (c::Call)() = c.func(data.(c.args)...)
|
||||
|
||||
mutable struct Tracked{T}
|
||||
ref::UInt32
|
||||
f::Call
|
||||
isleaf::Bool
|
||||
grad::T
|
||||
Tracked{T}(f::Call) where T = new(0, f, false)
|
||||
Tracked{T}(f::Call, grad::T) where T = new(0, f, false, grad)
|
||||
Tracked{T}(f::Call{Nothing}, grad::T) where T = new(0, f, true, grad)
|
||||
end
|
||||
|
||||
istracked(x::Tracked) = true
|
||||
isleaf(x::Tracked) = x.f == Call()
|
||||
grad(x::Tracked) = x.grad
|
||||
|
||||
track(f::Call, x) = Tracked{typeof(x)}(f)
|
||||
|
||||
function _forward end
|
||||
|
||||
function track(f::F, xs...; kw...) where F
|
||||
y, back = _forward(f, xs...; kw...)
|
||||
track(Call(back, tracker.(xs)), y)
|
||||
end
|
||||
|
||||
macro grad(ex)
|
||||
@capture(shortdef(ex), (name_(args__) = body_) |
|
||||
(name_(args__) where {T__} = body_)) || error("Need a function definition")
|
||||
T == nothing && (T = [])
|
||||
isexpr(name, :(::)) || (name = :(::typeof($name)))
|
||||
insert!(args, 1+isexpr(args[1], :parameters) , name)
|
||||
@q(Tracker._forward($(args...)) where $(T...) = $body) |> esc
|
||||
end
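A hedged sketch of the pattern this macro enables, in the same shape as the `hook`/`checkpoint` definitions below (assumes `TrackedArray` from the includes just after):

```julia
# Forward a call through `track`, and declare its pullback with `@grad`.
double(x) = 2 .* x
double(x::TrackedArray) = track(double, x)
@grad double(x) = 2 .* data(x), Δ -> (2 .* Δ,)
```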
|
||||
|
||||
include("idset.jl")
|
||||
include("back.jl")
|
||||
include("numeric.jl")
|
||||
include("lib/real.jl")
|
||||
include("lib/array.jl")
|
||||
|
||||
"""
|
||||
hook(f, x) -> x′
|
||||
|
||||
Hook into gradient backpropagation. `x` is unmodified, but when backpropagating
|
||||
`f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse
|
||||
the sign of the gradient applied to `x`."""
|
||||
hook(f, x) = istracked(x) ? track(hook, f, x) : x
|
||||
@grad hook(f, x) = data(x), Δ -> (nothing, f(Δ))
|
||||
|
||||
"""
|
||||
checkpoint(f, args...)
|
||||
|
||||
Behaves like `f(args...)`, but avoids storing the intermediate values needed for
|
||||
calculating gradients. Instead, `f(args...)` will be called again during the
|
||||
backward pass. This can be used to save memory in larger models.
|
||||
"""
|
||||
checkpoint(f, args...) = track(checkpoint, f, args...)
|
||||
|
||||
@grad function checkpoint(f, args...)
|
||||
data(f(args...)), function (Δ)
|
||||
y, back = forward(f, args...)
|
||||
(nothing, back(Δ)...)
|
||||
end
|
||||
end
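A hedged usage sketch, assuming `param` and `back!` as defined in this module; `expensive` is a stand-in for a memory-hungry block:

```julia
expensive(x) = sum(tanh.(x) .^ 2)   # placeholder for a large layer
x = param(randn(10))
y = checkpoint(expensive, x)
back!(y)                            # the forward pass of `expensive` is re-run here
```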
|
||||
|
||||
nobacksies(f, x) = track(nobacksies, f, x)
|
||||
nobacksies(f, xs::Tuple) = map(x -> nobacksies(f, x), xs)
|
||||
@grad nobacksies(f::Symbol, x) = data(x), Δ -> error("Nested AD not defined for $f")
|
||||
@grad nobacksies(f::String, x) = data(x), Δ -> error(f)
|
||||
|
||||
param(x::Number) = TrackedReal(float(x))
|
||||
param(xs::AbstractArray) = TrackedArray(float.(xs))
|
||||
|
||||
@grad identity(x) = data(x), Δ -> (Δ,)
|
||||
param(x::TrackedReal) = track(identity, x)
|
||||
param(x::TrackedArray) = track(identity, x)
|
||||
|
||||
import Adapt: adapt, adapt_structure
|
||||
|
||||
adapt_structure(T, xs::TrackedArray) = param(adapt(T, data(xs)))
|
||||
|
||||
end
|
181
src/tracker/back.jl
Normal file
@ -0,0 +1,181 @@
|
||||
init_grad(x) = zero(x)
|
||||
zero_grad!(x) = zero(x)
|
||||
zero_grad!(x::AbstractArray) = (x .= 0)
|
||||
|
||||
scan(c::Call) = foreach(scan, c.args)
|
||||
|
||||
function scan(x::Tracked)
|
||||
x.isleaf && return
|
||||
ref = x.ref += 1
|
||||
if ref == 1
|
||||
scan(x.f)
|
||||
isdefined(x, :grad) && (x.grad = zero_grad!(x.grad))
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
function scan(x)
|
||||
istracked(x) && scan(tracker(x))
|
||||
return
|
||||
end
|
||||
|
||||
function back_(c::Call, Δ, once)
|
||||
Δs = c.func(Δ)
|
||||
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
|
||||
error("Gradient is not a tuple of length $(length(c.args))")
|
||||
foreach((x, d) -> back(x, d, once), c.args, data.(Δs))
|
||||
end
|
||||
|
||||
back_(::Call{Nothing}, Δ, once) = nothing
|
||||
back_(::Call{Missing}, Δ, once) = error("`back!` was already used")
|
||||
|
||||
accum!(x, Δ) = x .+ Δ
|
||||
accum!(x::AbstractArray, Δ) = (x .+= Δ)
|
||||
|
||||
function back(x::Tracked, Δ, once)
|
||||
x.isleaf && (x.grad = accum!(x.grad, Δ); return)
|
||||
ref = x.ref -= 1
|
||||
grad = if isdefined(x, :grad)
|
||||
x.grad = accum!(x.grad, Δ)
|
||||
elseif ref > 0
|
||||
x.grad = Δ
|
||||
else
|
||||
Δ
|
||||
end
|
||||
if ref == 0
|
||||
back_(x.f, grad, once)
|
||||
once && !x.isleaf && (x.f = Call(missing, ()))
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
back(::Nothing, Δ, once) = return
|
||||
|
||||
# Interface methods
|
||||
|
||||
# TODO: if an error occurs in `back` the refcounts will be broken
|
||||
# and `back` will silently fail to update.
|
||||
# (but only if you re-use intermediate values between passes)
|
||||
# Refcounts are also probably not safe in some situations (e.g. back called
|
||||
# from within a backpropagator)
|
||||
|
||||
function back!(x, Δ; once = true)
|
||||
istracked(x) || return
|
||||
scan(x)
|
||||
back(tracker(x), Δ, once)
|
||||
return
|
||||
end
|
||||
|
||||
function gradient_(f, xs...)
|
||||
xs = param.(xs)
|
||||
l = f(xs...)
|
||||
losscheck(l)
|
||||
back!(l)
|
||||
nobacksies("Use `gradient(...; nest = true)` for nested derivatives",
|
||||
grad.(xs))
|
||||
end
|
||||
|
||||
# Out-of-place gradients
|
||||
|
||||
struct Params
|
||||
order::Vector{Any}
|
||||
params::IdSet{Any}
|
||||
Params() = new([], IdSet())
|
||||
end
|
||||
|
||||
@forward Params.order Base.iterate, Base.length
|
||||
|
||||
function Base.push!(ps::Params, x)
|
||||
if !(x in ps.params)
|
||||
push!(ps.order, x)
|
||||
push!(ps.params, x)
|
||||
end
|
||||
return ps
|
||||
end
|
||||
|
||||
Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps)
|
||||
|
||||
Params(xs) = push!(Params(), xs...)
|
||||
|
||||
function Base.show(io::IO, ps::Params)
|
||||
print(io, "Params([")
|
||||
join(io, ps.order, ", ")
|
||||
print(io, "])")
|
||||
end
|
||||
|
||||
struct Grads
|
||||
grads::IdDict{Any,Any}
|
||||
end
|
||||
|
||||
Base.show(io::IO, ps::Grads) = println(io, "Grads(...)")
|
||||
|
||||
Grads() = Grads(IdDict())
|
||||
|
||||
@forward Grads.grads Base.setindex!, Base.haskey, Base.length, Base.iterate
|
||||
|
||||
Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps))
|
||||
|
||||
Base.getindex(g::Grads, x::Tracked) = g.grads[x]
|
||||
|
||||
function Base.getindex(g::Grads, x)
|
||||
istracked(x) || error("Object not tracked: $x")
|
||||
g[tracker(x)]
|
||||
end
|
||||
|
||||
accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
|
||||
|
||||
function back_(g::Grads, c::Call, Δ)
|
||||
Δs = c.func(Δ)
|
||||
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
|
||||
error("Gradient is not a tuple of length $(length(c.args))")
|
||||
foreach((x, Δ) -> back(g, x, Δ), c.args, Δs)
|
||||
end
|
||||
|
||||
back_(g::Grads, ::Call{Nothing}, Δ) = nothing
|
||||
|
||||
function back(g::Grads, x::Tracked, Δ)
|
||||
x.isleaf && (accum!(g, x, Δ); return)
|
||||
ref = x.ref -= 1
|
||||
if ref > 0 || haskey(g, x)
|
||||
accum!(g, x, Δ)
|
||||
ref == 0 && back_(g, x.f, g[x])
|
||||
else
|
||||
ref == 0 && back_(g, x.f, Δ)
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
back(::Grads, ::Nothing, _) = return
|
||||
|
||||
function forward(f, ps::Params)
|
||||
y = f()
|
||||
y, function (Δ)
|
||||
g = Grads(ps)
|
||||
if istracked(y)
|
||||
scan(y)
|
||||
back(g, tracker(y), Δ)
|
||||
end
|
||||
return g
|
||||
end
|
||||
end
|
||||
|
||||
function forward(f, args...)
|
||||
args = param.(args)
|
||||
y, back = forward(() -> f(args...), Params(args))
|
||||
y, Δ -> getindex.(Ref(back(Δ)), args)
|
||||
end
|
||||
|
||||
function losscheck(x)
|
||||
x isa Real || error("Function output is not scalar")
|
||||
isinf(x) && error("Loss is infinite")
|
||||
isnan(x) && error("Loss is NaN")
|
||||
end
|
||||
|
||||
function gradient_nested(f, args...)
|
||||
y, back = forward(f, args...)
|
||||
losscheck(y)
|
||||
return back(1)
|
||||
end
|
||||
|
||||
gradient(f, xs...; nest = false) =
|
||||
nest ? gradient_nested(f, xs...) : gradient_(f, xs...)
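A hedged usage sketch of the interface defined above, called from outside the module:

```julia
f(x, y) = sum(x .* y)
gx, gy = Tracker.gradient(f, [1.0, 2.0], [3.0, 4.0])
# gx ≈ [3.0, 4.0], gy ≈ [1.0, 2.0], returned as tracked wrappers whose pullbacks
# error (per `nobacksies` above); pass `nest = true` for derivatives of derivatives.
```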
|
28
src/tracker/idset.jl
Normal file
@ -0,0 +1,28 @@
|
||||
struct IdSet{T} <: AbstractSet{T}
|
||||
dict::IdDict{T,Nothing}
|
||||
IdSet{T}() where T = new(IdDict{T,Nothing}())
|
||||
end
|
||||
|
||||
Base.eltype(::IdSet{T}) where T = T
|
||||
|
||||
IdSet() = IdSet{Any}()
|
||||
|
||||
Base.push!(s::IdSet) = s
|
||||
Base.push!(s::IdSet{T}, x::T) where T = (s.dict[x] = nothing; s)
|
||||
Base.delete!(s::IdSet{T}, x::T) where T = (delete!(s.dict, x); s)
|
||||
Base.in(x, s::IdSet) = haskey(s.dict, x)
|
||||
|
||||
IdSet{T}(xs) where T = push!(IdSet{T}(), xs...)
|
||||
|
||||
IdSet(xs) = IdSet{eltype(xs)}(xs)
|
||||
|
||||
Base.collect(s::IdSet) = Base.collect(keys(s.dict))
|
||||
Base.similar(s::IdSet, T::Type) = IdSet{T}()
|
||||
|
||||
@forward IdSet.dict Base.length
|
||||
|
||||
function Base.iterate(v::IdSet, state...)
|
||||
y = Base.iterate(keys(v.dict), state...)
|
||||
y === nothing && return nothing
|
||||
return (y[1], y[2])
|
||||
end
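Membership here is by object identity (`===`), not value, since the backing store is an `IdDict`. A quick check:

```julia
s = IdSet{Vector{Int}}()
a = [1, 2]
push!(s, a)
(a in s, [1, 2] in s)   # (true, false): an equal but distinct array is not a member
```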
|
496
src/tracker/lib/array.jl
Normal file
@ -0,0 +1,496 @@
|
||||
import Base: *
|
||||
|
||||
import LinearAlgebra
|
||||
import LinearAlgebra: inv, \, /
|
||||
|
||||
using Statistics
|
||||
using LinearAlgebra: Transpose, Adjoint, diagm, diag
|
||||
|
||||
struct TrackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
|
||||
tracker::Tracked{A}
|
||||
data::A
|
||||
grad::A
|
||||
TrackedArray{T,N,A}(t::Tracked{A}, data::A) where {T,N,A} = new(t, data)
|
||||
TrackedArray{T,N,A}(t::Tracked{A}, data::A, grad::A) where {T,N,A} = new(t, data, grad)
|
||||
end
|
||||
|
||||
data(x::TrackedArray) = x.data
|
||||
tracker(x::TrackedArray) = x.tracker
|
||||
|
||||
TrackedVector{T,A} = TrackedArray{T,1,A}
|
||||
TrackedMatrix{T,A} = TrackedArray{T,2,A}
|
||||
TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}}
|
||||
|
||||
track(c::Call, x::AbstractArray) = TrackedArray(c, x)
|
||||
|
||||
TrackedArray(c::Call, x::A) where A <: AbstractArray =
|
||||
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c), x)
|
||||
|
||||
TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
|
||||
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, Δ), x, Δ)
|
||||
|
||||
TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, zero(x))
|
||||
|
||||
Base.eltype(x::Type{<:TrackedArray{T}}) where T <: Real = TrackedReal{T}
|
||||
|
||||
Base.convert(::Type{T}, x::S) where {T<:TrackedArray,S<:T} = x
|
||||
|
||||
Base.convert(::Type{<:TrackedArray}, x::TrackedArray) =
|
||||
error("Not implemented: convert $(typeof(x)) to $T")
|
||||
|
||||
Base.convert(::Type{<:TrackedArray{T,N,A}}, x::AbstractArray) where {T,N,A} =
|
||||
TrackedArray(convert(A, x))
|
||||
|
||||
Base.show(io::IO, t::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
|
||||
@isdefined(A) ?
|
||||
print(io, "TrackedArray{…,$A}") :
|
||||
invoke(show, Tuple{IO,DataType}, io, t)
|
||||
|
||||
function Base.summary(io::IO, x::TrackedArray)
|
||||
print(io, "Tracked ")
|
||||
summary(io, data(x))
|
||||
end
|
||||
|
||||
Base.print_array(io::IO, x::TrackedArray) = Base.print_array(io, data(x))
|
||||
|
||||
function Base.show(io::IO, x::TrackedArray)
|
||||
show(io, data(x))
|
||||
print(io, " (tracked)")
|
||||
end
|
||||
|
||||
Base.copy(x::TrackedArray) = x
|
||||
|
||||
Base.setindex!(xs::TrackedArray, v, i...) =
|
||||
error("Can't differentiate `setindex!`")
|
||||
|
||||
back!(::TrackedArray) = error("Value is not scalar; use `back!(sum(x))` or `back!(x, Δ)`")
|
||||
|
||||
function update!(x::TrackedArray, Δ)
|
||||
x.data .+= data(Δ)
|
||||
tracker(x).grad .= 0
|
||||
return x
|
||||
end
|
||||
|
||||
# Fallthrough methods
|
||||
|
||||
for f in :[Base.size, Base.ndims, Base.collect].args
|
||||
@eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...)
|
||||
end
|
||||
|
||||
Base.size(x::TrackedArray, i::Integer, j::Integer, is::Integer...) =
|
||||
size(data(x), i, j, is...)
|
||||
|
||||
Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
|
||||
similar(data(x), dims...)
|
||||
|
||||
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
|
||||
|
||||
for op in [:(==), :≈]
|
||||
@eval Base.$op(x::TrackedArray, y::AbstractArray) = Base.$op(data(x), y)
|
||||
@eval Base.$op(x::AbstractArray, y::TrackedArray) = Base.$op(x, data(y))
|
||||
@eval Base.$op(x::TrackedArray, y::TrackedArray) = Base.$op(data(x), data(y))
|
||||
end
|
||||
|
||||
# Array Stdlib
|
||||
|
||||
Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...)
|
||||
|
||||
@grad function getindex(xs::AbstractArray, i...)
|
||||
data(xs)[i...], function (Δ)
|
||||
Δ′ = zero(xs)
|
||||
Δ′[i...] = data(Δ)
|
||||
(nobacksies(:getindex, Δ′), map(_->nothing, i)...)
|
||||
end
|
||||
end
|
||||
|
||||
Base.view(x::TrackedArray, inds...) = track(Base.view, x, inds...)
|
||||
|
||||
@grad function view(x::AbstractArray, inds...)
|
||||
view(data(x), inds...), function (Δ)
|
||||
grad_output = zero(x)
|
||||
subgrad = view(grad_output, inds...)
|
||||
subgrad[:] = data(Δ)
|
||||
(nobacksies(:view, grad_output), map(_->nothing, inds)...)
|
||||
end
|
||||
end
|
||||
|
||||
Base.:-(xs::TrackedArray) = track(-, xs)
|
||||
|
||||
@grad -(xs) = -data(xs), Δ -> (-Δ,)
|
||||
|
||||
Base.transpose(xs::TrackedArray) = track(transpose, xs)
|
||||
Base.adjoint(xs::TrackedArray) = track(adjoint, xs)
|
||||
|
||||
@grad transpose(xs) = transpose(data(xs)), Δ -> (reshape(transpose(Δ), size(xs)),)
|
||||
@grad adjoint(xs) = data(xs)', Δ -> (reshape(Δ', size(xs)),)
|
||||
|
||||
Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...)
|
||||
|
||||
@grad function repeat(xs; inner=ntuple(x->1, ndims(xs)), outer=ntuple(x->1, ndims(xs)))
|
||||
repeat(data(xs), inner = inner, outer = outer), function (Δ)
|
||||
Δ′ = zero(xs)
|
||||
S = size(xs)
|
||||
|
||||
# Loop through each element of Δ, calculate source dimensions, accumulate into Δ′
|
||||
for (dest_idx, val) in pairs(IndexCartesian(), data(Δ))
|
||||
# First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then
|
||||
# wrap around based on original size S.
|
||||
src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)]
|
||||
Δ′[src_idx...] += val
|
||||
end
|
||||
(nobacksies(:repeat, Δ′),)
|
||||
end
|
||||
end
|
||||
|
||||
function combinations(xs, n)
|
||||
n < 1 && return [[]]
|
||||
cs = combinations(xs, n-1)
|
||||
[[x, c...] for x in xs, c in cs]
|
||||
end
|
||||
|
||||
combinations([AbstractArray, TrackedArray], 2)
|
||||
|
||||
for i = 0:2, c = combinations([:AbstractArray, :TrackedArray], i), f = [:hcat, :vcat]
|
||||
cnames = map(_ -> gensym(), c)
|
||||
@eval Base.$f($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::TrackedArray, xs::AbstractArray...) =
|
||||
track($f, $(cnames...), x, xs...)
|
||||
end
|
||||
|
||||
for i = 0:2, c = combinations([:AbstractVecOrMat, :TrackedVecOrMat], i), f = [:hcat, :vcat]
|
||||
cnames = map(_ -> gensym(), c)
|
||||
@eval Base.$f($([:($x::$c{T}) for (x, c) in zip(cnames, c)]...), x::TrackedVecOrMat{T}, xs::AbstractVecOrMat{T}...) where T =
|
||||
track($f, $(cnames...), x, xs...)
|
||||
end
|
||||
|
||||
for i = 0:2, c = combinations([:AbstractVector, :TrackedVector], i), f = [:hcat, :vcat]
|
||||
cnames = map(_ -> gensym(), c)
|
||||
@eval Base.$f($([:($x::$c{T}) for (x, c) in zip(cnames, c)]...), x::TrackedVector{T}, xs::AbstractVector{T}...) where T =
|
||||
track($f, $(cnames...), x, xs...)
|
||||
end
|
||||
|
||||
@grad function vcat(xs...)
|
||||
vcat(data.(xs)...), function (Δ)
|
||||
start = 0
|
||||
Δs = [begin
|
||||
i = map(_ -> :, size(xsi)) |> Base.tail
|
||||
d = Δ[start+1:start+size(xsi,1), i...]
|
||||
start += size(xsi, 1)
|
||||
d
|
||||
end for xsi in xs]
|
||||
return (Δs...,)
|
||||
end
|
||||
end
|
||||
|
||||
@grad function hcat(xs...)
|
||||
hcat(data.(xs)...), function (Δ)
|
||||
start = 0
|
||||
Δs = [begin
|
||||
d = if ndims(xsi) == 1
|
||||
Δ[:, start+1]
|
||||
else
|
||||
i = map(_ -> :, size(xsi)) |> Base.tail |> Base.tail
|
||||
Δ[:, start+1:start+size(xsi,2), i...]
|
||||
end
|
||||
start += size(xsi, 2)
|
||||
d
|
||||
end for xsi in xs]
|
||||
return (Δs...,)
|
||||
end
|
||||
end
|
||||
|
||||
for i = 0:2, c = combinations([:AbstractArray, :TrackedArray], i)
|
||||
cnames = map(_ -> gensym(), c)
|
||||
@eval Base.cat($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::TrackedArray, xs::AbstractArray...; dims) =
|
||||
track(cat, $(cnames...), x, xs..., dims = dims)
|
||||
end
|
||||
|
||||
@grad function cat(Xs...; dims)
|
||||
cat(data.(Xs)..., dims = dims), function (Δ)
|
||||
start = ntuple(i -> 0, Val(ndims(Δ)))
|
||||
Δs = [begin
|
||||
dim_xs = 1:ndims(xs)
|
||||
till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val(ndims(Δ)))
|
||||
xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val(ndims(Δ)))
|
||||
d = reshape(Δ[xs_in_Δ...],size(xs))
|
||||
start = start .+ till_xs
|
||||
d
|
||||
end for xs in Xs]
|
||||
return (Δs...,)
|
||||
end
|
||||
end
|
||||
|
||||
Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) = reshape(xs, dims)
|
||||
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Union{Int64,Colon}}}) = reshape(xs, Base._reshape_uncolon(xs, dims))
|
||||
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64}}) = track(reshape, xs, dims)
|
||||
|
||||
@grad reshape(xs, dims) = reshape(data(xs), dims), Δ -> (reshape(Δ, size(xs)),nothing)
|
||||
|
||||
Base.permutedims(xs::TrackedArray, dims) = track(permutedims, xs, dims)
|
||||
@grad permutedims(xs, dims) = permutedims(data(xs), dims), Δ -> (permutedims(Δ, invperm(dims)),nothing)
|
||||
|
||||
function _kron(mat1::AbstractMatrix,mat2::AbstractMatrix)
|
||||
m1, n1 = size(mat1)
|
||||
mat1_rsh = reshape(mat1,(1,m1,1,n1))
|
||||
|
||||
m2, n2 = size(mat2)
|
||||
mat2_rsh = reshape(mat2,(m2,1,n2,1))
|
||||
|
||||
return reshape(mat1_rsh.*mat2_rsh, (m1*m2,n1*n2))
|
||||
end
|
||||
|
||||
Base.kron(a::TrackedMatrix, b::TrackedMatrix) = _kron(a, b)
|
||||
Base.kron(a::TrackedMatrix, b::AbstractMatrix) = _kron(a, b)
|
||||
Base.kron(a::AbstractMatrix, b::TrackedMatrix) = _kron(a, b)
|
||||
|
||||
|
||||
inv(A::TrackedArray) = Tracker.track(inv, A)
|
||||
@grad function inv(A)
|
||||
return inv(Tracker.data(A)), function (Δ)
|
||||
Ainv = inv(A)
|
||||
∇A = - Ainv' * Δ * Ainv'
|
||||
return (∇A, )
|
||||
end
|
||||
end
|
||||
|
||||
# (/) rdivide
|
||||
A::TrackedArray / B::TrackedArray = Tracker.track(/, A, B)
|
||||
A::AbstractVecOrMat / B::TrackedArray = Tracker.track(/, A, B)
|
||||
A::TrackedArray / B::AbstractVecOrMat = Tracker.track(/, A, B)
|
||||
@grad function (A / B)
|
||||
return Tracker.data(A) / Tracker.data(B), function (Δ)
|
||||
Binv = inv(B)
|
||||
∇B = - Binv' * A' * Δ * Binv'
|
||||
return (Δ * Binv', ∇B)
|
||||
end
|
||||
end
|
||||
|
||||
# (\) ldivide (left vec divide needs more work to resolve dispatch ambiguity)
|
||||
A::TrackedArray \ B::TrackedArray = Tracker.track(\, A, B)
|
||||
A::AbstractArray \ B::TrackedArray = Tracker.track(\, A, B)
|
||||
A::TrackedArray \ B::AbstractVecOrMat = Tracker.track(\, A, B)
|
||||
@grad function (A \ B)
|
||||
return Tracker.data(A) \ Tracker.data(B), function (Δ)
|
||||
Ainv = inv(A)
|
||||
∇A = - Ainv' * Δ * B' * Ainv'
|
||||
return (∇A, Ainv' * Δ)
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
# Reductions
|
||||
|
||||
Base.sum(xs::TrackedArray; dims = :) = track(sum, xs, dims = dims)
|
||||
Base.sum(f::Union{Function,Type},xs::TrackedArray) = sum(f.(xs))
|
||||
|
||||
@grad sum(xs; dims = :) = sum(data(xs), dims = dims),
|
||||
Δ -> (zero(xs) .+ Δ, )
|
||||
|
||||
Base.prod(xs::TrackedArray, dim) = track(prod, xs, dim)
|
||||
Base.prod(xs::TrackedArray) = track(prod, xs)
|
||||
Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs))
|
||||
|
||||
@grad prod(xs) = prod(data(xs)), Δ -> (prod(xs) ./ xs .* Δ,)
|
||||
@grad prod(xs, dim) = prod(data(xs), dims = dim),
|
||||
Δ -> (nobacksies(:sum,
|
||||
reshape(.*(circshift.([reshape(data(xs), length(xs))], 1:length(xs)-1)...), size(xs)) .* Δ),
|
||||
nothing)
|
||||
|
||||
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
|
||||
|
||||
Statistics.mean(xs::TrackedArray; dims = :) = track(mean, xs, dims = dims)
|
||||
|
||||
Base.maximum(xs::TrackedArray; dims = :) = track(maximum, xs, dims = dims)
|
||||
Base.minimum(xs::TrackedArray; dims = :) = track(minimum, xs, dims = dims)
|
||||
|
||||
import LinearAlgebra: dot
|
||||
|
||||
dot(xs::TrackedVector, ys::TrackedVector) = track(dot, xs, ys)
|
||||
dot(xs::AbstractVector, ys::TrackedVector) = track(dot, xs, ys)
|
||||
dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys)
|
||||
|
||||
@grad dot(xs, ys) = dot(data(xs), data(ys)), Δ -> (Δ .* ys, Δ .* xs)
|
||||
|
||||
# Hacks to get std working
|
||||
Statistics.std(x::TrackedArray; dims = :, mean = Statistics.mean(x, dims = dims)) = _std(x,mean,dims)
|
||||
_std(x::TrackedArray, mean, dims) = sqrt.(sum((x .- mean).^2, dims = dims) ./ (mapreduce(i -> size(x,i),*, dims) - 1))
|
||||
_std(x::TrackedArray, mean, ::Colon) = sqrt.(sum((x .- mean).^2) ./ (length(x) - 1))
|
||||
|
||||
LinearAlgebra.norm(x::TrackedArray, p::Real = 2) =
|
||||
sum(abs.(x).^p .+ eps(0f0))^(1/p) # avoid d(sqrt(x))/dx == Inf at 0
|
||||
|
||||
@grad mean(xs; dims = :) = mean(data(xs), dims=dims), Δ -> (_backmean(xs,Δ,dims),)
|
||||
_backmean(xs, Δ, ::Colon) = zero(xs) .+ Δ ./ length(xs)
|
||||
_backmean(xs, Δ, dims) = zero(xs) .+ Δ ./ mapreduce(i -> size(data(xs),i),*,dims)
|
||||
|
||||
@grad function maximum(xs; dims = dims)
|
||||
maximum(data(xs), dims = dims), function (Δ)
|
||||
Δ′ = zero(xs)
|
||||
_, i = findmax(data(xs), dims = dims)
|
||||
Δ′[i] = data(Δ)
|
||||
return (nobacksies(:maximum, Δ′),)
|
||||
end
|
||||
end
|
||||
|
||||
@grad function minimum(xs; dims = dims)
|
||||
minimum(data(xs), dims = dims), function (Δ)
|
||||
Δ′ = zero(xs)
|
||||
_, i = findmin(data(xs), dims = dims)
|
||||
Δ′[i] = data(Δ)
|
||||
return (nobacksies(:minimum, Δ′),)
|
||||
end
|
||||
end
|
||||
|
||||
# BLAS
|
||||
|
||||
LinearAlgebra.diagm(x::Pair{<:Integer, <:TrackedVector}) = track(diagm, x...)
|
||||
@grad diagm(i, x) = diagm(i => data(x)), Δ -> (nothing, diag(Δ, i))
|
||||
|
||||
x::TrackedMatrix * y::AbstractMatrix = track(*, x, y)
|
||||
x::AbstractMatrix * y::TrackedMatrix = track(*, x, y)
|
||||
x::TrackedMatrix * y::TrackedMatrix = track(*, x, y)
|
||||
|
||||
x::TrackedMatrix * y::AbstractVector = track(*, x, y)
|
||||
x::AbstractMatrix * y::TrackedVector = track(*, x, y)
|
||||
x::TrackedMatrix * y::TrackedVector = track(*, x, y)
|
||||
|
||||
x::TrackedVector * y::AbstractVector = track(*, x, y)
|
||||
x::AbstractVector * y::TrackedVector = track(*, x, y)
|
||||
x::TrackedVector * y::TrackedVector = track(*, x, y)
|
||||
|
||||
@grad a::AbstractMatrix * b::AbstractVecOrMat =
|
||||
data(a)*data(b), Δ -> (Δ * transpose(b), transpose(a) * Δ)
|
||||
|
||||
# NNlib
|
||||
|
||||
using NNlib
|
||||
import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, depthwiseconv, maxpool, meanpool
|
||||
|
||||
softmax(xs::TrackedArray) = track(softmax, xs)
|
||||
|
||||
@grad softmax(xs) = softmax(data(xs)), Δ -> (nobacksies(:softmax, ∇softmax(data(Δ), data(xs))),)
|
||||
|
||||
logsoftmax(xs::TrackedArray) = track(logsoftmax, xs)
|
||||
|
||||
@grad logsoftmax(xs) = logsoftmax(data(xs)), Δ -> (nobacksies(:logsoftmax, ∇logsoftmax(data(Δ), data(xs))),)
|
||||
|
||||
depthwiseconv(x::TrackedArray, w::TrackedArray; kw...) = track(depthwiseconv, x, w; kw...)
|
||||
depthwiseconv(x::AbstractArray, w::TrackedArray; kw...) = track(depthwiseconv, x, w; kw...)
|
||||
depthwiseconv(x::TrackedArray, w::AbstractArray; kw...) = track(depthwiseconv, x, w; kw...)
|
||||
|
||||
@grad depthwiseconv(x, w; kw...) =
|
||||
depthwiseconv(data(x), data(w); kw...),
|
||||
Δ -> nobacksies(:depthwiseconv,
|
||||
(NNlib.∇depthwiseconv_data(data.((Δ, x, w))...; kw...),
|
||||
NNlib.∇depthwiseconv_filter(data.((Δ, x, w))...; kw...)))
|
||||
|
||||
conv(x::TrackedArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
|
||||
conv(x::AbstractArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
|
||||
conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...)
|
||||
|
||||
@grad conv(x, w; kw...) =
|
||||
conv(data(x), data(w); kw...),
|
||||
Δ -> nobacksies(:conv,
|
||||
(NNlib.∇conv_data(data.((Δ, x, w))...; kw...),
|
||||
NNlib.∇conv_filter(data.((Δ, x, w))...; kw...)))
|
||||
|
||||
maxpool(x::TrackedArray, k; kw...) = track(maxpool, x, k; kw...)
|
||||
|
||||
@grad function maxpool(x, k; kw...)
|
||||
y = maxpool(data(x), k; kw...)
|
||||
y, Δ -> (nobacksies(:maxpool, NNlib.∇maxpool(data.((Δ, y, x))..., k; kw...)), nothing)
|
||||
end
|
||||
|
||||
meanpool(x::TrackedArray, k; kw...) = track(meanpool, x, k; kw...)
|
||||
|
||||
@grad function meanpool(x, k; kw...)
|
||||
y = meanpool(data(x), k; kw...)
|
||||
y, Δ -> (nobacksies(:maxpool, NNlib.∇meanpool(data.((Δ, y, x))..., k; kw...)), nothing)
|
||||
end
|
||||
|
||||
# Broadcasting
|
||||
|
||||
using ForwardDiff: Dual, partials, value
|
||||
|
||||
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
|
||||
|
||||
unbroadcast(x::AbstractArray, Δ) =
|
||||
size(x) == size(Δ) ? Δ :
|
||||
length(x) == length(Δ) ? trim(x, Δ) :
|
||||
trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
|
||||
|
||||
unbroadcast(x::Number, Δ) = sum(Δ)
|
||||
unbroadcast(x::Base.RefValue, _) = nothing
|
||||
|
||||
dual(x, p) = x
|
||||
dual(x::Real, p) = Dual(x, p)
|
||||
|
||||
function partial(f::F, Δ, i, args::Vararg{Any,N}) where {F,N}
|
||||
dargs = ntuple(j -> dual(args[j], i==j), Val(N))
|
||||
return Δ * f(dargs...).partials[1]
|
||||
end
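The same trick in isolation: `partial` seeds the `i`-th argument with a unit dual perturbation and reads the derivative back off the result. A hedged scalar sketch:

```julia
using ForwardDiff: Dual

f(a, b) = a * sin(b)
a, b = 2.0, 0.5
∂f∂b = f(Dual(a, false), Dual(b, true)).partials[1]   # == a * cos(b)
```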
|
||||
|
||||
@inline function ∇broadcast(f::F, args::Vararg{Any,N}) where {F,N}
|
||||
y = broadcast(f, data.(args)...)
|
||||
eltype(y) <: Real || return y
|
||||
eltype(y) == Bool && return y
|
||||
function back(Δ)
|
||||
Δargs = ntuple(i -> partial.(f, Δ, i, args...), Val(N))
|
||||
dxs = map(unbroadcast, args, Δargs)
|
||||
return dxs
|
||||
end
|
||||
# So we can return non-tracked arrays
|
||||
track(Call(back, tracker.(args)), y)
|
||||
end
|
||||
|
||||
using Base.Broadcast: BroadcastStyle, ArrayStyle, Broadcasted, broadcasted
|
||||
|
||||
struct TrackedStyle <: BroadcastStyle end
|
||||
|
||||
Broadcast.BroadcastStyle(::Type{<:Union{TrackedArray,TrackedReal}}) = TrackedStyle()
|
||||
Broadcast.BroadcastStyle(::TrackedStyle, ::BroadcastStyle) = TrackedStyle()
|
||||
|
||||
# We have to re-build the original broadcast struct to get the appropriate array
|
||||
# style. We need this primarily to support CuArrays' broadcasting fixes.
|
||||
broadcast_rebuild(xs) = data(xs)
|
||||
|
||||
broadcast_rebuild(bc::Broadcasted) =
|
||||
broadcasted(bc.f, broadcast_rebuild.(bc.args)...)
|
||||
|
||||
preprocess(x) = x
|
||||
|
||||
function Base.Broadcast.materialize(bc::Broadcasted{TrackedStyle})
|
||||
bc1 = Broadcast.flatten(bc)
|
||||
bc2 = Broadcast.flatten(broadcast_rebuild(bc))
|
||||
∇broadcast(bc2.f, bc1.args...)
|
||||
end
|
||||
|
||||
using Requires
|
||||
|
||||
# https://github.com/FluxML/Flux.jl/issues/353
|
||||
if VERSION < v"1.1.0-DEV.548"
|
||||
@init Requires.isprecompiling() || @eval Base.Broadcast begin
|
||||
function flatten(bc::Broadcasted{Style}) where {Style}
|
||||
isflat(bc) && return bc
|
||||
args = cat_nested(bc)
|
||||
let makeargs = make_makeargs(bc), f = bc.f
|
||||
newf = @inline function(args::Vararg{Any,N}) where N
|
||||
f(makeargs(args...)...)
|
||||
end
|
||||
return Broadcasted{Style}(newf, args, bc.axes)
|
||||
end
|
||||
end
|
||||
@inline function make_makeargs(makeargs, t::Tuple{<:Broadcasted,Vararg{Any}})
|
||||
bc = t[1]
|
||||
let makeargs = make_makeargs(makeargs, tail(t)), f = bc.f
|
||||
let makeargs = make_makeargs(makeargs, bc.args)
|
||||
headargs, tailargs = make_headargs(bc.args), make_tailargs(bc.args)
|
||||
return @inline function(args::Vararg{Any,N}) where N
|
||||
args1 = makeargs(args...)
|
||||
a, b = headargs(args1...), tailargs(args1...)
|
||||
(f(a...), b...)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
142
src/tracker/lib/real.jl
Normal file
@ -0,0 +1,142 @@
|
||||
mutable struct TrackedReal{T<:Real} <: Real
|
||||
data::T
|
||||
tracker::Tracked{T}
|
||||
end
|
||||
|
||||
TrackedReal(x::Real) = TrackedReal(x, Tracked{typeof(x)}(Call(), zero(x)))
|
||||
|
||||
data(x::TrackedReal) = x.data
|
||||
tracker(x::TrackedReal) = x.tracker
|
||||
|
||||
track(f::Call, x::Real) = TrackedReal(x, Tracked{typeof(x)}(f, zero(x)))
|
||||
|
||||
function back!(x::TrackedReal; once = true)
|
||||
isinf(x) && error("Loss is Inf")
|
||||
isnan(x) && error("Loss is NaN")
|
||||
return back!(x, 1, once = once)
|
||||
end
|
||||
|
||||
function update!(x::TrackedReal, Δ)
|
||||
x.data += data(Δ)
|
||||
tracker(x).grad = 0
|
||||
return x
|
||||
end
|
||||
|
||||
function Base.show(io::IO, x::TrackedReal)
|
||||
T = get(io, :typeinfo, Any)
|
||||
show(io, data(x))
|
||||
T <: TrackedReal || print(io, " (tracked)")
|
||||
end
|
||||
|
||||
Base.decompose(x::TrackedReal) = Base.decompose(data(x))
|
||||
|
||||
Base.copy(x::TrackedReal) = x
|
||||
|
||||
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{T}) where T = x
|
||||
|
||||
Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x))
|
||||
|
||||
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} =
|
||||
error("Not implemented: convert tracked $S to tracked $T")
|
||||
|
||||
for op in [:(==), :≈, :<, :(<=)]
|
||||
@eval Base.$op(x::TrackedReal, y::Real) = Base.$op(data(x), y)
|
||||
@eval Base.$op(x::Real, y::TrackedReal) = Base.$op(x, data(y))
|
||||
@eval Base.$op(x::TrackedReal, y::TrackedReal) = Base.$op(data(x), data(y))
|
||||
end
|
||||
|
||||
Base.eps(x::TrackedReal) = eps(data(x))
|
||||
Base.eps(::Type{TrackedReal{T}}) where T = eps(T)
|
||||
|
||||
for f in :[isinf, isnan, isfinite].args
|
||||
@eval Base.$f(x::TrackedReal) = Base.$f(data(x))
|
||||
end
|
||||
|
||||
Base.Printf.fix_dec(x::TrackedReal, n::Int, a...) = Base.Printf.fix_dec(data(x), n, a...)
|
||||
|
||||
Base.float(x::TrackedReal) = x
|
||||
|
||||
Base.promote_rule(::Type{TrackedReal{S}},::Type{T}) where {S,T} =
|
||||
TrackedReal{promote_type(S,T)}
|
||||
|
||||
using DiffRules, SpecialFunctions, NaNMath
|
||||
|
||||
for (M, f, arity) in DiffRules.diffrules()
|
||||
arity == 1 || continue
|
||||
@eval begin
|
||||
@grad $M.$f(a::Real) =
|
||||
$M.$f(data(a)), Δ -> (Δ * $(DiffRules.diffrule(M, f, :a)),)
|
||||
$M.$f(a::TrackedReal) = track($M.$f, a)
|
||||
end
|
||||
end
|
||||
|
||||
# Work around zero(π) not working, for some reason
|
||||
_zero(::Irrational) = nothing
|
||||
_zero(x) = zero(x)
|
||||
|
||||
for (M, f, arity) in DiffRules.diffrules()
|
||||
arity == 2 || continue
|
||||
da, db = DiffRules.diffrule(M, f, :a, :b)
|
||||
f = :($M.$f)
|
||||
@eval begin
|
||||
@grad $f(a::TrackedReal, b::TrackedReal) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db)
|
||||
@grad $f(a::TrackedReal, b::Real) = $f(data(a), b), Δ -> (Δ * $da, _zero(b))
|
||||
@grad $f(a::Real, b::TrackedReal) = $f(a, data(b)), Δ -> (_zero(a), Δ * $db)
|
||||
$f(a::TrackedReal, b::TrackedReal) = track($f, a, b)
|
||||
$f(a::TrackedReal, b::Real) = track($f, a, b)
|
||||
$f(a::Real, b::TrackedReal) = track($f, a, b)
|
||||
end
|
||||
end
|
||||
|
||||
# Eliminating ambiguity
|
||||
import Base:^
|
||||
|
||||
^(a::TrackedReal, b::Integer) = track(^, a, b)
|
||||
|
||||
# Tuples
|
||||
|
||||
struct TrackedTuple{T<:Tuple}
|
||||
data::T
|
||||
tracker::Tracked{T}
|
||||
end
|
||||
|
||||
data(xs::TrackedTuple) = xs.data
|
||||
tracker(xs::TrackedTuple) = xs.tracker
|
||||
|
||||
accum!(x::Tuple, Δ::Tuple) = accum!.(x, Δ)
|
||||
init_grad(x::Tuple) = init_grad.(x)
|
||||
zero_grad!(x::Tuple) = zero_grad!.(x)
|
||||
|
||||
track(f::Call, xs::Tuple) = TrackedTuple(xs, Tracked{typeof(xs)}(f, zero.(xs)))
|
||||
|
||||
function Base.show(io::IO, xs::TrackedTuple)
|
||||
show(io, data(xs))
|
||||
print(io, " (tracked)")
|
||||
end
|
||||
|
||||
Base.length(x::TrackedTuple) = length(data(x))
|
||||
|
||||
Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i)
|
||||
|
||||
@grad function getindex(xs::TrackedTuple, i)
|
||||
data(xs)[i], Δ -> (ntuple(j -> i == j ? Δ : 0, length(xs)), nothing)
|
||||
end
|
||||
|
||||
# Array collection
|
||||
|
||||
function collect(xs)
|
||||
xs = Base.collect(xs)
|
||||
track(Call(collect, (tracker.(xs),)), data.(xs))
|
||||
end
|
||||
|
||||
function scan(c::Call{typeof(collect)})
|
||||
foreach(scan, c.args[1])
|
||||
end
|
||||
|
||||
function back_(c::Call{typeof(collect)}, Δ, once)
|
||||
foreach((x, d) -> back(x, d, once), c.args[1], data(Δ))
|
||||
end
|
||||
|
||||
function back_(g::Grads, c::Call{typeof(collect)}, Δ)
|
||||
foreach((x, Δ) -> back(g, x, Δ), c.args[1], Δ)
|
||||
end
|
18
src/tracker/numeric.jl
Normal file
@ -0,0 +1,18 @@
|
||||
function ngradient(f, xs::AbstractArray...)
|
||||
grads = zero.(xs)
|
||||
for (x, Δ) in zip(xs, grads), i in 1:length(x)
|
||||
δ = sqrt(eps())
|
||||
tmp = x[i]
|
||||
x[i] = tmp - δ/2
|
||||
y1 = f(xs...)
|
||||
x[i] = tmp + δ/2
|
||||
y2 = f(xs...)
|
||||
x[i] = tmp
|
||||
Δ[i] = (y2-y1)/δ
|
||||
end
|
||||
return grads
|
||||
end
|
||||
|
||||
gradcheck(f, xs...) =
|
||||
all(isapprox.(ngradient(f, xs...),
|
||||
data.(gradient(f, xs...)), rtol = 1e-5, atol = 1e-5))
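A hedged usage sketch (as in Flux's tests): compare Tracker's gradients against the central-difference estimate above for a small function.

```julia
gradcheck((x, W) -> sum(W * x), rand(3), rand(2, 3))   # expected: true
```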
|
71
src/treelike.jl
Normal file
@ -0,0 +1,71 @@
|
||||
import Adapt: adapt
|
||||
import .Tracker: IdSet
|
||||
|
||||
children(x) = ()
|
||||
mapchildren(f, x) = x
|
||||
|
||||
children(x::Tuple) = x
|
||||
mapchildren(f, x::Tuple) = map(f, x)
|
||||
|
||||
function treelike(m::Module, T, fs = fieldnames(T))
|
||||
@eval m begin
|
||||
Flux.children(x::$T) = ($([:(x.$f) for f in fs]...),)
|
||||
Flux.mapchildren(f, x::$T) = $T(f.($children(x))...)
|
||||
end
|
||||
end
|
||||
|
||||
function treelike(T, fs = fieldnames(T))
|
||||
Base.depwarn("`treelike(T)` is deprecated, use `@treelike T`", :treelike)
|
||||
treelike(Base._current_module(), T, fs)
|
||||
end
|
||||
|
||||
macro treelike(T, fs = nothing)
|
||||
fs == nothing || isexpr(fs, :tuple) || error("@treelike T (a, b)")
|
||||
fs = fs == nothing ? [] : [:($(map(QuoteNode, fs.args)...),)]
|
||||
:(treelike(@__MODULE__, $(esc(T)), $(fs...)))
|
||||
end
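A hedged sketch of the intended use: mark a struct's fields as children so `params`/`mapleaves` below can traverse them. `Affine` is a hypothetical layer.

```julia
struct Affine
    W
    b
end

@treelike Affine   # params(Affine(param(rand(5, 10)), param(rand(5)))) now finds W and b
```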
|
||||
|
||||
isleaf(x) = isempty(children(x))
|
||||
|
||||
function mapleaves(f, x; cache = IdDict())
|
||||
haskey(cache, x) && return cache[x]
|
||||
cache[x] = isleaf(x) ? f(x) : mapchildren(x -> mapleaves(f, x, cache = cache), x)
|
||||
end
|
||||
|
||||
function prefor(f, x; seen = IdSet())
|
||||
x ∈ seen && return
|
||||
f(x)
|
||||
foreach(x -> prefor(f, x, seen = seen), children(x))
|
||||
return
|
||||
end
|
||||
|
||||
function params(m)
|
||||
ps = Params()
|
||||
prefor(p ->
|
||||
Tracker.istracked(p) && Tracker.isleaf(p) &&
|
||||
!any(p′ -> p′ === p, ps) && push!(ps, p),
|
||||
m)
|
||||
return ps
|
||||
end
|
||||
|
||||
params(m...) = params(m)
|
||||
|
||||
function loadparams!(m, xs)
|
||||
for (p, x) in zip(params(m), xs)
|
||||
size(p) == size(x) ||
|
||||
error("Expected param size $(size(p)), got $(size(x))")
|
||||
copyto!(data(p), data(x))
|
||||
end
|
||||
end
|
||||
|
||||
# CPU/GPU movement conveniences
|
||||
|
||||
cpu(m) = mapleaves(x -> adapt(Array, x), m)
|
||||
|
||||
gpu_adaptor = identity
|
||||
|
||||
@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
|
||||
global gpu_adaptor = CuArrays.cu
|
||||
end
|
||||
|
||||
gpu(x) = mapleaves(gpu_adaptor, x)
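Hedged usage: `gpu` is a no-op unless CuArrays is loaded (the adaptor defaults to `identity`); `Dense` is assumed from the rest of Flux.

```julia
m  = Dense(10, 5) |> gpu   # leaf arrays become CuArrays when available
m′ = m |> cpu              # and come back as plain Arrays
```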
|
245
src/utils.jl
@ -1,41 +1,6 @@
|
||||
# Arrays
|
||||
nfan() = 1, 1 # fan_in, fan_out
|
||||
nfan(n) = 1, n # A vector is treated as a n×1 matrix
|
||||
nfan(n_out, n_in) = n_in, n_out # In case of Dense kernels: arranged as matrices
|
||||
nfan(dims...) = prod(dims[1:end-2]) .* (dims[end-1], dims[end]) # In case of convolution kernels
|
||||
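`nfan` computes the `(fan_in, fan_out)` pair that the Glorot initialisers below divide by. Two worked cases; writing the internal helper as `Flux.nfan` is an assumption.

```julia
# Internal helper; the Flux. qualification is an assumption.
Flux.nfan(5, 10)         # Dense kernel of size (out=5, in=10)      -> (10, 5)
Flux.nfan(3, 3, 1, 16)   # 3×3 convolution kernel, 1 => 16 channels -> (9, 144)
```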
|
||||
"""
|
||||
glorot_uniform(dims...)
|
||||
|
||||
Return an `Array` of size `dims` containing random variables taken from a uniform
|
||||
distribution in the interval ``[-x, x]``, where `x = sqrt(24 / sum(dims)) / 2`.
|
||||
|
||||
# Examples
|
||||
```jldoctest; setup = :(using Random; Random.seed!(0))
|
||||
julia> Flux.glorot_uniform(2, 3)
|
||||
2×3 Array{Float32,2}:
|
||||
0.601094 -0.57414 -0.814925
|
||||
0.900868 0.805994 0.057514
|
||||
```
|
||||
"""
|
||||
glorot_uniform(dims...) = (rand(Float32, dims...) .- 0.5f0) .* sqrt(24.0f0 / sum(nfan(dims...)))
|
||||
|
||||
"""
|
||||
glorot_normal(dims...)
|
||||
|
||||
Return an `Array` of size `dims` containing random variables taken from a normal
|
||||
distribution with mean 0 and standard deviation `sqrt(2 / sum(dims))`.
|
||||
|
||||
# Examples
|
||||
```jldoctest; setup = :(using Random; Random.seed!(0))
|
||||
julia> Flux.glorot_normal(3, 2)
|
||||
3×2 Array{Float32,2}:
|
||||
0.429505 -0.0852891
|
||||
0.523935 0.371009
|
||||
-0.223261 0.188052
|
||||
```
|
||||
"""
|
||||
glorot_normal(dims...) = randn(Float32, dims...) .* sqrt(2.0f0 / sum(nfan(dims...)))
|
||||
glorot_uniform(dims...) = (rand(Float32, dims...) .- 0.5f0) .* sqrt(24.0f0/sum(dims))
|
||||
glorot_normal(dims...) = randn(Float32, dims...) .* sqrt(2.0f0/sum(dims))
|
||||
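Both variants draw Float32 values and rescale them; the factor in the docstring, `sqrt(24 / sum(dims)) / 2`, is the usual Glorot limit written differently, as this one-line check shows.

```julia
s = 10 + 5                          # fan_in + fan_out for some layer (arbitrary)
0.5 * sqrt(24 / s) ≈ sqrt(6 / s)    # true: same uniform interval [-x, x]
```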
|
||||
ones(T::Type, dims...) = Base.ones(T, dims...)
|
||||
zeros(T::Type, dims...) = Base.zeros(T, dims...)
|
||||
@ -43,98 +8,19 @@ zeros(T::Type, dims...) = Base.zeros(T, dims...)
|
||||
ones(dims...) = Base.ones(Float32, dims...)
|
||||
zeros(dims...) = Base.zeros(Float32, dims...)
|
||||
|
||||
"""
|
||||
unsqueeze(xs, dim)
|
||||
|
||||
Return `xs` reshaped into an `Array` one dimensionality higher than `xs`,
|
||||
where `dim` indicates in which dimension `xs` is extended.
|
||||
|
||||
# Examples
|
||||
```jldoctest
|
||||
julia> xs = [[1, 2], [3, 4], [5, 6]]
|
||||
3-element Array{Array{Int64,1},1}:
|
||||
[1, 2]
|
||||
[3, 4]
|
||||
[5, 6]
|
||||
|
||||
julia> Flux.unsqueeze(xs, 1)
|
||||
1×3 Array{Array{Int64,1},2}:
|
||||
[1, 2] [3, 4] [5, 6]
|
||||
|
||||
julia> Flux.unsqueeze([1 2; 3 4], 2)
|
||||
2×1×2 Array{Int64,3}:
|
||||
[:, :, 1] =
|
||||
1
|
||||
3
|
||||
|
||||
[:, :, 2] =
|
||||
2
|
||||
4
|
||||
```
|
||||
"""
|
||||
unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
|
||||
|
||||
"""
|
||||
stack(xs, dim)
|
||||
|
||||
Concatenate the given `Array` of `Array`s `xs` into a single `Array` along the
|
||||
given dimension `dim`.
|
||||
|
||||
# Examples
|
||||
```jldoctest
|
||||
julia> xs = [[1, 2], [3, 4], [5, 6]]
|
||||
3-element Array{Array{Int64,1},1}:
|
||||
[1, 2]
|
||||
[3, 4]
|
||||
[5, 6]
|
||||
|
||||
julia> Flux.stack(xs, 1)
|
||||
3×2 Array{Int64,2}:
|
||||
1 2
|
||||
3 4
|
||||
5 6
|
||||
|
||||
julia> cat(xs, dims=1)
|
||||
3-element Array{Array{Int64,1},1}:
|
||||
[1, 2]
|
||||
[3, 4]
|
||||
[5, 6]
|
||||
```
|
||||
"""
|
||||
stack(xs, dim) = cat(unsqueeze.(xs, dim)..., dims=dim)
|
||||
|
||||
"""
|
||||
unstack(xs, dim)
|
||||
|
||||
Unroll the given `xs` into an `Array` of `Array`s along the given dimension `dim`.
|
||||
|
||||
# Examples
|
||||
```jldoctest
|
||||
julia> Flux.unstack([1 3 5 7; 2 4 6 8], 2)
|
||||
4-element Array{Array{Int64,1},1}:
|
||||
[1, 2]
|
||||
[3, 4]
|
||||
[5, 6]
|
||||
[7, 8]
|
||||
```
|
||||
"""
|
||||
unstack(xs, dim) = [copy(selectdim(xs, dim, i)) for i in 1:size(xs, dim)]
|
||||
stack(xs, dim) = cat(dim, unsqueeze.(xs, dim)...)
|
||||
unstack(xs, dim) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)]
|
||||
|
||||
"""
|
||||
chunk(xs, n)
|
||||
|
||||
Split `xs` into `n` parts.
|
||||
|
||||
# Examples
|
||||
```jldoctest
|
||||
julia> Flux.chunk(1:10, 3)
|
||||
3-element Array{UnitRange{Int64},1}:
|
||||
1:4
|
||||
5:8
|
||||
9:10
|
||||
|
||||
julia> Flux.chunk(collect(1:10), 3)
|
||||
3-element Array{SubArray{Int64,1,Array{Int64,1},Tuple{UnitRange{Int64}},true},1}:
|
||||
```julia
|
||||
julia> chunk(1:10, 3)
|
||||
3-element Array{Array{Int64,1},1}:
|
||||
[1, 2, 3, 4]
|
||||
[5, 6, 7, 8]
|
||||
[9, 10]
|
||||
@ -149,12 +35,11 @@ batchindex(xs, i) = (reverse(Base.tail(reverse(axes(xs))))..., i)
|
||||
|
||||
Count the number of times that each element of `xs` appears.
|
||||
|
||||
# Examples
|
||||
```jldoctest
|
||||
julia> Flux.frequencies(['a','b','b'])
|
||||
```julia
|
||||
julia> frequencies(['a','b','b'])
|
||||
Dict{Char,Int64} with 2 entries:
|
||||
'a' => 1
|
||||
'b' => 2
|
||||
'a' => 1
|
||||
```
|
||||
"""
|
||||
function frequencies(xs)
|
||||
@ -170,13 +55,12 @@ head(x::Tuple) = reverse(Base.tail(reverse(x)))
|
||||
squeezebatch(x) = reshape(x, head(size(x)))
|
||||
|
||||
"""
|
||||
batch(xs)
|
||||
batch(xs)
|
||||
|
||||
Batch the arrays in `xs` into a single array.
|
||||
|
||||
# Examples
|
||||
```jldoctest
|
||||
julia> Flux.batch([[1,2,3],[4,5,6]])
|
||||
```julia
|
||||
julia> batch([[1,2,3],[4,5,6]])
|
||||
3×2 Array{Int64,2}:
|
||||
1 4
|
||||
2 5
|
||||
@ -193,25 +77,6 @@ function batch(xs)
|
||||
return data
|
||||
end
|
||||
|
||||
"""
|
||||
Return the given sequence padded with `p` up to a maximum length of `n`.
|
||||
|
||||
# Examples
|
||||
```jldoctest
|
||||
julia> rpad([1, 2], 4, 0)
|
||||
4-element Array{Int64,1}:
|
||||
1
|
||||
2
|
||||
0
|
||||
0
|
||||
|
||||
julia> rpad([1, 2, 3], 2, 0)
|
||||
3-element Array{Int64,1}:
|
||||
1
|
||||
2
|
||||
3
|
||||
```
|
||||
"""
|
||||
Base.rpad(v::AbstractVector, n::Integer, p) = [v; fill(p, max(n - length(v), 0))]
|
||||
|
||||
"""
|
||||
@ -220,9 +85,8 @@ Base.rpad(v::AbstractVector, n::Integer, p) = [v; fill(p, max(n - length(v), 0))
|
||||
Take a list of `N` sequences, and turn them into a single sequence where each
|
||||
item is a batch of `N`. Short sequences will be padded by `pad`.
|
||||
|
||||
# Examples
|
||||
```jldoctest
|
||||
julia> Flux.batchseq([[1, 2, 3], [4, 5]], 0)
|
||||
```julia
|
||||
julia> batchseq([[1, 2, 3], [4, 5]], 0)
|
||||
3-element Array{Array{Int64,1},1}:
|
||||
[1, 4]
|
||||
[2, 5]
|
||||
@ -234,64 +98,14 @@ function batchseq(xs, pad = nothing, n = maximum(length(x) for x in xs))
|
||||
[batch([xs_[j][i] for j = 1:length(xs_)]) for i = 1:n]
|
||||
end
|
||||
|
||||
# Flattening models to weight vectors, and back
|
||||
|
||||
function _restructure(m, xs)
|
||||
i = 0
|
||||
fmap(m) do x
|
||||
x isa AbstractArray || return x
|
||||
x = reshape(xs[i.+(1:length(x))], size(x))
|
||||
i += length(x)
|
||||
return x
|
||||
end
|
||||
end
|
||||
|
||||
@adjoint function _restructure(m, xs)
|
||||
_restructure(m, xs), dm -> (nothing,destructure(dm)[1])
|
||||
end
|
||||
|
||||
"""
|
||||
destructure(m)
|
||||
|
||||
Flatten a model's parameters into a single weight vector.
|
||||
|
||||
julia> m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
|
||||
Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
|
||||
|
||||
julia> θ, re = destructure(m);
|
||||
|
||||
julia> θ
|
||||
67-element Array{Float32,1}:
|
||||
-0.1407104
|
||||
...
|
||||
|
||||
The second return value `re` allows you to reconstruct the original network after making
|
||||
modifications to the weight vector (for example, with a hypernetwork).
|
||||
|
||||
julia> re(θ .* 2)
|
||||
Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
|
||||
"""
|
||||
function destructure(m)
|
||||
xs = Zygote.Buffer([])
|
||||
fmap(m) do x
|
||||
x isa AbstractArray && push!(xs, x)
|
||||
return x
|
||||
end
|
||||
return vcat(vec.(copy(xs))...), p -> _restructure(m, p)
|
||||
end
|
||||
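`destructure` above flattens every array in a model into one vector and returns a closure `re` that rebuilds the model; the `@adjoint` on `_restructure` lets gradients flow back to that flat vector. A hedged sketch of the round trip, assuming the Zygote-based code path shown above; sizes are arbitrary.

```julia
using Flux

m = Chain(Dense(2, 3, tanh), Dense(3, 1))
θ, re = Flux.destructure(m)       # θ is one flat parameter vector, re rebuilds a model

x, y = rand(Float32, 2, 8), rand(Float32, 1, 8)
loss(θ) = Flux.mse(re(θ)(x), y)

g = gradient(loss, θ)[1]          # gradient w.r.t. the flat parameter vector
θ .-= 0.1f0 .* g                  # one plain gradient-descent step
```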
|
||||
# Other
|
||||
|
||||
"""
|
||||
throttle(f, timeout; leading=true, trailing=false)
|
||||
|
||||
Return a function that when invoked, will only be triggered at most once
|
||||
during `timeout` seconds.
|
||||
|
||||
Normally, the throttled function will run as much as it can, without ever
|
||||
going more than once per `timeout` duration; but if you'd like to disable the
|
||||
execution on the leading edge, pass `leading=false`. To enable execution on
|
||||
the trailing edge, pass `trailing=true`.
|
||||
Returns a function that when invoked, will only be triggered at most once
|
||||
during `timeout` seconds. Normally, the throttled function will run
|
||||
as much as it can, without ever going more than once per `wait` duration;
|
||||
but if you'd like to disable the execution on the leading edge, pass
|
||||
`leading=false`. To enable execution on the trailing edge, ditto.
|
||||
"""
|
||||
function throttle(f, timeout; leading=true, trailing=false)
|
||||
cooldown = true
|
||||
@ -325,6 +139,25 @@ function throttle(f, timeout; leading=true, trailing=false)
|
||||
end
|
||||
end
|
||||
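The usual role of `throttle` is to rate-limit a logging callback passed to `train!`. A self-contained sketch; the model and data are toy placeholders.

```julia
using Flux
using Flux: throttle, params

m = Dense(2, 1)
data = [(rand(Float32, 2, 16), rand(Float32, 1, 16)) for _ in 1:100]
loss(x, y) = Flux.mse(m(x), y)

evalcb = throttle(() -> @show(loss(data[1]...)), 10)   # fires at most once every 10 s
Flux.train!(loss, params(m), data, Descent(0.1), cb = evalcb)
```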
|
||||
"""
|
||||
J = jacobian(m,x)
|
||||
|
||||
Calculate the output jacobian `J = d/dx m(x)` such that each row `i` of `J` corresponds to the gradient `J[i,:] = ∇ₓ(m(x)[i])`
|
||||
"""
|
||||
function jacobian(m,x)
|
||||
xp = param(x)
|
||||
y = m(xp)
|
||||
k = length(y)
|
||||
n = length(x)
|
||||
J = Matrix{eltype(x)}(undef,n,k)
|
||||
for i = 1:k
|
||||
Flux.back!(y[i], once = false) # Populate gradient accumulator
|
||||
J[:,i] = xp.grad
|
||||
xp.grad .= 0 # Reset gradient accumulator
|
||||
end
|
||||
J'
|
||||
end
|
||||
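`jacobian` above builds the Jacobian row by row by calling `Flux.back!` on each output element and reading the accumulated gradient of `xp`. A short sketch, relying on the Tracker-based `param`/`back!` path used in this function.

```julia
using Flux

m = Dense(3, 2)
x = rand(3)

J = Flux.jacobian(m, x)    # 2×3: row i is ∇ₓ(m(x)[i]), as in the docstring
size(J) == (2, 3)
```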
|
||||
"""
|
||||
@jit ...
|
||||
|
||||
|
106 src/zeros.jl
@@ -1,106 +0,0 @@
|
||||
import Base: +, -, *, reshape, size
|
||||
import Base.Broadcast: broadcasted, Broadcasted, BroadcastStyle
|
||||
|
||||
"""
|
||||
Zeros()
|
||||
Zeros(size...)
|
||||
Zeros(Type, size...)
|
||||
|
||||
Acts as a stand-in for an array of zeros that can be
|
||||
used during training which is ignored by the optimisers.
|
||||
|
||||
Useful to turn bias off for a forward pass of a layer.
|
||||
|
||||
## Examples
|
||||
|
||||
```julia
|
||||
julia> Flux.Zeros(3,3)
|
||||
3×3 Flux.Zeros{Bool,2}:
|
||||
false false false
|
||||
false false false
|
||||
false false false
|
||||
|
||||
julia> Flux.Zeros(Float32, 3,3)
|
||||
3×3 Flux.Zeros{Float32,2}:
|
||||
0.0 0.0 0.0
|
||||
0.0 0.0 0.0
|
||||
0.0 0.0 0.0
|
||||
|
||||
julia> rand(3,3) .+ Flux.Zeros()
|
||||
3×3 Array{Float64,2}:
|
||||
0.198739 0.490459 0.785386
|
||||
0.779074 0.39986 0.66383
|
||||
0.854981 0.447292 0.314497
|
||||
|
||||
julia> bias_less_conv = Conv((2,2), 1=>3, bias = Flux.Zeros())
|
||||
Conv((2, 2), 1=>3)
|
||||
```
|
||||
"""
|
||||
struct Zeros{T,N} <: AbstractArray{T,N}
|
||||
size::Tuple
|
||||
end
|
||||
|
||||
Zeros(::Type{T}, sz...) where T = Zeros{T,length(sz)}(sz)
|
||||
Zeros(sz::Integer...) = Zeros(Bool, sz...)
|
||||
|
||||
Base.size(xs::Zeros) = xs.size
|
||||
Base.axes(xs::Zeros) = Base.OneTo.(size(xs))
|
||||
|
||||
Base.IndexStyle(::Type{<:Zeros}) = IndexLinear()
|
||||
|
||||
Base.getindex(xs::Zeros{T,N}, I::Int) where {T,N} = zero(T)
|
||||
Base.getindex(xs::Zeros{T,N}, inds::Union{Base.OneTo, Base.UnitRange}) where {T,N} =
|
||||
Zeros(T, length(inds))
|
||||
|
||||
Base.collect(xs::Zeros{T,N}) where {T,N} = fill(zero(T), size(xs))
|
||||
|
||||
@adjoint reshape(xs::Zeros{T}, dims...) where T =
|
||||
reshape(xs, dims...), _ -> nothing
|
||||
|
||||
# Define basic ops
|
||||
for f in (:+, :-)
|
||||
@eval @inline function $f(a::Union{AbstractArray{<:Number}, Zeros}, b::Zeros)
|
||||
@assert size(a) == size(b) throw(DimensionMismatch("dimensions must match"))
|
||||
a
|
||||
end
|
||||
end
|
||||
|
||||
+(a::Zeros, b::AbstractArray) = b + a
|
||||
-(a::Zeros, b::AbstractArray) = -b + a
|
||||
|
||||
Base.copy(xs::Zeros{T,N}) where {T,N} = xs
|
||||
|
||||
# Define broadcasting behaviour
|
||||
for op in (:+, :-)
|
||||
@eval function broadcasted(::typeof($op), a::AbstractArray, b::Zeros)
|
||||
bs = Broadcast.broadcast_shape(size(a), size(b))
|
||||
size(a) == bs && return a
|
||||
sz = similar(a, bs)
|
||||
sz .= a
|
||||
end
|
||||
end
|
||||
|
||||
broadcasted(::typeof(+), a::Zeros, b::AbstractArray) = broadcasted(+, b, a)
|
||||
broadcasted(::typeof(-), a::Zeros, b::AbstractArray) = broadcasted(+, -b, a)
|
||||
|
||||
function broadcasted(::typeof(*), a::AbstractArray, b::Zeros)
|
||||
Zeros(Broadcast.broadcast_shape(size(a), size(b))...)
|
||||
end
|
||||
|
||||
broadcasted(::typeof(*), a::Zeros, b::AbstractArray) = broadcasted(*, b, a)
|
||||
|
||||
for op in (:+, :-, :*)
|
||||
@eval broadcasted(::typeof($op), a::Zeros, b::Zeros) = Zeros(Broadcast.broadcast_shape(size(a), size(b))...)
|
||||
end
|
||||
|
||||
# Some opportunities to avoid scalar indexing, intermediaries
|
||||
# Since it replicates a little of what we expect Base to do,
|
||||
# it should be possible to remove in the future, but for now,
|
||||
# these help with performance.
|
||||
broadcasted(::typeof(+), a::AbstractArray, b::Zeros{T,0}) where T = a
|
||||
broadcasted(::typeof(+), a::Zeros{T,0}, b::AbstractArray) where T = b
|
||||
broadcasted(::typeof(-), a::AbstractArray, b::Zeros{T,0}) where T = a
|
||||
broadcasted(::typeof(-), a::Zeros{T,0}, b::AbstractArray) where T = -b
|
||||
broadcasted(::typeof(*), a::AbstractArray, b::Zeros{T,0}) where T = zero(a)
|
||||
broadcasted(::typeof(*), a::Zeros{T,0}, b::AbstractArray) where T = zero(b)
|
||||
broadcasted(::typeof(/), a::Zeros{T,0}, b::AbstractArray) where T = zero(b)
|
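`Zeros`, defined in the src/zeros.jl shown above (which this comparison removes), stands in for an all-zero bias that optimisers skip. A sketch mirroring its docstring and the `Conv` bias test later in this diff.

```julia
using Flux

c = Conv((2, 2), 1 => 3, bias = Flux.Zeros())   # layer with no trainable bias
x = zeros(Float32, 28, 28, 1, 1)
sum(c(x)) == 0f0                                # zero input + zero bias => zero output

rand(3, 3) .+ Flux.Zeros()                      # broadcasting a 0-dim Zeros is a no-op
```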
@ -1,5 +1,4 @@
|
||||
using Flux, Test
|
||||
using Flux.CuArrays
|
||||
using Flux, Flux.Tracker, CuArrays, Test
|
||||
using Flux: gpu
|
||||
|
||||
@info "Testing GPU Support"
|
||||
@ -8,11 +7,9 @@ using Flux: gpu
|
||||
|
||||
CuArrays.allowscalar(false)
|
||||
|
||||
x = randn(5, 5)
|
||||
x = param(randn(5, 5))
|
||||
cx = gpu(x)
|
||||
@test cx isa CuArray
|
||||
|
||||
@test Flux.onecold(gpu([1.0, 2.0, 3.0])) == 3
|
||||
@test cx isa TrackedArray && cx.data isa CuArray
|
||||
|
||||
x = Flux.onehotbatch([1, 2, 3], 1:3)
|
||||
cx = gpu(x)
|
||||
@ -22,54 +19,25 @@ cx = gpu(x)
|
||||
m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)
|
||||
cm = gpu(m)
|
||||
|
||||
@test all(p isa CuArray for p in params(cm))
|
||||
@test cm(gpu(rand(10, 10))) isa CuArray{Float32,2}
|
||||
@test all(p isa TrackedArray && p.data isa CuArray for p in params(cm))
|
||||
@test cm(gpu(rand(10, 10))) isa TrackedArray{Float32,2,CuArray{Float32,2}}
|
||||
|
||||
x = [1.,2.,3.]
|
||||
x = [1,2,3]
|
||||
cx = gpu(x)
|
||||
@test Flux.crossentropy(x,x) ≈ Flux.crossentropy(cx,cx)
|
||||
@test Flux.crossentropy(x,x, weight=1.0) ≈ Flux.crossentropy(cx,cx, weight=1.0)
|
||||
@test Flux.crossentropy(x,x, weight=[1.0;2.0;3.0]) ≈ Flux.crossentropy(cx,cx, weight=cu([1.0;2.0;3.0]))
|
||||
|
||||
x = [-1.1491, 0.8619, 0.3127]
|
||||
y = [1, 1, 0.]
|
||||
@test Flux.binarycrossentropy.(σ.(x),y) ≈ Array(Flux.binarycrossentropy.(cu(σ.(x)),cu(y)))
|
||||
@test Flux.logitbinarycrossentropy.(x,y) ≈ Array(Flux.logitbinarycrossentropy.(cu(x),cu(y)))
|
||||
|
||||
xs = rand(5, 5)
|
||||
xs = param(rand(5,5))
|
||||
ys = Flux.onehotbatch(1:5,1:5)
|
||||
@test collect(cu(xs) .+ cu(ys)) ≈ collect(xs .+ ys)
|
||||
|
||||
c = gpu(Conv((2,2),3=>4))
|
||||
x = gpu(rand(10, 10, 3, 2))
|
||||
l = c(gpu(rand(10,10,3,2)))
|
||||
@test gradient(x -> sum(c(x)), x)[1] isa CuArray
|
||||
|
||||
c = gpu(CrossCor((2,2),3=>4))
|
||||
x = gpu(rand(10, 10, 3, 2))
|
||||
l = c(gpu(rand(10,10,3,2)))
|
||||
@test gradient(x -> sum(c(x)), x)[1] isa CuArray
|
||||
Flux.back!(sum(l))
|
||||
|
||||
end
|
||||
|
||||
@testset "onecold gpu" begin
|
||||
y = Flux.onehotbatch(ones(3), 1:10) |> gpu;
|
||||
@test Flux.onecold(y) isa CuArray
|
||||
@test y[3,:] isa CuArray
|
||||
end
|
||||
|
||||
@testset "restructure gpu" begin
|
||||
dudt = Dense(1,1) |> gpu
|
||||
p,re = Flux.destructure(dudt)
|
||||
foo(x) = sum(re(p)(x))
|
||||
@test gradient(foo, cu(rand(1)))[1] isa CuArray
|
||||
end
|
||||
|
||||
if CuArrays.has_cudnn()
|
||||
@info "Testing Flux/CUDNN"
|
||||
include("cudnn.jl")
|
||||
include("curnn.jl")
|
||||
include("layers.jl")
|
||||
else
|
||||
@warn "CUDNN unavailable, not testing GPU DNN support"
|
||||
if CuArrays.libcudnn != nothing
|
||||
@info "Testing Flux/CUDNN"
|
||||
include("cudnn.jl")
|
||||
include("curnn.jl")
|
||||
end
|
||||
|
@ -1,44 +1,48 @@
|
||||
using Flux, CuArrays, Test
|
||||
using Flux: pullback
|
||||
using Flux, Flux.Tracker, CuArrays, Test
|
||||
using Flux.Tracker: TrackedArray, data
|
||||
|
||||
@testset "CUDNN BatchNorm" begin
|
||||
@testset "4D Input" begin
|
||||
x = Float64.(collect(reshape(1:12, 2, 2, 3, 1)))
|
||||
x = TrackedArray(Float64.(collect(reshape(1:12, 2, 2, 3, 1))))
|
||||
m = BatchNorm(3)
|
||||
cx = gpu(x)
|
||||
cm = gpu(m)
|
||||
|
||||
y, back = pullback((m, x) -> m(x), m, x)
|
||||
cy, cback = pullback((m, x) -> m(x), cm, cx)
|
||||
y = m(x)
|
||||
cy = cm(cx)
|
||||
|
||||
@test cpu(cy) ≈ y
|
||||
@test cy isa TrackedArray{Float32,4,CuArray{Float32,4}}
|
||||
|
||||
Δ = randn(size(y))
|
||||
dm, dx = back(Δ)
|
||||
cdm, cdx = cback(gpu(Δ))
|
||||
@test cpu(data(cy)) ≈ data(y)
|
||||
|
||||
@test dm[].γ ≈ cpu(cdm[].γ)
|
||||
@test dm[].β ≈ cpu(cdm[].β)
|
||||
@test dx ≈ cpu(cdx)
|
||||
g = rand(size(y)...)
|
||||
Flux.back!(y, g)
|
||||
Flux.back!(cy, gpu(g))
|
||||
|
||||
@test m.γ.grad ≈ cpu(cm.γ.grad)
|
||||
@test m.β.grad ≈ cpu(cm.β.grad)
|
||||
@test x.grad ≈ cpu(x.grad)
|
||||
end
|
||||
|
||||
@testset "2D Input" begin
|
||||
x = Float64.(collect(reshape(1:12, 3, 4)))
|
||||
x = TrackedArray(Float64.(collect(reshape(1:12, 3, 4))))
|
||||
m = BatchNorm(3)
|
||||
cx = gpu(x)
|
||||
cm = gpu(m)
|
||||
|
||||
y, back = pullback((m, x) -> m(x), m, x)
|
||||
cy, cback = pullback((m, x) -> m(x), cm, cx)
|
||||
y = m(x)
|
||||
cy = cm(cx)
|
||||
|
||||
@test cpu(cy) ≈ y
|
||||
@test cy isa TrackedArray{Float32,2,CuArray{Float32,2}}
|
||||
|
||||
Δ = randn(size(y))
|
||||
dm, dx = back(Δ)
|
||||
cdm, cdx = cback(gpu(Δ))
|
||||
@test cpu(data(cy)) ≈ data(y)
|
||||
|
||||
@test dm[].γ ≈ cpu(cdm[].γ)
|
||||
@test dm[].β ≈ cpu(cdm[].β)
|
||||
@test dx ≈ cpu(cdx)
|
||||
g = rand(size(y)...)
|
||||
Flux.back!(y, g)
|
||||
Flux.back!(cy, gpu(g))
|
||||
|
||||
@test m.γ.grad ≈ cpu(cm.γ.grad)
|
||||
@test m.β.grad ≈ cpu(cm.β.grad)
|
||||
@test x.grad ≈ cpu(x.grad)
|
||||
end
|
||||
end
|
||||
|
@ -1,63 +1,46 @@
|
||||
using Flux, CuArrays, Test
|
||||
using Flux: pullback
|
||||
|
||||
@testset for R in [RNN, GRU, LSTM]
|
||||
m = R(10, 5) |> gpu
|
||||
x = gpu(rand(10))
|
||||
(m̄,) = gradient(m -> sum(m(x)), m)
|
||||
Flux.reset!(m)
|
||||
θ = gradient(() -> sum(m(x)), params(m))
|
||||
@test collect(m̄[].cell[].Wi) == collect(θ[m.cell.Wi])
|
||||
end
|
||||
|
||||
@testset "RNN" begin
|
||||
@testset for R in [RNN, GRU, LSTM], batch_size in (1, 5)
|
||||
@testset for R in [RNN, GRU, LSTM]
|
||||
rnn = R(10, 5)
|
||||
curnn = fmap(gpu, rnn)
|
||||
curnn = mapleaves(gpu, rnn)
|
||||
@testset for batch_size in (1, 5)
|
||||
Flux.reset!(rnn)
|
||||
Flux.reset!(curnn)
|
||||
x = batch_size == 1 ?
|
||||
param(rand(10)) :
|
||||
param(rand(10,batch_size))
|
||||
cux = gpu(x)
|
||||
y = (rnn(x); rnn(x))
|
||||
cuy = (curnn(cux); curnn(cux))
|
||||
|
||||
Flux.reset!(rnn)
|
||||
Flux.reset!(curnn)
|
||||
x = batch_size == 1 ?
|
||||
rand(10) :
|
||||
rand(10, batch_size)
|
||||
cux = gpu(x)
|
||||
@test y.data ≈ collect(cuy.data)
|
||||
@test haskey(Flux.CUDA.descs, curnn.cell)
|
||||
|
||||
y, back = pullback((r, x) -> r(x), rnn, x)
|
||||
cuy, cuback = pullback((r, x) -> r(x), curnn, cux)
|
||||
Δ = randn(size(y))
|
||||
|
||||
@test y ≈ collect(cuy)
|
||||
@test haskey(Flux.CUDA.descs, curnn.cell)
|
||||
Flux.back!(y, Δ)
|
||||
Flux.back!(cuy, gpu(Δ))
|
||||
|
||||
ȳ = randn(size(y))
|
||||
m̄, x̄ = back(ȳ)
|
||||
cum̄, cux̄ = cuback(gpu(ȳ))
|
||||
|
||||
m̄[].cell[].Wi
|
||||
|
||||
m̄[].state
|
||||
cum̄[].state
|
||||
|
||||
@test x̄ ≈ collect(cux̄)
|
||||
@test m̄[].cell[].Wi ≈ collect(cum̄[].cell[].Wi)
|
||||
@test m̄[].cell[].Wh ≈ collect(cum̄[].cell[].Wh)
|
||||
@test m̄[].cell[].b ≈ collect(cum̄[].cell[].b)
|
||||
if m̄[].state isa Tuple
|
||||
for (x, cx) in zip(m̄[].state, cum̄[].state)
|
||||
@test x ≈ collect(cx)
|
||||
@test x.grad ≈ collect(cux.grad)
|
||||
@test rnn.cell.Wi.grad ≈ collect(curnn.cell.Wi.grad)
|
||||
@test rnn.cell.Wh.grad ≈ collect(curnn.cell.Wh.grad)
|
||||
@test rnn.cell.b.grad ≈ collect(curnn.cell.b.grad)
|
||||
@test rnn.cell.h.grad ≈ collect(curnn.cell.h.grad)
|
||||
if isdefined(rnn.cell, :c)
|
||||
@test rnn.cell.c.grad ≈ collect(curnn.cell.c.grad)
|
||||
end
|
||||
else
|
||||
@test m̄[].state ≈ collect(cum̄[].state)
|
||||
|
||||
Flux.reset!(rnn)
|
||||
Flux.reset!(curnn)
|
||||
ohx = batch_size == 1 ?
|
||||
Flux.onehot(rand(1:10), 1:10) :
|
||||
Flux.onehotbatch(rand(1:10, batch_size), 1:10)
|
||||
cuohx = gpu(ohx)
|
||||
y = (rnn(ohx); rnn(ohx))
|
||||
cuy = (curnn(cuohx); curnn(cuohx))
|
||||
|
||||
@test y.data ≈ collect(cuy.data)
|
||||
end
|
||||
|
||||
Flux.reset!(rnn)
|
||||
Flux.reset!(curnn)
|
||||
ohx = batch_size == 1 ?
|
||||
Flux.onehot(rand(1:10), 1:10) :
|
||||
Flux.onehotbatch(rand(1:10, batch_size), 1:10)
|
||||
cuohx = gpu(ohx)
|
||||
y = (rnn(ohx); rnn(ohx))
|
||||
cuy = (curnn(cuohx); curnn(cuohx))
|
||||
|
||||
@test y ≈ collect(cuy)
|
||||
end
|
||||
end
|
||||
|
@ -1,98 +0,0 @@
|
||||
# Test layers and data/model movements on and off the GPU
|
||||
# Add tests for layers and their gradients on the GPU
|
||||
# Most of the forward passes should be fine being applied
|
||||
# to bitstype objects, but this gives higher coverage for our use-cases
|
||||
# Check that getting the gradients does not throw
|
||||
|
||||
# generic movement tests
|
||||
@testset "Basic GPU Movement" begin
|
||||
@test gradient(x -> sum(gpu(x)), rand(3,3)) isa Tuple
|
||||
@test gradient(x -> sum(cpu(x)), gpu(rand(3,3))) isa Tuple
|
||||
end
|
||||
|
||||
# TODO: These layers get into scalar indexing
|
||||
# `AlphaDropout` throws a compilation error on GPUs,
|
||||
# whereas, the rest are scalar indexing issues.
|
||||
const BROKEN_LAYERS = [DepthwiseConv,
|
||||
AlphaDropout,
|
||||
InstanceNorm,
|
||||
GroupNorm]
|
||||
|
||||
function gradtest(name::String, layers::Vector, xs = nothing, args...)
|
||||
isnothing(xs) && error("Missing input to test the layers against.")
|
||||
@testset "$name GPU grad tests" begin
|
||||
for layer in layers
|
||||
@testset "$layer GPU grad test" begin
|
||||
l = gpu(layer(args...))
|
||||
xs = gpu(xs)
|
||||
if any(x -> isa(l, x), BROKEN_LAYERS)
|
||||
ps = Flux.params(l)
|
||||
@test_broken gradient(() -> sum(l(xs)), ps) isa Flux.Zygote.Grads
|
||||
else
|
||||
ps = Flux.params(l)
|
||||
@test gradient(() -> sum(l(xs)), ps) isa Flux.Zygote.Grads
|
||||
gs = gradient(() -> sum(l(xs)), ps)
|
||||
|
||||
# Handle pooling layers
|
||||
if !isempty(ps)
|
||||
@test gs[first(ps)] isa Flux.CuArrays.CuArray
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
# Repeats from Conv, CrossCor
|
||||
|
||||
r = rand(Float32, 28, 28, 1, 1)
|
||||
conv_layers = [Conv, ConvTranspose, CrossCor, DepthwiseConv]
|
||||
gradtest("Conv", conv_layers, r, (2,2), 1=>3)
|
||||
|
||||
pooling_layers = [MaxPool, MeanPool]
|
||||
gradtest("Pooling", pooling_layers, r, (2,2))
|
||||
|
||||
dropout_layers = [Dropout, AlphaDropout]
|
||||
gradtest("Dropout", dropout_layers, r, 0.5f0)
|
||||
|
||||
norm_layers = [LayerNorm, BatchNorm]
|
||||
gradtest("Normalising", norm_layers, rand(Float32, 28,28,3,1), 1)
|
||||
|
||||
instancenorm = [InstanceNorm]
|
||||
gradtest("InstanceNorm", instancenorm, r, 1)
|
||||
|
||||
groupnorm = [GroupNorm]
|
||||
gradtest("GroupNorm", groupnorm, rand(Float32, 28,28,3,1), 3, 1)
|
||||
|
||||
const stateless_layers = [Flux.mse,
|
||||
Flux.crossentropy,
|
||||
Flux.logitcrossentropy,
|
||||
Flux.normalise]
|
||||
|
||||
const stateless_layers_broadcasted = [Flux.binarycrossentropy,
|
||||
Flux.logitbinarycrossentropy]
|
||||
|
||||
function stateless_gradtest(f, args...)
|
||||
@test gradient((args...) -> sum(f(args...)), args...)[1] isa CuArray
|
||||
end
|
||||
|
||||
function stateless_gradtest_broadcasted(f, args...)
|
||||
@test gradient((args...) -> sum(f.(args...)), args...)[1] isa CuArray
|
||||
end
|
||||
|
||||
@testset "Stateless GPU grad tests" begin
|
||||
x = gpu(rand(3,3))
|
||||
y = gpu(rand(3,3))
|
||||
|
||||
for layer in stateless_layers
|
||||
if layer == Flux.normalise
|
||||
stateless_gradtest(layer, x)
|
||||
else
|
||||
stateless_gradtest(layer, x, y)
|
||||
end
|
||||
end
|
||||
|
||||
for layer in stateless_layers_broadcasted
|
||||
stateless_gradtest_broadcasted(layer, x, y)
|
||||
end
|
||||
end
|
120 test/data.jl
@@ -1,116 +1,16 @@
|
||||
@testset "DataLoader" begin
|
||||
X = reshape([1:10;], (2, 5))
|
||||
Y = [1:5;]
|
||||
using Flux.Data
|
||||
using Test
|
||||
|
||||
d = DataLoader(X, batchsize=2)
|
||||
@inferred first(d)
|
||||
batches = collect(d)
|
||||
@test eltype(batches) == eltype(d) == typeof(X)
|
||||
@test length(batches) == 3
|
||||
@test batches[1] == X[:,1:2]
|
||||
@test batches[2] == X[:,3:4]
|
||||
@test batches[3] == X[:,5:5]
|
||||
@test cmudict()["CATASTROPHE"] == :[K,AH0,T,AE1,S,T,R,AH0,F,IY0].args
|
||||
|
||||
d = DataLoader(X, batchsize=2, partial=false)
|
||||
@inferred first(d)
|
||||
batches = collect(d)
|
||||
@test eltype(batches) == eltype(d) == typeof(X)
|
||||
@test length(batches) == 2
|
||||
@test batches[1] == X[:,1:2]
|
||||
@test batches[2] == X[:,3:4]
|
||||
@test length(CMUDict.phones()) == 39
|
||||
|
||||
d = DataLoader((X,), batchsize=2, partial=false)
|
||||
@inferred first(d)
|
||||
batches = collect(d)
|
||||
@test eltype(batches) == eltype(d) == Tuple{typeof(X)}
|
||||
@test length(batches) == 2
|
||||
@test batches[1] == (X[:,1:2],)
|
||||
@test batches[2] == (X[:,3:4],)
|
||||
@test length(CMUDict.symbols()) == 84
|
||||
|
||||
d = DataLoader((X, Y), batchsize=2)
|
||||
@inferred first(d)
|
||||
batches = collect(d)
|
||||
@test eltype(batches) == eltype(d) == Tuple{typeof(X), typeof(Y)}
|
||||
@test length(batches) == 3
|
||||
@test length(batches[1]) == 2
|
||||
@test length(batches[2]) == 2
|
||||
@test length(batches[3]) == 2
|
||||
@test batches[1][1] == X[:,1:2]
|
||||
@test batches[1][2] == Y[1:2]
|
||||
@test batches[2][1] == X[:,3:4]
|
||||
@test batches[2][2] == Y[3:4]
|
||||
@test batches[3][1] == X[:,5:5]
|
||||
@test batches[3][2] == Y[5:5]
|
||||
@test MNIST.images()[1] isa Matrix
|
||||
@test MNIST.labels() isa Vector{Int64}
|
||||
|
||||
# test with NamedTuple
|
||||
d = DataLoader((x=X, y=Y), batchsize=2)
|
||||
@inferred first(d)
|
||||
batches = collect(d)
|
||||
@test eltype(batches) == eltype(d) == NamedTuple{(:x, :y), Tuple{typeof(X), typeof(Y)}}
|
||||
@test length(batches) == 3
|
||||
@test length(batches[1]) == 2
|
||||
@test length(batches[2]) == 2
|
||||
@test length(batches[3]) == 2
|
||||
@test batches[1][1] == batches[1].x == X[:,1:2]
|
||||
@test batches[1][2] == batches[1].y == Y[1:2]
|
||||
@test batches[2][1] == batches[2].x == X[:,3:4]
|
||||
@test batches[2][2] == batches[2].y == Y[3:4]
|
||||
@test batches[3][1] == batches[3].x == X[:,5:5]
|
||||
@test batches[3][2] == batches[3].y == Y[5:5]
|
||||
@test FashionMNIST.images()[1] isa Matrix
|
||||
@test FashionMNIST.labels() isa Vector{Int64}
|
||||
|
||||
# test interaction with `train!`
|
||||
θ = ones(2)
|
||||
X = zeros(2, 10)
|
||||
loss(x) = sum((x .- θ).^2)
|
||||
d = DataLoader(X)
|
||||
Flux.train!(loss, [θ], ncycle(d, 10), Descent(0.1))
|
||||
@test norm(θ) < 1e-4
|
||||
|
||||
# test interaction with `train!`
|
||||
θ = zeros(2)
|
||||
X = ones(2, 10)
|
||||
Y = fill(2, 10)
|
||||
loss(x, y) = sum((y - x'*θ).^2)
|
||||
d = DataLoader((X, Y))
|
||||
Flux.train!(loss, [θ], ncycle(d, 10), Descent(0.1))
|
||||
@test norm(θ .- 1) < 1e-10
|
||||
end
|
||||
|
||||
@testset "CMUDict" begin
|
||||
@test cmudict()["CATASTROPHE"] == :[K,AH0,T,AE1,S,T,R,AH0,F,IY0].args
|
||||
|
||||
@test length(CMUDict.phones()) == 39
|
||||
|
||||
@test length(CMUDict.symbols()) == 84
|
||||
end
|
||||
|
||||
@testset "MNIST" begin
|
||||
@test MNIST.images()[1] isa Matrix
|
||||
@test MNIST.labels() isa Vector{Int64}
|
||||
end
|
||||
|
||||
@testset "FashionMNIST" begin
|
||||
@test FashionMNIST.images()[1] isa Matrix
|
||||
@test FashionMNIST.labels() isa Vector{Int64}
|
||||
end
|
||||
|
||||
@testset "Sentiment" begin
|
||||
@test Data.Sentiment.train() isa Vector{Data.Tree{Any}}
|
||||
end
|
||||
|
||||
@testset "Iris" begin
|
||||
@test Iris.features() isa Matrix
|
||||
@test size(Iris.features()) == (4,150)
|
||||
|
||||
@test Iris.labels() isa Vector{String}
|
||||
@test size(Iris.labels()) == (150,)
|
||||
end
|
||||
|
||||
|
||||
@testset "Housing" begin
|
||||
@test Housing.features() isa Matrix # test broken due to SSL certificate expiration problem
|
||||
@test size(Housing.features()) == (506, 13)
|
||||
|
||||
@test Housing.targets() isa Array{Float64}
|
||||
@test size(Housing.targets()) == (506, 1)
|
||||
end
|
||||
@test Data.Sentiment.train() isa Vector{Data.Tree{Any}}
|
||||
|
@ -1,117 +1,33 @@
|
||||
using Test, Random
|
||||
import Flux: activations
|
||||
|
||||
@testset "basic" begin
|
||||
@testset "helpers" begin
|
||||
@testset "activations" begin
|
||||
dummy_model = Chain(x->x.^2, x->x .- 3, x -> tan.(x))
|
||||
x = randn(10)
|
||||
@test activations(dummy_model, x)[1] == x.^2
|
||||
@test activations(dummy_model, x)[2] == (x.^2 .- 3)
|
||||
@test activations(dummy_model, x)[3] == tan.(x.^2 .- 3)
|
||||
|
||||
@test activations(Chain(), x) == ()
|
||||
@test activations(Chain(identity, x->:foo), x)[2] == :foo # results include `Any` type
|
||||
end
|
||||
end
|
||||
|
||||
@testset "Chain" begin
|
||||
@test_nowarn Chain(Dense(10, 5, σ), Dense(5, 2))(randn(10))
|
||||
@test_throws DimensionMismatch Chain(Dense(10, 5, σ),Dense(2, 1))(randn(10))
|
||||
# numeric test should be put into testset of corresponding layer
|
||||
end
|
||||
|
||||
@testset "Activations" begin
|
||||
c = Chain(Dense(3,5,relu), Dense(5,1,relu))
|
||||
X = Float32.([1.0; 1.0; 1.0])
|
||||
@test_nowarn gradient(()->Flux.activations(c, X)[2][1], params(c))
|
||||
end
|
||||
|
||||
@testset "Dense" begin
|
||||
@testset "constructors" begin
|
||||
@test size(Dense(10, 100).W) == (100, 10)
|
||||
@test Dense(rand(100,10), rand(10)).σ == identity
|
||||
|
||||
@test_throws MethodError Dense(10, 10.5)
|
||||
@test_throws MethodError Dense(10, 10.5, tanh)
|
||||
@testset "Chain" begin
|
||||
@test_nowarn Chain(Dense(10, 5, σ), Dense(5, 2))(randn(10))
|
||||
@test_throws DimensionMismatch Chain(Dense(10, 5, σ),Dense(2, 1))(randn(10))
|
||||
# numeric test should be put into testset of corresponding layer
|
||||
end
|
||||
|
||||
@test length(Dense(10, 5)(randn(10))) == 5
|
||||
@test_throws DimensionMismatch Dense(10, 5)(randn(1))
|
||||
@test_throws MethodError Dense(10, 5)(1) # avoid broadcasting
|
||||
@test_throws MethodError Dense(10, 5).(randn(10)) # avoid broadcasting
|
||||
@testset "Dense" begin
|
||||
@test length(Dense(10, 5)(randn(10))) == 5
|
||||
@test_throws DimensionMismatch Dense(10, 5)(randn(1))
|
||||
@test_throws MethodError Dense(10, 5)(1) # avoid broadcasting
|
||||
@test_throws MethodError Dense(10, 5).(randn(10)) # avoid broadcasting
|
||||
|
||||
@test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(1, 1)
|
||||
@test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,2)) == 10*ones(1, 2)
|
||||
@test Dense(10, 2, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(2, 1)
|
||||
@test Dense(10, 2, identity, initW = ones, initb = zeros)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20]
|
||||
end
|
||||
@test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(1, 1)
|
||||
@test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,2)) == 10*ones(1, 2)
|
||||
@test Dense(10, 2, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(2, 1)
|
||||
@test Dense(10, 2, identity, initW = ones, initb = zeros)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20]
|
||||
|
||||
@testset "Diagonal" begin
|
||||
@test length(Flux.Diagonal(10)(randn(10))) == 10
|
||||
@test length(Flux.Diagonal(10)(1)) == 10
|
||||
@test length(Flux.Diagonal(10)(randn(1))) == 10
|
||||
@test_throws DimensionMismatch Flux.Diagonal(10)(randn(2))
|
||||
|
||||
@test Flux.Diagonal(2)([1 2]) == [1 2; 1 2]
|
||||
@test Flux.Diagonal(2)([1,2]) == [1,2]
|
||||
@test Flux.Diagonal(2)([1 2; 3 4]) == [1 2; 3 4]
|
||||
end
|
||||
|
||||
@testset "Maxout" begin
|
||||
# Note that the normal common usage of Maxout is as per the docstring
|
||||
# These are abnormal constructors used for testing purposes
|
||||
|
||||
@testset "Constructor" begin
|
||||
mo = Maxout(() -> identity, 4)
|
||||
input = rand(40)
|
||||
@test mo(input) == input
|
||||
end
|
||||
|
||||
@testset "simple alternatives" begin
|
||||
mo = Maxout((x -> x, x -> 2x, x -> 0.5x))
|
||||
input = rand(40)
|
||||
@test mo(input) == 2*input
|
||||
@testset "Diagonal" begin
|
||||
@test length(Flux.Diagonal(10)(randn(10))) == 10
|
||||
@test length(Flux.Diagonal(10)(1)) == 10
|
||||
@test length(Flux.Diagonal(10)(randn(1))) == 10
|
||||
@test_throws DimensionMismatch Flux.Diagonal(10)(randn(2))
|
||||
|
||||
@test Flux.Diagonal(2)([1 2]) == [1 2; 1 2]
|
||||
@test Flux.Diagonal(2)([1,2]) == [1,2]
|
||||
@test Flux.Diagonal(2)([1 2; 3 4]) == [1 2; 3 4]
|
||||
end
|
||||
|
||||
@testset "complex alternatives" begin
|
||||
mo = Maxout((x -> [0.5; 0.1]*x, x -> [0.2; 0.7]*x))
|
||||
input = [3.0 2.0]
|
||||
target = [0.5, 0.7].*input
|
||||
@test mo(input) == target
|
||||
end
|
||||
|
||||
@testset "params" begin
|
||||
mo = Maxout(()->Dense(32, 64), 4)
|
||||
ps = params(mo)
|
||||
@test length(ps) == 8 #4 alts, each with weight and bias
|
||||
end
|
||||
end
|
||||
|
||||
@testset "SkipConnection" begin
|
||||
@testset "zero sum" begin
|
||||
input = randn(10, 10, 10, 10)
|
||||
@test SkipConnection(x -> zeros(size(x)), (a,b) -> a + b)(input) == input
|
||||
end
|
||||
|
||||
@testset "concat size" begin
|
||||
input = randn(10, 2)
|
||||
@test size(SkipConnection(Dense(10,10), (a,b) -> cat(a, b, dims = 2))(input)) == (10,4)
|
||||
end
|
||||
end
|
||||
|
||||
@testset "output dimensions" begin
|
||||
m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32))
|
||||
@test Flux.outdims(m, (10, 10)) == (6, 6)
|
||||
|
||||
m = Dense(10, 5)
|
||||
@test Flux.outdims(m, (5, 2)) == (5,)
|
||||
@test Flux.outdims(m, (10,)) == (5,)
|
||||
|
||||
m = Flux.Diagonal(10)
|
||||
@test Flux.outdims(m, (10,)) == (10,)
|
||||
|
||||
m = Maxout(() -> Conv((3, 3), 3 => 16), 2)
|
||||
@test Flux.outdims(m, (10, 10)) == (8, 8)
|
||||
end
|
||||
end
|
||||
|
@ -1,17 +1,12 @@
|
||||
using Flux, Test
|
||||
using Flux: maxpool, meanpool
|
||||
using Flux: gradient
|
||||
|
||||
@testset "Pooling" begin
|
||||
x = randn(Float32, 10, 10, 3, 2)
|
||||
gmp = GlobalMaxPool()
|
||||
@test size(gmp(x)) == (1, 1, 3, 2)
|
||||
gmp = GlobalMeanPool()
|
||||
@test size(gmp(x)) == (1, 1, 3, 2)
|
||||
mp = MaxPool((2, 2))
|
||||
@test mp(x) == maxpool(x, PoolDims(x, 2))
|
||||
@test mp(x) == maxpool(x, (2,2))
|
||||
mp = MeanPool((2, 2))
|
||||
@test mp(x) == meanpool(x, PoolDims(x, 2))
|
||||
@test mp(x) == meanpool(x, (2,2))
|
||||
end
|
||||
|
||||
@testset "CNN" begin
|
||||
@ -25,194 +20,4 @@ end
|
||||
Dense(288, 10), softmax)
|
||||
|
||||
@test size(m(r)) == (10, 5)
|
||||
|
||||
# Test bias switch
|
||||
bias = Conv(ones(Float32, 2, 2, 1, 3), ones(Float32, 3))
|
||||
ip = zeros(Float32, 28,28,1,1)
|
||||
|
||||
op = bias(ip)
|
||||
@test sum(op) == prod(size(op))
|
||||
|
||||
bias = Conv((2,2), 1=>3, bias = Flux.Zeros())
|
||||
op = bias(ip)
|
||||
@test sum(op) === 0.f0
|
||||
gs = gradient(() -> sum(bias(ip)), Flux.params(bias))
|
||||
@test gs[bias.bias] == nothing
|
||||
|
||||
# Train w/o bias and make sure no convergence happens
|
||||
# when only bias can be converged
|
||||
bias = Conv((2, 2), 1=>3, bias = Flux.Zeros());
|
||||
ip = zeros(Float32, 28,28,1,1)
|
||||
op = zeros(Float32, 27,27,3,1) .+ 2.f0
|
||||
opt = Descent()
|
||||
|
||||
for _ = 1:10^3
|
||||
gs = gradient(params(bias)) do
|
||||
Flux.mse(bias(ip), op)
|
||||
end
|
||||
Flux.Optimise.update!(opt, params(bias), gs)
|
||||
end
|
||||
|
||||
@test Flux.mse(bias(ip), op) ≈ 4.f0
|
||||
end
|
||||
|
||||
@testset "asymmetric padding" begin
|
||||
r = ones(Float32, 28, 28, 1, 1)
|
||||
m = Conv((3, 3), 1=>1, relu; pad=(0,1,1,2))
|
||||
m.weight[:] .= 1.0
|
||||
m.bias[:] .= 0.0
|
||||
y_hat = m(r)[:,:,1,1]
|
||||
@test size(y_hat) == (27, 29)
|
||||
@test y_hat[1, 1] ≈ 6.0
|
||||
@test y_hat[2, 2] ≈ 9.0
|
||||
@test y_hat[end, 1] ≈ 4.0
|
||||
@test y_hat[1, end] ≈ 3.0
|
||||
@test y_hat[1, end-1] ≈ 6.0
|
||||
@test y_hat[end, end] ≈ 2.0
|
||||
end
|
||||
|
||||
@testset "Depthwise Conv" begin
|
||||
r = zeros(Float32, 28, 28, 3, 5)
|
||||
m1 = DepthwiseConv((2, 2), 3=>15)
|
||||
@test size(m1(r), 3) == 15
|
||||
|
||||
m3 = DepthwiseConv((2, 3), 3=>9)
|
||||
@test size(m3(r), 3) == 9
|
||||
|
||||
# Test that we cannot ask for non-integer multiplication factors
|
||||
@test_throws AssertionError DepthwiseConv((2,2), 3=>10)
|
||||
end
|
||||
|
||||
@testset "ConvTranspose" begin
|
||||
x = zeros(Float32, 28, 28, 1, 1)
|
||||
y = Conv((3,3), 1 => 1)(x)
|
||||
x_hat = ConvTranspose((3, 3), 1 => 1)(y)
|
||||
@test size(x_hat) == size(x)
|
||||
|
||||
m = ConvTranspose((3,3), 1=>1)
|
||||
# Test that the gradient call does not throw: #900
|
||||
@test gradient(()->sum(m(x)), params(m)) isa Flux.Zygote.Grads
|
||||
end
|
||||
|
||||
@testset "CrossCor" begin
|
||||
x = rand(Float32, 28, 28, 1, 1)
|
||||
w = rand(2,2,1,1)
|
||||
y = CrossCor(w, [0.0])
|
||||
|
||||
@test isapprox(sum(w .* x[1:2, 1:2, :, :]), y(x)[1, 1, 1, 1], rtol=1e-7)
|
||||
|
||||
r = zeros(Float32, 28, 28, 1, 5)
|
||||
m = Chain(
|
||||
CrossCor((2, 2), 1=>16, relu),
|
||||
MaxPool((2,2)),
|
||||
CrossCor((2, 2), 16=>8, relu),
|
||||
MaxPool((2,2)),
|
||||
x -> reshape(x, :, size(x, 4)),
|
||||
Dense(288, 10), softmax)
|
||||
|
||||
@test size(m(r)) == (10, 5)
|
||||
@test y(x) != Conv(w, [0.0])(x)
|
||||
@test CrossCor(w[end:-1:1, end:-1:1, :, :], [0.0])(x) == Conv(w, [0.0])(x)
|
||||
end
|
||||
|
||||
@testset "Conv with non quadratic window #700" begin
|
||||
data = zeros(Float32, 7,7,1,1)
|
||||
data[4,4,1,1] = 1
|
||||
|
||||
l = Conv((3,3), 1=>1)
|
||||
expected = zeros(eltype(l.weight),5,5,1,1)
|
||||
expected[2:end-1,2:end-1,1,1] = l.weight
|
||||
@test expected ≈ l(data)
|
||||
|
||||
l = Conv((3,1), 1=>1)
|
||||
expected = zeros(eltype(l.weight),5,7,1,1)
|
||||
expected[2:end-1,4,1,1] = l.weight
|
||||
@test expected ≈ l(data)
|
||||
|
||||
l = Conv((1,3), 1=>1)
|
||||
expected = zeros(eltype(l.weight),7,5,1,1)
|
||||
expected[4,2:end-1,1,1] = l.weight
|
||||
@test expected ≈ l(data)
|
||||
|
||||
@test begin
|
||||
# we test that the next expression does not throw
|
||||
randn(Float32, 10,10,1,1) |> Conv((6,1), 1=>1, Flux.σ)
|
||||
true
|
||||
end
|
||||
end
|
||||
|
||||
@testset "conv output dimensions" begin
|
||||
m = Conv((3, 3), 3 => 16)
|
||||
@test Flux.outdims(m, (10, 10)) == (8, 8)
|
||||
m = Conv((3, 3), 3 => 16; stride = 2)
|
||||
@test Flux.outdims(m, (5, 5)) == (2, 2)
|
||||
m = Conv((3, 3), 3 => 16; stride = 2, pad = 3)
|
||||
@test Flux.outdims(m, (5, 5)) == (5, 5)
|
||||
m = Conv((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2)
|
||||
@test Flux.outdims(m, (5, 5)) == (4, 4)
|
||||
|
||||
m = ConvTranspose((3, 3), 3 => 16)
|
||||
@test Flux.outdims(m, (8, 8)) == (10, 10)
|
||||
m = ConvTranspose((3, 3), 3 => 16; stride = 2)
|
||||
@test Flux.outdims(m, (2, 2)) == (5, 5)
|
||||
m = ConvTranspose((3, 3), 3 => 16; stride = 2, pad = 3)
|
||||
@test Flux.outdims(m, (5, 5)) == (5, 5)
|
||||
m = ConvTranspose((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2)
|
||||
@test Flux.outdims(m, (4, 4)) == (5, 5)
|
||||
|
||||
m = DepthwiseConv((3, 3), 3 => 6)
|
||||
@test Flux.outdims(m, (10, 10)) == (8, 8)
|
||||
m = DepthwiseConv((3, 3), 3 => 6; stride = 2)
|
||||
@test Flux.outdims(m, (5, 5)) == (2, 2)
|
||||
m = DepthwiseConv((3, 3), 3 => 6; stride = 2, pad = 3)
|
||||
@test Flux.outdims(m, (5, 5)) == (5, 5)
|
||||
m = DepthwiseConv((3, 3), 3 => 6; stride = 2, pad = 3, dilation = 2)
|
||||
@test Flux.outdims(m, (5, 5)) == (4, 4)
|
||||
|
||||
m = CrossCor((3, 3), 3 => 16)
|
||||
@test Flux.outdims(m, (10, 10)) == (8, 8)
|
||||
m = CrossCor((3, 3), 3 => 16; stride = 2)
|
||||
@test Flux.outdims(m, (5, 5)) == (2, 2)
|
||||
m = CrossCor((3, 3), 3 => 16; stride = 2, pad = 3)
|
||||
@test Flux.outdims(m, (5, 5)) == (5, 5)
|
||||
m = CrossCor((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2)
|
||||
@test Flux.outdims(m, (5, 5)) == (4, 4)
|
||||
|
||||
m = MaxPool((2, 2))
|
||||
@test Flux.outdims(m, (10, 10)) == (5, 5)
|
||||
m = MaxPool((2, 2); stride = 1)
|
||||
@test Flux.outdims(m, (5, 5)) == (4, 4)
|
||||
m = MaxPool((2, 2); stride = 2, pad = 3)
|
||||
@test Flux.outdims(m, (5, 5)) == (5, 5)
|
||||
|
||||
m = MeanPool((2, 2))
|
||||
@test Flux.outdims(m, (10, 10)) == (5, 5)
|
||||
m = MeanPool((2, 2); stride = 1)
|
||||
@test Flux.outdims(m, (5, 5)) == (4, 4)
|
||||
m = MeanPool((2, 2); stride = 2, pad = 3)
|
||||
@test Flux.outdims(m, (5, 5)) == (5, 5)
|
||||
end
|
||||
|
||||
@testset "$ltype SamePad kernelsize $k" for ltype in (Conv, ConvTranspose, DepthwiseConv, CrossCor), k in ( (1,), (2,), (3,), (4,5), (6,7,8))
|
||||
data = ones(Float32, (k .+ 3)..., 1,1)
|
||||
l = ltype(k, 1=>1, pad=SamePad())
|
||||
@test size(l(data)) == size(data)
|
||||
|
||||
l = ltype(k, 1=>1, pad=SamePad(), dilation = k .÷ 2)
|
||||
@test size(l(data)) == size(data)
|
||||
|
||||
stride = 3
|
||||
l = ltype(k, 1=>1, pad=SamePad(), stride = stride)
|
||||
if ltype == ConvTranspose
|
||||
@test size(l(data))[1:end-2] == stride .* size(data)[1:end-2] .- stride .+ 1
|
||||
else
|
||||
@test size(l(data))[1:end-2] == ceil.(Int, size(data)[1:end-2] ./ stride)
|
||||
end
|
||||
end
|
||||
|
||||
@testset "$ltype SamePad windowsize $k" for ltype in (MeanPool, MaxPool), k in ( (1,), (2,), (3,), (4,5), (6,7,8))
|
||||
data = ones(Float32, (k .+ 3)..., 1,1)
|
||||
|
||||
l = ltype(k, pad=SamePad())
|
||||
@test size(l(data))[1:end-2] == ceil.(Int, size(data)[1:end-2] ./ k)
|
||||
end
|
||||
|
@ -1,58 +1,46 @@
|
||||
using Flux, Test, Statistics
|
||||
using Zygote: pullback
|
||||
|
||||
evalwgrad(f, x...) = pullback(f, x...)[1]
|
||||
using Flux: testmode!
|
||||
using Flux.Tracker: data
|
||||
|
||||
@testset "Dropout" begin
|
||||
x = [1.,2.,3.]
|
||||
@test x == Dropout(0.1)(x)
|
||||
@test x == evalwgrad(Dropout(0), x)
|
||||
@test zero(x) == evalwgrad(Dropout(1), x)
|
||||
@test x == testmode!(Dropout(0.1))(x)
|
||||
@test x == Dropout(0)(x)
|
||||
@test zero(x) == Dropout(1)(x)
|
||||
|
||||
x = rand(100)
|
||||
m = Dropout(0.9)
|
||||
y = evalwgrad(m, x)
|
||||
y = m(x)
|
||||
@test count(a->a==0, y) > 50
|
||||
testmode!(m, true)
|
||||
y = evalwgrad(m, x) # should override istraining
|
||||
testmode!(m)
|
||||
y = m(x)
|
||||
@test count(a->a==0, y) == 0
|
||||
testmode!(m, false)
|
||||
y = evalwgrad(m, x)
|
||||
y = m(x)
|
||||
@test count(a->a==0, y) > 50
|
||||
|
||||
x = rand(Float32, 100)
|
||||
x = rand(100)
|
||||
m = Chain(Dense(100,100),
|
||||
Dropout(0.9))
|
||||
y = evalwgrad(m, x)
|
||||
y = m(x)
|
||||
@test count(a->a == 0, y) > 50
|
||||
testmode!(m, true)
|
||||
y = evalwgrad(m, x) # should override istraining
|
||||
testmode!(m)
|
||||
y = m(x)
|
||||
@test count(a->a == 0, y) == 0
|
||||
|
||||
x = rand(100, 50)
|
||||
m = Dropout(0.5, dims = 2)
|
||||
y = m(x)
|
||||
c = map(i->count(a->a==0, @view y[i, :]), 1:100)
|
||||
@test minimum(c) == maximum(c)
|
||||
m = Dropout(0.5, dims = 1)
|
||||
y = m(x)
|
||||
c = map(i->count(a->a==0, @view y[:, i]), 1:50)
|
||||
@test minimum(c) == maximum(c)
|
||||
end
|
||||
|
||||
@testset "BatchNorm" begin
|
||||
let m = BatchNorm(2), x = [1.0 3.0 5.0;
|
||||
2.0 4.0 6.0]
|
||||
let m = BatchNorm(2), x = param([1 3 5;
|
||||
2 4 6])
|
||||
|
||||
@test length(params(m)) == 2
|
||||
|
||||
@test m.β == [0, 0] # initβ(2)
|
||||
@test m.γ == [1, 1] # initγ(2)
|
||||
@test m.β.data == [0, 0] # initβ(2)
|
||||
@test m.γ.data == [1, 1] # initγ(2)
|
||||
# initial m.σ is 1
|
||||
# initial m.μ is 0
|
||||
@test m.active
|
||||
|
||||
# @test m(x).data ≈ [-1 -1; 0 0; 1 1]'
|
||||
m(x)
|
||||
|
||||
y = evalwgrad(m, x)
|
||||
@test isapprox(y, [-1.22474 0 1.22474; -1.22474 0 1.22474], atol = 1.0e-5)
|
||||
# julia> x
|
||||
# 2×3 Array{Float64,2}:
|
||||
# 1.0 3.0 5.0
|
||||
@ -71,226 +59,43 @@ end
|
||||
# 2×1 Array{Float64,2}:
|
||||
# 1.3
|
||||
# 1.3
|
||||
@test m.σ² ≈ .1 .* var(x, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
|
||||
@test m.σ² ≈ .1 .* var(x.data, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
|
||||
|
||||
x′ = m(x)
|
||||
testmode!(m)
|
||||
@test !m.active
|
||||
|
||||
x′ = m(x).data
|
||||
@test isapprox(x′[1], (1 .- 0.3) / sqrt(1.3), atol = 1.0e-5)
|
||||
end
|
||||
|
||||
# with activation function
|
||||
let m = BatchNorm(2, sigmoid), x = [1.0 3.0 5.0;
|
||||
2.0 4.0 6.0]
|
||||
y = m(x)
|
||||
@test isapprox(y, sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ)), atol = 1.0e-7)
|
||||
let m = BatchNorm(2, sigmoid), x = param([1 3 5;
|
||||
2 4 6])
|
||||
@test m.active
|
||||
m(x)
|
||||
|
||||
testmode!(m)
|
||||
@test !m.active
|
||||
|
||||
y = m(x).data
|
||||
@test isapprox(y, data(sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ))), atol = 1.0e-7)
|
||||
end
|
||||
|
||||
let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:6), 3, 2, 1)
|
||||
let m = BatchNorm(2), x = param(reshape(1:6, 3, 2, 1))
|
||||
y = reshape(permutedims(x, [2, 1, 3]), 2, :)
|
||||
y = permutedims(reshape(m(y), 2, 3, 1), [2, 1, 3])
|
||||
@test m(x) == y
|
||||
end
|
||||
|
||||
let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:12), 2, 3, 2, 1)
|
||||
let m = BatchNorm(2), x = param(reshape(1:12, 2, 3, 2, 1))
|
||||
y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :)
|
||||
y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4])
|
||||
@test m(x) == y
|
||||
end
|
||||
|
||||
let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:24), 2, 2, 3, 2, 1)
|
||||
let m = BatchNorm(2), x = param(reshape(1:24, 2, 2, 3, 2, 1))
|
||||
y = reshape(permutedims(x, [4, 1, 2, 3, 5]), 2, :)
|
||||
y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5])
|
||||
@test m(x) == y
|
||||
end
|
||||
|
||||
let m = BatchNorm(32), x = randn(Float32, 416, 416, 32, 1);
|
||||
m(x)
|
||||
@test (@allocated m(x)) < 100_000_000
|
||||
end
|
||||
end
|
||||
|
||||
@testset "InstanceNorm" begin
|
||||
# helper functions
|
||||
expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...)
|
||||
# begin tests
|
||||
let m = InstanceNorm(2), sizes = (3, 2, 2),
|
||||
x = reshape(collect(1:prod(sizes)), sizes)
|
||||
|
||||
@test length(params(m)) == 2
|
||||
x = Float64.(x)
|
||||
@test m.β == [0, 0] # initβ(2)
|
||||
@test m.γ == [1, 1] # initγ(2)
|
||||
y = evalwgrad(m, x)
|
||||
|
||||
#julia> x
|
||||
#[:, :, 1] =
|
||||
# 1.0 4.0
|
||||
# 2.0 5.0
|
||||
# 3.0 6.0
|
||||
#
|
||||
#[:, :, 2] =
|
||||
# 7.0 10.0
|
||||
# 8.0 11.0
|
||||
# 9.0 12.0
|
||||
#
|
||||
# μ will be
|
||||
# (1. + 2. + 3.) / 3 = 2.
|
||||
# (4. + 5. + 6.) / 3 = 5.
|
||||
#
|
||||
# (7. + 8. + 9.) / 3 = 8.
|
||||
# (10. + 11. + 12.) / 3 = 11.
|
||||
#
|
||||
# ∴ update rule with momentum:
|
||||
# (1. - .1) * 0 + .1 * (2. + 8.) / 2 = .5
|
||||
# (1. - .1) * 0 + .1 * (5. + 11.) / 2 = .8
|
||||
@test m.μ ≈ [0.5, 0.8]
|
||||
# momentum * var * num_items / (num_items - 1) + (1 - momentum) * sigma_sq
|
||||
# julia> reshape(mean(.1 .* var(x, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1.
|
||||
# 2-element Array{Float64,1}:
|
||||
# 1.
|
||||
# 1.
|
||||
@test m.σ² ≈ reshape(mean(.1 .* var(x, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1.
|
||||
|
||||
x′ = m(x)
|
||||
@test isapprox(x′[1], (1 - 0.5) / sqrt(1. + 1f-5), atol = 1.0e-5)
|
||||
end
|
||||
# with activation function
|
||||
let m = InstanceNorm(2, sigmoid), sizes = (3, 2, 2),
|
||||
x = reshape(collect(1:prod(sizes)), sizes)
|
||||
x = Float64.(x)
|
||||
affine_shape = collect(sizes)
|
||||
affine_shape[1] = 1
|
||||
|
||||
y = m(x)
|
||||
@test isapprox(y, sigmoid.((x .- expand_inst(m.μ, affine_shape)) ./ sqrt.(expand_inst(m.σ², affine_shape) .+ m.ϵ)), atol = 1.0e-7)
|
||||
end
|
||||
|
||||
let m = trainmode!(InstanceNorm(2)), sizes = (2, 4, 1, 2, 3),
|
||||
x = Float32.(reshape(collect(1:prod(sizes)), sizes))
|
||||
y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
|
||||
y = reshape(m(y), sizes...)
|
||||
@test m(x) == y
|
||||
end
|
||||
|
||||
# check that μ, σ², and the output are the correct size for higher rank tensors
|
||||
let m = InstanceNorm(2), sizes = (5, 5, 3, 4, 2, 6),
|
||||
x = reshape(Float32.(collect(1:prod(sizes))), sizes)
|
||||
y = evalwgrad(m, x)
|
||||
@test size(m.μ) == (sizes[end - 1], )
|
||||
@test size(m.σ²) == (sizes[end - 1], )
|
||||
@test size(y) == sizes
|
||||
end
|
||||
|
||||
# show that instance norm is equal to batch norm when channel and batch dims are squashed
|
||||
let m_inorm = trainmode!(InstanceNorm(2)), m_bnorm = trainmode!(BatchNorm(12)), sizes = (5, 5, 3, 4, 2, 6),
|
||||
x = reshape(Float32.(collect(1:prod(sizes))), sizes)
|
||||
@test m_inorm(x) == reshape(m_bnorm(reshape(x, (sizes[1:end - 2]..., :, 1))), sizes)
|
||||
end
|
||||
|
||||
let m = InstanceNorm(32), x = randn(Float32, 416, 416, 32, 1);
|
||||
m(x)
|
||||
@test (@allocated m(x)) < 100_000_000
|
||||
end
|
||||
|
||||
end
|
||||
|
||||
if VERSION >= v"1.1"
|
||||
@testset "GroupNorm" begin
|
||||
# begin tests
|
||||
squeeze(x) = dropdims(x, dims = tuple(findall(size(x) .== 1)...)) # To remove all singular dimensions
|
||||
|
||||
let m = GroupNorm(4,2), sizes = (3,4,2),
|
||||
x = reshape(collect(1:prod(sizes)), sizes)
|
||||
|
||||
@test length(params(m)) == 2
|
||||
x = Float64.(x)
|
||||
@test m.β == [0, 0, 0, 0] # initβ(32)
|
||||
@test m.γ == [1, 1, 1, 1] # initγ(32)
|
||||
|
||||
y = evalwgrad(m, x)
|
||||
|
||||
#julia> x
|
||||
#[:, :, 1] =
|
||||
# 1.0 4.0 7.0 10.0
|
||||
# 2.0 5.0 8.0 11.0
|
||||
# 3.0 6.0 9.0 12.0
|
||||
#
|
||||
#[:, :, 2] =
|
||||
# 13.0 16.0 19.0 22.0
|
||||
# 14.0 17.0 20.0 23.0
|
||||
# 15.0 18.0 21.0 24.0
|
||||
#
|
||||
# μ will be
|
||||
# (1. + 2. + 3. + 4. + 5. + 6.) / 6 = 3.5
|
||||
# (7. + 8. + 9. + 10. + 11. + 12.) / 6 = 9.5
|
||||
#
|
||||
# (13. + 14. + 15. + 16. + 17. + 18.) / 6 = 15.5
|
||||
# (19. + 20. + 21. + 22. + 23. + 24.) / 6 = 21.5
|
||||
#
|
||||
# μ =
|
||||
# 3.5 15.5
|
||||
# 9.5 21.5
|
||||
#
|
||||
# ∴ update rule with momentum:
|
||||
# (1. - .1) * 0 + .1 * (3.5 + 15.5) / 2 = 0.95
|
||||
# (1. - .1) * 0 + .1 * (9.5 + 21.5) / 2 = 1.55
|
||||
@test m.μ ≈ [0.95, 1.55]
|
||||
|
||||
# julia> mean(var(reshape(x,3,2,2,2),dims=(1,2)).* .1,dims=2) .+ .9*1.
|
||||
# 2-element Array{Float64,1}:
|
||||
# 1.25
|
||||
# 1.25
|
||||
@test m.σ² ≈ mean(squeeze(var(reshape(x,3,2,2,2),dims=(1,2))).*.1,dims=2) .+ .9*1.
|
||||
|
||||
x′ = m(x)
|
||||
@test isapprox(x′[1], (1 - 0.95) / sqrt(1.25 + 1f-5), atol = 1.0e-5)
|
||||
end
|
||||
# with activation function
|
||||
let m = GroupNorm(4,2, sigmoid), sizes = (3, 4, 2),
|
||||
x = reshape(collect(1:prod(sizes)), sizes)
|
||||
x = Float64.(x)
|
||||
μ_affine_shape = ones(Int,length(sizes) + 1)
|
||||
μ_affine_shape[end-1] = 2 # Number of groups
|
||||
|
||||
affine_shape = ones(Int,length(sizes) + 1)
|
||||
affine_shape[end-2] = 2 # Channels per group
|
||||
affine_shape[end-1] = 2 # Number of groups
|
||||
affine_shape[1] = sizes[1]
|
||||
affine_shape[end] = sizes[end]
|
||||
|
||||
og_shape = size(x)
|
||||
|
||||
y = m(x)
|
||||
x_ = reshape(x,affine_shape...)
|
||||
out = reshape(sigmoid.((x_ .- reshape(m.μ,μ_affine_shape...)) ./ sqrt.(reshape(m.σ²,μ_affine_shape...) .+ m.ϵ)),og_shape)
|
||||
@test isapprox(y, out, atol = 1.0e-7)
|
||||
end
|
||||
|
||||
let m = trainmode!(GroupNorm(2,2)), sizes = (2, 4, 1, 2, 3),
|
||||
x = Float32.(reshape(collect(1:prod(sizes)), sizes))
|
||||
y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
|
||||
y = reshape(m(y), sizes...)
|
||||
@test m(x) == y
|
||||
end
|
||||
|
||||
# check that μ, σ², and the output are the correct size for higher rank tensors
|
||||
let m = GroupNorm(4,2), sizes = (5, 5, 3, 4, 4, 6),
|
||||
x = Float32.(reshape(collect(1:prod(sizes)), sizes))
|
||||
y = evalwgrad(m, x)
|
||||
@test size(m.μ) == (m.G,1)
|
||||
@test size(m.σ²) == (m.G,1)
|
||||
@test size(y) == sizes
|
||||
end
|
||||
|
||||
# show that group norm is the same as instance norm when the group size is the same as the number of channels
|
||||
let IN = trainmode!(InstanceNorm(4)), GN = trainmode!(GroupNorm(4,4)), sizes = (2,2,3,4,5),
|
||||
x = Float32.(reshape(collect(1:prod(sizes)), sizes))
|
||||
@test IN(x) ≈ GN(x)
|
||||
end
|
||||
|
||||
# show that group norm is the same as batch norm for a group of size 1 and batch of size 1
|
||||
let BN = trainmode!(BatchNorm(4)), GN = trainmode!(GroupNorm(4,4)), sizes = (2,2,3,4,1),
|
||||
x = Float32.(reshape(collect(1:prod(sizes)), sizes))
|
||||
@test BN(x) ≈ GN(x)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
@ -1,26 +1,9 @@
|
||||
using Test
|
||||
using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
|
||||
σ, binarycrossentropy, logitbinarycrossentropy, flatten,
|
||||
xlogx, xlogy
|
||||
σ, binarycrossentropy, logitbinarycrossentropy
|
||||
|
||||
const ϵ = 1e-7
|
||||
|
||||
@testset "xlogx & xlogy" begin
|
||||
@test iszero(xlogx(0))
|
||||
@test isnan(xlogx(NaN))
|
||||
@test xlogx(2) ≈ 2.0 * log(2.0)
|
||||
@inferred xlogx(2)
|
||||
@inferred xlogx(0)
|
||||
|
||||
@test iszero(xlogy(0, 1))
|
||||
@test isnan(xlogy(NaN, 1))
|
||||
@test isnan(xlogy(1, NaN))
|
||||
@test isnan(xlogy(NaN, NaN))
|
||||
@test xlogy(2, 3) ≈ 2.0 * log(3.0)
|
||||
@inferred xlogy(2, 3)
|
||||
@inferred xlogy(0, 1)
|
||||
end
|
||||
|
||||
@testset "losses" begin
|
||||
# First, regression-style y's
|
||||
y = [1, 1, 0, 0]
|
||||
@ -30,20 +13,6 @@ end
|
||||
@test mse(ŷ, y) ≈ (.1^2 + .9^2)/2
|
||||
end
|
||||
|
||||
@testset "mae" begin
|
||||
@test Flux.mae(ŷ, y) ≈ 1/2
|
||||
end
|
||||
|
||||
@testset "huber_loss" begin
|
||||
@test Flux.huber_loss(ŷ, y) ≈ 0.20500000000000002
|
||||
end
|
||||
|
||||
y = [123.0,456.0,789.0]
|
||||
ŷ = [345.0,332.0,789.0]
|
||||
@testset "msle" begin
|
||||
@test Flux.msle(ŷ, y) ≈ 0.38813985859136585
|
||||
end
|
||||
|
||||
# Now onehot y's
|
||||
y = onehotbatch([1, 1, 0, 0], 0:1)
|
||||
ŷ = [.1 .9; .9 .1; .9 .1; .1 .9]'
|
||||
@ -52,7 +21,6 @@ end
|
||||
lossvalue = 1.203972804325936
|
||||
|
||||
@testset "crossentropy" begin
|
||||
@test crossentropy([0.1,0.0,0.9], [0.1,0.0,0.9]) ≈ crossentropy([0.1,0.9], [0.1,0.9])
|
||||
@test crossentropy(ŷ, y) ≈ lossvalue
|
||||
end
|
||||
|
||||
@ -81,64 +49,4 @@ end
|
||||
@testset "logitbinarycrossentropy" begin
|
||||
@test logitbinarycrossentropy.(logŷ, y) ≈ binarycrossentropy.(σ.(logŷ), y; ϵ=0)
|
||||
end
|
||||
|
||||
y = [1 2 3]
|
||||
ŷ = [4.0 5.0 6.0]
|
||||
@testset "kldivergence" begin
|
||||
@test Flux.kldivergence([0.1,0.0,0.9], [0.1,0.0,0.9]) ≈ Flux.kldivergence([0.1,0.9], [0.1,0.9])
|
||||
@test Flux.kldivergence(ŷ, y) ≈ -1.7661057888493457
|
||||
@test Flux.kldivergence(y, y) ≈ 0
|
||||
end
|
||||
|
||||
y = [1 2 3 4]
|
||||
ŷ = [5.0 6.0 7.0 8.0]
|
||||
@testset "hinge" begin
|
||||
@test Flux.hinge(ŷ, y) ≈ 0
|
||||
@test Flux.hinge(y, 0.5 .* y) ≈ 0.125
|
||||
end
|
||||
|
||||
@testset "squared_hinge" begin
|
||||
@test Flux.squared_hinge(ŷ, y) ≈ 0
|
||||
@test Flux.squared_hinge(y, 0.5 .* y) ≈ 0.0625
|
||||
end
|
||||
|
||||
y = [0.1 0.2 0.3]
|
||||
ŷ = [0.4 0.5 0.6]
|
||||
@testset "poisson" begin
|
||||
@test Flux.poisson(ŷ, y) ≈ 0.6278353988097339
|
||||
@test Flux.poisson(y, y) ≈ 0.5044459776946685
|
||||
end
|
||||
|
||||
y = [1.0 0.5 0.3 2.4]
|
||||
ŷ = [0 1.4 0.5 1.2]
|
||||
@testset "dice_coeff_loss" begin
|
||||
@test Flux.dice_coeff_loss(ŷ, y) ≈ 0.2799999999999999
|
||||
@test Flux.dice_coeff_loss(y, y) ≈ 0.0
|
||||
end
|
||||
|
||||
@testset "tversky_loss" begin
|
||||
@test Flux.tversky_loss(ŷ, y) ≈ -0.06772009029345383
|
||||
@test Flux.tversky_loss(ŷ, y, β = 0.8) ≈ -0.09490740740740744
|
||||
@test Flux.tversky_loss(y, y) ≈ -0.5576923076923075
|
||||
end
|
||||
|
||||
@testset "no spurious promotions" begin
|
||||
for T in (Float32, Float64)
|
||||
y = rand(T, 2)
|
||||
ŷ = rand(T, 2)
|
||||
for f in (mse, crossentropy, logitcrossentropy, Flux.kldivergence, Flux.hinge, Flux.poisson,
|
||||
Flux.mae, Flux.huber_loss, Flux.msle, Flux.squared_hinge, Flux.dice_coeff_loss, Flux.tversky_loss)
|
||||
fwd, back = Flux.pullback(f, ŷ, y)
|
||||
@test fwd isa T
|
||||
@test eltype(back(one(T))[1]) == T
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
@testset "helpers" begin
|
||||
@testset "flatten" begin
|
||||
x = randn(Float32, 10, 10, 3, 2)
|
||||
@test size(flatten(x)) == (300, 2)
|
||||
end
|
||||
end
@ -11,9 +11,3 @@ using Test
@test onecold(a, labels) == 'C'
@test onecold(A, labels) == ['C', 'A', 'D']
end

@testset "onehotbatch indexing" begin
y = Flux.onehotbatch(ones(3), 1:10)
@test y[:,1] isa Flux.OneHotVector
@test y[:,:] isa Flux.OneHotMatrix
end

@ -1,44 +1,48 @@
using Flux.Optimise
using Flux.Optimise: runall
using Flux: Params, gradient
using Flux.Tracker
using Test

@testset "Optimise" begin
w = randn(10, 10)
@testset for opt in [ADAMW(), ADAGrad(0.1), AdaMax(), ADADelta(0.9), AMSGrad(),
NADAM(), RADAM(), Descent(0.1), ADAM(), Nesterov(), RMSProp(),
Momentum()]
w′ = randn(10, 10)
@testset for Opt in [ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, Descent, ADAM, Nesterov, RMSProp, Momentum]
w′ = param(randn(10, 10))
loss(x) = Flux.mse(w*x, w′*x)
for t = 1: 10^5
θ = Params([w′])
x = rand(10)
θ̄ = gradient(() -> loss(x), θ)
Optimise.update!(opt, θ, θ̄)
opt = Opt(0.001)
if opt isa Descent || opt isa ADAGrad
opt = Opt(0.1)
end
@test loss(rand(10, 10)) < 0.01
if opt isa ADADelta
opt = Opt(0.9)
end
for t = 1: 10^5
l = loss(rand(10))
back!(l)
delta = Optimise.update!(opt, w′.data, w′.grad)
w′.data .-= delta
end
@test Flux.mse(w, w′) < 0.01
end
end

@testset "Optimiser" begin
w = randn(10, 10)
@testset for Opt in [InvDecay, WeightDecay, ExpDecay]
w′ = randn(10, 10)
w′ = param(randn(10, 10))
loss(x) = Flux.mse(w*x, w′*x)
opt = Optimiser(Opt(), ADAM(0.001))
for t = 1:10^5
θ = Params([w′])
x = rand(10)
θ̄ = gradient(() -> loss(x), θ)
Optimise.update!(opt, θ, θ̄)
l = loss(rand(10))
back!(l)
delta = Optimise.update!(opt, w′.data, w′.grad)
w′.data .-= delta
end
@test loss(rand(10, 10)) < 0.01
@test Flux.mse(w, w′) < 0.01
end
end

@testset "Training Loop" begin
i = 0
l = 1
l = param(1)

Flux.train!(() -> (sleep(0.1); i += 1; l),
(),
@ -55,59 +59,3 @@ end
cbs()
@test x == 1
end

@testset "ExpDecay" begin

@testset "Sanity Check" begin
o = ExpDecay(0.2, 0.5, 1, 1e-3)
p = [0.0]
steps = 1:8
eta_expected = @. max(o.eta * 0.5 ^ steps, o.clip)
eta_actual = [Optimise.apply!(o, p, [1.0])[1] for _ in steps]
@test eta_actual == eta_expected
end

w = randn(10, 10)
o = ExpDecay(0.1, 0.1, 1000, 1e-4)
w1 = randn(10,10)
loss(x) = Flux.mse(w*x, w1*x)
flag = 1
decay_steps = []
for t = 1:10^5
prev_eta = o.eta
θ = Params([w1])
x = rand(10)
θ̄ = gradient(() -> loss(x), θ)
prev_grad = collect(θ̄[w1])
delta = Optimise.apply!(o, w1, θ̄[w1])
w1 .-= delta
new_eta = o.eta
if new_eta != prev_eta
push!(decay_steps, t)
end
array = fill(o.eta, size(prev_grad))
if array .* prev_grad != delta
flag = 0
end
end
@test flag == 1
# Test to check if decay happens at decay steps. Eta reaches clip value (1e-4) after 4000 steps (decay by 0.1 every 1000 steps starting at 0.1).
ground_truth = []
for i in 1:4
push!(ground_truth, 1000*i) # Expected decay steps for this example.
end
@test decay_steps == ground_truth
@test o.eta == o.clip
end
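# Worked values for the "Sanity Check" above: with ExpDecay(0.2, 0.5, 1, 1e-3)
# the expected step size halves every step until it hits the 1e-3 clip.
max.(0.2 * 0.5 .^ (1:8), 1e-3)
# → [0.1, 0.05, 0.025, 0.0125, 0.00625, 0.003125, 0.0015625, 0.001]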
@testset "Clipping" begin
w = randn(10, 10)
loss(x) = sum(w * x)
θ = Params([w])
x = 1000 * randn(10)
w̄ = gradient(() -> loss(x), θ)[w]
w̄_value = Optimise.apply!(ClipValue(1.0), w, copy(w̄))
@test all(w̄_value .<= 1)
w̄_norm = Optimise.apply!(ClipNorm(1.0), w, copy(w̄))
@test norm(w̄_norm) <= 1
end
@ -1,46 +1,48 @@
using Flux
using Flux.Data
using Test
using Random, Statistics, LinearAlgebra
using IterTools: ncycle
# Pkg.test runs with --check_bounds=1, forcing all bounds checks.
# This is incompatible with CUDAnative (see JuliaGPU/CUDAnative.jl#98)
if Base.JLOptions().check_bounds == 1
file = @__FILE__
run(```
$(Base.julia_cmd())
--color=$(Base.have_color ? "yes" : "no")
--compiled-modules=$(Bool(Base.JLOptions().use_compiled_modules) ? "yes" : "no")
--startup-file=$(Base.JLOptions().startupfile != 2 ? "yes" : "no")
--code-coverage=$(["none", "user", "all"][1+Base.JLOptions().code_coverage])
$(file)
```)
exit()
end

using Flux, Test, Random, Statistics
using Random

Random.seed!(0)

@testset "Utils" begin
include("utils.jl")
# So we can use the system CuArrays
insert!(LOAD_PATH, 2, "@v#.#")

@testset "Flux" begin

@info "Testing Basics"

include("utils.jl")
include("onehot.jl")
include("optimise.jl")
include("data.jl")

@info "Testing Layers"

include("layers/basic.jl")
include("layers/normalisation.jl")
include("layers/stateless.jl")
include("layers/conv.jl")

@info "Running Gradient Checks"

include("tracker.jl")

if Base.find_package("CuArrays") != nothing
include("cuda/cuda.jl")
end

@testset "Onehot" begin
include("onehot.jl")
end

@testset "Optimise" begin
include("optimise.jl")
end

@testset "Data" begin
include("data.jl")
end

@testset "Layers" begin
include("layers/basic.jl")
include("layers/normalisation.jl")
include("layers/stateless.jl")
include("layers/conv.jl")
end

@testset "CUDA" begin
if Flux.use_cuda[]
include("cuda/cuda.jl")
else
@warn "CUDA unavailable, not testing GPU support"
end
end

@static if VERSION >= v"1.4"
using Documenter
@testset "Docs" begin
DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true)
doctest(Flux)
end
end
298
test/tracker.jl
Normal file
@ -0,0 +1,298 @@
using Flux
using Flux.Tracker, Test, NNlib
using Flux.Tracker: TrackedReal, gradcheck, grad, checkpoint
using NNlib: conv, depthwiseconv
using Printf: @sprintf
using LinearAlgebra: diagm, dot, LowerTriangular, norm
using Statistics: mean, std
using Random
# using StatsBase

gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
@testset "Tracker" begin
@test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2)
@test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)
@test gradtest((x, W, b) -> logσ.(W*x .+ b), 5, (2,5), 2)
@test gradtest((x, W, b) -> logσ.(W*x .+ b), (5,3), (2,5), 2)
@test gradtest((w, x) -> w'*x, randn(Float64,10, 2), randn(Float64,10))
@test gradtest((w, x) -> w*x', randn(Float64,5,5), randn(Float64,5,5))
@test gradtest(x -> sum(x, dims = (2, 3)), (3,4,5))
@test gradtest(x -> sum(x, dims = 1), randn(Float64,2,3))
@test gradtest(x -> sum(x, dims = [1,2]), randn(Float64,2,3))
@test gradtest(x -> sum(x), randn(Float64,2,3))
@test gradtest(x -> prod(x, dims=(2, 3)), (3,4,5))
@test gradtest(x -> prod(x), (3,4,5))

@test gradtest(x -> softmax(x).*(1:3), 3)
@test gradtest(x -> softmax(x).*(1:3), (3,5))
@test gradtest(x -> logsoftmax(x).*(1:3), 3)
@test gradtest(x -> logsoftmax(x).*(1:3), (3,5))

@test gradtest(Flux.mse, rand(5,5), rand(5, 5))
@test gradtest(Flux.crossentropy, rand(5,5), rand(5, 5))

@test gradtest(x -> x', rand(5))

@testset "indexing & slicing" begin
gradtest(x->view(x, 1:2, 1:2), rand(4, 4))
end

function promotiontest(f, A, B, C)
r0 = f(A, B, C)
r1 = f(param(A), B, C)
r2 = f(A, param(B), C)
r3 = f(A, B, param(C))
r4 = f(param(A), param(B), param(C))

@test !isa(r0, TrackedArray)
@test all(isa.([r1,r2,r3,r4], TrackedArray))
@test r1 == r2 == r3 == r4
@test r0 == Flux.data(r4)
end
@testset "concat" begin
cat1(x...) = cat(x..., dims = 1)
cat2(x...) = cat(x..., dims = 2)

@testset for vcatf in [vcat, cat1]
@test gradtest(vcatf, rand(5), rand(3))
@test gradtest(vcatf, rand(5), rand(3), rand(8))
@test gradtest(vcatf, rand(5)', rand(5)')
@test gradtest(vcatf, rand(5,2), rand(3,2), rand(8,2))
@test gradtest(vcatf, rand(5,2,3), rand(3,2,3), rand(8,2,3))
@test gradtest(vcatf, rand(5), rand(3,1))
@test gradtest(vcatf, rand(5)', rand(2,5))
end

@testset for hcatf in [hcat, cat2]
@test gradtest(hcatf, rand(5), rand(5))
@test gradtest(hcatf, rand(5)', rand(5)')
@test gradtest(hcatf, rand(2,5), rand(2,3), rand(2,8))
@test gradtest(hcatf, rand(2,5,3), rand(2,3,3), rand(2,8,3))
@test gradtest(hcatf, rand(5), rand(5), rand(5,2))
@test gradtest(hcatf, rand(5)', rand(1,3))
@test gradtest(hcatf, rand(5), rand(5,2))
end

@testset for catf in [vcat, cat1, hcat, cat2, (x...) -> cat(x..., dims = 3), (x...) -> cat(x..., dims = (1,2))]
@test gradtest(catf, rand(5))
@test gradtest(catf, rand(5)')
@test gradtest(catf, rand(2,5))
@test gradtest(catf, rand(2,5,3))
end

@test gradtest((x...) -> cat(x..., dims = 3), rand(2,5,2), rand(2,5,3), rand(2,5,4))

@testset "cat($dim, ...)" for dim in 3:5
catdim = (x...) -> cat(x..., dims = dim)
@test gradtest(catdim, rand(5), rand(5), rand(5))
@test gradtest(catdim, rand(2,5), rand(2,5), rand(2,5))
@test gradtest(catdim, rand(2,5,3), rand(2,5,3), rand(2,5,3))
end

@test !isa(vcat(rand(2)), TrackedArray)
@test !isa(hcat(rand(2)), TrackedArray)
@test !isa(cat(rand(2), dims=1), TrackedArray)

@test gradtest((a,b)->cat(a, b, dims = (2,3,5)), rand(2,3), rand(2,4,2,1))

@testset "promotiontest" begin
@testset for fcat in [hcat, vcat, (x...) -> cat(x..., dims = 3), (x...) -> cat(x..., dims = (1,2))]
promotiontest(fcat, rand(2), rand(2), rand(2))
promotiontest(fcat, rand(2)', rand(2)', rand(2)')
promotiontest(fcat, rand(2,2), rand(2,2), rand(2,2))
promotiontest(fcat, rand(2,2,2), rand(2,2,2), rand(2,2,2))
end

promotiontest(vcat, rand(1,2), rand(2)', rand(2,2))
promotiontest(hcat, rand(2,1), rand(2), rand(2,2))
promotiontest(vcat, rand(3,4,5), rand(1,4,5), rand(2,4,5))
promotiontest(hcat, rand(4,3,5), rand(4,1,5), rand(4,2,5))
promotiontest((x...) -> cat(x..., dims = 3), rand(4,5,3), rand(4,5,1), rand(4,5,2))
end

end
@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))

@test gradtest(x -> repeat(x; inner=2), rand(5))
@test gradtest(x -> repeat(x; inner=2, outer=3), rand(5))
@test gradtest(x -> repeat(x; inner=(2,2,1), outer=(1,1,3)), rand(5,4,3))

@test gradtest(kron, rand(5), rand(3))
@test gradtest(kron, rand(5), rand(3), rand(8))
@test gradtest(kron, rand(5,1), rand(3,1))
@test gradtest(kron, rand(5,1), rand(3,1), rand(8,1))
@test gradtest(kron, rand(5,2), rand(3,2), rand(8,2))

@test gradtest(x -> diagm(0 => x), rand(3))

@test gradtest(W -> inv(log.(W * W)), (5,5))
@test gradtest((A, B) -> A / B , (1,5), (5,5))
@test gradtest((A, B) -> log.(A * A) / exp.(B * B), (5,5), (5,5))
@test gradtest((A, B) -> log.(A * A) \ exp.(B * B), (5,5), (5,5))

@testset "mean" begin
@test gradtest(mean, rand(2, 3))

@test gradtest(x -> mean(x, dims=1), rand(2, 3))
@test gradtest(x -> mean(x, dims=2), rand(2, 3))
@test gradtest(x -> mean(x, dims=3), rand(2, 3, 4))

@test gradtest(x -> mean(x, dims=[1, 2]), rand(2, 3, 4))
end

@testset "maximum" begin
@test gradtest(maximum, rand(2, 3))

@test gradtest(x -> maximum(x, dims=1), rand(2, 3))
@test gradtest(x -> maximum(x, dims=2), rand(2, 3))
@test gradtest(x -> maximum(x, dims=3), rand(2, 3, 4))

@test gradtest(x -> maximum(x, dims=[1, 2]), rand(2, 3, 4))
end

@testset "minimum" begin
@test gradtest(minimum, rand(2, 3))

@test gradtest(x -> minimum(x, dims=1), rand(2, 3))
@test gradtest(x -> minimum(x, dims=2), rand(2, 3))
@test gradtest(x -> minimum(x, dims=3), rand(2, 3, 4))

@test gradtest(x -> minimum(x, dims=[1, 2]), rand(2, 3, 4))
end

@test gradtest(x -> std(x), rand(5,5))
@test gradtest(x -> std(x, dims = 1), rand(5,5))

@test gradtest((x, y) -> x .* y, rand(5), rand(5))
@test gradtest(dot, rand(5), rand(5))

@test gradtest(norm, rand(5))

@test gradtest(rand(5)) do x
y = x.^2
2y + x
end

@test gradtest(conv, rand(10, 3, 2), randn(Float64,2, 3, 2))
@test gradtest(conv, rand(10, 10, 3, 2), randn(Float64,2, 2, 3, 2))
@test gradtest(conv, rand(10, 10, 10, 3, 2), randn(Float64,2, 2, 2, 3, 2))

@test gradtest(depthwiseconv, rand(10,10,3,2), randn(2, 2, 2, 3))

@test gradtest(x -> maxpool(x, (2,2)), rand(10, 10, 3, 2))
@test gradtest(x -> maxpool(x, (2,2,2)), rand(10, 10, 10, 3, 2))

@test gradtest(x -> meanpool(x, (2,2)), rand(10, 10, 3, 2))
@test gradtest(x -> meanpool(x, (2,2,2)), rand(5, 5, 5, 3, 2))

@testset "equality & order" begin
# TrackedReal
@test param(2)^2 == param(4)
@test param(2)^2 == 4
@test 4 == param(2)^2

@test param(2)^2 ≈ param(4)
@test param(2)^2 ≈ 4
@test 4 ≈ param(2)^2

@test (param([1,2,3]) .< 2) == [true, false, false]
@test (param([1,2,3]) .<= 2) == [true, true, false]
@test (2 .> param([1,2,3])) == [true, false, false]
@test (2 .>= param([1,2,3])) == [true, true, false]

# TrackedArray
@test param([1,2,3]).^2 == param([1,4,9])
@test [1,2,3].^2 == param([1,4,9])
@test param([1,2,3]).^2 == [1,4,9]

@test param([1,2,3]).^2 ≈ param([1,4,9])
@test [1,2,3].^2 ≈ param([1,4,9])
@test param([1,2,3]).^2 ≈ [1,4,9]
end
@testset "reshape" begin
x = reshape(param(rand(2,2,2)), 4, 2)
@test x isa TrackedArray
@test size(x) == (4,2)
x = reshape(param([1]), (1,:))
@test x isa TrackedArray
@test size(x) == (1,1)
x = reshape(param(rand(2)), (2,:))
@test x isa TrackedArray
@test size(x) == (2,1)
x = reshape(param(rand(2,2)), (1,:,2))
@test x isa TrackedArray
@test size(x) == (1,2,2)
end

@testset "Intermediates" begin
x = param([1])
l = sum((x .+ x).^2)
Flux.back!(l, once = false)
@test x.grad == [8]
x.grad .= 0
Flux.back!(l, once = false)
@test x.grad == [8]
end

@testset "Fallbacks" begin
xs = param([1 2; 3 4])
@test similar(xs) isa Matrix{Float64}
end

@test @sprintf("%.2f", sum(param([1,2,3]))) == "6.00"

@inferred NNlib.conv(param(rand(10,10,3,2)),randn(Float64,2,2,3,4))

b = param(rand())
Tracker.back!(b)
@test Tracker.grad(b) == 1

@testset "collect" begin
x, y = param(2), param(3)
xy = Tracker.collect([x, y])
@test xy isa TrackedArray{Float64}
z = xy[1]*xy[2]
back!(z)
@test grad.((x,y)) == (3, 2)

@test Tracker.gradient(2, 3) do x, y
xy = Tracker.collect([x, y])
xy[1]*xy[2]
end == (3, 2)
end

# Gradient Hooks
@testset "Hooks" begin
x = param(2)
y = Tracker.hook(-, x)
back!(y)
@test grad(x) == -1
end

@testset "Checkpointing" begin
count = 0
function mul(a, b)
count += 1
a * b
end
@test gradient(x -> mul(5, x), 3)[1] == 5
@test count == 1
@test gradient(x -> checkpoint(mul, 5, x), 3)[1] == 5
@test count == 3
end
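# Note on the counts above: without checkpointing, mul runs only in the forward
# pass (count 0 → 1); with Tracker.checkpoint the forward computation is re-run
# while the gradient is taken, so mul executes twice for the second call (count 1 → 3).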
@testset "Updates" begin
xs = param([1, 2, 3])
Tracker.update!(xs, param([4, 5, 6]))
@test xs == [5, 7, 9]
x = param(3)
Tracker.update!(x, param(4))
@test x == 7
end

end #testset
@ -1,6 +1,6 @@
using Flux
using Flux: throttle, nfan, glorot_uniform, glorot_normal, stack, unstack
using StatsBase: var
using Flux: throttle, jacobian, glorot_uniform, glorot_normal
using StatsBase: std
using Random
using Test

@ -52,30 +52,31 @@ using Test
end
end

@testset "Jacobian" begin
A = param(randn(2,2))
x = randn(2)
m(x) = A*x
y = m(x)
J = jacobian(m,x)
@test J ≈ A.data
end

@testset "Initialization" begin
# Set random seed so that these tests don't fail randomly
Random.seed!(0)

@testset "Fan in/out" begin
@test nfan() == (1, 1) #For a constant
@test nfan(100) == (1, 100) #For vector
@test nfan(100, 200) == (200, 100) #For Dense layer
@test nfan(2, 30, 40) == (2 * 30, 2 * 40) #For 1D Conv layer
@test nfan(2, 3, 40, 50) == (2 * 3 * 40, 2 * 3 * 50) #For 2D Conv layer
@test nfan(2, 3, 4, 50, 60) == (2 * 3 * 4 * 50, 2 * 3 * 4 * 60) #For 3D Conv layer
end
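# A minimal sketch (an assumption, not Flux's actual definition) of the fan-in/fan-out
# convention the tests above encode: scalars count as (1, 1), vectors as biases,
# matrices as Dense weights (out × in), and higher-rank arrays as conv kernels whose
# spatial dimensions multiply into both fans.
nfan_sketch() = 1, 1
nfan_sketch(n) = 1, n
nfan_sketch(n_out, n_in) = n_in, n_out
nfan_sketch(dims...) = prod(dims[1:end-2]) .* (dims[end-1], dims[end])

nfan_sketch(2, 30, 40)   # (60, 80), i.e. (2 * 30, 2 * 40)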
# glorot_uniform should yield a kernel with stddev ~= sqrt(6/(n_in + n_out)),
# and glorot_normal should yield a kernel with stddev ~= sqrt(2/(n_in + n_out))
for (n_in, n_out) in [(100, 100), (100, 400)]
v = glorot_uniform(n_in, n_out)
@test minimum(v) > -1.1*sqrt(6/(n_in + n_out))
@test minimum(v) < -0.9*sqrt(6/(n_in + n_out))
@test maximum(v) > 0.9*sqrt(6/(n_in + n_out))
@test maximum(v) < 1.1*sqrt(6/(n_in + n_out))

@testset "glorot" begin
# glorot_uniform and glorot_normal should both yield a kernel with
# variance ≈ 2/(fan_in + fan_out)
for dims ∈ [(1000,), (100, 100), (100, 400), (2, 3, 32, 64), (2, 3, 4, 32, 64)]
for init ∈ [glorot_uniform, glorot_normal]
v = init(dims...)
fan_in, fan_out = nfan(dims...)
σ2 = 2 / (fan_in + fan_out)
@test 0.9σ2 < var(v) < 1.1σ2
end
end
v = glorot_normal(n_in, n_out)
@test std(v) > 0.9*sqrt(2/(n_in + n_out))
@test std(v) < 1.1*sqrt(2/(n_in + n_out))
end
end
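# The glorot tests above rely on Var(U(-a, a)) = a^2/3, so choosing
# a = sqrt(6 / (fan_in + fan_out)) gives variance 2 / (fan_in + fan_out).
# A quick numerical check (an illustration, not Flux's implementation):
using Statistics
fan_in, fan_out = 100, 400
a = sqrt(6 / (fan_in + fan_out))
v = (2a) .* rand(fan_in, fan_out) .- a              # samples from U(-a, a)
isapprox(var(v), 2 / (fan_in + fan_out); rtol=0.1)  # ≈ true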
@ -84,37 +85,4 @@ end
@test size.(params(m)) == [(5, 10), (5,)]
m = RNN(10, 5)
@test size.(params(m)) == [(5, 10), (5, 5), (5,), (5,)]

# Layer duplicated in same chain, params just once pls.
c = Chain(m, m)
@test size.(params(c)) == [(5, 10), (5, 5), (5,), (5,)]

# Self-referential array. Just want params, no stack overflow pls.
r = Any[nothing,m]
r[1] = r
@test size.(params(r)) == [(5, 10), (5, 5), (5,), (5,)]
end

@testset "Basic Stacking" begin
x = randn(3,3)
stacked = stack([x, x], 2)
@test size(stacked) == (3,2,3)
end

@testset "Precision" begin
m = Chain(Dense(10, 5, relu), Dense(5, 2))
x = rand(10)
@test eltype(m[1].W) == Float32
@test eltype(m(x)) == Float32
@test eltype(f64(m)(x)) == Float64
@test eltype(f64(m)[1].W) == Float64
@test eltype(f32(f64(m))[1].W) == Float32
end

@testset "Stacking" begin
stacked_array=[ 8 9 3 5; 9 6 6 9; 9 1 7 2; 7 4 10 6 ]
unstacked_array=[[8, 9, 9, 7], [9, 6, 1, 4], [3, 6, 7, 10], [5, 9, 2, 6]]
@test unstack(stacked_array, 2) == unstacked_array
@test stack(unstacked_array, 2) == stacked_array
@test stack(unstack(stacked_array, 1), 1) == stacked_array
end