Merge branch 'master' of https://github.com/FluxML/Flux.jl into cudnn_batchnorm
This commit is contained in:
commit
564518e448
|
@ -5,4 +5,3 @@ docs/build/
|
|||
docs/site/
|
||||
docs/flux.css
|
||||
deps
|
||||
Manifest.toml
|
||||
|
|
|
@ -4,7 +4,6 @@ os:
|
|||
- linux
|
||||
# - osx
|
||||
julia:
|
||||
- 0.7
|
||||
- 1.0
|
||||
- nightly
|
||||
# uncomment the following lines to override the default test script
|
||||
|
|
|
@ -0,0 +1,270 @@
|
|||
[[AbstractTrees]]
|
||||
deps = ["Markdown", "Test"]
|
||||
git-tree-sha1 = "feb8b2c99359901e295443c9d0c7e711604acf39"
|
||||
uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
|
||||
version = "0.2.0"
|
||||
|
||||
[[Adapt]]
|
||||
deps = ["LinearAlgebra", "Test"]
|
||||
git-tree-sha1 = "a1245c11af6876245c32f82f2067bf67f7da8cee"
|
||||
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
|
||||
version = "0.4.0"
|
||||
|
||||
[[Base64]]
|
||||
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
|
||||
|
||||
[[BinDeps]]
|
||||
deps = ["Compat", "Libdl", "SHA", "URIParser"]
|
||||
git-tree-sha1 = "12093ca6cdd0ee547c39b1870e0c9c3f154d9ca9"
|
||||
uuid = "9e28174c-4ba2-5203-b857-d8d62c4213ee"
|
||||
version = "0.8.10"
|
||||
|
||||
[[BinaryProvider]]
|
||||
deps = ["Libdl", "Pkg", "SHA", "Test"]
|
||||
git-tree-sha1 = "9930c1a6cd49d9fcd7218df6be417e6ae4f1468a"
|
||||
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
|
||||
version = "0.5.2"
|
||||
|
||||
[[CodecZlib]]
|
||||
deps = ["BinaryProvider", "Libdl", "Pkg", "Test", "TranscodingStreams"]
|
||||
git-tree-sha1 = "83cb3d65c37ea1364c2d5bf7bcea41843ba645dc"
|
||||
uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
|
||||
version = "0.5.0"
|
||||
|
||||
[[ColorTypes]]
|
||||
deps = ["FixedPointNumbers", "Random", "Test"]
|
||||
git-tree-sha1 = "f73b0e10f2a5756de7019818a41654686da06b09"
|
||||
uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
|
||||
version = "0.7.5"
|
||||
|
||||
[[Colors]]
|
||||
deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Pkg", "Printf", "Reexport", "Test"]
|
||||
git-tree-sha1 = "9f0a0210450acb91c730b730a994f8eef1d3d543"
|
||||
uuid = "5ae59095-9a9b-59fe-a467-6f913c188581"
|
||||
version = "0.9.5"
|
||||
|
||||
[[CommonSubexpressions]]
|
||||
deps = ["Test"]
|
||||
git-tree-sha1 = "efdaf19ab11c7889334ca247ff4c9f7c322817b0"
|
||||
uuid = "bbf7d656-a473-5ed7-a52c-81e309532950"
|
||||
version = "0.2.0"
|
||||
|
||||
[[Compat]]
|
||||
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
|
||||
git-tree-sha1 = "2d9e14d19bad3f9ad5cc5e4cffabc3cfa59de825"
|
||||
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
|
||||
version = "1.3.0"
|
||||
|
||||
[[DataStructures]]
|
||||
deps = ["InteractiveUtils", "OrderedCollections", "REPL", "Random", "Serialization", "Test"]
|
||||
git-tree-sha1 = "8fc6e166e24fda04b2b648d4260cdad241788c54"
|
||||
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
|
||||
version = "0.14.0"
|
||||
|
||||
[[Dates]]
|
||||
deps = ["Printf"]
|
||||
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
|
||||
|
||||
[[DelimitedFiles]]
|
||||
deps = ["Mmap"]
|
||||
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
|
||||
|
||||
[[DiffResults]]
|
||||
deps = ["Compat", "StaticArrays"]
|
||||
git-tree-sha1 = "db8acf46717b13d6c48deb7a12007c7f85a70cf7"
|
||||
uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
|
||||
version = "0.0.3"
|
||||
|
||||
[[DiffRules]]
|
||||
deps = ["Random", "Test"]
|
||||
git-tree-sha1 = "c49ec69428ffea0c1d1bbdc63d1a70f5df5860ad"
|
||||
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
|
||||
version = "0.0.7"
|
||||
|
||||
[[Distributed]]
|
||||
deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"]
|
||||
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
|
||||
|
||||
[[FixedPointNumbers]]
|
||||
deps = ["Pkg", "Test"]
|
||||
git-tree-sha1 = "b8045033701c3b10bf2324d7203404be7aef88ba"
|
||||
uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
|
||||
version = "0.5.3"
|
||||
|
||||
[[ForwardDiff]]
|
||||
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Pkg", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"]
|
||||
git-tree-sha1 = "d8f3e0f19d0d546aa92eb1cd67cd3e515768d9f7"
|
||||
uuid = "f6369f11-7733-5829-9624-2563aa707210"
|
||||
version = "0.10.0"
|
||||
|
||||
[[InteractiveUtils]]
|
||||
deps = ["LinearAlgebra", "Markdown"]
|
||||
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
|
||||
|
||||
[[Juno]]
|
||||
deps = ["Base64", "Logging", "Media", "Profile", "Test"]
|
||||
git-tree-sha1 = "3c29a199713e7ec62cfdc11f44d7760219d5f658"
|
||||
uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
|
||||
version = "0.5.3"
|
||||
|
||||
[[LibGit2]]
|
||||
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
|
||||
|
||||
[[Libdl]]
|
||||
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
|
||||
|
||||
[[LinearAlgebra]]
|
||||
deps = ["Libdl"]
|
||||
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
|
||||
|
||||
[[Logging]]
|
||||
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
|
||||
|
||||
[[MacroTools]]
|
||||
deps = ["Compat"]
|
||||
git-tree-sha1 = "c443e1c8d58a4e9f61b708ad0a88286c7042145b"
|
||||
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
|
||||
version = "0.4.4"
|
||||
|
||||
[[Markdown]]
|
||||
deps = ["Base64"]
|
||||
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
|
||||
|
||||
[[Media]]
|
||||
deps = ["MacroTools", "Test"]
|
||||
git-tree-sha1 = "9f390271c9a43dcbe908a10b5b9632cf58cbab5b"
|
||||
uuid = "e89f7d12-3494-54d1-8411-f7d8b9ae1f27"
|
||||
version = "0.4.1"
|
||||
|
||||
[[Missings]]
|
||||
deps = ["Dates", "InteractiveUtils", "SparseArrays", "Test"]
|
||||
git-tree-sha1 = "adc26d2ee85a49c413464110d922cf21efc9d233"
|
||||
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
|
||||
version = "0.3.1"
|
||||
|
||||
[[Mmap]]
|
||||
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
|
||||
|
||||
[[NNlib]]
|
||||
deps = ["Libdl", "LinearAlgebra", "MacroTools", "Requires", "Test"]
|
||||
git-tree-sha1 = "d7f65ad9734adea3c5a4c473bc65b365f8afbb2b"
|
||||
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
|
||||
version = "0.4.2"
|
||||
|
||||
[[NaNMath]]
|
||||
deps = ["Compat"]
|
||||
git-tree-sha1 = "ce3b85e484a5d4c71dd5316215069311135fa9f2"
|
||||
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
|
||||
version = "0.3.2"
|
||||
|
||||
[[OrderedCollections]]
|
||||
deps = ["Pkg", "Random", "Serialization", "Test"]
|
||||
git-tree-sha1 = "85619a3f3e17bb4761fe1b1fd47f0e979f964d5b"
|
||||
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
|
||||
version = "1.0.2"
|
||||
|
||||
[[Pkg]]
|
||||
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
|
||||
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
|
||||
|
||||
[[Printf]]
|
||||
deps = ["Unicode"]
|
||||
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
|
||||
|
||||
[[Profile]]
|
||||
deps = ["Printf"]
|
||||
uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"
|
||||
|
||||
[[REPL]]
|
||||
deps = ["InteractiveUtils", "Markdown", "Sockets"]
|
||||
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
|
||||
|
||||
[[Random]]
|
||||
deps = ["Serialization"]
|
||||
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
|
||||
|
||||
[[Reexport]]
|
||||
deps = ["Pkg"]
|
||||
git-tree-sha1 = "7b1d07f411bc8ddb7977ec7f377b97b158514fe0"
|
||||
uuid = "189a3867-3050-52da-a836-e630ba90ab69"
|
||||
version = "0.2.0"
|
||||
|
||||
[[Requires]]
|
||||
deps = ["Test"]
|
||||
git-tree-sha1 = "f6fbf4ba64d295e146e49e021207993b6b48c7d1"
|
||||
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
|
||||
version = "0.5.2"
|
||||
|
||||
[[SHA]]
|
||||
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
|
||||
|
||||
[[Serialization]]
|
||||
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
|
||||
|
||||
[[SharedArrays]]
|
||||
deps = ["Distributed", "Mmap", "Random", "Serialization"]
|
||||
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
|
||||
|
||||
[[Sockets]]
|
||||
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
|
||||
|
||||
[[SortingAlgorithms]]
|
||||
deps = ["DataStructures", "Random", "Test"]
|
||||
git-tree-sha1 = "03f5898c9959f8115e30bc7226ada7d0df554ddd"
|
||||
uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
|
||||
version = "0.3.1"
|
||||
|
||||
[[SparseArrays]]
|
||||
deps = ["LinearAlgebra", "Random"]
|
||||
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
|
||||
|
||||
[[SpecialFunctions]]
|
||||
deps = ["BinDeps", "BinaryProvider", "Libdl", "Test"]
|
||||
git-tree-sha1 = "c35c9c76008babf4d658060fc64aeb369a41e7bd"
|
||||
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
|
||||
version = "0.7.1"
|
||||
|
||||
[[StaticArrays]]
|
||||
deps = ["InteractiveUtils", "LinearAlgebra", "Pkg", "Random", "Statistics", "Test"]
|
||||
git-tree-sha1 = "ebc5c2a27d91d5ec611a9861168182e2168effd3"
|
||||
uuid = "90137ffa-7385-5640-81b9-e52037218182"
|
||||
version = "0.9.2"
|
||||
|
||||
[[Statistics]]
|
||||
deps = ["LinearAlgebra", "SparseArrays"]
|
||||
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
|
||||
|
||||
[[StatsBase]]
|
||||
deps = ["DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "Test"]
|
||||
git-tree-sha1 = "723193a13e8078cec6dcd0b8fe245c8bfd81690e"
|
||||
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
|
||||
version = "0.25.0"
|
||||
|
||||
[[Test]]
|
||||
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
|
||||
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
|
||||
|
||||
[[TranscodingStreams]]
|
||||
deps = ["DelimitedFiles", "Pkg", "Random", "Test"]
|
||||
git-tree-sha1 = "a34a2d588e2d2825602bf14a24216d5c8b0921ec"
|
||||
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
|
||||
version = "0.8.1"
|
||||
|
||||
[[URIParser]]
|
||||
deps = ["Test", "Unicode"]
|
||||
git-tree-sha1 = "6ddf8244220dfda2f17539fa8c9de20d6c575b69"
|
||||
uuid = "30578b45-9adc-5946-b283-645ec420af67"
|
||||
version = "0.4.0"
|
||||
|
||||
[[UUIDs]]
|
||||
deps = ["Random"]
|
||||
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
|
||||
|
||||
[[Unicode]]
|
||||
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
|
||||
|
||||
[[ZipFile]]
|
||||
deps = ["Printf", "Test"]
|
||||
git-tree-sha1 = "c191e56c849b1784cacbf7cd5e52cc672f1ae2db"
|
||||
uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
|
||||
version = "0.7.0"
|
|
@ -0,0 +1,24 @@
|
|||
name = "Flux"
|
||||
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
|
||||
|
||||
[deps]
|
||||
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
|
||||
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
|
||||
CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193"
|
||||
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
|
||||
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
|
||||
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
|
||||
Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
|
||||
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
|
||||
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
|
||||
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
|
||||
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
|
||||
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
|
||||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
|
||||
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
|
||||
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
|
||||
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
|
||||
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
|
||||
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
|
||||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
|
||||
ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
|
|
@ -10,6 +10,12 @@ MaxPool
|
|||
MeanPool
|
||||
```
|
||||
|
||||
## Additional Convolution Layers
|
||||
|
||||
```@docs
|
||||
DepthwiseConv
|
||||
```
|
||||
|
||||
## Recurrent Layers
|
||||
|
||||
Much like the core layers above, but can be used to process sequence data (as well as other kinds of structured data).
|
||||
|
|
|
@ -6,7 +6,7 @@ using MacroTools, Juno, Requires, Reexport, Statistics, Random
|
|||
using MacroTools: @forward
|
||||
|
||||
export Chain, Dense, RNN, LSTM, GRU, Conv, MaxPool, MeanPool,
|
||||
Dropout, LayerNorm, BatchNorm,
|
||||
DepthwiseConv, Dropout, LayerNorm, BatchNorm,
|
||||
params, mapleaves, cpu, gpu
|
||||
|
||||
@reexport using NNlib
|
||||
|
@ -19,8 +19,9 @@ export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param
|
|||
include("optimise/Optimise.jl")
|
||||
using .Optimise
|
||||
using .Optimise: @epochs
|
||||
export SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
|
||||
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
|
||||
export SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
|
||||
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
|
||||
ADAMW, InvDecay, ExpDecay, WeightDecay
|
||||
|
||||
include("utils.jl")
|
||||
include("onehot.jl")
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
using NNlib: conv
|
||||
using NNlib: conv, depthwiseconv
|
||||
|
||||
@generated sub2(::Val{N}) where N = :(Val($(N-2)))
|
||||
|
||||
|
@ -30,8 +30,8 @@ Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
|
|||
stride = 1, pad = 0, dilation = 1) where {T,N} =
|
||||
Conv(σ, w, b, expand.(sub2(Val(N)), (stride, pad, dilation))...)
|
||||
|
||||
Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
|
||||
stride = 1, pad = 0, dilation = 1) where N =
|
||||
Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
|
||||
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
|
||||
Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ,
|
||||
stride = stride, pad = pad, dilation = dilation)
|
||||
|
||||
|
@ -51,6 +51,56 @@ function Base.show(io::IO, l::Conv)
|
|||
print(io, ")")
|
||||
end
|
||||
|
||||
"""
|
||||
DepthwiseConv(size, in)
|
||||
DepthwiseConv(size, in=>mul)
|
||||
DepthwiseConv(size, in=>mul, relu)
|
||||
|
||||
Depthwise convolutional layer. `size` should be a tuple like `(2, 2)`.
|
||||
`in` and `mul` specify the number of input channels and channel multiplier respectively.
|
||||
In case the `mul` is not specified it is taken as 1.
|
||||
|
||||
Data should be stored in WHCN order. In other words, a 100×100 RGB image would
|
||||
be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
|
||||
|
||||
Takes the keyword arguments `pad` and `stride`.
|
||||
"""
|
||||
struct DepthwiseConv{N,F,A,V}
|
||||
σ::F
|
||||
weight::A
|
||||
bias::V
|
||||
stride::NTuple{N,Int}
|
||||
pad::NTuple{N,Int}
|
||||
end
|
||||
|
||||
DepthwiseConv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
|
||||
stride = 1, pad = 0) where {T,N} =
|
||||
DepthwiseConv(σ, w, b, expand.(sub2(Val(N)), (stride, pad))...)
|
||||
|
||||
DepthwiseConv(k::NTuple{N,Integer}, ch::Integer, σ = identity; init = initn,
|
||||
stride = 1, pad = 0) where N =
|
||||
DepthwiseConv(param(init(k..., 1, ch)), param(zeros(ch)), σ,
|
||||
stride = stride, pad = pad)
|
||||
|
||||
DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
|
||||
stride::NTuple{N,Integer} = map(_->1,k),
|
||||
pad::NTuple{N,Integer} = map(_->0,k)) where N =
|
||||
DepthwiseConv(param(init(k..., ch[2], ch[1])), param(zeros(ch[2]*ch[1])), σ,
|
||||
stride = stride, pad = pad)
|
||||
|
||||
@treelike DepthwiseConv
|
||||
|
||||
function (c::DepthwiseConv)(x)
|
||||
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
|
||||
σ.(depthwiseconv(x, c.weight, stride = c.stride, pad = c.pad) .+ b)
|
||||
end
|
||||
|
||||
function Base.show(io::IO, l::DepthwiseConv)
|
||||
print(io, "DepthwiseConv(", size(l.weight)[1:ndims(l.weight)-2])
|
||||
print(io, ", ", size(l.weight, ndims(l.weight)), "=>", size(l.weight, ndims(l.weight)-1))
|
||||
l.σ == identity || print(io, ", ", l.σ)
|
||||
print(io, ")")
|
||||
end
|
||||
|
||||
"""
|
||||
MaxPool(k)
|
||||
|
@ -60,9 +110,9 @@ Max pooling layer. `k` stands for the size of the window for each dimension of t
|
|||
Takes the keyword arguments `pad` and `stride`.
|
||||
"""
|
||||
struct MaxPool{N}
|
||||
k::NTuple{N,Int}
|
||||
pad::NTuple{N,Int}
|
||||
stride::NTuple{N,Int}
|
||||
k::NTuple{N,Int}
|
||||
pad::NTuple{N,Int}
|
||||
stride::NTuple{N,Int}
|
||||
end
|
||||
|
||||
MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =
|
||||
|
|
|
@ -1,23 +1,12 @@
|
|||
module Optimise
|
||||
|
||||
export train!,
|
||||
SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
|
||||
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM, stop, StopException
|
||||
|
||||
struct Param{T}
|
||||
x::T
|
||||
Δ::T
|
||||
end
|
||||
|
||||
Param(x::AbstractArray) = Param(x, zero(x))
|
||||
SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
|
||||
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,
|
||||
InvDecay, ExpDecay, WeightDecay, stop, Optimiser
|
||||
|
||||
include("optimisers.jl")
|
||||
include("interface.jl")
|
||||
include("train.jl")
|
||||
|
||||
using Flux.Tracker: TrackedArray
|
||||
|
||||
Param(x::TrackedArray) = Param(x.data, x.grad)
|
||||
# Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad)
|
||||
include("deprecations.jl")
|
||||
|
||||
end
|
||||
|
|
|
@ -0,0 +1,126 @@
|
|||
using Base: depwarn
|
||||
using Flux: Params
|
||||
|
||||
check_decay(opt, decay) = decay == 0 ? opt : Optimiser(opt, InvDecay(decay))
|
||||
|
||||
# legacy update rule
|
||||
updaterule(opt, ps) = () -> update!(opt, ps)
|
||||
|
||||
function SGD(params::Union{AbstractArray, Params}, η = 0.1; decay = 0.)
|
||||
depwarn("SGD(params) is deprecated; use Descent(η::Float64) instead", :SGD)
|
||||
|
||||
ps = params
|
||||
opt = Descent(η)
|
||||
opt = check_decay(opt, decay)
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
function Momentum(params::Union{AbstractArray, Params}, η = 0.01; ρ = 0.9, decay = 0.)
|
||||
depwarn("Momentum(params) is deprecated; use Momentum(η::Float64) instead", :Momentum)
|
||||
|
||||
ps = params
|
||||
opt = Momentum(η, ρ)
|
||||
opt = check_decay(opt, decay)
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
function Nesterov(params::Union{AbstractArray, Params}, η = 0.001; ρ = 0.9, decay = 0.)
|
||||
depwarn("Nesterov(params) is deprecated; use Nesterov(η::Float64) instead", :Nesterov)
|
||||
|
||||
ps = params
|
||||
opt = Nesterov(η, ρ)
|
||||
opt = check_decay(opt, decay)
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
function RMSProp(params::Union{AbstractArray, Params}, η = 0.001; ρ = 0.9, decay = 0.)
|
||||
depwarn("RMSProp(params) is deprecated; use RMSProp(η::Float64) instead", :RMSProp)
|
||||
|
||||
ps = params
|
||||
opt = RMSProp(η, ρ)
|
||||
opt = check_decay(opt, decay)
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
function ADAM(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
|
||||
depwarn("ADAM(params) is deprecated; use ADAM(η::Float64) instead", :ADAM)
|
||||
|
||||
ps = params
|
||||
β = (β1, β2)
|
||||
opt = ADAM(η, β)
|
||||
opt = check_decay(opt, decay)
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
function ADAGrad(params::Union{AbstractArray, Params}, η::Float64 = 0.1; decay = 0.)
|
||||
depwarn("ADAGrad(params) is deprecated; use ADAGrad(η::Float64) instead", :ADAGrad)
|
||||
|
||||
ps = params
|
||||
opt = ADAGrad(η)
|
||||
opt = check_decay(opt, decay)
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
function ADADelta(params::Union{AbstractArray, Params}, ρ::Float64 = 0.9; decay = 0.)
|
||||
depwarn("ADADelta(params) is deprecated; use ADADelta(η::Float64) instead", :ADADelta)
|
||||
|
||||
ps = params
|
||||
opt = ADADelta(ρ)
|
||||
opt = check_decay(opt, decay)
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
function AdaMax(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
|
||||
depwarn("AdaMax(params) is deprecated; use AdaMax(η::Float64) instead", :AdaMax)
|
||||
|
||||
ps = params
|
||||
β = (β1, β2)
|
||||
opt = AdaMax(η, β)
|
||||
opt = check_decay(opt, decay)
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
function AMSGrad(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
|
||||
depwarn("AMSGrad(params) is deprecated; use AMSGrad(η::Float64) instead", :AMSGrad)
|
||||
|
||||
ps = params
|
||||
β = (β1, β2)
|
||||
opt = AMSGrad(η, β)
|
||||
opt = check_decay(opt, decay)
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
function NADAM(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
|
||||
depwarn("NADAM(params) is deprecated; use NADAM(η::Float64) instead", :NADAM)
|
||||
|
||||
ps = params
|
||||
β = (β1, β2)
|
||||
opt = NADAM(η, β)
|
||||
opt = check_decay(opt, decay)
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
function ADAMW(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
|
||||
depwarn("ADAMW(params) is deprecated; use ADAMW(η::Float64) instead", :ADAMW)
|
||||
|
||||
ps = params
|
||||
β = (β1, β2)
|
||||
opt = ADAMW(η, β)
|
||||
opt = check_decay(opt, decay)
|
||||
decay != 0 && (opt = Optimiser(opt, WeightDecay(decay)))
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
# Old training loop
|
||||
|
||||
struct OldOptimiser
|
||||
func
|
||||
end
|
||||
|
||||
update!(opt::OldOptimiser, ps) = opt.func()
|
||||
|
||||
# Train function
|
||||
function train!(loss, data, opt; cb = () -> ())
|
||||
depwarn("train!(loss, data, opt) is deprecated; use train!(loss, params, data, opt) instead", :train!)
|
||||
train!(loss, (), data, OldOptimiser(opt); cb = cb)
|
||||
end
|
|
@ -1,110 +0,0 @@
|
|||
call(f, xs...) = f(xs...)
|
||||
|
||||
# note for optimisers: set to zero
|
||||
# p.Δ at the end of the weights update
|
||||
function optimiser(ps, fs...)
|
||||
ps = [Param(p) for p in ps]
|
||||
fs = map(ps) do p
|
||||
os = map(f -> f(p), fs)
|
||||
() -> foreach(call, os)
|
||||
end
|
||||
() -> foreach(call, fs)
|
||||
end
|
||||
|
||||
"""
|
||||
SGD(params, η = 0.1; decay = 0)
|
||||
|
||||
Classic gradient descent optimiser with learning rate `η`.
|
||||
For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`.
|
||||
|
||||
Supports inverse decaying learning rate if the `decay` argument is provided.
|
||||
"""
|
||||
SGD(ps, η = 0.1; decay = 0) =
|
||||
optimiser(ps, p -> invdecay(p, decay), p -> descent(p,η))
|
||||
|
||||
"""
|
||||
Momentum(params, η = 0.01; ρ = 0.9, decay = 0)
|
||||
|
||||
SGD with learning rate `η`, momentum `ρ` and optional learning rate inverse decay.
|
||||
"""
|
||||
Momentum(ps, η = 0.01; ρ = 0.9, decay = 0) =
|
||||
optimiser(ps, p->invdecay(p,decay), p->momentum(p, ρ, η), p->descent(p,1))
|
||||
|
||||
"""
|
||||
Nesterov(params, η = 0.01; ρ = 0.9, decay = 0)
|
||||
|
||||
SGD with learning rate `η`, Nesterov momentum `ρ` and optional learning rate inverse decay.
|
||||
"""
|
||||
Nesterov(ps, η = 0.01; ρ = 0.9, decay = 0) =
|
||||
optimiser(ps, p->invdecay(p,decay), p->nesterov(p, ρ, η), p->descent(p,1))
|
||||
|
||||
"""
|
||||
RMSProp(params, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0)
|
||||
|
||||
[RMSProp](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
|
||||
optimiser. Parameters other than learning rate don't need tuning. Often a good
|
||||
choice for recurrent networks.
|
||||
"""
|
||||
RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) =
|
||||
optimiser(ps, p->rmsprop(p; η=η, ρ=ρ, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
|
||||
|
||||
"""
|
||||
ADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
|
||||
|
||||
[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
|
||||
"""
|
||||
ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
|
||||
optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
|
||||
|
||||
"""
|
||||
ADAMW((params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
|
||||
|
||||
[ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam.
|
||||
"""
|
||||
ADAMW(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
|
||||
optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->descentweightdecay(p,1,decay))
|
||||
|
||||
"""
|
||||
AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
|
||||
|
||||
[AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser. Variant of ADAM based on
|
||||
the ∞-norm.
|
||||
"""
|
||||
AdaMax(ps, η = 0.002; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
|
||||
optimiser(ps, p->adamax(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
|
||||
|
||||
"""
|
||||
ADAGrad(params, η = 0.01; ϵ = 1e-8, decay = 0)
|
||||
|
||||
[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
|
||||
Parameters don't need tuning.
|
||||
"""
|
||||
ADAGrad(ps, η = 0.01; ϵ = 1e-8, decay = 0) =
|
||||
optimiser(ps, p->adagrad(p; η=η, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
|
||||
|
||||
"""
|
||||
ADADelta(params; ρ = 0.9, ϵ = 1e-8, decay = 0)
|
||||
|
||||
[ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
|
||||
tuning.
|
||||
"""
|
||||
ADADelta(ps; ρ = 0.9, ϵ = 1e-8, decay = 0) =
|
||||
optimiser(ps, p->adadelta(p; ρ=ρ, ϵ=ϵ), p->descent(p,1))
|
||||
|
||||
"""
|
||||
AMSGrad(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
|
||||
|
||||
[AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need
|
||||
tuning.
|
||||
"""
|
||||
AMSGrad(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
|
||||
optimiser(ps, p -> amsgrad(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
|
||||
|
||||
"""
|
||||
NADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
|
||||
|
||||
[NADAM](https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ) optimiser. Parameters other
|
||||
than learning rate don't need tuning.
|
||||
"""
|
||||
NADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
|
||||
optimiser(ps, p->nadam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
|
|
@ -1,130 +1,304 @@
|
|||
function descent(p::Param, η::Real)
|
||||
function ()
|
||||
@. p.x -= η * p.Δ
|
||||
@. p.Δ = 0
|
||||
using Flux
|
||||
using Base: @get!
|
||||
using MacroTools: @forward
|
||||
|
||||
const ϵ = 1e-8
|
||||
|
||||
# TODO: should use weak refs
|
||||
|
||||
"""
|
||||
Descent(η)
|
||||
|
||||
Classic gradient descent optimiser with learning rate `η`.
|
||||
For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`.
|
||||
"""
|
||||
mutable struct Descent
|
||||
eta::Float64
|
||||
end
|
||||
|
||||
Descent() = Descent(0.1)
|
||||
|
||||
function update!(o::Descent, x, Δ)
|
||||
Δ .*= o.eta
|
||||
end
|
||||
|
||||
"""
|
||||
Momentum(params, η = 0.01; ρ = 0.9)
|
||||
|
||||
Gradient descent with learning rate `η` and momentum `ρ`.
|
||||
"""
|
||||
mutable struct Momentum
|
||||
eta::Float64
|
||||
rho::Float64
|
||||
velocity::IdDict
|
||||
end
|
||||
|
||||
Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict())
|
||||
|
||||
function update!(o::Momentum, x, Δ)
|
||||
η, ρ = o.eta, o.rho
|
||||
v = get!(o.velocity, x, zero(x))::typeof(x)
|
||||
@. v = ρ * v - η * Δ
|
||||
@. Δ = -v
|
||||
end
|
||||
|
||||
"""
|
||||
Nesterov(eta, ρ = 0.9)
|
||||
|
||||
Gradient descent with learning rate `η` and Nesterov momentum `ρ`.
|
||||
"""
|
||||
mutable struct Nesterov
|
||||
eta::Float64
|
||||
rho::Float64
|
||||
velocity::IdDict
|
||||
end
|
||||
|
||||
Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict())
|
||||
|
||||
function update!(o::Nesterov, x, Δ)
|
||||
η, ρ = o.eta, o.rho
|
||||
v = get!(o.velocity, x, zero(x))::typeof(x)
|
||||
d = @. ρ^2 * v - (1+ρ) * η * Δ
|
||||
@. v = ρ*v - η*Δ
|
||||
@. Δ = -d
|
||||
end
|
||||
|
||||
"""
|
||||
RMSProp(η = 0.001, ρ = 0.9)
|
||||
|
||||
[RMSProp](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
|
||||
optimiser. Parameters other than learning rate don't need tuning. Often a good
|
||||
choice for recurrent networks.
|
||||
"""
|
||||
mutable struct RMSProp
|
||||
eta::Float64
|
||||
rho::Float64
|
||||
acc::IdDict
|
||||
end
|
||||
|
||||
RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict())
|
||||
|
||||
function update!(o::RMSProp, x, Δ)
|
||||
η, ρ = o.eta, o.rho
|
||||
acc = get!(o.acc, x, zero(x))::typeof(x)
|
||||
@. acc = ρ * acc + (1 - ρ) * Δ^2
|
||||
@. Δ *= η / (√acc + ϵ)
|
||||
end
|
||||
|
||||
"""
|
||||
ADAM(η = 0.001, β = (0.9, 0.999))
|
||||
|
||||
[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
|
||||
"""
|
||||
mutable struct ADAM
|
||||
eta::Float64
|
||||
beta::Tuple{Float64,Float64}
|
||||
state::IdDict
|
||||
end
|
||||
|
||||
ADAM(η = 0.001, β = (0.9, 0.999)) = ADAM(η, β, IdDict())
|
||||
|
||||
function update!(o::ADAM, x, Δ)
|
||||
η, β = o.eta, o.beta
|
||||
mt, vt, βp = get!(o.state, x, (zero(x), zero(x), β))
|
||||
@. mt = β[1] * mt + (1 - β[1]) * Δ
|
||||
@. vt = β[2] * vt + (1 - β[2]) * Δ^2
|
||||
@. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η
|
||||
o.state[x] = (mt, vt, βp .* β)
|
||||
return Δ
|
||||
end
|
||||
|
||||
"""
|
||||
AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08)
|
||||
|
||||
[AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser. Variant of ADAM based on
|
||||
the ∞-norm.
|
||||
"""
|
||||
mutable struct AdaMax
|
||||
eta::Float64
|
||||
beta::Tuple{Float64,Float64}
|
||||
state::IdDict
|
||||
end
|
||||
|
||||
AdaMax(η = 0.001, β = (0.9, 0.999)) = AdaMax(η, β, IdDict())
|
||||
|
||||
function update!(o::AdaMax, x, Δ)
|
||||
η, β = o.eta, o.beta
|
||||
mt, ut, βp = get!(o.state, x, (zero(x), zero(x), β))
|
||||
@. mt = β[1] * mt + (1 - β[1]) * Δ
|
||||
@. ut = max(β[2] * ut, abs(Δ))
|
||||
@. Δ = (η/(1 - βp[1])) * mt/(ut + ϵ)
|
||||
o.state[x] = (mt, ut, βp .* β)
|
||||
return Δ
|
||||
end
|
||||
|
||||
"""
|
||||
ADAGrad(η = 0.1; ϵ = 1e-8)
|
||||
|
||||
[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
|
||||
Parameters don't need tuning.
|
||||
"""
|
||||
mutable struct ADAGrad
|
||||
eta::Float64
|
||||
acc::IdDict
|
||||
end
|
||||
|
||||
ADAGrad(η = 0.1) = ADAGrad(η, IdDict())
|
||||
|
||||
function update!(o::ADAGrad, x, Δ)
|
||||
η = o.eta
|
||||
acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(x)
|
||||
@. acc += Δ^2
|
||||
@. Δ *= η / (√acc + ϵ)
|
||||
end
|
||||
|
||||
"""
|
||||
ADADelta(ρ = 0.9, ϵ = 1e-8)
|
||||
|
||||
[ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
|
||||
tuning.
|
||||
"""
|
||||
mutable struct ADADelta
|
||||
rho::Float64
|
||||
state::IdDict
|
||||
end
|
||||
|
||||
ADADelta(ρ = 0.9) = ADADelta(ρ, IdDict())
|
||||
|
||||
function update!(o::ADADelta, x, Δ)
|
||||
ρ = o.rho
|
||||
acc, Δacc = get!(o.state, x, (zero(x), zero(x)))
|
||||
@. acc = ρ * acc + (1 - ρ) * Δ^2
|
||||
@. Δ *= √Δacc/ (√acc + ϵ)
|
||||
@. Δacc = ρ * Δacc + (1 - ρ) * Δ^2
|
||||
return Δ
|
||||
end
|
||||
|
||||
"""
|
||||
AMSGrad(η = 0.001, β = (0.9, 0.999))
|
||||
|
||||
[AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need
|
||||
tuning.
|
||||
"""
|
||||
mutable struct AMSGrad
|
||||
eta::Float64
|
||||
beta::Tuple{Float64, Float64}
|
||||
state::IdDict
|
||||
end
|
||||
|
||||
AMSGrad(η = 0.001, β = (0.9, 0.999)) = AMSGrad(η, β, IdDict())
|
||||
|
||||
function update!(o::AMSGrad, x, Δ)
|
||||
η, β = o.eta, o.beta
|
||||
mt, vt, v̂t = get!(o.state, x, (fill(ϵ, size(x)), fill(ϵ, size(x)), fill(ϵ, size(x))))
|
||||
@. mt = β[1] * mt + (1 - β[1]) * Δ
|
||||
@. vt = β[2] * vt + (1 - β[2]) * Δ ^ 2
|
||||
@. v̂t = max.(v̂t, vt)
|
||||
@. Δ = η * mt / (√v̂t + ϵ)
|
||||
end
|
||||
|
||||
"""
|
||||
NADAM(η = 0.001, β = (0.9, 0.999))
|
||||
|
||||
[NADAM](http://cs229.stanford.edu/proj2015/054_report.pdf) optimiser. Parameters don't need
|
||||
tuning.
|
||||
"""
|
||||
mutable struct NADAM
|
||||
eta::Float64
|
||||
beta::Tuple{Float64, Float64}
|
||||
state::IdDict
|
||||
end
|
||||
|
||||
NADAM(η = 0.001, β = (0.9, 0.999)) = NADAM(η, β, IdDict())
|
||||
|
||||
function update!(o::NADAM, x, Δ)
|
||||
η, β = o.eta, o.beta
|
||||
β1p, β2p = o.beta
|
||||
mt, vt = get!(o.state, x, (zero(x), zero(x)))
|
||||
@. mt = β[1] * mt + (1 - β[1]) * Δ
|
||||
@. vt = β[2] * vt + (1 - β[2]) * Δ^2
|
||||
@. Δ = (β[1] * mt / (1 - β[1] * β1p) + (1 - β[1]) * Δ / (1 - β1p)) / (√(vt * β[2] / (1 - β2p)) + ϵ) * η
|
||||
o.state[x] = (mt, vt, (β1p * β[1], β2p * β[2]))
|
||||
return Δ
|
||||
end
|
||||
|
||||
"""
|
||||
ADAMW((η = 0.001, β = (0.9, 0.999), decay = 0)
|
||||
|
||||
[ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam.
|
||||
"""
|
||||
ADAMW(η = 0.001, β = (0.9, 0.999), decay = 0) =
|
||||
Optimiser(ADAM(η, β), WeightDecay(wd))
|
||||
|
||||
# Compose optimizers
|
||||
|
||||
"""
|
||||
Optimiser(a, b, c...)
|
||||
|
||||
Combine several optimisers into one; each optimiser produces a modified gradient
|
||||
that will be fed into the next, and this is finally applied to the parameter as
|
||||
usual.
|
||||
"""
|
||||
mutable struct Optimiser
|
||||
os::Vector{Any}
|
||||
end
|
||||
|
||||
Optimiser(o...) = Optimiser(Any[o...])
|
||||
|
||||
@forward Optimiser.os Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!, Base.setindex!
|
||||
@forward Optimiser.os Base.iterate
|
||||
|
||||
Base.getindex(c::Optimiser, i::AbstractArray) = Optimiser(c.os[i]...)
|
||||
|
||||
function update!(o::Optimiser, x, Δ)
|
||||
for opt in o.os
|
||||
Δ = update!(opt, x, Δ)
|
||||
end
|
||||
return Δ
|
||||
end
|
||||
|
||||
# Ref: https://arxiv.org/abs/1711.05101.pdf
|
||||
function descentweightdecay(p::Param, η::Real, γ::Real)
|
||||
function ()
|
||||
@. p.x = p.x - η * (p.Δ + γ * p.x)
|
||||
@. p.Δ = 0
|
||||
mutable struct InvDecay
|
||||
gamma::Float64
|
||||
state::IdDict
|
||||
end
|
||||
|
||||
InvDecay(γ = 0.001) = InvDecay(γ, IdDict())
|
||||
|
||||
function update!(o::InvDecay, x, Δ)
|
||||
γ = o.gamma
|
||||
n = get!(o.state, x, 1)
|
||||
Δ .*= 1 / (1 + γ * n)
|
||||
o.state[x] = n + 1
|
||||
return Δ
|
||||
end
|
||||
|
||||
mutable struct ExpDecay
|
||||
eta::Float64
|
||||
decay::Float64
|
||||
step::Int64
|
||||
clip::Float64
|
||||
current::IdDict
|
||||
end
|
||||
|
||||
ExpDecay(opt = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4) = ExpDecay(opt, decay, decay_step, clip, IdDict())
|
||||
|
||||
function update!(o::ExpDecay, x, Δ)
|
||||
η, s, decay = o.eta, o.step, o.decay
|
||||
n = o.current[x] = get(o.current, x, 0) + 1
|
||||
if o.current[x]%s == 0 && count(x -> x%s == 0, values(o.current)) == 1
|
||||
η = max(η * decay^(s / n), o.clip)
|
||||
o.eta = η
|
||||
end
|
||||
@. Δ *= decay
|
||||
end
|
||||
|
||||
function momentum(p::Param, ρ, η)
|
||||
v = zero(p.x)
|
||||
function ()
|
||||
@. v = ρ * v - η * p.Δ
|
||||
@. p.Δ = -v
|
||||
end
|
||||
mutable struct WeightDecay
|
||||
wd::Real
|
||||
end
|
||||
|
||||
# Ref. https://arxiv.org/pdf/1212.0901.pdf
|
||||
function nesterov(p::Param, ρ, η)
|
||||
v = zero(p.x)
|
||||
function ()
|
||||
d = @. ρ^2 * v - (1+ρ) * η * p.Δ
|
||||
@. v = ρ*v - η*p.Δ
|
||||
@. p.Δ = -d
|
||||
end
|
||||
end
|
||||
WeightDecay() = WeightDecay(0)
|
||||
|
||||
function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
|
||||
acc = zero(p.x)
|
||||
function ()
|
||||
@. acc = ρ * acc + (1 - ρ) * p.Δ^2
|
||||
@. p.Δ *= η / √(acc + ϵ)
|
||||
end
|
||||
end
|
||||
|
||||
function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
|
||||
acc = zero(p.x) .+ ϵ
|
||||
function ()
|
||||
@. acc += p.Δ^2
|
||||
@. p.Δ *= η / √(acc + ϵ)
|
||||
end
|
||||
end
|
||||
|
||||
function adadelta(p::Param; ρ::Real = 0.9, ϵ::Real = 1e-8)
|
||||
acc = zero(p.x)
|
||||
Δacc = zero(p.x)
|
||||
function ()
|
||||
@. acc = ρ * acc + (1 - ρ) * p.Δ^2
|
||||
@. p.Δ *= √(Δacc + ϵ) / √(acc + ϵ)
|
||||
@. Δacc = ρ * Δacc + (1 - ρ) * p.Δ^2
|
||||
end
|
||||
end
|
||||
|
||||
function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
|
||||
mt = zero(p.x)
|
||||
vt = zero(p.x)
|
||||
β1p, β2p = β1, β2
|
||||
function ()
|
||||
@. mt = β1 * mt + (1 - β1) * p.Δ
|
||||
@. vt = β2 * vt + (1 - β2) * p.Δ^2
|
||||
@. p.Δ = mt / (1 - β1p) / √(vt / (1 - β2p) + ϵ) * η
|
||||
β1p *= β1
|
||||
β2p *= β2
|
||||
end
|
||||
end
|
||||
|
||||
function adamax(p::Param; η::Real = 0.002, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
|
||||
mt = zero(p.x)
|
||||
ut = zero(p.x)
|
||||
β1p = β1
|
||||
function ()
|
||||
@. mt = β1 * mt + (1 - β1) * p.Δ
|
||||
@. ut = max(β2 * ut, abs(p.Δ))
|
||||
@. p.Δ = (η/(1 - β1p)) * mt/(ut + ϵ)
|
||||
β1p *= β1
|
||||
end
|
||||
end
|
||||
|
||||
function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
|
||||
mt = zero(p.x)
|
||||
vt = zero(p.x) .+ ϵ
|
||||
v̂t = zero(p.x) .+ ϵ
|
||||
function ()
|
||||
@. mt = β1 * mt + (1 - β1) * p.Δ
|
||||
@. vt = β2 * vt + (1 - β2) * p.Δ ^ 2
|
||||
@. v̂t = max.(v̂t, vt)
|
||||
@. p.Δ = η * mt / √v̂t
|
||||
end
|
||||
end
|
||||
|
||||
function nadam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
|
||||
mt = zero(p.x)
|
||||
vt = zero(p.x)
|
||||
β1p, β2p = β1, β2
|
||||
function ()
|
||||
@. mt = β1 * mt + (1 - β1) * p.Δ
|
||||
@. vt = β2 * vt + (1 - β2) * p.Δ^2
|
||||
@. p.Δ = (β1 * mt / (1 - β1 * β1p) + (1 - β1) * p.Δ / (1 - β1p)) / √(vt * β2 / (1 - β2p) + ϵ) * η
|
||||
β1p *= β1
|
||||
β2p *= β2
|
||||
end
|
||||
end
|
||||
|
||||
clip(p::Param, thresh::Real) = () -> clamp!(p.Δ, -thresh, thresh)
|
||||
|
||||
function expdecay(p::Param, γ::Real)
|
||||
if γ != 0
|
||||
return () -> p.Δ .+= γ .* p.x
|
||||
else
|
||||
return () -> nothing
|
||||
end
|
||||
end
|
||||
|
||||
function invdecay(p::Param, γ::Real)
|
||||
if γ != 0
|
||||
n = 0
|
||||
return () -> begin
|
||||
p.Δ .*= 1 / (1 + γ * n)
|
||||
n += 1
|
||||
end
|
||||
else
|
||||
return () -> nothing
|
||||
end
|
||||
function update!(o::WeightDecay, x, Δ)
|
||||
wd = o.wd
|
||||
@. Δ += wd * x
|
||||
end
|
||||
|
|
|
@ -1,7 +1,17 @@
|
|||
using Juno
|
||||
using Flux.Tracker: back!
|
||||
using Flux.Tracker: data, grad, back!
|
||||
import Base.depwarn
|
||||
|
||||
function update!(opt, xs)
|
||||
for x in xs
|
||||
Δ = update!(opt, x.data, x.grad)
|
||||
x.data .-= Δ
|
||||
Δ .= 0
|
||||
end
|
||||
end
|
||||
|
||||
# Callback niceties
|
||||
call(f, xs...) = f(xs...)
|
||||
runall(f) = f
|
||||
runall(fs::AbstractVector) = () -> foreach(call, fs)
|
||||
|
||||
|
@ -35,7 +45,7 @@ function stop()
|
|||
end
|
||||
|
||||
"""
|
||||
train!(loss, data, opt)
|
||||
train!(model, loss, data, opt)
|
||||
|
||||
For each datapoint `d` in `data` computes the gradient of `loss(d...)` through
|
||||
backpropagation and calls the optimizer `opt`.
|
||||
|
@ -44,7 +54,7 @@ Takes a callback as keyword argument `cb`. For example, this will print "trainin
|
|||
every 10 seconds:
|
||||
|
||||
```julia
|
||||
Flux.train!(loss, data, opt,
|
||||
Flux.train!(model, loss, data, opt,
|
||||
cb = throttle(() -> println("training"), 10))
|
||||
```
|
||||
|
||||
|
@ -52,14 +62,14 @@ The callback can return `:stop` to interrupt the training loop.
|
|||
|
||||
Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
|
||||
"""
|
||||
function train!(loss, data, opt; cb = () -> ())
|
||||
function train!(loss, ps, data, opt; cb = () -> ())
|
||||
cb = runall(cb)
|
||||
opt = runall(opt)
|
||||
@progress for d in data
|
||||
try
|
||||
l = loss(d...)
|
||||
@interrupts back!(l)
|
||||
opt()
|
||||
update!(opt, ps)
|
||||
if cb() == :stop
|
||||
depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
|
||||
break
|
||||
|
|
|
@ -68,9 +68,9 @@ end
|
|||
|
||||
include("idset.jl")
|
||||
include("back.jl")
|
||||
include("scalar.jl")
|
||||
include("array.jl")
|
||||
include("numeric.jl")
|
||||
include("lib/real.jl")
|
||||
include("lib/array.jl")
|
||||
|
||||
"""
|
||||
hook(f, x) -> x′
|
||||
|
|
|
@ -19,62 +19,78 @@ function scan(x)
|
|||
return
|
||||
end
|
||||
|
||||
function back_(c::Call, Δ)
|
||||
function back_(c::Call, Δ, once)
|
||||
Δs = c.func(Δ)
|
||||
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
|
||||
error("Gradient is not a tuple of length $(length(c.args))")
|
||||
foreach(back, c.args, data.(Δs))
|
||||
foreach((x, d) -> back(x, d, once), c.args, data.(Δs))
|
||||
end
|
||||
|
||||
back_(::Call{Nothing}, Δ) = nothing
|
||||
back_(::Call{Nothing}, Δ, once) = nothing
|
||||
back_(::Call{Missing}, Δ, once) = error("`back!` was already used")
|
||||
|
||||
accum!(x, Δ) = x .+ Δ
|
||||
accum!(x::AbstractArray, Δ) = (x .+= Δ)
|
||||
|
||||
function back(x::Tracked, Δ)
|
||||
function back(x::Tracked, Δ, once)
|
||||
x.isleaf && (x.grad = accum!(x.grad, Δ); return)
|
||||
ref = x.ref -= 1
|
||||
if ref > 0 || isdefined(x, :grad)
|
||||
if isdefined(x, :grad)
|
||||
x.grad = accum!(x.grad, Δ)
|
||||
else
|
||||
x.grad = Δ
|
||||
end
|
||||
ref == 0 && back_(x.f, x.grad)
|
||||
grad = if isdefined(x, :grad)
|
||||
x.grad = accum!(x.grad, Δ)
|
||||
elseif ref > 0
|
||||
x.grad = Δ
|
||||
else
|
||||
ref == 0 && back_(x.f, Δ)
|
||||
Δ
|
||||
end
|
||||
if ref == 0
|
||||
back_(x.f, grad, once)
|
||||
once && !x.isleaf && (x.f = Call(missing, ()))
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
back(::Nothing, _) = return
|
||||
back(::Nothing, Δ, once) = return
|
||||
|
||||
# Interface methods
|
||||
|
||||
# TODO: if an error occurs in `back` the refcounts will be broken
|
||||
# and `back` will silently fail to update.
|
||||
# (but only if you re-use intermediate values between passes)
|
||||
# Refcounts are also probably not safe in some situations (e.g. back called
|
||||
# from within a backpropagator)
|
||||
|
||||
function back!(x, Δ)
|
||||
function back!(x, Δ; once = true)
|
||||
istracked(x) || return
|
||||
scan(x)
|
||||
back(tracker(x), Δ)
|
||||
back(tracker(x), Δ, once)
|
||||
return
|
||||
end
|
||||
|
||||
# Out-of-place gradients
|
||||
|
||||
struct Params
|
||||
params::IdSet
|
||||
Params(xs) = new(IdSet(xs))
|
||||
order::Vector{Any}
|
||||
params::IdSet{Any}
|
||||
Params() = new([], IdSet())
|
||||
end
|
||||
|
||||
@forward Params.params Base.iterate, Base.length
|
||||
@forward Params.order Base.iterate, Base.length
|
||||
|
||||
function Base.push!(ps::Params, x)
|
||||
if !(x in ps.params)
|
||||
push!(ps.order, x)
|
||||
push!(ps.params, x)
|
||||
end
|
||||
return ps
|
||||
end
|
||||
|
||||
Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps)
|
||||
|
||||
Params(xs) = push!(Params(), xs...)
|
||||
|
||||
function Base.show(io::IO, ps::Params)
|
||||
print(io, "Params([")
|
||||
join(io, ps.params, ", ")
|
||||
join(io, ps.order, ", ")
|
||||
print(io, "])")
|
||||
end
|
||||
|
||||
|
@ -91,12 +107,12 @@ Grads() = Grads(IdDict())
|
|||
Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps))
|
||||
|
||||
Base.getindex(g::Grads, x::Tracked) = g.grads[x]
|
||||
|
||||
function Base.getindex(g::Grads, x)
|
||||
istracked(x) || error("Object not tracked: $x")
|
||||
g[tracker(x)]
|
||||
end
|
||||
|
||||
|
||||
accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
|
||||
|
||||
function back_(g::Grads, c::Call, Δ)
|
||||
|
|
|
@ -7,6 +7,7 @@ Base.eltype(::IdSet{T}) where T = T
|
|||
|
||||
IdSet() = IdSet{Any}()
|
||||
|
||||
Base.push!(s::IdSet) = s
|
||||
Base.push!(s::IdSet{T}, x::T) where T = (s.dict[x] = nothing; s)
|
||||
Base.delete!(s::IdSet{T}, x::T) where T = (delete!(s.dict, x); s)
|
||||
Base.in(x, s::IdSet) = haskey(s.dict, x)
|
||||
|
|
|
@ -33,8 +33,10 @@ TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, zero(x))
|
|||
|
||||
Base.eltype(x::Type{<:TrackedArray{T}}) where T <: Real = TrackedReal{T}
|
||||
|
||||
Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
|
||||
print(io, "TrackedArray{…,$A}")
|
||||
Base.show(io::IO, t::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
|
||||
@isdefined(A) ?
|
||||
print(io, "TrackedArray{…,$A}") :
|
||||
invoke(show, Tuple{IO,DataType}, io, t)
|
||||
|
||||
function Base.summary(io::IO, x::TrackedArray)
|
||||
print(io, "Tracked ")
|
||||
|
@ -82,6 +84,17 @@ Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...)
|
|||
end
|
||||
end
|
||||
|
||||
Base.view(x::TrackedArray, inds...) = track(Base.view, x, inds...)
|
||||
|
||||
@grad function view(x::AbstractArray, inds...)
|
||||
view(data(x), inds...), function (Δ)
|
||||
grad_output = zero(x)
|
||||
subgrad = view(grad_output, inds...)
|
||||
subgrad[:] = data(Δ)
|
||||
(nobacksies(:view, grad_output), map(_->nothing, inds)...)
|
||||
end
|
||||
end
|
||||
|
||||
Base.:-(xs::TrackedArray) = track(-, xs)
|
||||
|
||||
@grad -(xs) = -data(xs), Δ -> (-Δ,)
|
||||
|
@ -309,8 +322,8 @@ end
|
|||
|
||||
# BLAS
|
||||
|
||||
LinearAlgebra.diagm(x::TrackedVector) = track(diagm, x)
|
||||
@grad diagm(x) = diagm(data(x)), Δ -> (diag(Δ),)
|
||||
LinearAlgebra.diagm(x::Pair{<:Integer, <:TrackedVector}) = track(diagm, x...)
|
||||
@grad diagm(i, x) = diagm(i => data(x)), Δ -> (nothing, diag(Δ, i))
|
||||
|
||||
x::TrackedMatrix * y::AbstractMatrix = track(*, x, y)
|
||||
x::AbstractMatrix * y::TrackedMatrix = track(*, x, y)
|
||||
|
@ -330,7 +343,7 @@ x::TrackedVector * y::TrackedVector = track(*, x, y)
|
|||
# NNlib
|
||||
|
||||
using NNlib
|
||||
import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, maxpool, meanpool
|
||||
import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, depthwiseconv, maxpool, meanpool
|
||||
|
||||
softmax(xs::TrackedArray) = track(softmax, xs)
|
||||
|
||||
|
@ -340,6 +353,16 @@ logsoftmax(xs::TrackedArray) = track(logsoftmax, xs)
|
|||
|
||||
@grad logsoftmax(xs) = logsoftmax(data(xs)), Δ -> (nobacksies(:logsoftmax, ∇logsoftmax(data(Δ), data(xs))),)
|
||||
|
||||
depthwiseconv(x::TrackedArray, w::TrackedArray; kw...) = track(depthwiseconv, x, w; kw...)
|
||||
depthwiseconv(x::AbstractArray, w::TrackedArray; kw...) = track(depthwiseconv, x, w; kw...)
|
||||
depthwiseconv(x::TrackedArray, w::AbstractArray; kw...) = track(depthwiseconv, x, w; kw...)
|
||||
|
||||
@grad depthwiseconv(x, w; kw...) =
|
||||
depthwiseconv(data(x), data(w); kw...),
|
||||
Δ -> nobacksies(:depthwiseconv,
|
||||
(NNlib.∇depthwiseconv_data(data.((Δ, x, w))...; kw...),
|
||||
NNlib.∇depthwiseconv_filter(data.((Δ, x, w))...; kw...)))
|
||||
|
||||
conv(x::TrackedArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
|
||||
conv(x::AbstractArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
|
||||
conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...)
|
||||
|
@ -424,26 +447,28 @@ end
|
|||
using Requires
|
||||
|
||||
# https://github.com/FluxML/Flux.jl/issues/353
|
||||
@init Requires.isprecompiling() || @eval Base.Broadcast begin
|
||||
function flatten(bc::Broadcasted{Style}) where {Style}
|
||||
isflat(bc) && return bc
|
||||
args = cat_nested(bc)
|
||||
let makeargs = make_makeargs(bc), f = bc.f
|
||||
newf = @inline function(args::Vararg{Any,N}) where N
|
||||
f(makeargs(args...)...)
|
||||
if VERSION < v"1.1.0-DEV.548"
|
||||
@init Requires.isprecompiling() || @eval Base.Broadcast begin
|
||||
function flatten(bc::Broadcasted{Style}) where {Style}
|
||||
isflat(bc) && return bc
|
||||
args = cat_nested(bc)
|
||||
let makeargs = make_makeargs(bc), f = bc.f
|
||||
newf = @inline function(args::Vararg{Any,N}) where N
|
||||
f(makeargs(args...)...)
|
||||
end
|
||||
return Broadcasted{Style}(newf, args, bc.axes)
|
||||
end
|
||||
return Broadcasted{Style}(newf, args, bc.axes)
|
||||
end
|
||||
end
|
||||
@inline function make_makeargs(makeargs, t::Tuple{<:Broadcasted,Vararg{Any}})
|
||||
bc = t[1]
|
||||
let makeargs = make_makeargs(makeargs, tail(t)), f = bc.f
|
||||
let makeargs = make_makeargs(makeargs, bc.args)
|
||||
headargs, tailargs = make_headargs(bc.args), make_tailargs(bc.args)
|
||||
return @inline function(args::Vararg{Any,N}) where N
|
||||
args1 = makeargs(args...)
|
||||
a, b = headargs(args1...), tailargs(args1...)
|
||||
(f(a...), b...)
|
||||
@inline function make_makeargs(makeargs, t::Tuple{<:Broadcasted,Vararg{Any}})
|
||||
bc = t[1]
|
||||
let makeargs = make_makeargs(makeargs, tail(t)), f = bc.f
|
||||
let makeargs = make_makeargs(makeargs, bc.args)
|
||||
headargs, tailargs = make_headargs(bc.args), make_tailargs(bc.args)
|
||||
return @inline function(args::Vararg{Any,N}) where N
|
||||
args1 = makeargs(args...)
|
||||
a, b = headargs(args1...), tailargs(args1...)
|
||||
(f(a...), b...)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -10,10 +10,10 @@ tracker(x::TrackedReal) = x.tracker
|
|||
|
||||
track(f::Call, x::Real) = TrackedReal(x, Tracked{typeof(x)}(f, zero(x)))
|
||||
|
||||
function back!(x::TrackedReal)
|
||||
function back!(x::TrackedReal; once = true)
|
||||
isinf(x) && error("Loss is Inf")
|
||||
isnan(x) && error("Loss is NaN")
|
||||
return back!(x, 1)
|
||||
return back!(x, 1, once = once)
|
||||
end
|
||||
|
||||
function Base.show(io::IO, x::TrackedReal)
|
||||
|
@ -123,8 +123,8 @@ function scan(c::Call{typeof(collect)})
|
|||
foreach(scan, c.args[1])
|
||||
end
|
||||
|
||||
function back_(c::Call{typeof(collect)}, Δ)
|
||||
foreach(back, c.args[1], data(Δ))
|
||||
function back_(c::Call{typeof(collect)}, Δ, once)
|
||||
foreach((x, d) -> back(x, d, once), c.args[1], data(Δ))
|
||||
end
|
||||
|
||||
function back_(g::Grads, c::Call{typeof(collect)}, Δ)
|
|
@ -40,7 +40,7 @@ function prefor(f, x; seen = IdSet())
|
|||
end
|
||||
|
||||
function params(m)
|
||||
ps = []
|
||||
ps = Params()
|
||||
prefor(p ->
|
||||
Tracker.istracked(p) && Tracker.isleaf(p) &&
|
||||
!any(p′ -> p′ === p, ps) && push!(ps, p),
|
||||
|
|
|
@ -1,6 +1,4 @@
|
|||
# Arrays
|
||||
|
||||
initn(dims...) = randn(dims...)/100
|
||||
glorot_uniform(dims...) = (rand(dims...) .- 0.5) .* sqrt(24.0/(sum(dims)))
|
||||
glorot_normal(dims...) = randn(dims...) .* sqrt(2.0/sum(dims))
|
||||
|
||||
|
@ -147,9 +145,9 @@ function jacobian(m,x)
|
|||
n = length(x)
|
||||
J = Matrix{eltype(x)}(undef,n,k)
|
||||
for i = 1:k
|
||||
Flux.back!(y[i]) # Populate gradient accumulator
|
||||
Flux.back!(y[i], once = false) # Populate gradient accumulator
|
||||
J[:,i] = xp.grad
|
||||
xp.grad .*= 0 # Reset gradient accumulator
|
||||
xp.grad .= 0 # Reset gradient accumulator
|
||||
end
|
||||
J'
|
||||
end
|
||||
|
|
|
@ -1,16 +1,40 @@
|
|||
using Flux.Optimise
|
||||
using Flux.Optimise: runall
|
||||
using Flux.Tracker
|
||||
using Test
|
||||
@testset "Optimise" begin
|
||||
w = randn(10, 10)
|
||||
@testset for Opt in [SGD, Nesterov, Momentum, ADAM, AdaMax, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad, NADAM]
|
||||
@testset for Opt in [ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, Descent, ADAM, Nesterov, RMSProp, Momentum]
|
||||
w′ = param(randn(10, 10))
|
||||
loss(x) = Flux.mse(w*x, w′*x)
|
||||
opt = Opt([w′])
|
||||
for t=1:10^5
|
||||
opt = Opt(0.001)
|
||||
if opt isa Descent || opt isa ADAGrad
|
||||
opt = Opt(0.1)
|
||||
end
|
||||
if opt isa ADADelta
|
||||
opt = Opt(0.9)
|
||||
end
|
||||
for t = 1: 10^5
|
||||
l = loss(rand(10))
|
||||
back!(l)
|
||||
opt()
|
||||
delta = Optimise.update!(opt, w′.data, w′.grad)
|
||||
w′.data .-= delta
|
||||
end
|
||||
@test Flux.mse(w, w′) < 0.01
|
||||
end
|
||||
end
|
||||
|
||||
@testset "Optimiser" begin
|
||||
w = randn(10, 10)
|
||||
@testset for Opt in [InvDecay, WeightDecay, ExpDecay]
|
||||
w′ = param(randn(10, 10))
|
||||
loss(x) = Flux.mse(w*x, w′*x)
|
||||
opt = Optimiser(Opt(), ADAM(0.001))
|
||||
for t = 1:10^5
|
||||
l = loss(rand(10))
|
||||
back!(l)
|
||||
delta = Optimise.update!(opt, w′.data, w′.grad)
|
||||
w′.data .-= delta
|
||||
end
|
||||
@test Flux.mse(w, w′) < 0.01
|
||||
end
|
||||
|
@ -21,9 +45,17 @@ end
|
|||
l = param(1)
|
||||
|
||||
Flux.train!(() -> (sleep(0.1); i += 1; l),
|
||||
(),
|
||||
Iterators.repeated((), 100),
|
||||
()->(),
|
||||
cb = Flux.throttle(() -> (i > 3 && stop()), 1))
|
||||
Descent(),
|
||||
cb = Flux.throttle(() -> (i > 3 && Flux.stop()), 1))
|
||||
|
||||
@test 3 < i < 50
|
||||
|
||||
# Test multiple callbacks
|
||||
x = 0
|
||||
fs = [() -> (), () -> x = 1]
|
||||
cbs = runall(fs)
|
||||
cbs()
|
||||
@test x == 1
|
||||
end
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
using Flux
|
||||
using Flux.Tracker, Test, NNlib
|
||||
using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint
|
||||
using NNlib: conv
|
||||
using NNlib: conv, depthwiseconv
|
||||
using Printf: @sprintf
|
||||
using LinearAlgebra: Diagonal, dot, LowerTriangular, norm
|
||||
using LinearAlgebra: diagm, dot, LowerTriangular, norm
|
||||
using Statistics: mean, std
|
||||
using Random
|
||||
# using StatsBase
|
||||
|
@ -33,6 +33,11 @@ gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
|
|||
@test gradtest(Flux.crossentropy, rand(5,5), rand(5, 5))
|
||||
|
||||
@test gradtest(x -> x', rand(5))
|
||||
|
||||
@testset "indexing & slicing" begin
|
||||
gradtest(x->view(x, 1:2, 1:2), rand(4, 4))
|
||||
end
|
||||
|
||||
function promotiontest(f, A, B, C)
|
||||
r0 = f(A, B, C)
|
||||
r1 = f(param(A), B, C)
|
||||
|
@ -127,7 +132,7 @@ end
|
|||
@test gradtest(kron, rand(5,1), rand(3,1), rand(8,1))
|
||||
@test gradtest(kron, rand(5,2), rand(3,2), rand(8,2))
|
||||
|
||||
@test gradtest(f-> Matrix(Diagonal(f)), rand(3))
|
||||
@test gradtest(x -> diagm(0 => x), rand(3))
|
||||
|
||||
@test gradtest(W -> inv(log.(W * W)), (5,5))
|
||||
@test gradtest((A, B) -> A / B , (1,5), (5,5))
|
||||
|
@ -181,6 +186,8 @@ end
|
|||
@test gradtest(conv, rand(10, 10, 3, 2), randn(Float64,2, 2, 3, 2))
|
||||
@test gradtest(conv, rand(10, 10, 10, 3, 2), randn(Float64,2, 2, 2, 3, 2))
|
||||
|
||||
@test gradtest(depthwiseconv, rand(10,10,3,2), randn(2, 2, 2, 3))
|
||||
|
||||
@test gradtest(x -> maxpool(x, (2,2)), rand(10, 10, 3, 2))
|
||||
@test gradtest(x -> maxpool(x, (2,2,2)), rand(10, 10, 10, 3, 2))
|
||||
|
||||
|
@ -230,10 +237,10 @@ end
|
|||
@testset "Intermediates" begin
|
||||
x = param([1])
|
||||
l = sum((x .+ x).^2)
|
||||
Flux.back!(l)
|
||||
Flux.back!(l, once = false)
|
||||
@test x.grad == [8]
|
||||
x.grad .= 0
|
||||
Flux.back!(l)
|
||||
Flux.back!(l, once = false)
|
||||
@test x.grad == [8]
|
||||
end
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
using Flux
|
||||
using Flux: throttle, jacobian, initn, glorot_uniform, glorot_normal
|
||||
using Flux: throttle, jacobian, glorot_uniform, glorot_normal
|
||||
using StatsBase: std
|
||||
using Random
|
||||
using Test
|
||||
|
@ -64,10 +64,6 @@ end
|
|||
@testset "Initialization" begin
|
||||
# Set random seed so that these tests don't fail randomly
|
||||
Random.seed!(0)
|
||||
# initn() should yield a kernel with stddev ~= 1e-2
|
||||
v = initn(10, 10)
|
||||
@test std(v) > 0.9*1e-2
|
||||
@test std(v) < 1.1*1e-2
|
||||
|
||||
# glorot_uniform should yield a kernel with stddev ~= sqrt(6/(n_in + n_out)),
|
||||
# and glorot_normal should yield a kernel with stddev != 2/(n_in _ n_out)
|
||||
|
|
Loading…
Reference in New Issue