diff --git a/README.md b/README.md index 4785c55c..0baa74d4 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@

-[![Build Status](https://travis-ci.org/FluxML/Flux.jl.svg?branch=master)](https://travis-ci.org/FluxML/Flux.jl) [![](https://img.shields.io/badge/docs-stable-blue.svg)](https://fluxml.github.io/Flux.jl/stable/) [![](https://img.shields.io/badge/chat-on%20slack-yellow.svg)](https://slackinvite.julialang.org/) +[![Build Status](https://travis-ci.org/FluxML/Flux.jl.svg?branch=master)](https://travis-ci.org/FluxML/Flux.jl) [![](https://img.shields.io/badge/docs-stable-blue.svg)](https://fluxml.github.io/Flux.jl/stable/) [![](https://img.shields.io/badge/chat-on%20slack-yellow.svg)](https://slackinvite.julialang.org/) [![DOI](http://joss.theoj.org/papers/10.21105/joss.00602/status.svg)](https://doi.org/10.21105/joss.00602) Flux is an elegant approach to machine learning. It's a 100% pure-Julia stack, and provides lightweight abstractions on top of Julia's native GPU and AD support. Flux makes the easy things easy while remaining fully hackable. @@ -12,6 +12,18 @@ julia> Pkg.add("Flux") See the [documentation](http://fluxml.github.io/Flux.jl/) or the [model zoo](https://github.com/FluxML/model-zoo/) for examples. +If you use Flux in research, please cite the following paper: + +``` +@article{innes:2018, + author = {Mike Innes}, + title = {Flux: Elegant Machine Learning with Julia}, + journal = {Journal of Open Source Software}, + year = {2018}, + doi = {10.21105/joss.00602}, +} +``` + ## Features Flux has powerful high-level features, and common architectures can be defined in a few lines. @@ -79,3 +91,9 @@ For general questions and help, check out Julia's [community forum](https://disc Flux development is carried out via our [GitHub issues](https://github.com/FluxML/Flux.jl/issues), so feel free to open feature requests or PRs here. For more informal discussions we'd love to have you on the [Julia slack](https://slackinvite.julialang.org/), where we hang out on the #machine-learning channel. + +## Related Packages + +Check out [Metalhead.jl](https://github.com/FluxML/Metalhead.jl) for common computer vision datasets and trained models. + +[MLDatasets.jl](https://github.com/JuliaML/MLDatasets.jl) provides further common datasets. diff --git a/docs/make.jl b/docs/make.jl index d7f14d8e..ed6a8c8b 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -18,6 +18,8 @@ makedocs(modules=[Flux, NNlib], "One-Hot Encoding" => "data/onehot.md", "GPU Support" => "gpu.md", "Saving & Loading" => "saving.md", + "Internals" => + ["Backpropagation" => "internals/tracker.md"], "Community" => "community.md"]) deploydocs( diff --git a/docs/src/internals/tracker.md b/docs/src/internals/tracker.md new file mode 100644 index 00000000..b9addc34 --- /dev/null +++ b/docs/src/internals/tracker.md @@ -0,0 +1,156 @@ +# Flux.Tracker + +Backpropagation, or reverse-mode automatic differentiation, is handled by the `Flux.Tracker` module. + +```julia +julia> using Flux.Tracker +``` + +The `param` function converts a normal Julia array into a new object that, while behaving like an array, tracks extra information that allows us to calculate derivatives. For example, say we multiply two parameters: + +```julia +julia> W = param([1 2; 3 4]) +Tracked 2×2 Array{Float64,2}: + 1.0 2.0 + 3.0 4.0 + +julia> x = param([5, 6]) +Tracked 2-element Array{Float64,1}: + 5.0 + 6.0 + +julia> y = W*x +Tracked 2-element Array{Float64,1}: + 17.0 + 39.0 +``` + +The output `y` is also a `TrackedArray` object. We can now backpropagate sensitivities to `W` and `x` via the `back!` function, and see the gradients accumulated in the `W` and `x` tracked arrays: + +```julia +julia> Tracker.back!(y, [1, -1]) + +julia> W.grad +2×2 Array{Float64,2}: + 5.0 6.0 +-5.0 -6.0 + +julia> x.grad +2-element Array{Float64,1}: + -2.0 + -2.0 +``` + +## Internals + +All `Tracked*` objects (`TrackedArray`, `TrackedReal`) are light wrappers around the `Tracked` type, which you can access via the `.tracker` field. + +```julia +julia> x.tracker +Flux.Tracker.Tracked{Array{Float64,1}}(0x00000000, Flux.Tracker.Call{Void,Tuple{}}(nothing, ()), true, [5.0, 6.0], [-2.0, -2.0]) +``` + +The `Tracker` stores the value and gradient of a given object, which we've seen before. + +```julia +julia> x.tracker.data +2-element Array{Float64,1}: + 5.0 + 6.0 + +julia> x.tracker.grad +2-element Array{Float64,1}: + -2.0 + -2.0 +``` + +The tracker also contains a `Call` object, which simply represents a function call that was made at some point during the forward pass. For example, the `+` call would look like this: + +```julia +julia> Tracker.Call(+, 1, 2) +Flux.Tracker.Call{Base.#+,Tuple{Int64,Int64}}(+, (1, 2)) +``` + +In the case of the `y` we produced above, we can see that it stores the call that produced it -- that is, `W*x`. + +```julia +julia> y.tracker.f +Flux.Tracker.Call{...}(*, (param([1.0 2.0; 3.0 4.0]), param([5.0, 6.0]))) +``` + +Notice that because the arguments to the call may also be tracked arrays, storing their own calls, this means that `Tracker` ends up forming a data structure that records everything that happened during the forward pass (often known as a *tape*). + +When we call `back!(y, [1, -1])`, the sensitivities `[1, -1]` simply get forwarded to `y`'s call (`*`), effectively calling + +```julia +Tracker.back(*, [1, -1], W, x) +``` + +which in turn calculates the sensitivities of the arguments (`W` and `x`) and backpropagates through their calls. This is recursive, so it will walk the entire program graph and propagate gradients to the original model parameters. + +## Custom Gradients + +We can hook in to the processes above to implement custom gradients for a function or kernel. For a toy example, imagine a custom implementation of `minus`: + +```julia +julia> minus(a, b) = a - b +``` + +Firstly, we must tell the tracker system to stop when it sees a call to `minus`, and record it. We can do this using dispatch: + +```julia +julia> minus(a::TrackedArray, b::TrackedArray) = Tracker.track(minus, a, b) +minus (generic function with 2 methods) +``` + +`Tracker.track` does two things: (1) it makes sure `minus` is called with *normal* array, not tracked ones (you can use `@show` inside `minus` to verify this), and (2) it uses the result to add a `minus` node to the tape. Look inside the result of calling `minus` to see what happened: + +```julia +julia> a, b = param([6,5,4]), param([1,2,3]) +(param([6.0, 5.0, 4.0]), param([1.0, 2.0, 3.0])) + +julia> c = minus(a, b) +Tracked 3-element Array{Float64,1}: + 5.0 + 3.0 + 1.0 + +julia> c.tracker.f +Flux.Tracker.Call{...}(minus, (param([6.0, 5.0, 4.0]), param([1.0, 2.0, 3.0]))) +``` + +Finally, we have to specify the gradient of `minus`. + +```julia +julia> Tracker.back(::typeof(minus), Δ, a, b) = + (Tracker.@back(a, Δ); Tracker.@back(b, -Δ)) +``` + +`@back(x, Δ)` tells the tracker to continue propagating the sensitivity `Δ` through `x`. Now, AD will work with any program that calls `minus`. + +```julia +julia> Flux.back!(c, 1) + +julia> a.grad +3-element Array{Float64,1}: + 1.0 + 1.0 + 1.0 + +julia> b.grad +3-element Array{Float64,1}: + -1.0 + -1.0 + -1.0 +``` + +## Notes + +For multi-argument functions with custom gradients, you likely want to catch not just `minus(::TrackedArray, ::TrackedArray)` but also `minus(::Array, TrackedArray)` and so on. To do so, just define those extra signatures as needed: + +```julia +minus(a::AbstractArray, b::TrackedArray) = Tracker.track(minus, a, b) +minus(a::TrackedArray, b::AbstractArray) = Tracker.track(minus, a, b) +``` + +`@back` *must* be called exactly once on each tracked input argument. You do not need to do any special handling if one of the arguments is not tracked, as `@back` will just become a no-op. diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md index 379268b3..c2056bb4 100644 --- a/docs/src/models/layers.md +++ b/docs/src/models/layers.md @@ -5,7 +5,7 @@ These core layers form the foundation of almost all neural networks. ```@docs Chain Dense -Conv2D +Conv ``` ## Recurrent Layers diff --git a/src/Flux.jl b/src/Flux.jl index 7746ecff..7125630f 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -7,22 +7,23 @@ module Flux using Juno, Requires, Reexport using MacroTools: @forward -export Chain, Dense, RNN, LSTM, GRU, Conv, Conv2D, - Dropout, LayerNorm, BatchNorm, - SGD, ADAM, Momentum, Nesterov, AMSGrad, - param, params, mapleaves, cpu, gpu +export Chain, Dense, RNN, LSTM, GRU, Conv, + Dropout, LayerNorm, BatchNorm, + params, mapleaves, cpu, gpu @reexport using NNlib using NNlib: @fix include("tracker/Tracker.jl") using .Tracker -export Tracker -import .Tracker: data +using .Tracker: data +export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param include("optimise/Optimise.jl") using .Optimise using .Optimise: @epochs +export SGD, ADAM, AdaMax, Momentum, Nesterov, + RMSProp, ADAGrad, ADADelta, AMSGrad include("utils.jl") include("onehot.jl") diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl index acec542e..5d5d9ea0 100644 --- a/src/optimise/Optimise.jl +++ b/src/optimise/Optimise.jl @@ -1,7 +1,7 @@ module Optimise -export update!, params, train!, - SGD, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad +export train!, + SGD, ADAM, AdaMax, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad struct Param{T} x::T diff --git a/src/optimise/interface.jl b/src/optimise/interface.jl index 42b05dc8..29068983 100644 --- a/src/optimise/interface.jl +++ b/src/optimise/interface.jl @@ -56,6 +56,15 @@ RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) = ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) = optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1)) +""" + AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) + +[AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser. Variant of ADAM based on +the ∞-norm. +""" +AdaMax(ps, η = 0.002; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) = + optimiser(ps, p->adamax(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1)) + """ ADAGrad(params, η = 0.01; ϵ = 1e-8, decay = 0) diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl index c09e6131..29b058ba 100644 --- a/src/optimise/optimisers.jl +++ b/src/optimise/optimisers.jl @@ -62,6 +62,18 @@ function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ end end +function adamax(p::Param; η::Real = 0.002, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8) + mt = zeros(p.x) + ut = zeros(p.x) + β1p = β1 + function () + @. mt = β1 * mt + (1 - β1) * p.Δ + @. ut = max(β2 * ut, abs(p.Δ)) + @. p.Δ = (η/(1 - β1p)) * mt/(ut + ϵ) + β1p *= β1 + end +end + function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8) mt = zeros(p.x) vt = zeros(p.x) .+ ϵ diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 0bc63c63..5465dcc3 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -41,7 +41,7 @@ end Base.setindex!(xs::TrackedArray, v, i...) = error("Can't differentiate `setindex!`") -back!(::TrackedArray) = error("Use back!(x, Δ)") +back!(::TrackedArray) = error("Value is not scalar; use `back!(sum(x))` or `back!(x, Δ)`") # Fallthrough methods @@ -81,21 +81,6 @@ back(::typeof(ctranspose), Δ, xs) = @back(xs, trim(xs, Δ')) Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...) Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...) -Base.vcat(a::TrackedVector, b::TrackedVector) = track(vcat, a, b) -Base.vcat(a::TrackedVector, b::TrackedVector...) = track(vcat, a, b...) -Base.vcat(a::TrackedVector, b::AbstractVector) = track(vcat, a, b) -Base.vcat(a::AbstractVector, b::TrackedVector) = track(vcat, a, b) - -Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat) = track(vcat, a, b) -Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat...) = track(vcat, a, b...) -Base.vcat(a::TrackedVecOrMat, b::AbstractVecOrMat) = track(vcat, a, b) -Base.vcat(a::AbstractVecOrMat, b::TrackedVecOrMat) = track(vcat, a, b) - -Base.vcat(a::TrackedMatrix, b::TrackedMatrix) = track(vcat, a, b) -Base.vcat(a::TrackedMatrix, b::TrackedMatrix...) = track(vcat, a, b...) -Base.vcat(a::TrackedMatrix, b::AbstractMatrix) = track(vcat, a, b) -Base.vcat(a::AbstractMatrix, b::TrackedMatrix) = track(vcat, a, b) - function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1) Δ′ = similar(xs.data) S = size(xs.data) @@ -108,15 +93,70 @@ function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1) back(xs, Δ′) end +for f in [:vcat, :hcat] + @eval begin + # This section is a bit of a hack since julia doesn't have a standardised + # promotion mechanism for concatenation yet + # https://github.com/JuliaLang/julia/pull/20815 + + # It should support tracked concatenation with rank ∈ (1,2) with a + # TrackedArray anywhere among the arguments This works as long as base has + # other functions that captures `(::Union{Vector,RowVector,Matrix}...)`. + Base.$f(a::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a...) + + # It should support tracked concatenation with rank>2 if the TrackedArray is + # first + Base.$f(a::TrackedArray, b::AbstractArray...) = track($f, a, b...) + Base.$f(a::TrackedArray, b::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a, b...) # resolves ambiguity introduced by previous row + + # It should support tracked concatenation with rank>2 if the TrackedArray is + # second + Base.$f(a::Array, b::TrackedArray, c::AbstractArray...) = track($f, a, b, c...) + Base.$f(a::Union{Vector,RowVector,Matrix}, b::TrackedArray, + c::Union{TrackedArray,Vector,RowVector,Matrix}...) = + track($f, a, b, c...) # resolves ambiguity introduced by previous row + end +end + function back(::typeof(vcat), Δ, xs...) - i = Base.tail(map(_ -> :, size(Δ))) start = 0 for xsi in xs + i = map(_ -> :, size(xsi)) |> Base.tail @back(xsi, Δ[start+1:start+size(xsi,1), i...]) start += size(xsi, 1) end end +function back(::typeof(hcat), Δ, xs...) + start = 0 + for xsi in xs + if ndims(xsi) == 1 + @back(xsi, Δ[:, start+1]) + else + i = map(_ -> :, size(xsi)) |> Base.tail |> Base.tail + @back(xsi, Δ[:, start+1:start+size(xsi,2), i...]) + end + start += size(xsi, 2) + end +end + +Base.cat(dims, a::TrackedArray, b::AbstractArray...) = track(cat, dims, a, b...) +Base.cat(dims, a::Union{RowVector,Array}, b::TrackedArray, c::AbstractArray...) = track(cat, dims, a, b, c...) + +function back(::typeof(cat), Δ, dims, Xs...) + start = ntuple(i -> 0, Val{ndims(Δ)}) + for xs in Xs + dim_xs = 1:ndims(xs) + till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val{ndims(Δ)}) + + xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val{ndims(Δ)}) + + @back(xs, reshape(Δ[xs_in_Δ...],size(xs))) + + start = start .+ till_xs + end +end + Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) = reshape(xs, dims) Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Union{Int64,Colon}}}) = reshape(xs, Base._reshape_uncolon(xs, dims)) Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64}}) = track(reshape, xs, dims) @@ -156,12 +196,16 @@ Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs)) back(::typeof(prod), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= (prod(xs.data, dim...) ./ xs.data) .* Δ) back(::typeof(prod), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= (reshape(.*(circshift.([reshape(xs.data, length(xs.data))], 1:length(xs.data)-1)...), size(xs.data))) .* Δ) -Base.maximum(xs::TrackedArray, args...) = maximum(xs.data, args...) Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...) Base.mean(xs::TrackedArray) = track(mean, xs) Base.mean(xs::TrackedArray, region) = track(mean, xs, region) +Base.maximum(xs::TrackedArray) = track(maximum, xs) +Base.maximum(xs::TrackedArray, region) = track(maximum, xs, region) +Base.minimum(xs::TrackedArray) = track(minimum, xs) +Base.minimum(xs::TrackedArray, region) = track(minimum, xs, region) + LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = track(dot, xs, ys) LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = track(dot, xs, ys) LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys) @@ -184,6 +228,31 @@ back(::typeof(mean), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= Δ ./ back(::typeof(mean), Δ, xs::TrackedArray, region) = back(xs, similar(xs.data) .= Δ ./ prod(size(xs.data, region...))) +function back(::typeof(maximum), Δ, xs::TrackedArray) + Δ′ = zeros(xs.data) + _, i = findmax(xs.data) + Δ′[i] = Δ + @back(xs, Δ′) +end +function back(::typeof(maximum), Δ, xs::TrackedArray, region) + Δ′ = zeros(xs.data) + _, is = findmax(xs.data, region) + Δ′[is] = Δ + @back(xs, Δ′) +end +function back(::typeof(minimum), Δ, xs::TrackedArray) + Δ′ = zeros(xs.data) + _, i = findmin(xs.data) + Δ′[i] = Δ + @back(xs, Δ′) +end +function back(::typeof(minimum), Δ, xs::TrackedArray, region) + Δ′ = zeros(xs.data) + _, is = findmin(xs.data, region) + Δ′[is] = Δ + @back(xs, Δ′) +end + # BLAS Base.diagm(x::TrackedVector) = track(diagm, x) diff --git a/test/optimise.jl b/test/optimise.jl index d57e4985..ae7ec8fe 100644 --- a/test/optimise.jl +++ b/test/optimise.jl @@ -3,7 +3,7 @@ using Flux.Tracker @testset "Optimise" begin w = randn(10, 10) - @testset for Opt in [SGD, Nesterov, Momentum, ADAM, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad] + @testset for Opt in [SGD, Nesterov, Momentum, ADAM, AdaMax, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad] w′ = param(randn(10, 10)) loss(x) = Flux.mse(w*x, w′*x) opt = Opt([w′]) diff --git a/test/tracker.jl b/test/tracker.jl index 0f5b6189..434148f0 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -29,17 +29,94 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...) @test gradtest(x -> x', rand(5)) -@test gradtest(vcat, rand(5), rand(3)) -@test gradtest(vcat, rand(5), rand(3), rand(8)) -@test gradtest(vcat, rand(5,2), rand(3,2), rand(8,2)) +function promotiontest(f, A, B, C) + r0 = f(A, B, C) + r1 = f(param(A), B, C) + r2 = f(A, param(B), C) + if all(ndims.((A,B,C)) .≤ 2) && f ∈ [hcat, vcat] + r3 = f(A, B, param(C)) + else + @test_throws MethodError f(A, B, param(C)) # until julia#20815 is resolved + r3 = r2 + end + r4 = f(param(A), param(B), param(C)) + + @test !isa(r0, TrackedArray) + @test all(isa.([r1,r2,r3,r4], TrackedArray)) + @test r1 == r2 == r3 == r4 + @test r0 == Flux.data(r4) +end + +@testset "concat" begin + cat1(x...) = cat(1, x...) + cat2(x...) = cat(2, x...) + + @testset for vcatf in [vcat, cat1] + @test gradtest(vcatf, rand(5), rand(3)) + @test gradtest(vcatf, rand(5), rand(3), rand(8)) + @test gradtest(vcatf, rand(5)', rand(5)') + @test gradtest(vcatf, rand(5,2), rand(3,2), rand(8,2)) + @test gradtest(vcatf, rand(5,2,3), rand(3,2,3), rand(8,2,3)) + @test gradtest(vcatf, rand(5), rand(3,1)) + @test gradtest(vcatf, rand(5)', rand(2,5)) + end + + @testset for hcatf in [hcat, cat2] + @test gradtest(hcatf, rand(5), rand(5)) + @test gradtest(hcatf, rand(5)', rand(5)') + @test gradtest(hcatf, rand(2,5), rand(2,3), rand(2,8)) + @test gradtest(hcatf, rand(2,5,3), rand(2,3,3), rand(2,8,3)) + @test gradtest(hcatf, rand(5), rand(5), rand(5,2)) + @test gradtest(hcatf, rand(5)', rand(1,3)) + @test gradtest(hcatf, rand(5), rand(5,2)) +end + + @testset for catf in [vcat, cat1, hcat, cat2, (x...) -> cat(3, x...), (x...) -> cat((1,2), x...)] + @test gradtest(catf, rand(5)) + @test gradtest(catf, rand(5)') + @test gradtest(catf, rand(2,5)) + @test gradtest(catf, rand(2,5,3)) + end + + @test gradtest((x...) -> cat(3, x...), rand(2,5,2), rand(2,5,3), rand(2,5,4)) + + @testset "cat($dim, ...)" for dim in 3:5 + catdim = (x...) -> cat(dim, x...) + @test gradtest(catdim, rand(5), rand(5), rand(5)) + @test gradtest(catdim, rand(2,5), rand(2,5), rand(2,5)) + @test gradtest(catdim, rand(2,5,3), rand(2,5,3), rand(2,5,3)) + end + + @test !isa(vcat(rand(2)), TrackedArray) + @test !isa(hcat(rand(2)), TrackedArray) + @test !isa(cat(1,rand(2)), TrackedArray) + + @test gradtest((a,b)->cat((2,3,5), a, b), rand(2,3), rand(2,4,2,1)) + + @testset "promotiontest" begin + @testset for fcat in [hcat, vcat, (x...) -> cat(3, x...), (x...) -> cat((1,2), x...)] + promotiontest(fcat, rand(2), rand(2), rand(2)) + promotiontest(fcat, rand(2)', rand(2)', rand(2)') + promotiontest(fcat, rand(2,2), rand(2,2), rand(2,2)) + promotiontest(fcat, rand(2,2,2), rand(2,2,2), rand(2,2,2)) + end + + promotiontest(vcat, rand(1,2), rand(2)', rand(2,2)) + promotiontest(hcat, rand(2,1), rand(2), rand(2,2)) + promotiontest(vcat, rand(3,4,5), rand(1,4,5), rand(2,4,5)) + promotiontest(hcat, rand(4,3,5), rand(4,1,5), rand(4,2,5)) + promotiontest((x...) -> cat(3, x...), rand(4,5,3), rand(4,5,1), rand(4,5,2)) + end +end + @test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6)) @test gradtest(x -> repmat(x, 5,5), rand(4,5)) @test gradtest(x -> repmat(x, 5), rand(4,5)) -@test gradtest(kron,rand(5), rand(3)) +@test gradtest(kron, rand(5), rand(3)) @test gradtest(kron, rand(5), rand(3), rand(8)) -@test gradtest(kron,rand(5,1), rand(3,1)) +@test gradtest(kron, rand(5,1), rand(3,1)) @test gradtest(kron, rand(5,1), rand(3,1), rand(8,1)) @test gradtest(kron, rand(5,2), rand(3,2), rand(8,2)) @@ -55,6 +132,26 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...) @test gradtest(x -> mean(x, [1, 2]), rand(2, 3, 4)) end +@testset "maximum" begin + @test gradtest(maximum, rand(2, 3)) + + @test gradtest(x -> maximum(x, 1), rand(2, 3)) + @test gradtest(x -> maximum(x, 2), rand(2, 3)) + @test gradtest(x -> maximum(x, 3), rand(2, 3, 4)) + + @test gradtest(x -> maximum(x, [1, 2]), rand(2, 3, 4)) +end + +@testset "minimum" begin + @test gradtest(minimum, rand(2, 3)) + + @test gradtest(x -> minimum(x, 1), rand(2, 3)) + @test gradtest(x -> minimum(x, 2), rand(2, 3)) + @test gradtest(x -> minimum(x, 3), rand(2, 3, 4)) + + @test gradtest(x -> minimum(x, [1, 2]), rand(2, 3, 4)) +end + @test gradtest(x -> std(x), rand(5,5)) @test gradtest(x -> std(x, 1), rand(5,5))