Compare commits

..

6 Commits

Author SHA1 Message Date
CarloLucibello f8c8bb4e35 change train 2020-05-04 20:29:19 +02:00
CarloLucibello c1f0c29026 update project 2020-05-04 20:27:08 +02:00
CarloLucibello 14e7181c7c fix type instability 2020-05-04 20:27:08 +02:00
CarloLucibello 89191bdeb1 cleanup 2020-05-04 20:27:08 +02:00
CarloLucibello c6ba49e8ea remove multi-arg constructor 2020-05-04 20:27:08 +02:00
CarloLucibello d77dbc4931 extend dataloader 2020-05-04 20:27:08 +02:00
30 changed files with 146 additions and 774 deletions

View File

@@ -1,12 +0,0 @@
[Please delete this text and describe your change here.
For bugfixes, please detail the bug and include a test case which your patch fixes.
If you are adding a new feature, please clearly describe the design, its rationale, the possible alternatives considered.
It is easiest to merge new features when there is clear precedent in other systems; we need to know we're taking
the right direction since it can be hard to change later.]
### PR Checklist
- [ ] Tests are added
- [ ] Entry in NEWS.md
- [ ] Documentation, if applicable
- [ ] Final review from `@MikeInnes` or `@dhairyagandhi96` (for API changes).

View File

@@ -6,8 +6,16 @@ on:
jobs:
CompatHelper:
runs-on: ubuntu-latest
runs-on: ${{ matrix.os }}
strategy:
matrix:
julia-version: [1.3]
julia-arch: [x64]
os: [ubuntu-latest]
steps:
- uses: julia-actions/setup-julia@latest
with:
version: ${{ matrix.julia-version }}
- name: Pkg.add("CompatHelper")
run: julia -e 'using Pkg; Pkg.add("CompatHelper")'
- name: CompatHelper.main()

1
.gitignore vendored
View File

@@ -4,3 +4,4 @@
docs/build/
docs/site/
deps
Manifest.toml

View File

@@ -1,387 +0,0 @@
# This file is machine-generated - editing it directly is not advised
[[AbstractFFTs]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "051c95d6836228d120f5f4b984dd5aba1624f716"
uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
version = "0.5.0"
[[AbstractTrees]]
deps = ["Markdown"]
git-tree-sha1 = "33e450545eaf7699da1a6e755f9ea65f14077a45"
uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
version = "0.3.3"
[[Adapt]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "fd04049c7dd78cfef0b06cdc1f0f181467655712"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "1.1.0"
[[ArrayLayouts]]
deps = ["FillArrays", "LinearAlgebra"]
git-tree-sha1 = "a504dca2ac7eda8761c8f7c1ed52427a1be75a3c"
uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
version = "0.2.6"
[[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
[[BinaryProvider]]
deps = ["Libdl", "Logging", "SHA"]
git-tree-sha1 = "ecdec412a9abc8db54c0efc5548c64dfce072058"
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
version = "0.5.10"
[[CEnum]]
git-tree-sha1 = "1b77a77c3b28e0b3f413f7567c9bb8dd9bdccd14"
uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
version = "0.3.0"
[[CUDAapi]]
deps = ["Libdl", "Logging"]
git-tree-sha1 = "831b825d10104bd29e28f6da93312a976830717b"
uuid = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
version = "4.0.0"
[[CUDAdrv]]
deps = ["CEnum", "CUDAapi", "Printf"]
git-tree-sha1 = "f56bbf18c86bcff7a961a32a4947a5abb2963a29"
uuid = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
version = "6.3.0"
[[CUDAnative]]
deps = ["Adapt", "BinaryProvider", "CEnum", "CUDAapi", "CUDAdrv", "ExprTools", "GPUCompiler", "LLVM", "Libdl", "Pkg", "Printf"]
git-tree-sha1 = "ac86db2b05fdfec96b011e25a504ffe7476e8a68"
uuid = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
version = "3.1.0"
[[CodeTracking]]
deps = ["InteractiveUtils", "UUIDs"]
git-tree-sha1 = "cab4da992adc0a64f63fa30d2db2fd8bec40cab4"
uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2"
version = "0.5.11"
[[CodecZlib]]
deps = ["TranscodingStreams", "Zlib_jll"]
git-tree-sha1 = "ded953804d019afa9a3f98981d99b33e3db7b6da"
uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
version = "0.7.0"
[[ColorTypes]]
deps = ["FixedPointNumbers", "Random"]
git-tree-sha1 = "c73d9cfc2a9d8433dc77f5bff4bddf46b1d78c20"
uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
version = "0.10.3"
[[Colors]]
deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Reexport"]
git-tree-sha1 = "1e9bba7984e78aa8cdeea7f9f7cc984ad4e4b1c7"
uuid = "5ae59095-9a9b-59fe-a467-6f913c188581"
version = "0.12.2"
[[CommonSubexpressions]]
deps = ["Test"]
git-tree-sha1 = "efdaf19ab11c7889334ca247ff4c9f7c322817b0"
uuid = "bbf7d656-a473-5ed7-a52c-81e309532950"
version = "0.2.0"
[[CompilerSupportLibraries_jll]]
deps = ["Libdl", "Pkg"]
git-tree-sha1 = "7c4f882c41faa72118841185afc58a2eb00ef612"
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
version = "0.3.3+0"
[[Cthulhu]]
deps = ["CodeTracking", "InteractiveUtils", "REPL", "UUIDs", "Unicode"]
git-tree-sha1 = "f3643e78353199d3097821e806348bd83f364155"
uuid = "f68482b8-f384-11e8-15f7-abe071a5a75f"
version = "1.1.1"
[[CuArrays]]
deps = ["AbstractFFTs", "Adapt", "CEnum", "CUDAapi", "CUDAdrv", "CUDAnative", "DataStructures", "GPUArrays", "Libdl", "LinearAlgebra", "MacroTools", "NNlib", "Pkg", "Printf", "Random", "Reexport", "Requires", "SparseArrays", "Statistics", "TimerOutputs"]
git-tree-sha1 = "1582b74d2322df7dd94549d4ac9d095e0f20e884"
uuid = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
version = "2.2.1"
[[DataAPI]]
git-tree-sha1 = "176e23402d80e7743fc26c19c681bfb11246af32"
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
version = "1.3.0"
[[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections"]
git-tree-sha1 = "af6d9c86e191c917c2276fbede1137e8ea20157f"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.17.17"
[[Dates]]
deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
[[DelimitedFiles]]
deps = ["Mmap"]
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
[[DiffResults]]
deps = ["StaticArrays"]
git-tree-sha1 = "da24935df8e0c6cf28de340b958f6aac88eaa0cc"
uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
version = "1.0.2"
[[DiffRules]]
deps = ["NaNMath", "Random", "SpecialFunctions"]
git-tree-sha1 = "eb0c34204c8410888844ada5359ac8b96292cfd1"
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "1.0.1"
[[Distributed]]
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
[[ExprTools]]
git-tree-sha1 = "6f0517056812fd6aa3af23d4b70d5325a2ae4e95"
uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
version = "0.1.1"
[[FillArrays]]
deps = ["LinearAlgebra", "Random", "SparseArrays"]
git-tree-sha1 = "44f561e293987ffc84272cd3d2b14b0b93123d63"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "0.8.10"
[[FixedPointNumbers]]
git-tree-sha1 = "3ba9ea634d4c8b289d590403b4a06f8e227a6238"
uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
version = "0.8.0"
[[ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "NaNMath", "Random", "SpecialFunctions", "StaticArrays"]
git-tree-sha1 = "869540e4367122fbffaace383a5bdc34d6e5e5ac"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.10"
[[Functors]]
deps = ["MacroTools"]
git-tree-sha1 = "f40adc6422f548176bb4351ebd29e4abf773040a"
uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
version = "0.1.0"
[[Future]]
deps = ["Random"]
uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820"
[[GPUArrays]]
deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization"]
git-tree-sha1 = "d887693eb1bd5e1fd573262a978745481895ec7d"
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
version = "3.4.1"
[[GPUCompiler]]
deps = ["Cthulhu", "DataStructures", "InteractiveUtils", "LLVM", "Libdl", "TimerOutputs"]
git-tree-sha1 = "5275aa268ecd09640b32560e1eae90c78816e4d1"
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
version = "0.2.0"
[[IRTools]]
deps = ["InteractiveUtils", "MacroTools", "Test"]
git-tree-sha1 = "90ee39f9beaaa186e4968417ea2b8ed5673c91c0"
uuid = "7869d1d1-7146-5819-86e3-90919afe41df"
version = "0.3.3"
[[InteractiveUtils]]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
[[Juno]]
deps = ["Base64", "Logging", "Media", "Profile"]
git-tree-sha1 = "a686b0cf235fa3e491b79b4783c2d2382292b436"
uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
version = "0.8.2"
[[LLVM]]
deps = ["CEnum", "Libdl", "Printf", "Unicode"]
git-tree-sha1 = "dd3f584c3dbefe39b2a8fbafa1a3b77e31e21255"
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
version = "1.5.1"
[[LibGit2]]
deps = ["Printf"]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
[[Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
[[LinearAlgebra]]
deps = ["Libdl"]
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
[[Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
[[MacroTools]]
deps = ["Markdown", "Random"]
git-tree-sha1 = "f7d2e3f654af75f01ec49be82c231c382214223a"
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
version = "0.5.5"
[[Markdown]]
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
[[Media]]
deps = ["MacroTools", "Test"]
git-tree-sha1 = "75a54abd10709c01f1b86b84ec225d26e840ed58"
uuid = "e89f7d12-3494-54d1-8411-f7d8b9ae1f27"
version = "0.5.0"
[[Missings]]
deps = ["DataAPI"]
git-tree-sha1 = "de0a5ce9e5289f27df672ffabef4d1e5861247d5"
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
version = "0.4.3"
[[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
[[NNlib]]
deps = ["BinaryProvider", "Libdl", "LinearAlgebra", "Requires", "Statistics"]
git-tree-sha1 = "d9f196d911f55aeaff11b11f681b135980783824"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.6.6"
[[NaNMath]]
git-tree-sha1 = "928b8ca9b2791081dc71a51c55347c27c618760f"
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
version = "0.3.3"
[[OpenSpecFun_jll]]
deps = ["CompilerSupportLibraries_jll", "Libdl", "Pkg"]
git-tree-sha1 = "d51c416559217d974a1113522d5919235ae67a87"
uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
version = "0.5.3+3"
[[OrderedCollections]]
git-tree-sha1 = "12ce190210d278e12644bcadf5b21cbdcf225cd3"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.2.0"
[[Pkg]]
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
[[Printf]]
deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
[[Profile]]
deps = ["Printf"]
uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"
[[REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
[[Random]]
deps = ["Serialization"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
[[Reexport]]
deps = ["Pkg"]
git-tree-sha1 = "7b1d07f411bc8ddb7977ec7f377b97b158514fe0"
uuid = "189a3867-3050-52da-a836-e630ba90ab69"
version = "0.2.0"
[[Requires]]
deps = ["UUIDs"]
git-tree-sha1 = "d37400976e98018ee840e0ca4f9d20baa231dc6b"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "1.0.1"
[[SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
[[Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
[[Sockets]]
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
[[SortingAlgorithms]]
deps = ["DataStructures", "Random", "Test"]
git-tree-sha1 = "03f5898c9959f8115e30bc7226ada7d0df554ddd"
uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
version = "0.3.1"
[[SparseArrays]]
deps = ["LinearAlgebra", "Random"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
[[SpecialFunctions]]
deps = ["OpenSpecFun_jll"]
git-tree-sha1 = "d8d8b8a9f4119829410ecd706da4cc8594a1e020"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "0.10.3"
[[StaticArrays]]
deps = ["LinearAlgebra", "Random", "Statistics"]
git-tree-sha1 = "5c06c0aeb81bef54aed4b3f446847905eb6cbda0"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "0.12.3"
[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[[StatsBase]]
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"]
git-tree-sha1 = "a6102b1f364befdb05746f386b67c6b7e3262c45"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.33.0"
[[Test]]
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[[TimerOutputs]]
deps = ["Printf"]
git-tree-sha1 = "f458ca23ff80e46a630922c555d838303e4b9603"
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
version = "0.5.6"
[[TranscodingStreams]]
deps = ["Random", "Test"]
git-tree-sha1 = "7c53c35547de1c5b9d46a4797cf6d8253807108c"
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
version = "0.9.5"
[[UUIDs]]
deps = ["Random", "SHA"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
[[Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
[[ZipFile]]
deps = ["Libdl", "Printf", "Zlib_jll"]
git-tree-sha1 = "254975fef2fc526583bb9b7c9420fe66ffe09f2f"
uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
version = "0.9.2"
[[Zlib_jll]]
deps = ["Libdl", "Pkg"]
git-tree-sha1 = "a2e0d558f6031002e380a90613b199e37a8565bf"
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
version = "1.2.11+10"
[[Zygote]]
deps = ["AbstractFFTs", "ArrayLayouts", "DiffRules", "FillArrays", "ForwardDiff", "Future", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
git-tree-sha1 = "707ceea58e2bd0ff3077ab13a92f8355181d3ee4"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.4.20"
[[ZygoteRules]]
deps = ["MacroTools"]
git-tree-sha1 = "b3b4882cc9accf6731a08cc39543fbc6b669dca8"
uuid = "700de1a5-db45-46bc-99cf-38207098b444"
version = "0.2.0"

13
NEWS.md
View File

@@ -1,18 +1,5 @@
# v0.11
* Change to `DataLoader`'s constructor [https://github.com/FluxML/Flux.jl/pull/1152]
* Use `DataLoader` with `NamedTuple`s, so that tensors can be accessed by name [https://github.com/FluxML/Flux.jl/pull/1221].
* Error if Dense layers weights and biases are not arrays [https://github.com/FluxML/Flux.jl/pull/1218].
# v0.10.5
* Add option for [same padding](https://github.com/FluxML/Flux.jl/pull/901) to conv and pooling layers by setting `pad=SamePad()`.
* Added option to set `bias` to [Flux.Zeros](https://github.com/FluxML/Flux.jl/pull/873) to eliminating `bias` from being trained.
* Added `GlobalMaxPool` and `GlobalMeanPool` [layers](https://github.com/FluxML/Flux.jl/pull/950) for performing global pooling operations.
* Added `ClipValue` and `ClipNorm` in this [pr](https://github.com/FluxML/Flux.jl/pull/1133) to `Flux.Optimise` to provide a cleaner API for gradient clipping.
* Added new kwarg-only [constructors](https://github.com/FluxML/Flux.jl/pull/873) for the various convolutional layers.
* Documented the convolutional layer constructors accepting `weight` and `bias` keyword arguments to supply custom arrays for those fields.
* Testing suite improvements now test for gradients of all layers along with GPU support.
* Functors have now moved to [Functors.jl](https://github.com/FluxML/Flux.jl/pull/1174) to allow for their use outside of Flux.
* Added [helper functions](https://github.com/FluxML/Flux.jl/pull/873) `Flux.convfilter` and `Flux.depthwiseconvfilter` to construct weight arrays for convolutions outside of layer constructors so as to not have to depend on the default layers for custom implementations.
# v0.10.0
* The default AD engine has switched from [Tracker to Zygote.jl](https://github.com/FluxML/Flux.jl/pull/669)

View File

@@ -1,6 +1,6 @@
name = "Flux"
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
version = "0.11.0-DEV"
version = "0.11.0"
[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
@@ -9,9 +9,7 @@ CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
@@ -27,11 +25,10 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
[compat]
AbstractTrees = "0.2, 0.3"
Adapt = "1, 2.0"
Adapt = "1"
CodecZlib = "0.5, 0.6, 0.7"
Colors = "0.8, 0.9, 0.10, 0.11, 0.12"
CuArrays = "2"
Functors = "0.1"
Juno = "0.5, 0.6, 0.7, 0.8"
MacroTools = "0.3, 0.4, 0.5"
NNlib = "0.6"

View File

@@ -7,15 +7,15 @@ julia> using Flux: onehot, onecold
julia> onehot(:b, [:a, :b, :c])
3-element Flux.OneHotVector:
0
1
0
false
true
false
julia> onehot(:c, [:a, :b, :c])
3-element Flux.OneHotVector:
0
0
1
false
false
true
```
The inverse is `onecold` (which can take a general probability distribution, as well as just booleans).
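For context, a minimal sketch of that inverse (illustrative, not part of this changeset):

```julia
julia> using Flux: onehot, onecold

julia> onecold([0.1, 0.7, 0.2], [:a, :b, :c])  # the label with the highest score
:b

julia> onecold(onehot(:b, [:a, :b, :c]), [:a, :b, :c])
:b
```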

View File

@@ -19,7 +19,7 @@ Affine{Array{Float64,2},Array{Float64,1}}([0.66722 0.774872 0.249809; 0.843321 0
julia> Flux.params(a) # default behavior
Params([[0.66722 0.774872 0.249809; 0.843321 0.403843 0.429232; 0.683525 0.662455 0.065297], [0.42394, 0.0170927, 0.544955]])
julia> Flux.trainable(a::Affine) = (a.W,)
julia> Flux.trainable(a::Affine) = (a.W, a.b,)
julia> Flux.params(a)
Params([[0.66722 0.774872 0.249809; 0.843321 0.403843 0.429232; 0.683525 0.662455 0.065297]])
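To make the effect of the override concrete, a hedged sketch (reusing the guide's hypothetical `Affine` type): returning only `a.W` from `trainable` means `params` skips the bias, so optimisers never update `a.b`.

```julia
using Flux

struct Affine; W; b; end
Flux.@functor Affine

Flux.trainable(a::Affine) = (a.W,)  # collect only the weight, freezing the bias

a = Affine(rand(3, 3), rand(3))
length(Flux.params(a))              # 1 — only a.W is gathered for training
```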

View File

@@ -32,6 +32,8 @@ julia> gradient(f, [2, 1], [2, 0])
But machine learning models can have *hundreds* of parameters! To handle this, Flux lets you work with collections of parameters, via `params`. You can get the gradient of all parameters used in a program without explicitly passing them in.
```jldoctest basics
julia> using Flux
julia> x = [2, 1];
julia> y = [2, 0];
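The doctest above is cut off by the hunk; as a sketch, the pattern it builds toward looks roughly like this (gradient values follow from `sum((x .- y).^2)`):

```julia
using Flux

x, y = [2, 1], [2, 0]
gs = gradient(() -> sum((x .- y).^2), params(x, y))

gs[x]  # [0.0, 2.0]  — ∂/∂x = 2(x - y)
gs[y]  # [0.0, -2.0] — ∂/∂y = -2(x - y)
```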

View File

@@ -20,11 +20,7 @@ GlobalMeanPool
DepthwiseConv
ConvTranspose
CrossCor
SamePad
flatten
Flux.Zeros
Flux.convfilter
Flux.depthwiseconvfilter
```
## Recurrent Layers

View File

@@ -39,7 +39,7 @@ E.g. the following will have run into the same problem as above:
leaky_tanh(x) = 0.01*x + tanh(x)
```
While one could change the activation function (e.g. to use `0.01f0*x`), the idiomatic (and safe way) to avoid type casts whenever inputs changes is to use `oftype`:
While one could change the activation function (e.g. to use `0.01f0x`), the idiomatic (and safe way) to avoid type casts whenever inputs changes is to use `oftype`:
```
leaky_tanh(x) = oftype(x/1, 0.01)*x + tanh(x)
```
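A minimal check of the idiom (a sketch; any scalar activation behaves the same way):

```julia
leaky_tanh(x) = oftype(x/1, 0.01)*x + tanh(x)

typeof(leaky_tanh(0.5f0))  # Float32 — the constant follows the input type
typeof(leaky_tanh(0.5))    # Float64
```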

View File

@@ -140,16 +140,3 @@ ExpDecay
InvDecay
WeightDecay
```
## Gradient Clipping
Gradient clipping is useful for training recurrent neural networks, which have a tendency to suffer from the exploding gradient problem. An example usage is
```julia
opt = Optimiser(ClipValue(1e-3), ADAM(1e-3))
```
```@docs
ClipValue
ClipNorm
```

View File

@@ -142,7 +142,7 @@ function my_custom_train!(loss, ps, data, opt)
for d in data
gs = gradient(ps) do
training_loss = loss(d...)
# Insert whatever code you want here that needs Training loss, e.g. logging
# Insert what ever code you want here that needs Training loss, e.g. logging
return training_loss
end
# insert what ever code you want here that needs gradient

View File

@@ -3,8 +3,7 @@ module Flux
# Zero Flux Given
using Base: tail
using Statistics, Random, LinearAlgebra
using Zygote, MacroTools, Juno, Reexport
using Zygote, MacroTools, Juno, Reexport, Statistics, Random
using MacroTools: @forward
@reexport using NNlib
using Zygote: Params, @adjoint, gradient, pullback, @nograd
@@ -21,8 +20,7 @@ using .Optimise
using .Optimise: @epochs
export Descent, ADAM, Momentum, Nesterov, RMSProp,
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
ADAMW, RADAM, InvDecay, ExpDecay, WeightDecay,
ClipValue, ClipNorm
ADAMW, RADAM, InvDecay, ExpDecay, WeightDecay
using CuArrays

View File

@@ -16,8 +16,8 @@ end
An object that iterates over mini-batches of `data`, each mini-batch containing `batchsize` observations
(except possibly the last one).
Takes as input a single data tensor, or a tuple (or a named tuple) of tensors.
The last dimension in each tensor is considered to be the observation dimension.
Takes as input a data tensors or a tuple of one or more such tensors.
The last dimension in each tensor is considered to be the observation dimension.
If `shuffle=true`, shuffles the observations each time iterations are re-started.
If `partial=false`, drops the last mini-batch if it is smaller than the batchsize.
@@ -57,13 +57,6 @@ Usage example:
# train for 10 epochs
using IterTools: ncycle
Flux.train!(loss, ps, ncycle(train_loader, 10), opt)
# can use NamedTuple to name tensors
train_loader = DataLoader((images=Xtrain, labels=Ytrain), batchsize=2, shuffle=true)
for datum in train_loader
@assert size(datum.images) == (10, 2)
@assert size(datum.labels) == (2,)
end
"""
function DataLoader(data; batchsize=1, shuffle=false, partial=true)
batchsize > 0 || throw(ArgumentError("Need positive batchsize"))
@@ -74,6 +67,7 @@ function DataLoader(data; batchsize=1, shuffle=false, partial=true)
batchsize = n
end
imax = partial ? n : n - batchsize + 1
ids = 1:min(n, batchsize)
DataLoader(data, batchsize, n, partial, imax, [1:n;], shuffle)
end
@@ -95,16 +89,19 @@ end
_nobs(data::AbstractArray) = size(data)[end]
function _nobs(data::Union{Tuple, NamedTuple})
function _nobs(data::Tuple)
length(data) > 0 || throw(ArgumentError("Need at least one data input"))
n = _nobs(data[1])
if !all(x -> _nobs(x) == n, Base.tail(data))
if !all(x -> _nobs(x) == n, data[2:end])
throw(DimensionMismatch("All data should contain same number of observations"))
end
return n
end
_getobs(data::AbstractArray, i) = data[ntuple(i -> Colon(), Val(ndims(data) - 1))..., i]
_getobs(data::Union{Tuple, NamedTuple}, i) = map(Base.Fix2(_getobs, i), data)
function _getobs(data::A, i) where A<:AbstractArray{T,N} where {T,N}
getindex(data, ntuple(i->Colon(), N-1)..., i)
end
Base.eltype(::DataLoader{D}) where D = D
_getobs(data::Tuple, i) = ((_getobs(x, i) for x in data)...,)
Base.eltype(d::DataLoader{D}) where D = D
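A usage sketch of the constructor above (array shapes are assumptions; the last dimension holds the observations):

```julia
using Flux: DataLoader

X = rand(10, 100)   # 100 observations of 10 features
Y = rand(100)
loader = DataLoader((X, Y), batchsize=32, shuffle=true)

for (x, y) in loader
    # each batch keeps the leading dims and slices the observation dim
    @assert size(x, 2) == length(y)
end
```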

View File

@@ -1,6 +1,41 @@
import Adapt: adapt, adapt_storage
using Zygote: IdSet
import Functors: @functor, functor, fmap
functor(x) = (), _ -> x
functor(x::Tuple) = x, y -> y
functor(x::NamedTuple) = x, y -> y
functor(x::AbstractArray) = x, y -> y
functor(x::AbstractArray{<:Number}) = (), _ -> x
function makefunctor(m::Module, T, fs = fieldnames(T))
@eval m begin
Flux.functor(x::$T) = ($([:($f=x.$f) for f in fs]...),), y -> $T(y...)
end
end
function functorm(T, fs = nothing)
fs == nothing || isexpr(fs, :tuple) || error("@functor T (a, b)")
fs = fs == nothing ? [] : [:($(map(QuoteNode, fs.args)...),)]
:(makefunctor(@__MODULE__, $(esc(T)), $(fs...)))
end
macro functor(args...)
functorm(args...)
end
isleaf(x) = functor(x)[1] === ()
function fmap1(f, x)
func, re = functor(x)
re(map(f, func))
end
function fmap(f, x; cache = IdDict())
haskey(cache, x) && return cache[x]
cache[x] = isleaf(x) ? f(x) : fmap1(x -> fmap(f, x, cache = cache), x)
end
trainable(m) = functor(m)[1]
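A sketch of how these pieces compose, inside the `Flux` module (or with the names qualified); `Point` is a hypothetical type:

```julia
struct Point; x; y; end
@functor Point                # functor(p::Point) returns (x = p.x, y = p.y) plus a rebuilder

p = Point([1.0, 2.0], [3.0])
isleaf(p)                     # false — Point has registered children
fmap(a -> 2 .* a, p)          # Point([2.0, 4.0], [6.0]); numeric arrays are the leaves
trainable(p)                  # (x = [1.0, 2.0], y = [3.0]) — the same field tuple
```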
@@ -24,7 +59,7 @@ testmode!(m, mode = true) = m
trainmode!(m, mode = true)
Set a layer of model's train mode (see below).
Symmetric to [`testmode!`](@ref) (i.e. `trainmode!(m, mode) == testmode!(m, !mode)`).
Symmetric to [`testmode!`](@ref) (i.e. `trainmode!(m, mode) == testmode!(m, !mode)).
_Note_: if you manually set a model into train mode, you need to manually place
it into test mode during testing phase.
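Illustratively (a sketch):

```julia
using Flux

m = Chain(Dense(10, 5), Dropout(0.5))
trainmode!(m)   # dropout stays active even outside gradient calls
testmode!(m)    # dropout disabled; call this before evaluation
```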

View File

@@ -30,7 +30,7 @@ end
@forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
Base.iterate, Base.lastindex
functor(::Type{<:Chain}, c) = c.layers, ls -> Chain(ls...)
functor(c::Chain) = c.layers, ls -> Chain(ls...)
applychain(::Tuple{}, x) = x
applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x))
@@ -102,7 +102,7 @@ julia> d(rand(5))
-0.16210233
0.12311903
```
"""
struct Dense{F,S<:AbstractArray,T<:AbstractArray}
struct Dense{F,S,T}
W::S
b::T
σ::F

View File

@@ -132,7 +132,7 @@ end
function (c::Conv)(x::AbstractArray)
# TODO: breaks gpu broadcast :(
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
σ, b = c.σ, reshape(c.bias, ntuple(_->1, length(c.stride))..., :, 1)
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
σ.(conv(x, c.weight, cdims) .+ b)
end
@@ -222,7 +222,7 @@ end
function ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
weight = convfilter(k, reverse(ch), init = init), bias = zeros(ch[2])) where N
ConvTranspose(weight, bias, σ,
stride = stride, pad = pad, dilation = dilation)
end

View File

@@ -46,24 +46,23 @@ given the prediction `ŷ` and true values `y`.
             | 0.5 * |ŷ - y|^2,          for |ŷ - y| <= δ
Huber loss = |
             |  δ * (|ŷ - y| - 0.5 * δ), otherwise
"""
#TODO: remove dropgrad when Zygote can handle this function with CuArrays
function huber_loss(ŷ, y; δ=eltype(ŷ)(1))
abs_error = abs.(ŷ .- y)
temp = Zygote.dropgrad(abs_error .< δ)
temp = abs_error .< δ
x = eltype(ŷ)(0.5)
hub_loss = sum(((abs_error.^2) .* temp) .* x .+ δ*(abs_error .- x*δ) .* (1 .- temp)) * 1 // length(y)
end
function _crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat, weight::Nothing)
return -sum(xlogy.(y, ŷ)) * 1 // size(y, 2)
return -sum(y .* log.(ŷ)) * 1 // size(y, 2)
end
function _crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat, weight::Number)
return -sum(xlogy.(y, ŷ)) .* weight * 1 // size(y, 2)
return -sum(y .* log.(ŷ)) .* weight * 1 // size(y, 2)
end
function _crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat, weight::AbstractVector)
return -sum(xlogy.(y, ŷ) .* weight) * 1 // size(y, 2)
return -sum(y .* log.(ŷ) .* weight) * 1 // size(y, 2)
end
"""
@@ -92,7 +91,7 @@ Return the crossentropy computed after a [`Flux.logsoftmax`](@ref) operation;
calculated as `-sum(y .* logsoftmax(ŷ) .* weight) / size(y, 2)`.
`logitcrossentropy(ŷ, y)` is mathematically equivalent to
[`Flux.crossentropy(softmax(ŷ), y)`](@ref) but it is more numerically stable.
[`Flux.crossentropy(softmax(log(ŷ)), y)`](@ref) but it is more numerically stable.
See also: [`Flux.crossentropy`](@ref), [`Flux.binarycrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref)
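A hedged illustration of the stability claim, with logits extreme enough to underflow `softmax`:

```julia
using Flux

ŷ = [1000.0, -1000.0]             # raw logits
y = [0.0, 1.0]                    # one-hot target on the underflowing entry

Flux.logitcrossentropy(ŷ, y)      # 2000.0 — finite
Flux.crossentropy(softmax(ŷ), y)  # Inf, since softmax(ŷ)[2] underflows to 0
```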
@@ -124,7 +123,7 @@ julia> Flux.binarycrossentropy.(σ.([-1.1491, 0.8619, 0.3127]), [1, 1, 0])
0.8616703662235441
```
"""
binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -xlogy(y, ŷ + ϵ) - xlogy(1 - y, 1 - ŷ + ϵ)
binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ)
# Re-definition to fix interaction with CuArrays.
CuArrays.@cufunc binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ)
@@ -133,7 +132,7 @@ CuArrays.@cufunc binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1
logitbinarycrossentropy(ŷ, y)
`logitbinarycrossentropy(ŷ, y)` is mathematically equivalent to
[`Flux.binarycrossentropy(σ(ŷ), y)`](@ref) but it is more numerically stable.
[`Flux.binarycrossentropy(σ(log(ŷ)), y)`](@ref) but it is more numerically stable.
See also: [`Flux.crossentropy`](@ref), [`Flux.logitcrossentropy`](@ref), [`Flux.binarycrossentropy`](@ref)
@@ -196,7 +195,7 @@ It is always non-negative and zero only when both the distributions are equal
everywhere.
"""
function kldivergence(ŷ, y)
entropy = sum(xlogx.(y)) * 1 //size(y,2)
entropy = sum(y .* log.(y)) * 1 //size(y,2)
cross_entropy = crossentropy(ŷ, y)
return entropy + cross_entropy
end
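A small worked check (sketch): this computes KL(y ‖ ŷ) = Σ y * log(y / ŷ).

```julia
y = [0.5, 0.5]
ŷ = [0.9, 0.1]

Flux.kldivergence(ŷ, y)  # ≈ 0.511, strictly positive
Flux.kldivergence(y, y)  # 0 — the distributions coincide
```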
@@ -209,7 +208,7 @@ distribution `y`; calculated as `sum(ŷ .- y .* log.(ŷ)) / size(y, 2)`.
[More information.](https://peltarion.com/knowledge-center/documentation/modeling-view/build-an-ai-model/loss-functions/poisson).
"""
poisson(ŷ, y) = sum(ŷ .- xlogy.(y, ŷ)) * 1 // size(y,2)
poisson(ŷ, y) = sum(ŷ .- y .* log.(ŷ)) * 1 // size(y,2)
"""
hinge(ŷ, y)
@@ -263,34 +262,3 @@ by linearizing all values for each element in the batch.
function flatten(x::AbstractArray)
return reshape(x, :, size(x)[end])
end
"""
xlogx(x)
Return `x * log(x)` for `x ≥ 0`, handling `x = 0` by taking the downward limit.
"""
function xlogx(x)
result = x * log(x)
ifelse(iszero(x), zero(result), result)
end
CuArrays.@cufunc function xlogx(x)
result = x * log(x)
ifelse(iszero(x), zero(result), result)
end
"""
xlogy(x, y)
Return `x * log(y)` for `y > 0` with correct limit at `x = 0`.
"""
function xlogy(x, y)
result = x * log(y)
ifelse(iszero(x), zero(result), result)
end
CuArrays.@cufunc function xlogy(x, y)
result = x * log(y)
ifelse(iszero(x), zero(result), result)
end
@adjoint function broadcasted(::typeof(xlogy), x::Zygote.Numeric, y::Zygote.Numeric)
res = xlogy.(x, y)
res, Δ -> (nothing, Zygote.unbroadcast(x, xlogy.(Δ, y)), Zygote.unbroadcast(y, Δ .* x ./ y))
end
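The limit handling in practice (sketch):

```julia
Flux.xlogx(0.0)       # 0.0 rather than NaN: 0 * log(0) takes the downward limit
Flux.xlogy(0.0, 0.0)  # 0.0 for the same reason
Flux.xlogy(2.0, 3.0)  # 2 * log(3) ≈ 2.197
```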

View File

@@ -27,8 +27,7 @@ Base.getindex(xs::OneHotMatrix, ::Colon, ::Colon) = OneHotMatrix(xs.height, copy
Base.getindex(xs::OneHotMatrix, i::Integer, ::Colon) = map(x -> x[i], xs.data)
# remove workaround when https://github.com/JuliaGPU/CuArrays.jl/issues/676 is fixed
A::AbstractMatrix * B::OneHotMatrix = A[:, cpu(map(x->x.ix, B.data))]
A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)]
Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...])
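As an aside, the multiply above amounts to column selection (a sketch; shapes and labels are assumptions):

```julia
using Flux: onehotbatch

A = rand(3, 3)
B = onehotbatch([:b, :a], [:a, :b, :c])
A * B == A[:, [2, 1]]   # true — one-hot multiplication just gathers columns
```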
@@ -49,7 +48,7 @@ cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.d
Create a `OneHotVector` with its `l`-th element `true` based on the
possible set of `labels`.
If `unk` is given, return `onehot(unk, labels)` if the input label `l` is not found
in `labels`; otherwise, it will raise an error.
in `labels`; otherwise it will error.
# Examples
```jldoctest

View File

@@ -1,12 +1,9 @@
module Optimise
using LinearAlgebra
export train!, update!,
Descent, ADAM, Momentum, Nesterov, RMSProp,
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,RADAM,
InvDecay, ExpDecay, WeightDecay, stop, Optimiser,
ClipValue, ClipNorm
InvDecay, ExpDecay, WeightDecay, stop, Optimiser
include("optimisers.jl")
include("train.jl")

View File

@@ -509,7 +509,7 @@ function apply!(o::ExpDecay, x, Δ)
η, s, decay = o.eta, o.step, o.decay
n = o.current[x] = get(o.current, x, 0) + 1
if o.current[x]%s == 0 && count(x -> x%s == 0, values(o.current)) == 1
η = max(η * decay, o.clip)
η = max(η * decay^(s / n), o.clip)
o.eta = η
end
@. Δ *= η
@@ -533,31 +533,3 @@ function apply!(o::WeightDecay, x, Δ)
wd = o.wd
@. Δ += wd * x
end
"""
ClipValue(thresh)
Clip gradients when their absolute value exceeds `thresh`.
"""
mutable struct ClipValue{T}
thresh::T
end
apply!(o::ClipValue, x, Δ) = clamp!(Δ, -o.thresh, o.thresh)
"""
ClipNorm(thresh)
Clip gradients when their L2 norm exceeds `thresh`.
"""
mutable struct ClipNorm{T}
thresh::T
end
function apply!(o::ClipNorm, x, Δ)
Δnrm = norm(Δ)
if Δnrm > o.thresh
rmul!(Δ, o.thresh / Δnrm)
end
return Δ
end

View File

@@ -56,19 +56,23 @@ function stop()
throw(StopException())
end
maketuple(x) = (x,)
maketuple(x::Tuple) = x
"""
train!(loss, params, data, opt; cb)
For each datapoint `d` in `data` compute the gradient of `loss(d...)` through
backpropagation and call the optimizer `opt`.
For each datapoint `d` in `data`, assumed to be a tuple, compute the gradient of `loss(d...)`
with respect to `params`, and call the optimizer `opt`.
In case datapoints `d` are of numeric array type, assume no splatting is needed
and compute the gradient of `loss(d)`.
If `data` yields a tuple mini-batch `d` under iteration, it will be splatted in the function call
`loss(d...)`, otherwise `loss(d)` will be called for non-tuple mini-batches.
A callback is given with the keyword argument `cb`. For example, this will print
"training" every 10 seconds (using [`Flux.throttle`](@ref)):
train!(loss, params, data, opt, cb = throttle(() -> println("training"), 10))
train!(loss, params, data, opt,
cb = throttle(() -> println("training"), 10))
The callback can call [`Flux.stop`](@ref) to interrupt the training loop.
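A sketch of the two calling conventions this implies (`m`, the arrays, and the hyperparameters here are hypothetical):

```julia
using Flux

m   = Dense(4, 2)
ps  = Flux.params(m)
opt = Descent(0.1)

# Tuple batches are splatted: the loss receives (x, y)
data = [(rand(4, 8), rand(2, 8)) for _ in 1:3]
Flux.train!((x, y) -> Flux.mse(m(x), y), ps, data, opt)

# Plain array batches are passed whole: the loss receives d itself
xs = [rand(4, 8) for _ in 1:3]
Flux.train!(x -> sum(abs2, m(x)), ps, xs, opt)
```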
@@ -79,14 +83,8 @@ function train!(loss, ps, data, opt; cb = () -> ())
cb = runall(cb)
@progress for d in data
try
if d isa AbstractArray{<:Number}
gs = gradient(ps) do
loss(d)
end
else
gs = gradient(ps) do
loss(d...)
end
gs = gradient(ps) do
loss(maketuple(d)...)
end
update!(opt, ps, gs)
cb()

View File

@@ -246,10 +246,6 @@ function _restructure(m, xs)
end
end
@adjoint function _restructure(m, xs)
_restructure(m, xs), dm -> (nothing,destructure(dm)[1])
end
"""
destructure(m)

View File

@@ -69,7 +69,6 @@ if CuArrays.has_cudnn()
@info "Testing Flux/CUDNN"
include("cudnn.jl")
include("curnn.jl")
include("layers.jl")
else
@warn "CUDNN unavailable, not testing GPU DNN support"
end

View File

@@ -1,98 +0,0 @@
# Test layers and data/model movements on and off the GPU
# Add tests for layers and their gradients on the GPU
# Most of the forward passes should be fine being applied
# to bitstype objects, but this gives higher coverage for our use-cases
# Check that getting the gradients does not throw
# generic movement tests
@testset "Basic GPU Movement" begin
@test gradient(x -> sum(gpu(x)), rand(3,3)) isa Tuple
@test gradient(x -> sum(cpu(x)), gpu(rand(3,3))) isa Tuple
end
# TODO: These layers get into scalar indexing
# `AlphaDropout` throws a compilation error on GPUs,
# whereas, the rest are scalar indexing issues.
const BROKEN_LAYERS = [DepthwiseConv,
AlphaDropout,
InstanceNorm,
GroupNorm]
function gradtest(name::String, layers::Vector, xs = nothing, args...)
isnothing(xs) && error("Missing input to test the layers against.")
@testset "$name GPU grad tests" begin
for layer in layers
@testset "$layer GPU grad test" begin
l = gpu(layer(args...))
xs = gpu(xs)
if any(x -> isa(l, x), BROKEN_LAYERS)
ps = Flux.params(l)
@test_broken gradient(() -> sum(l(xs)), ps) isa Flux.Zygote.Grads
else
ps = Flux.params(l)
@test gradient(() -> sum(l(xs)), ps) isa Flux.Zygote.Grads
gs = gradient(() -> sum(l(xs)), ps)
# Handle pooling layers
if !isempty(ps)
@test gs[first(ps)] isa Flux.CuArrays.CuArray
end
end
end
end
end
end
# Repeats from Conv, CrossCor
r = rand(Float32, 28, 28, 1, 1)
conv_layers = [Conv, ConvTranspose, CrossCor, DepthwiseConv]
gradtest("Conv", conv_layers, r, (2,2), 1=>3)
pooling_layers = [MaxPool, MeanPool]
gradtest("Pooling", pooling_layers, r, (2,2))
dropout_layers = [Dropout, AlphaDropout]
gradtest("Dropout", dropout_layers, r, 0.5f0)
norm_layers = [LayerNorm, BatchNorm]
gradtest("Normalising", norm_layers, rand(Float32, 28,28,3,1), 1)
instancenorm = [InstanceNorm]
gradtest("InstanceNorm", instancenorm, r, 1)
groupnorm = [GroupNorm]
gradtest("GroupNorm", groupnorm, rand(Float32, 28,28,3,1), 3, 1)
const stateless_layers = [Flux.mse,
Flux.crossentropy,
Flux.logitcrossentropy,
Flux.normalise]
const stateless_layers_broadcasted = [Flux.binarycrossentropy,
Flux.logitbinarycrossentropy]
function stateless_gradtest(f, args...)
@test gradient((args...) -> sum(f(args...)), args...)[1] isa CuArray
end
function stateless_gradtest_broadcasted(f, args...)
@test gradient((args...) -> sum(f.(args...)), args...)[1] isa CuArray
end
@testset "Stateless GPU grad tests" begin
x = gpu(rand(3,3))
y = gpu(rand(3,3))
for layer in stateless_layers
if layer == Flux.normalise
stateless_gradtest(layer, x)
else
stateless_gradtest(layer, x, y)
end
end
for layer in stateless_layers_broadcasted
stateless_gradtest_broadcasted(layer, x, y)
end
end

View File

@@ -3,7 +3,6 @@
Y = [1:5;]
d = DataLoader(X, batchsize=2)
@inferred first(d)
batches = collect(d)
@test eltype(batches) == eltype(d) == typeof(X)
@test length(batches) == 3
@@ -12,7 +11,6 @@
@test batches[3] == X[:,5:5]
d = DataLoader(X, batchsize=2, partial=false)
@inferred first(d)
batches = collect(d)
@test eltype(batches) == eltype(d) == typeof(X)
@test length(batches) == 2
@@ -20,7 +18,6 @@
@test batches[2] == X[:,3:4]
d = DataLoader((X,), batchsize=2, partial=false)
@inferred first(d)
batches = collect(d)
@test eltype(batches) == eltype(d) == Tuple{typeof(X)}
@test length(batches) == 2
@@ -28,7 +25,6 @@
@test batches[2] == (X[:,3:4],)
d = DataLoader((X, Y), batchsize=2)
@inferred first(d)
batches = collect(d)
@test eltype(batches) == eltype(d) == Tuple{typeof(X), typeof(Y)}
@test length(batches) == 3
@@ -42,22 +38,6 @@
@test batches[3][1] == X[:,5:5]
@test batches[3][2] == Y[5:5]
# test with NamedTuple
d = DataLoader((x=X, y=Y), batchsize=2)
@inferred first(d)
batches = collect(d)
@test eltype(batches) == eltype(d) == NamedTuple{(:x, :y), Tuple{typeof(X), typeof(Y)}}
@test length(batches) == 3
@test length(batches[1]) == 2
@test length(batches[2]) == 2
@test length(batches[3]) == 2
@test batches[1][1] == batches[1].x == X[:,1:2]
@test batches[1][2] == batches[1].y == Y[1:2]
@test batches[2][1] == batches[2].x == X[:,3:4]
@test batches[2][2] == batches[2].y == Y[3:4]
@test batches[3][1] == batches[3].x == X[:,5:5]
@test batches[3][2] == batches[3].y == Y[5:5]
# test interaction with `train!`
θ = ones(2)
X = zeros(2, 10)

View File

@@ -28,14 +28,6 @@ import Flux: activations
end
@testset "Dense" begin
@testset "constructors" begin
@test size(Dense(10, 100).W) == (100, 10)
@test Dense(rand(100,10), rand(10)).σ == identity
@test_throws MethodError Dense(10, 10.5)
@test_throws MethodError Dense(10, 10.5, tanh)
end
@test length(Dense(10, 5)(randn(10))) == 5
@test_throws DimensionMismatch Dense(10, 5)(randn(1))
@test_throws MethodError Dense(10, 5)(1) # avoid broadcasting
@@ -45,6 +37,7 @@ import Flux: activations
@test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,2)) == 10*ones(1, 2)
@test Dense(10, 2, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(2, 1)
@test Dense(10, 2, identity, initW = ones, initb = zeros)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20]
end
@testset "Diagonal" begin

View File

@@ -1,26 +1,9 @@
using Test
using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
σ, binarycrossentropy, logitbinarycrossentropy, flatten,
xlogx, xlogy
σ, binarycrossentropy, logitbinarycrossentropy, flatten
const ϵ = 1e-7
@testset "xlogx & xlogy" begin
@test iszero(xlogx(0))
@test isnan(xlogx(NaN))
@test xlogx(2) ≈ 2.0 * log(2.0)
@inferred xlogx(2)
@inferred xlogx(0)
@test iszero(xlogy(0, 1))
@test isnan(xlogy(NaN, 1))
@test isnan(xlogy(1, NaN))
@test isnan(xlogy(NaN, NaN))
@test xlogy(2, 3) ≈ 2.0 * log(3.0)
@inferred xlogy(2, 3)
@inferred xlogy(0, 1)
end
@testset "losses" begin
# First, regression-style y's
y = [1, 1, 0, 0]
@@ -29,15 +12,15 @@ end
@testset "mse" begin
@test mse(ŷ, y) ≈ (.1^2 + .9^2)/2
end
@testset "mae" begin
@test Flux.mae(ŷ, y) ≈ 1/2
end
@testset "huber_loss" begin
@test Flux.huber_loss(ŷ, y) ≈ 0.20500000000000002
end
end
y = [123.0,456.0,789.0]
ŷ = [345.0,332.0,789.0]
@testset "msle" begin
@@ -52,7 +35,6 @@ end
lossvalue = 1.203972804325936
@testset "crossentropy" begin
@test crossentropy([0.1,0.0,0.9], [0.1,0.0,0.9]) ≈ crossentropy([0.1,0.9], [0.1,0.9])
@test crossentropy(ŷ, y) ≈ lossvalue
end
@@ -81,47 +63,46 @@ end
@testset "logitbinarycrossentropy" begin
@test logitbinarycrossentropy.(logŷ, y) ≈ binarycrossentropy.(σ.(logŷ), y; ϵ=0)
end
y = [1 2 3]
ŷ = [4.0 5.0 6.0]
@testset "kldivergence" begin
@test Flux.kldivergence([0.1,0.0,0.9], [0.1,0.0,0.9]) ≈ Flux.kldivergence([0.1,0.9], [0.1,0.9])
@test Flux.kldivergence(ŷ, y) ≈ -1.7661057888493457
@test Flux.kldivergence(y, y) ≈ 0
@test Flux.kldivergence(y, y) ≈ 0
end
y = [1 2 3 4]
ŷ = [5.0 6.0 7.0 8.0]
@testset "hinge" begin
@test Flux.hinge(ŷ, y) ≈ 0
@test Flux.hinge(y, 0.5 .* y) ≈ 0.125
end
@testset "squared_hinge" begin
@test Flux.squared_hinge(ŷ, y) ≈ 0
@test Flux.squared_hinge(y, 0.5 .* y) ≈ 0.0625
end
y = [0.1 0.2 0.3]
ŷ = [0.4 0.5 0.6]
@testset "poisson" begin
@test Flux.poisson(ŷ, y) ≈ 0.6278353988097339
@test Flux.poisson(y, y) ≈ 0.5044459776946685
end
y = [1.0 0.5 0.3 2.4]
ŷ = [0 1.4 0.5 1.2]
@testset "dice_coeff_loss" begin
@test Flux.dice_coeff_loss(ŷ, y) ≈ 0.2799999999999999
@test Flux.dice_coeff_loss(y, y) ≈ 0.0
end
@testset "tversky_loss" begin
@test Flux.tversky_loss(ŷ, y) ≈ -0.06772009029345383
@test Flux.tversky_loss(ŷ, y, β = 0.8) ≈ -0.09490740740740744
@test Flux.tversky_loss(y, y) ≈ -0.5576923076923075
end
@testset "no spurious promotions" begin
for T in (Float32, Float64)
y = rand(T, 2)

View File

@@ -57,57 +57,35 @@ end
end
@testset "ExpDecay" begin
@testset "Sanity Check" begin
o = ExpDecay(0.2, 0.5, 1, 1e-3)
p = [0.0]
steps = 1:8
eta_expected = @. max(o.eta * 0.5 ^ steps, o.clip)
eta_actual = [Optimise.apply!(o, p, [1.0])[1] for _ in steps]
@test eta_actual == eta_expected
end
w = randn(10, 10)
o = ExpDecay(0.1, 0.1, 1000, 1e-4)
w1 = randn(10,10)
loss(x) = Flux.mse(w*x, w1*x)
flag = 1
decay_steps = []
for t = 1:10^5
prev_eta = o.eta
θ = Params([w1])
x = rand(10)
θ̄ = gradient(() -> loss(x), θ)
prev_grad = collect(θ̄[w1])
delta = Optimise.apply!(o, w1, θ̄[w1])
w1 .-= delta
new_eta = o.eta
if new_eta != prev_eta
push!(decay_steps, t)
end
array = fill(o.eta, size(prev_grad))
if array .* prev_grad != delta
flag = 0
end
end
@test flag == 1
# Test to check if decay happens at decay steps. Eta reaches clip value (1e-4) after 4000 steps (decay by 0.1 every 1000 steps starting at 0.1).
ground_truth = []
for i in 1:4
push!(ground_truth, 1000*i) # Expected decay steps for this example.
end
@test decay_steps == ground_truth
@test o.eta == o.clip
end
@testset "Clipping" begin
w = randn(10, 10)
loss(x) = sum(w * x)
θ = Params([w])
x = 1000 * randn(10)
w̄ = gradient(() -> loss(x), θ)[w]
w̄_value = Optimise.apply!(ClipValue(1.0), w, copy(w̄))
@test all(w̄_value .<= 1)
w̄_norm = Optimise.apply!(ClipNorm(1.0), w, copy(w̄))
@test norm(w̄_norm) <= 1
end
o = ExpDecay(0.1, 0.1, 1000, 1e-4)
w1 = randn(10,10)
loss(x) = Flux.mse(w*x, w1*x)
flag = 1
decay_steps = []
for t = 1:10^5
prev_eta = o.eta
θ = Params([w1])
x = rand(10)
θ̄ = gradient(() -> loss(x), θ)
prev_grad = collect(θ̄[w1])
delta = Optimise.apply!(o, w1, θ̄[w1])
w1 .-= delta
new_eta = o.eta
if new_eta != prev_eta
push!(decay_steps, t)
end
array = fill(o.eta, size(prev_grad))
if array .* prev_grad != delta
flag = 0
end
end
@test flag == 1
# Test to check if decay happens at decay steps. Eta reaches clip value eventually.
ground_truth = []
for i in 1:11
push!(ground_truth, 1000*i) # Expected decay steps for this example.
end
@test decay_steps == ground_truth
@test o.eta == o.clip
end