diff --git a/README.md b/README.md
index f8e301ed..10d611f1 100644
--- a/README.md
+++ b/README.md
@@ -1,8 +1,8 @@
-
+
-[](https://travis-ci.org/FluxML/Flux.jl) [](https://fluxml.github.io/Flux.jl/stable/) [](https://slackinvite.julialang.org/)
+[](https://travis-ci.org/FluxML/Flux.jl) [](https://fluxml.github.io/Flux.jl/stable/) [](https://slackinvite.julialang.org/) [](https://doi.org/10.21105/joss.00602)
Flux is an elegant approach to machine learning. It's a 100% pure-Julia stack, and provides lightweight abstractions on top of Julia's native GPU and AD support. Flux makes the easy things easy while remaining fully hackable.
@@ -12,13 +12,25 @@ julia> Pkg.add("Flux")
See the [documentation](http://fluxml.github.io/Flux.jl/) or the [model zoo](https://github.com/FluxML/model-zoo/) for examples.
+If you use Flux in research, please cite the following paper:
+
+```
+@article{innes:2018,
+ author = {Mike Innes},
+ title = {Flux: Elegant Machine Learning with Julia},
+ journal = {Journal of Open Source Software},
+ year = {2018},
+ doi = {10.21105/joss.00602},
+}
+```
+
## Features
Flux has powerful high-level features, and common architectures can be defined in a few lines.
```julia
model = Chain(
- Dense(768, 128),
+ Dense(768, 128, σ),
  LSTM(128, 256),
  LSTM(256, 128),
Dense(128, 10),
@@ -79,3 +91,9 @@ For general questions and help, check out Julia's [community forum](https://disc
Flux development is carried out via our [GitHub issues](https://github.com/FluxML/Flux.jl/issues), so feel free to open feature requests or PRs here.
For more informal discussions we'd love to have you on the [Julia slack](https://slackinvite.julialang.org/), where we hang out on the #machine-learning channel.
+
+## Related Packages
+
+Check out [Metalhead.jl](https://github.com/FluxML/Metalhead.jl) for common computer vision datasets and trained models.
+
+[MLDatasets.jl](https://github.com/JuliaML/MLDatasets.jl) provides further common datasets.
diff --git a/docs/make.jl b/docs/make.jl
index d7f14d8e..ed6a8c8b 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -18,6 +18,8 @@ makedocs(modules=[Flux, NNlib],
"One-Hot Encoding" => "data/onehot.md",
"GPU Support" => "gpu.md",
"Saving & Loading" => "saving.md",
+ "Internals" =>
+ ["Backpropagation" => "internals/tracker.md"],
"Community" => "community.md"])
deploydocs(
diff --git a/docs/src/internals/tracker.md b/docs/src/internals/tracker.md
new file mode 100644
index 00000000..b9addc34
--- /dev/null
+++ b/docs/src/internals/tracker.md
@@ -0,0 +1,156 @@
+# Flux.Tracker
+
+Backpropagation, or reverse-mode automatic differentiation, is handled by the `Flux.Tracker` module.
+
+```julia
+julia> using Flux.Tracker
+```
+
+The `param` function converts a normal Julia array into a new object that, while behaving like an array, tracks extra information that allows us to calculate derivatives. For example, say we multiply two parameters:
+
+```julia
+julia> W = param([1 2; 3 4])
+Tracked 2×2 Array{Float64,2}:
+ 1.0 2.0
+ 3.0 4.0
+
+julia> x = param([5, 6])
+Tracked 2-element Array{Float64,1}:
+ 5.0
+ 6.0
+
+julia> y = W*x
+Tracked 2-element Array{Float64,1}:
+ 17.0
+ 39.0
+```
+
+The output `y` is also a `TrackedArray` object. We can now backpropagate sensitivities to `W` and `x` via the `back!` function, and see the gradients accumulated in the `W` and `x` tracked arrays:
+
+```julia
+julia> Tracker.back!(y, [1, -1])
+
+julia> W.grad
+2×2 Array{Float64,2}:
+ 5.0 6.0
+-5.0 -6.0
+
+julia> x.grad
+2-element Array{Float64,1}:
+ -2.0
+ -2.0
+```
+
+## Internals
+
+All `Tracked*` objects (`TrackedArray`, `TrackedReal`) are light wrappers around the `Tracked` type, which you can access via the `.tracker` field.
+
+```julia
+julia> x.tracker
+Flux.Tracker.Tracked{Array{Float64,1}}(0x00000000, Flux.Tracker.Call{Void,Tuple{}}(nothing, ()), true, [5.0, 6.0], [-2.0, -2.0])
+```
+
+The tracker stores the value and gradient of a given object, as we saw above.
+
+```julia
+julia> x.tracker.data
+2-element Array{Float64,1}:
+ 5.0
+ 6.0
+
+julia> x.tracker.grad
+2-element Array{Float64,1}:
+ -2.0
+ -2.0
+```
+
+The tracker also contains a `Call` object, which simply represents a function call that was made at some point during the forward pass. For example, the `+` call would look like this:
+
+```julia
+julia> Tracker.Call(+, 1, 2)
+Flux.Tracker.Call{Base.#+,Tuple{Int64,Int64}}(+, (1, 2))
+```
+
+In the case of the `y` we produced above, we can see that it stores the call that produced it -- that is, `W*x`.
+
+```julia
+julia> y.tracker.f
+Flux.Tracker.Call{...}(*, (param([1.0 2.0; 3.0 4.0]), param([5.0, 6.0])))
+```
+
+Because the arguments to the call may themselves be tracked arrays, each storing their own calls, `Tracker` ends up forming a data structure that records everything that happened during the forward pass (often known as a *tape*).
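+
+For instance, continuing the example above, we can walk this structure by hand. (A rough sketch: it assumes `Call` stores its arguments in an `args` field, and the output shown is abbreviated.)
+
+```julia
+julia> z = sum(y);   # one more tracked call on top of y = W*x
+
+julia> z.tracker.f   # z records the `sum` call ...
+Flux.Tracker.Call{...}(sum, (param([17.0, 39.0]),))
+
+julia> z.tracker.f.args[1].tracker.f   # ... whose argument records the `*` call that produced y
+Flux.Tracker.Call{...}(*, (param([1.0 2.0; 3.0 4.0]), param([5.0, 6.0])))
+```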
+
+When we call `back!(y, [1, -1])`, the sensitivities `[1, -1]` simply get forwarded to `y`'s call (`*`), effectively calling
+
+```julia
+Tracker.back(*, [1, -1], W, x)
+```
+
+which in turn calculates the sensitivities of the arguments (`W` and `x`) and backpropagates through their calls. This is recursive, so it will walk the entire program graph and propagate gradients to the original model parameters.
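+
+As a minimal end-to-end sketch (the parameters and data below are made up for illustration), backpropagating through a small composed program walks every intermediate call on the tape:
+
+```julia
+W1, b1 = param(randn(3, 5)), param(zeros(3))
+W2, b2 = param(randn(2, 3)), param(zeros(2))
+x = rand(5)
+
+y = sum(W2 * tanh.(W1*x .+ b1) .+ b2)  # builds a tape of *, .+, tanh. and sum calls
+Tracker.back!(y)                       # y is a tracked scalar, so no explicit sensitivity is needed
+
+W1.grad                                # gradients have now reached every parameter
+```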
+
+## Custom Gradients
+
+We can hook into the process above to implement custom gradients for a function or kernel. For a toy example, imagine a custom implementation of `minus`:
+
+```julia
+julia> minus(a, b) = a - b
+```
+
+Firstly, we must tell the tracker system to stop when it sees a call to `minus`, and record it. We can do this using dispatch:
+
+```julia
+julia> minus(a::TrackedArray, b::TrackedArray) = Tracker.track(minus, a, b)
+minus (generic function with 2 methods)
+```
+
+`Tracker.track` does two things: (1) it makes sure `minus` is called with *normal* arrays, not tracked ones (you can use `@show` inside `minus` to verify this), and (2) it uses the result to add a `minus` node to the tape. We can look inside the result of calling `minus` to see what happened:
+
+```julia
+julia> a, b = param([6,5,4]), param([1,2,3])
+(param([6.0, 5.0, 4.0]), param([1.0, 2.0, 3.0]))
+
+julia> c = minus(a, b)
+Tracked 3-element Array{Float64,1}:
+ 5.0
+ 3.0
+ 1.0
+
+julia> c.tracker.f
+Flux.Tracker.Call{...}(minus, (param([6.0, 5.0, 4.0]), param([1.0, 2.0, 3.0])))
+```
+
+Finally, we have to specify the gradient of `minus`.
+
+```julia
+julia> Tracker.back(::typeof(minus), Δ, a, b) =
+ (Tracker.@back(a, Δ); Tracker.@back(b, -Δ))
+```
+
+`@back(x, Δ)` tells the tracker to continue propagating the sensitivity `Δ` through `x`. Now, AD will work with any program that calls `minus`.
+
+```julia
+julia> Flux.back!(c, 1)
+
+julia> a.grad
+3-element Array{Float64,1}:
+ 1.0
+ 1.0
+ 1.0
+
+julia> b.grad
+3-element Array{Float64,1}:
+ -1.0
+ -1.0
+ -1.0
+```
+
+## Notes
+
+For multi-argument functions with custom gradients, you likely want to catch not just `minus(::TrackedArray, ::TrackedArray)` but also `minus(::Array, ::TrackedArray)` and so on. To do so, just define those extra signatures as needed:
+
+```julia
+minus(a::AbstractArray, b::TrackedArray) = Tracker.track(minus, a, b)
+minus(a::TrackedArray, b::AbstractArray) = Tracker.track(minus, a, b)
+```
+
+`@back` *must* be called exactly once on each tracked input argument. You do not need to do any special handling if one of the arguments is not tracked, as `@back` will just become a no-op.
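+
+As a hedged illustration (`scale` below is a hypothetical example, not part of Flux), suppose one argument is a plain number and is therefore never tracked:
+
+```julia
+scale(k::Real, x::AbstractArray) = k .* x                     # plain implementation
+scale(k::Real, x::TrackedArray) = Tracker.track(scale, k, x)  # record the call on the tape
+
+Tracker.back(::typeof(scale), Δ, k, x) =
+  (Tracker.@back(k, sum(Δ .* Tracker.data(x))); Tracker.@back(x, k .* Δ))
+```
+
+Here `k` is an ordinary `Real`, so `Tracker.@back(k, ...)` is a no-op and only `x` receives a gradient.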
diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md
index cb0c6615..c2056bb4 100644
--- a/docs/src/models/layers.md
+++ b/docs/src/models/layers.md
@@ -5,7 +5,7 @@ These core layers form the foundation of almost all neural networks.
```@docs
Chain
Dense
-Conv2D
+Conv
```
## Recurrent Layers
@@ -15,6 +15,7 @@ Much like the core layers above, but can be used to process sequence data (as we
```@docs
RNN
LSTM
+GRU
Flux.Recur
```
diff --git a/docs/src/models/regularisation.md b/docs/src/models/regularisation.md
index d4325a53..70d06348 100644
--- a/docs/src/models/regularisation.md
+++ b/docs/src/models/regularisation.md
@@ -7,6 +7,7 @@ add the result to the overall loss.
For example, say we have a simple regression.
```julia
+using Flux: crossentropy
m = Dense(10, 5)
loss(x, y) = crossentropy(softmax(m(x)), y)
```
diff --git a/src/Flux.jl b/src/Flux.jl
index c335488d..7d1d66e6 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -7,22 +7,23 @@ module Flux
using Juno, Requires, Reexport
using MacroTools: @forward
-export Chain, Dense, RNN, LSTM, GRU, Conv, Conv2D,
- Dropout, LayerNorm, BatchNorm,
- SGD, ADAM, Momentum, Nesterov, AMSGrad,
- param, params, mapleaves, cpu, gpu
+export Chain, Dense, RNN, LSTM, GRU, Conv,
+ Dropout, LayerNorm, BatchNorm,
+ params, mapleaves, cpu, gpu
@reexport using NNlib
using NNlib: @fix
include("tracker/Tracker.jl")
using .Tracker
-export Tracker
-import .Tracker: data
+using .Tracker: data
+export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param
include("optimise/Optimise.jl")
using .Optimise
using .Optimise: @epochs
+export SGD, ADAM, AdaMax, Momentum, Nesterov,
+ RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
include("utils.jl")
include("onehot.jl")
@@ -32,7 +33,7 @@ include("layers/stateless.jl")
include("layers/basic.jl")
include("layers/conv.jl")
include("layers/recurrent.jl")
-include("layers/normalisation.jl")
+include("layers/normalise.jl")
include("data/Data.jl")
diff --git a/src/cuda/cuda.jl b/src/cuda/cuda.jl
index 1ee10908..eaa3fe00 100644
--- a/src/cuda/cuda.jl
+++ b/src/cuda/cuda.jl
@@ -4,11 +4,4 @@ using CuArrays
CuArrays.cudnn_available() && include("cudnn.jl")
-import ..Flux.JIT: Shape, restructure
-
-function restructure(sh::Shape{T}, buf::CuVector{UInt8}) where T
- buf = buf[1:sizeof(sh)]
- reshape(reinterpret(T, buf), size(sh))
-end
-
end
diff --git a/src/data/cmudict.jl b/src/data/cmudict.jl
index 3ac47ef1..2a26b691 100644
--- a/src/data/cmudict.jl
+++ b/src/data/cmudict.jl
@@ -10,13 +10,15 @@ const cache_prefix = "https://cache.julialang.org"
function load()
suffixes = ["", ".phones", ".symbols"]
if isdir(deps("cmudict"))
- if all(isfile.(["cmudict$x" for x in suffixes]))
+ if all(isfile(deps("cmudict", "cmudict$x")) for x in suffixes)
return
end
end
+ info("Downloading CMUDict dataset")
mkpath(deps("cmudict"))
for x in suffixes
- download("$cache_prefix/http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-$version$x", deps("cmudict", "cmudict$x"))
+ download("$cache_prefix/http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-$version$x",
+ deps("cmudict", "cmudict$x"))
end
end
diff --git a/src/data/mnist.jl b/src/data/mnist.jl
index 132bf219..34bcd50c 100644
--- a/src/data/mnist.jl
+++ b/src/data/mnist.jl
@@ -14,6 +14,7 @@ function load()
"t10k-images-idx3-ubyte",
"t10k-labels-idx1-ubyte"]
isfile(file) && continue
+ info("Downloading MNIST dataset")
download("https://cache.julialang.org/http://yann.lecun.com/exdb/mnist/$file.gz", "$file.gz")
open(file, "w") do io
write(io, GZip.open(read, "$file.gz"))
diff --git a/src/data/sentiment.jl b/src/data/sentiment.jl
index ae9f9261..570fcf5d 100644
--- a/src/data/sentiment.jl
+++ b/src/data/sentiment.jl
@@ -4,10 +4,10 @@ using ZipFile
using ..Data: deps
function load()
- isfile(deps("sentiment.zip")) ||
- download("https://cache.julialang.org/https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip",
- deps("sentiment.zip"))
- return
+ isfile(deps("sentiment.zip")) || return
+ info("Downloading sentiment treebank dataset")
+ download("https://cache.julialang.org/https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip",
+ deps("sentiment.zip"))
end
getfile(r, name) = r.files[findfirst(x -> x.name == name, r.files)]
diff --git a/src/layers/conv.jl b/src/layers/conv.jl
index 1ef38a21..7548fc96 100644
--- a/src/layers/conv.jl
+++ b/src/layers/conv.jl
@@ -13,7 +13,7 @@ Standard convolutional layer. `size` should be a tuple like `(2, 2)`.
Data should be stored in WHCN order. In other words, a 100×100 RGB image would
be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
-Takes the keyword arguments `pad` and `stride`.
+Takes the keyword arguments `pad`, `stride` and `dilation`.
"""
struct Conv{N,F,A,V}
σ::F
@@ -21,17 +21,17 @@ struct Conv{N,F,A,V}
bias::V
stride::NTuple{N,Int}
pad::NTuple{N,Int}
+ dilation::NTuple{N,Int}
end
Conv(w::AbstractArray{T}, b::AbstractVector{T}, σ = identity;
- stride = 1, pad = 0) where T =
- Conv(σ, w, b, expand(Val{ndims(w) - 2}, stride),
- expand(Val{ndims(w) - 2}, pad))
+ stride = 1, pad = 0, dilation = 1) where T =
+ Conv(σ, w, b, expand.(Val{ndims(w)-2}, (stride, pad, dilation))...)
Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
- stride = 1, pad = 0) where N =
- Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ;
- stride = stride, pad = pad)
+ stride = 1, pad = 0, dilation = 1) where N =
+ Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ,
+ stride = stride, pad = pad, dilation = dilation)
Flux.treelike(Conv)
@@ -39,7 +39,7 @@ function (c::Conv)(x)
# TODO: breaks gpu broadcast :(
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
- σ.(conv(x, c.weight, stride = c.stride, pad = c.pad) .+ b)
+ σ.(conv(x, c.weight, stride = c.stride, pad = c.pad, dilation = c.dilation) .+ b)
end
function Base.show(io::IO, l::Conv)
@@ -48,6 +48,3 @@ function Base.show(io::IO, l::Conv)
l.σ == identity || print(io, ", ", l.σ)
print(io, ")")
end
-
-# v0.5
-@deprecate Conv2D(args...; kw...) Conv(args...; kw...)
diff --git a/src/layers/normalisation.jl b/src/layers/normalise.jl
similarity index 51%
rename from src/layers/normalisation.jl
rename to src/layers/normalise.jl
index 69854f44..54f5eb56 100644
--- a/src/layers/normalisation.jl
+++ b/src/layers/normalise.jl
@@ -31,15 +31,14 @@ function Dropout(p)
Dropout{typeof(p)}(p, true)
end
+_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
+
function (a::Dropout)(x)
a.active || return x
y = similar(x)
rand!(y)
- q = 1 - a.p
- @inbounds for i=1:length(y)
- y[i] = y[i] > a.p ? 1 / q : 0
- end
- return y .* x
+ y .= _dropout_kernel.(y, a.p, 1 - a.p)
+ return x .* y
end
_testmode!(a::Dropout, test) = (a.active = !test)
@@ -68,70 +67,88 @@ function Base.show(io::IO, l::LayerNorm)
end
"""
- BatchNorm(dims...; λ = identity,
- initβ = zeros, initγ = ones, ϵ = 1e-8, momentum = .1)
+ BatchNorm(channels::Integer, σ = identity;
+ initβ = zeros, initγ = ones,
+ ϵ = 1e-8, momentum = .1)
-Batch Normalization Layer for [`Dense`](@ref) layer.
+Batch Normalization layer. The `channels` input should be the size of the
+channel dimension in your data (see below).
+
+Given an array with `N` dimensions, call the `N-1`th dimension the channel
+dimension. (For a batch of feature vectors this is just the data dimension; for
+`WHCN` images it's the usual channel dimension.)
+
+`BatchNorm` computes the mean and variance for each `W×H×1×N` slice and
+shifts them to have a new mean and variance (corresponding to the learnable,
+per-channel `bias` and `scale` parameters).
See [Batch Normalization: Accelerating Deep Network Training by Reducing
- Internal Covariate Shift](https://arxiv.org/pdf/1502.03167.pdf)
+Internal Covariate Shift](https://arxiv.org/pdf/1502.03167.pdf).
-In the example of MNIST,
-in order to normalize the input of other layer,
-put the `BatchNorm` layer before activation function.
+Example:
```julia
m = Chain(
Dense(28^2, 64),
- BatchNorm(64, λ = relu),
+ BatchNorm(64, relu),
Dense(64, 10),
BatchNorm(10),
softmax)
```
"""
-mutable struct BatchNorm{F,V,N}
+mutable struct BatchNorm{F,V,W,N}
λ::F # activation function
β::V # bias
γ::V # scale
- μ # moving mean
- σ # moving std
+ μ::W # moving mean
+ σ::W # moving std
ϵ::N
momentum::N
active::Bool
end
-BatchNorm(dims::Integer...; λ = identity,
+BatchNorm(chs::Integer, λ = identity;
initβ = zeros, initγ = ones, ϵ = 1e-8, momentum = .1) =
- BatchNorm(λ, param(initβ(dims)), param(initγ(dims)), 0., 1., ϵ, momentum, true)
+ BatchNorm(λ, param(initβ(chs)), param(initγ(chs)),
+ zeros(chs), ones(chs), ϵ, momentum, true)
function (BN::BatchNorm)(x)
- λ, γ, β = BN.λ, BN.γ, BN.β
+ size(x, ndims(x)-1) == length(BN.β) ||
+ error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))")
+ γ, β = BN.γ, BN.β
+ dims = length(size(x))
+ channels = size(x, dims-1)
+ affine_shape = ones(Int, dims)
+ affine_shape[end-1] = channels
+ m = prod(size(x)[1:end-2]) * size(x)[end]
if !BN.active
- μ = BN.μ
- σ = BN.σ
+ μ = reshape(BN.μ, affine_shape...)
+ σ = reshape(BN.σ, affine_shape...)
else
T = eltype(x)
ϵ = data(convert(T, BN.ϵ))
- m = size(x, 2) # batch size
- μ = mean(x, 2)
- σ = sqrt.(sum((x .- μ).^2, 2) ./ m .+ ϵ)
+ axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
+ μ = mean(x, axes)
+ σ = sqrt.(mean((x .- μ).^2, axes) .+ ϵ)
# update moving mean/std
mtm = data(convert(T, BN.momentum))
- BN.μ = (1 - mtm) .* BN.μ .+ mtm .* data(μ)
- BN.σ = (1 - mtm) .* BN.σ .+ mtm .* data(σ) .* m ./ (m - 1)
+ BN.μ = (1 - mtm) .* BN.μ .+ mtm .* squeeze(data(μ), (axes...))
+ BN.σ = (1 - mtm) .* BN.σ .+ mtm .* squeeze(data(σ), (axes...)) .* m ./ (m - 1)
end
- λ.(γ .* ((x .- μ) ./ σ) .+ β)
+ let λ = BN.λ
+ λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ σ) .+ reshape(β, affine_shape...))
+ end
end
children(BN::BatchNorm) =
- (BN.λ, BN.β, BN.γ, BN.μ, BN.σ, BN.momentum, BN.ϵ, BN.active)
+ (BN.λ, BN.β, BN.γ, BN.μ, BN.σ, BN.ϵ, BN.momentum, BN.active)
mapchildren(f, BN::BatchNorm) = # e.g. mapchildren(cu, BN)
- BatchNorm(BN.λ, f(BN.β), f(BN.γ), BN.μ, BN.σ, BN.momentum, BN.ϵ, BN.active)
+ BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ), BN.ϵ, BN.momentum, BN.active)
_testmode!(BN::BatchNorm, test) = (BN.active = !test)
diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl
index f4394e61..ccd4fe4c 100644
--- a/src/layers/stateless.jl
+++ b/src/layers/stateless.jl
@@ -5,7 +5,7 @@ using NNlib: logsoftmax, logσ
mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)
function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
- return @fix -sum(y .* log.(ŷ) .* weight) / size(y, 2)
+ @fix -sum(y .* log.(ŷ) .* weight) / size(y, 2)
end
@deprecate logloss(x, y) crossentropy(x, y)
diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl
index acec542e..0c541b93 100644
--- a/src/optimise/Optimise.jl
+++ b/src/optimise/Optimise.jl
@@ -1,7 +1,8 @@
module Optimise
-export update!, params, train!,
- SGD, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad
+export train!,
+ SGD, ADAM, AdaMax, Momentum, Nesterov,
+ RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
struct Param{T}
x::T
diff --git a/src/optimise/interface.jl b/src/optimise/interface.jl
index 42b05dc8..3a07f6ce 100644
--- a/src/optimise/interface.jl
+++ b/src/optimise/interface.jl
@@ -56,6 +56,15 @@ RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) =
ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
+"""
+ AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
+
+[AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser. Variant of ADAM based on
+the ∞-norm.
+"""
+AdaMax(ps, η = 0.002; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
+ optimiser(ps, p->adamax(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
+
"""
ADAGrad(params, η = 0.01; ϵ = 1e-8, decay = 0)
@@ -82,3 +91,12 @@ tuning.
"""
AMSGrad(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p -> amsgrad(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
+
+"""
+ NADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
+
+[NADAM](https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ) optimiser. Parameters other
+than learning rate don't need tuning.
+"""
+NADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
+ optimiser(ps, p->nadam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl
index c09e6131..e3a4ed34 100644
--- a/src/optimise/optimisers.jl
+++ b/src/optimise/optimisers.jl
@@ -27,7 +27,7 @@ function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
acc = zeros(p.x)
function ()
@. acc = ρ * acc + (1 - ρ) * p.Δ^2
- @. p.Δ *= η / (√acc + ϵ)
+ @. p.Δ *= η / √(acc + ϵ)
end
end
@@ -35,7 +35,7 @@ function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
acc = zeros(p.x) .+ ϵ
function ()
@. acc += p.Δ^2
- @. p.Δ *= η / √acc
+ @. p.Δ *= η / √(acc + ϵ)
end
end
@@ -56,12 +56,24 @@ function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ
function ()
@. mt = β1 * mt + (1 - β1) * p.Δ
@. vt = β2 * vt + (1 - β2) * p.Δ^2
- @. p.Δ = mt / (1 - β1p) / (√(vt / (1 - β2p)) + ϵ) * η
+ @. p.Δ = mt / (1 - β1p) / √(vt / (1 - β2p) + ϵ) * η
β1p *= β1
β2p *= β2
end
end
+function adamax(p::Param; η::Real = 0.002, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
+ mt = zeros(p.x)
+ ut = zeros(p.x)
+ β1p = β1
+ function ()
+ @. mt = β1 * mt + (1 - β1) * p.Δ
+ @. ut = max(β2 * ut, abs(p.Δ))
+ @. p.Δ = (η/(1 - β1p)) * mt/(ut + ϵ)
+ β1p *= β1
+ end
+end
+
function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
mt = zeros(p.x)
vt = zeros(p.x) .+ ϵ
@@ -74,6 +86,19 @@ function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999,
end
end
+function nadam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
+ mt = zeros(p.x)
+ vt = zeros(p.x)
+ β1p, β2p = β1, β2
+ function ()
+ @. mt = β1 * mt + (1 - β1) * p.Δ
+ @. vt = β2 * vt + (1 - β2) * p.Δ^2
+ @. p.Δ = (β1 * mt / (1 - β1 * β1p) + (1 - β1) * p.Δ / (1 - β1p)) / √(vt * β2 / (1 - β2p) + ϵ) * η
+ β1p *= β1
+ β2p *= β2
+ end
+end
+
clip(p::Param, thresh::Real) = () -> clamp!(p.Δ, -thresh, thresh)
function expdecay(p::Param, γ::Real)
diff --git a/src/tracker/array.jl b/src/tracker/array.jl
index 35261abe..7a54d2eb 100644
--- a/src/tracker/array.jl
+++ b/src/tracker/array.jl
@@ -41,7 +41,7 @@ end
Base.setindex!(xs::TrackedArray, v, i...) =
error("Can't differentiate `setindex!`")
-back!(::TrackedArray) = error("Use back!(x, Δ)")
+back!(::TrackedArray) = error("Value is not scalar; use `back!(sum(x))` or `back!(x, Δ)`")
# Fallthrough methods
@@ -81,21 +81,6 @@ back(::typeof(ctranspose), Δ, xs) = @back(xs, trim(xs, Δ'))
Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...)
Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...)
-Base.vcat(a::TrackedVector, b::TrackedVector) = track(vcat, a, b)
-Base.vcat(a::TrackedVector, b::TrackedVector...) = track(vcat, a, b...)
-Base.vcat(a::TrackedVector, b::AbstractVector) = track(vcat, a, b)
-Base.vcat(a::AbstractVector, b::TrackedVector) = track(vcat, a, b)
-
-Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat) = track(vcat, a, b)
-Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat...) = track(vcat, a, b...)
-Base.vcat(a::TrackedVecOrMat, b::AbstractVecOrMat) = track(vcat, a, b)
-Base.vcat(a::AbstractVecOrMat, b::TrackedVecOrMat) = track(vcat, a, b)
-
-Base.vcat(a::TrackedMatrix, b::TrackedMatrix) = track(vcat, a, b)
-Base.vcat(a::TrackedMatrix, b::TrackedMatrix...) = track(vcat, a, b...)
-Base.vcat(a::TrackedMatrix, b::AbstractMatrix) = track(vcat, a, b)
-Base.vcat(a::AbstractMatrix, b::TrackedMatrix) = track(vcat, a, b)
-
function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1)
Δ′ = similar(xs.data)
S = size(xs.data)
@@ -108,20 +93,93 @@ function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1)
back(xs, Δ′)
end
+
+_repeat(A, inner, outer) = Base.repeat(A; inner=inner, outer=outer)
+Base.repeat(A::TrackedArray; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A))) = track(_repeat, A, inner, outer)
+
+function back(::typeof(_repeat), Δ, xs::TrackedArray, inner, outer)
+ Δ′ = similar(xs.data)
+ Δ′ .= 0
+ S = size(xs.data)
+
+ # Loop through each element of Δ, calculate source dimensions, accumulate into Δ′
+ for (dest_idx, val) in enumerate(IndexCartesian(), Δ)
+ # First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then
+ # wrap around based on original size S.
+ src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)]
+ Δ′[src_idx...] += val
+ end
+ back(xs, Δ′)
+end
+
+
+for f in [:vcat, :hcat]
+ @eval begin
+ # This section is a bit of a hack since julia doesn't have a standardised
+ # promotion mechanism for concatenation yet
+ # https://github.com/JuliaLang/julia/pull/20815
+
+    # It should support tracked concatenation with rank ∈ (1,2) with a
+    # TrackedArray anywhere among the arguments. This works as long as Base has
+    # other methods that capture `(::Union{Vector,RowVector,Matrix}...)`.
+ Base.$f(a::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a...)
+
+ # It should support tracked concatenation with rank>2 if the TrackedArray is
+ # first
+ Base.$f(a::TrackedArray, b::AbstractArray...) = track($f, a, b...)
+ Base.$f(a::TrackedArray, b::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a, b...) # resolves ambiguity introduced by previous row
+
+ # It should support tracked concatenation with rank>2 if the TrackedArray is
+ # second
+ Base.$f(a::Array, b::TrackedArray, c::AbstractArray...) = track($f, a, b, c...)
+ Base.$f(a::Union{Vector,RowVector,Matrix}, b::TrackedArray,
+ c::Union{TrackedArray,Vector,RowVector,Matrix}...) =
+ track($f, a, b, c...) # resolves ambiguity introduced by previous row
+ end
+end
+
function back(::typeof(vcat), Δ, xs...)
- i = Base.tail(map(_ -> :, size(Δ)))
start = 0
for xsi in xs
+ i = map(_ -> :, size(xsi)) |> Base.tail
@back(xsi, Δ[start+1:start+size(xsi,1), i...])
start += size(xsi, 1)
end
end
-Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) =
- track(reshape, xs, dims...)
+function back(::typeof(hcat), Δ, xs...)
+ start = 0
+ for xsi in xs
+ if ndims(xsi) == 1
+ @back(xsi, Δ[:, start+1])
+ else
+ i = map(_ -> :, size(xsi)) |> Base.tail |> Base.tail
+ @back(xsi, Δ[:, start+1:start+size(xsi,2), i...])
+ end
+ start += size(xsi, 2)
+ end
+end
-Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64,N}} where N) =
- track(reshape, xs, dims)
+Base.cat(dims, a::TrackedArray, b::AbstractArray...) = track(cat, dims, a, b...)
+Base.cat(dims, a::Union{RowVector,Array}, b::TrackedArray, c::AbstractArray...) = track(cat, dims, a, b, c...)
+
+function back(::typeof(cat), Δ, dims, Xs...)
+ start = ntuple(i -> 0, Val{ndims(Δ)})
+ for xs in Xs
+ dim_xs = 1:ndims(xs)
+ till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val{ndims(Δ)})
+
+ xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val{ndims(Δ)})
+
+ @back(xs, reshape(Δ[xs_in_Δ...],size(xs)))
+
+ start = start .+ till_xs
+ end
+end
+
+Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) = reshape(xs, dims)
+Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Union{Int64,Colon}}}) = reshape(xs, Base._reshape_uncolon(xs, dims))
+Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64}}) = track(reshape, xs, dims)
back(::typeof(reshape), Δ, xs::TrackedArray, _...) =
back(xs, reshape(Δ, size(xs)))
@@ -158,12 +216,16 @@ Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs))
back(::typeof(prod), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= (prod(xs.data, dim...) ./ xs.data) .* Δ)
back(::typeof(prod), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= (reshape(.*(circshift.([reshape(xs.data, length(xs.data))], 1:length(xs.data)-1)...), size(xs.data))) .* Δ)
-Base.maximum(xs::TrackedArray, args...) = maximum(xs.data, args...)
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
Base.mean(xs::TrackedArray) = track(mean, xs)
Base.mean(xs::TrackedArray, region) = track(mean, xs, region)
+Base.maximum(xs::TrackedArray) = track(maximum, xs)
+Base.maximum(xs::TrackedArray, region) = track(maximum, xs, region)
+Base.minimum(xs::TrackedArray) = track(minimum, xs)
+Base.minimum(xs::TrackedArray, region) = track(minimum, xs, region)
+
LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = track(dot, xs, ys)
LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = track(dot, xs, ys)
LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys)
@@ -186,6 +248,31 @@ back(::typeof(mean), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= Δ ./
back(::typeof(mean), Δ, xs::TrackedArray, region) =
back(xs, similar(xs.data) .= Δ ./ prod(size(xs.data, region...)))
+function back(::typeof(maximum), Δ, xs::TrackedArray)
+ Δ′ = zeros(xs.data)
+ _, i = findmax(xs.data)
+ Δ′[i] = Δ
+ @back(xs, Δ′)
+end
+function back(::typeof(maximum), Δ, xs::TrackedArray, region)
+ Δ′ = zeros(xs.data)
+ _, is = findmax(xs.data, region)
+ Δ′[is] = Δ
+ @back(xs, Δ′)
+end
+function back(::typeof(minimum), Δ, xs::TrackedArray)
+ Δ′ = zeros(xs.data)
+ _, i = findmin(xs.data)
+ Δ′[i] = Δ
+ @back(xs, Δ′)
+end
+function back(::typeof(minimum), Δ, xs::TrackedArray, region)
+ Δ′ = zeros(xs.data)
+ _, is = findmin(xs.data, region)
+ Δ′[is] = Δ
+ @back(xs, Δ′)
+end
+
# BLAS
Base.diagm(x::TrackedVector) = track(diagm, x)
@@ -247,35 +334,35 @@ logsoftmax(xs::TrackedArray) = track(logsoftmax, xs)
back(::typeof(logsoftmax), Δ, xs) = @back(xs, ∇logsoftmax(Δ, data(xs)))
# TODO: can store kwargs efficiently in namedtuples
-_conv(x, w, stride, pad) = conv(x, w, stride = stride, pad = pad)
+_conv(x, w, stride, pad, dilation) = conv(x, w, stride = stride, pad = pad, dilation = dilation)
-conv(x::TrackedArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0) where N =
- track(_conv, x, w, stride, pad)
-conv(x::AbstractArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0) where N =
- track(_conv, x, w, stride, pad)
-conv(x::TrackedArray{<:Real,N}, w::AbstractArray{<:Real,N}; stride = 1, pad = 0) where N =
- track(_conv, x, w, stride, pad)
+conv(x::TrackedArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N =
+ track(_conv, x, w, stride, pad, dilation)
+conv(x::AbstractArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N =
+ track(_conv, x, w, stride, pad, dilation)
+conv(x::TrackedArray{<:Real,N}, w::AbstractArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N =
+ track(_conv, x, w, stride, pad, dilation)
-function back(::typeof(_conv), Δ, x, w, stride, pad)
- @back(x, NNlib.∇conv_data(Δ, data(x), data(w); stride = stride, pad = pad))
- @back(w, NNlib.∇conv_filter(Δ, data(x), data(w); stride = stride, pad = pad))
+function back(::typeof(_conv), Δ, x, w, stride, pad, dilation)
+ @back(x, NNlib.∇conv_data(Δ, data(x), data(w); stride = stride, pad = pad, dilation = dilation))
+ @back(w, NNlib.∇conv_filter(Δ, data(x), data(w); stride = stride, pad = pad, dilation = dilation))
end
-_maxpool(x, k, pad) = maxpool(x, k; pad = pad)
+_maxpool(x, k, pad, stride) = maxpool(x, k; pad = pad, stride = stride)
-maxpool(x::TrackedArray, k; pad = map(_->0,k)) =
- track(_maxpool, x, k, pad)
+maxpool(x::TrackedArray, k; pad = map(_->0,k), stride = k) =
+ track(_maxpool, x, k, pad, stride)
-back_(::typeof(_maxpool), y, Δ, x, k, pad) =
- back(x, NNlib.∇maxpool(Δ, y, data(x), k, pad=pad))
+back_(::typeof(_maxpool), y, Δ, x, k, pad, stride) =
+ back(x, NNlib.∇maxpool(Δ, y, data(x), k, pad=pad, stride=stride))
-_meanpool(x, k, pad) = meanpool(x, k; pad = pad)
+_meanpool(x, k, pad, stride) = meanpool(x, k; pad = pad, stride = stride)
-meanpool(x::TrackedArray, k; pad = map(_->0,k)) =
- track(_meanpool, x, k, pad)
+meanpool(x::TrackedArray, k; pad = map(_->0,k), stride = k) =
+ track(_meanpool, x, k, pad, stride)
-back_(::typeof(_meanpool), y, Δ, x, k, pad) =
- back(x, NNlib.∇meanpool(Δ, y, data(x), k, pad=pad))
+back_(::typeof(_meanpool), y, Δ, x, k, pad, stride) =
+ back(x, NNlib.∇meanpool(Δ, y, data(x), k, pad=pad, stride=stride))
# Broadcasting
diff --git a/src/tracker/numeric.jl b/src/tracker/numeric.jl
index 272f9ba4..755e1f7d 100644
--- a/src/tracker/numeric.jl
+++ b/src/tracker/numeric.jl
@@ -1,4 +1,4 @@
-function gradient(f, xs::AbstractArray...)
+function gradient(f, xs...)
xs = param.(xs)
back!(f(xs...))
grad.(xs)
diff --git a/src/tracker/scalar.jl b/src/tracker/scalar.jl
index 632046cd..8d0aa29e 100644
--- a/src/tracker/scalar.jl
+++ b/src/tracker/scalar.jl
@@ -19,11 +19,11 @@ Base.decompose(x::TrackedReal) = Base.decompose(data(x))
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{T}) where T = x
-Base.convert(::Type{TrackedReal{T}}, x::TrackedReal) where T =
- TrackedReal(Tracked(x.tracker.f, convert(T, x.tracker.data)))
-
Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x))
+Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} =
+ error("Not implemented: convert tracked $S to tracked $T")
+
Base.:(<)(x::TrackedReal, y::TrackedReal) = data(x) < data(y)
Base.:(==)(x::TrackedReal, y::TrackedReal) = data(x) == data(y)
@@ -91,3 +91,18 @@ Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i)
back(::typeof(getindex), Δ, t, i) =
back(t, ntuple(j -> i == j ? Δ : 0, length(t)))
+
+# Array collection
+
+function collect(xs)
+ xs = Base.collect(xs)
+ track(Call(collect, xs), data.(xs))
+end
+
+function scan(c::Call{typeof(collect)})
+ foreach(scan, c.args[1])
+end
+
+function back(::typeof(collect), Δ, xs)
+ foreach((x, Δ) -> @back(x, Δ), xs, Δ)
+end
diff --git a/test/cuda/cuda.jl b/test/cuda/cuda.jl
index 8953dadd..d16ce8f2 100644
--- a/test/cuda/cuda.jl
+++ b/test/cuda/cuda.jl
@@ -21,6 +21,10 @@ cm = gpu(m)
@test all(p isa TrackedArray && p.data isa CuArray for p in params(cm))
@test cm(gpu(rand(10, 10))) isa TrackedArray{Float32,2,CuArray{Float32,2}}
+x = [1,2,3]
+cx = gpu(x)
+@test Flux.crossentropy(x,x) ≈ Flux.crossentropy(cx,cx)
+
# Fails in Pkg.test ffs
# c = gpu(Conv((2,2),3=>4))
# l = c(gpu(rand(10,10,3,2)))
diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl
index 118a5700..0fdb1021 100644
--- a/test/layers/normalisation.jl
+++ b/test/layers/normalisation.jl
@@ -67,7 +67,7 @@ end
end
# with activation function
- let m = BatchNorm(2, λ = σ), x = param([1 2; 3 4; 5 6]')
+ let m = BatchNorm(2, σ), x = param([1 2; 3 4; 5 6]')
@test m.active
m(x)
@@ -77,4 +77,22 @@ end
x′ = m(x).data
@test x′[1] ≈ σ((1 - 0.3) / 1.1449489742783179)
end
+
+ let m = BatchNorm(2), x = param(reshape(1:6, 3, 2, 1))
+ y = reshape(permutedims(x, [2, 1, 3]), 2, :)
+ y = permutedims(reshape(m(y), 2, 3, 1), [2, 1, 3])
+ @test m(x) == y
+ end
+
+ let m = BatchNorm(2), x = param(reshape(1:12, 2, 3, 2, 1))
+ y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :)
+ y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4])
+ @test m(x) == y
+ end
+
+ let m = BatchNorm(2), x = param(reshape(1:24, 2, 2, 3, 2, 1))
+ y = reshape(permutedims(x, [4, 1, 2, 3, 5]), 2, :)
+ y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5])
+ @test m(x) == y
+ end
end
diff --git a/test/optimise.jl b/test/optimise.jl
index d57e4985..c896bb39 100644
--- a/test/optimise.jl
+++ b/test/optimise.jl
@@ -3,7 +3,7 @@ using Flux.Tracker
@testset "Optimise" begin
w = randn(10, 10)
- @testset for Opt in [SGD, Nesterov, Momentum, ADAM, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad]
+ @testset for Opt in [SGD, Nesterov, Momentum, ADAM, AdaMax, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad, NADAM]
w′ = param(randn(10, 10))
loss(x) = Flux.mse(w*x, w′*x)
opt = Opt([w′])
diff --git a/test/tracker.jl b/test/tracker.jl
index 3c29c3ea..66c08f62 100644
--- a/test/tracker.jl
+++ b/test/tracker.jl
@@ -1,5 +1,5 @@
using Flux.Tracker, Base.Test, NNlib
-using Flux.Tracker: TrackedReal, gradcheck
+using Flux.Tracker: TrackedReal, gradcheck, grad
using NNlib: conv
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
@@ -29,17 +29,97 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@test gradtest(x -> x', rand(5))
-@test gradtest(vcat, rand(5), rand(3))
-@test gradtest(vcat, rand(5), rand(3), rand(8))
-@test gradtest(vcat, rand(5,2), rand(3,2), rand(8,2))
+function promotiontest(f, A, B, C)
+ r0 = f(A, B, C)
+ r1 = f(param(A), B, C)
+ r2 = f(A, param(B), C)
+ if all(ndims.((A,B,C)) .≤ 2) && f ∈ [hcat, vcat]
+ r3 = f(A, B, param(C))
+ else
+ @test_throws MethodError f(A, B, param(C)) # until julia#20815 is resolved
+ r3 = r2
+ end
+ r4 = f(param(A), param(B), param(C))
+
+ @test !isa(r0, TrackedArray)
+ @test all(isa.([r1,r2,r3,r4], TrackedArray))
+ @test r1 == r2 == r3 == r4
+ @test r0 == Flux.data(r4)
+end
+
+@testset "concat" begin
+ cat1(x...) = cat(1, x...)
+ cat2(x...) = cat(2, x...)
+
+ @testset for vcatf in [vcat, cat1]
+ @test gradtest(vcatf, rand(5), rand(3))
+ @test gradtest(vcatf, rand(5), rand(3), rand(8))
+ @test gradtest(vcatf, rand(5)', rand(5)')
+ @test gradtest(vcatf, rand(5,2), rand(3,2), rand(8,2))
+ @test gradtest(vcatf, rand(5,2,3), rand(3,2,3), rand(8,2,3))
+ @test gradtest(vcatf, rand(5), rand(3,1))
+ @test gradtest(vcatf, rand(5)', rand(2,5))
+ end
+
+ @testset for hcatf in [hcat, cat2]
+ @test gradtest(hcatf, rand(5), rand(5))
+ @test gradtest(hcatf, rand(5)', rand(5)')
+ @test gradtest(hcatf, rand(2,5), rand(2,3), rand(2,8))
+ @test gradtest(hcatf, rand(2,5,3), rand(2,3,3), rand(2,8,3))
+ @test gradtest(hcatf, rand(5), rand(5), rand(5,2))
+ @test gradtest(hcatf, rand(5)', rand(1,3))
+ @test gradtest(hcatf, rand(5), rand(5,2))
+  end
+
+ @testset for catf in [vcat, cat1, hcat, cat2, (x...) -> cat(3, x...), (x...) -> cat((1,2), x...)]
+ @test gradtest(catf, rand(5))
+ @test gradtest(catf, rand(5)')
+ @test gradtest(catf, rand(2,5))
+ @test gradtest(catf, rand(2,5,3))
+ end
+
+ @test gradtest((x...) -> cat(3, x...), rand(2,5,2), rand(2,5,3), rand(2,5,4))
+
+ @testset "cat($dim, ...)" for dim in 3:5
+ catdim = (x...) -> cat(dim, x...)
+ @test gradtest(catdim, rand(5), rand(5), rand(5))
+ @test gradtest(catdim, rand(2,5), rand(2,5), rand(2,5))
+ @test gradtest(catdim, rand(2,5,3), rand(2,5,3), rand(2,5,3))
+ end
+
+ @test !isa(vcat(rand(2)), TrackedArray)
+ @test !isa(hcat(rand(2)), TrackedArray)
+ @test !isa(cat(1,rand(2)), TrackedArray)
+
+ @test gradtest((a,b)->cat((2,3,5), a, b), rand(2,3), rand(2,4,2,1))
+
+ @testset "promotiontest" begin
+ @testset for fcat in [hcat, vcat, (x...) -> cat(3, x...), (x...) -> cat((1,2), x...)]
+ promotiontest(fcat, rand(2), rand(2), rand(2))
+ promotiontest(fcat, rand(2)', rand(2)', rand(2)')
+ promotiontest(fcat, rand(2,2), rand(2,2), rand(2,2))
+ promotiontest(fcat, rand(2,2,2), rand(2,2,2), rand(2,2,2))
+ end
+
+ promotiontest(vcat, rand(1,2), rand(2)', rand(2,2))
+ promotiontest(hcat, rand(2,1), rand(2), rand(2,2))
+ promotiontest(vcat, rand(3,4,5), rand(1,4,5), rand(2,4,5))
+ promotiontest(hcat, rand(4,3,5), rand(4,1,5), rand(4,2,5))
+ promotiontest((x...) -> cat(3, x...), rand(4,5,3), rand(4,5,1), rand(4,5,2))
+ end
+end
+
@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))
@test gradtest(x -> repmat(x, 5,5), rand(4,5))
@test gradtest(x -> repmat(x, 5), rand(4,5))
-@test gradtest(kron,rand(5), rand(3))
+@test gradtest(x -> repeat(x; inner=2, outer=3), rand(5))
+@test gradtest(x -> repeat(x; inner=(2,2,1), outer=(1,1,3)), rand(5,4,3))
+
+@test gradtest(kron, rand(5), rand(3))
@test gradtest(kron, rand(5), rand(3), rand(8))
-@test gradtest(kron,rand(5,1), rand(3,1))
+@test gradtest(kron, rand(5,1), rand(3,1))
@test gradtest(kron, rand(5,1), rand(3,1), rand(8,1))
@test gradtest(kron, rand(5,2), rand(3,2), rand(8,2))
@@ -55,6 +135,26 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@test gradtest(x -> mean(x, [1, 2]), rand(2, 3, 4))
end
+@testset "maximum" begin
+ @test gradtest(maximum, rand(2, 3))
+
+ @test gradtest(x -> maximum(x, 1), rand(2, 3))
+ @test gradtest(x -> maximum(x, 2), rand(2, 3))
+ @test gradtest(x -> maximum(x, 3), rand(2, 3, 4))
+
+ @test gradtest(x -> maximum(x, [1, 2]), rand(2, 3, 4))
+end
+
+@testset "minimum" begin
+ @test gradtest(minimum, rand(2, 3))
+
+ @test gradtest(x -> minimum(x, 1), rand(2, 3))
+ @test gradtest(x -> minimum(x, 2), rand(2, 3))
+ @test gradtest(x -> minimum(x, 3), rand(2, 3, 4))
+
+ @test gradtest(x -> minimum(x, [1, 2]), rand(2, 3, 4))
+end
+
@test gradtest(x -> std(x), rand(5,5))
@test gradtest(x -> std(x, 1), rand(5,5))
@@ -82,6 +182,21 @@ end
@test param(2)^2 == 4.0
+@testset "reshape" begin
+ x = reshape(param(rand(2,2,2)), 4, 2)
+ @test x isa TrackedArray
+ @test size(x) == (4,2)
+ x = reshape(param([1]), (1,:))
+ @test x isa TrackedArray
+ @test size(x) == (1,1)
+ x = reshape(param(rand(2)), (2,:))
+ @test x isa TrackedArray
+ @test size(x) == (2,1)
+ x = reshape(param(rand(2,2)), (1,:,2))
+ @test x isa TrackedArray
+ @test size(x) == (1,2,2)
+end
+
@testset "Intermediates" begin
x = param([1])
l = sum((x .+ x).^2)
@@ -108,4 +223,13 @@ b = param(rand())
Tracker.back!(b)
@test Tracker.grad(b) == 1
+@testset "collect" begin
+ x, y = param(2), param(3)
+ xy = Tracker.collect([x, y])
+ @test xy isa TrackedArray{Float64}
+ z = xy[1]*xy[2]
+ back!(z)
+ @test grad.((x,y)) == (3, 2)
+end
+
end #testset