Merge branch 'master' into depthwiseconv

Avik Pal 2018-05-30 15:58:35 +05:30 committed by GitHub
commit 33a7f545b7
11 changed files with 398 additions and 34 deletions


@@ -2,7 +2,7 @@
<img width="400px" src="https://raw.githubusercontent.com/FluxML/fluxml.github.io/master/logo.png"/>
</p>
-[![Build Status](https://travis-ci.org/FluxML/Flux.jl.svg?branch=master)](https://travis-ci.org/FluxML/Flux.jl) [![](https://img.shields.io/badge/docs-stable-blue.svg)](https://fluxml.github.io/Flux.jl/stable/) [![](https://img.shields.io/badge/chat-on%20slack-yellow.svg)](https://slackinvite.julialang.org/)
[![Build Status](https://travis-ci.org/FluxML/Flux.jl.svg?branch=master)](https://travis-ci.org/FluxML/Flux.jl) [![](https://img.shields.io/badge/docs-stable-blue.svg)](https://fluxml.github.io/Flux.jl/stable/) [![](https://img.shields.io/badge/chat-on%20slack-yellow.svg)](https://slackinvite.julialang.org/) [![DOI](http://joss.theoj.org/papers/10.21105/joss.00602/status.svg)](https://doi.org/10.21105/joss.00602)
Flux is an elegant approach to machine learning. It's a 100% pure-Julia stack, and provides lightweight abstractions on top of Julia's native GPU and AD support. Flux makes the easy things easy while remaining fully hackable.
@@ -12,6 +12,18 @@ julia> Pkg.add("Flux")
See the [documentation](http://fluxml.github.io/Flux.jl/) or the [model zoo](https://github.com/FluxML/model-zoo/) for examples.
If you use Flux in research, please cite the following paper:
```
@article{innes:2018,
author = {Mike Innes},
title = {Flux: Elegant Machine Learning with Julia},
journal = {Journal of Open Source Software},
year = {2018},
doi = {10.21105/joss.00602},
}
```
## Features
Flux has powerful high-level features, and common architectures can be defined in a few lines.
@@ -79,3 +91,9 @@ For general questions and help, check out Julia's [community forum](https://disc
Flux development is carried out via our [GitHub issues](https://github.com/FluxML/Flux.jl/issues), so feel free to open feature requests or PRs here.
For more informal discussions we'd love to have you on the [Julia slack](https://slackinvite.julialang.org/), where we hang out on the #machine-learning channel.
## Related Packages
Check out [Metalhead.jl](https://github.com/FluxML/Metalhead.jl) for common computer vision datasets and trained models.
[MLDatasets.jl](https://github.com/JuliaML/MLDatasets.jl) provides further common datasets.


@@ -18,6 +18,8 @@ makedocs(modules=[Flux, NNlib],
"One-Hot Encoding" => "data/onehot.md",
"GPU Support" => "gpu.md",
"Saving & Loading" => "saving.md",
"Internals" =>
["Backpropagation" => "internals/tracker.md"],
"Community" => "community.md"])
deploydocs(


@@ -0,0 +1,156 @@
# Flux.Tracker
Backpropagation, or reverse-mode automatic differentiation, is handled by the `Flux.Tracker` module.
```julia
julia> using Flux.Tracker
```
The `param` function converts a normal Julia array into a new object that, while behaving like an array, tracks extra information that allows us to calculate derivatives. For example, say we multiply two parameters:
```julia
julia> W = param([1 2; 3 4])
Tracked 2×2 Array{Float64,2}:
1.0 2.0
3.0 4.0
julia> x = param([5, 6])
Tracked 2-element Array{Float64,1}:
5.0
6.0
julia> y = W*x
Tracked 2-element Array{Float64,1}:
17.0
39.0
```
The output `y` is also a `TrackedArray` object. We can now backpropagate sensitivities to `W` and `x` via the `back!` function, and see the gradients accumulated in the `W` and `x` tracked arrays:
```julia
julia> Tracker.back!(y, [1, -1])
julia> W.grad
2×2 Array{Float64,2}:
5.0 6.0
-5.0 -6.0
julia> x.grad
2-element Array{Float64,1}:
-2.0
-2.0
```
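These values can be checked by hand: for `y = W*x` with sensitivity `Δ`, the gradient with respect to `W` is `Δ*x'` and the gradient with respect to `x` is `W'*Δ`:
```julia
Δ = [1, -1]
Δ * [5, 6]'        # == [5 6; -5 -6], matching W.grad
[1 2; 3 4]' * Δ    # == [-2, -2],     matching x.grad
```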
## Internals
All `Tracked*` objects (`TrackedArray`, `TrackedReal`) are light wrappers around the `Tracked` type, which you can access via the `.tracker` field.
```julia
julia> x.tracker
Flux.Tracker.Tracked{Array{Float64,1}}(0x00000000, Flux.Tracker.Call{Void,Tuple{}}(nothing, ()), true, [5.0, 6.0], [-2.0, -2.0])
```
The `Tracker` stores the value and gradient of a given object, which we've seen before.
```julia
julia> x.tracker.data
2-element Array{Float64,1}:
5.0
6.0
julia> x.tracker.grad
2-element Array{Float64,1}:
-2.0
-2.0
```
The tracker also contains a `Call` object, which simply represents a function call that was made at some point during the forward pass. For example, the `+` call would look like this:
```julia
julia> Tracker.Call(+, 1, 2)
Flux.Tracker.Call{Base.#+,Tuple{Int64,Int64}}(+, (1, 2))
```
In the case of the `y` we produced above, we can see that it stores the call that produced it -- that is, `W*x`.
```julia
julia> y.tracker.f
Flux.Tracker.Call{...}(*, (param([1.0 2.0; 3.0 4.0]), param([5.0, 6.0])))
```
Because the arguments to the call may themselves be tracked arrays that store their own calls, `Tracker` ends up forming a data structure that records everything that happened during the forward pass (often known as a *tape*).
When we call `back!(y, [1, -1])`, the sensitivities `[1, -1]` simply get forwarded to `y`'s call (`*`), effectively calling
```julia
Tracker.back(*, [1, -1], W, x)
```
which in turn calculates the sensitivities of the arguments (`W` and `x`) and backpropagates through their calls. This is recursive, so it will walk the entire program graph and propagate gradients to the original model parameters.
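To make the recursion concrete, here is a small self-contained toy -- it is *not* Flux's implementation, just a sketch of the idea -- in which every node remembers the call that produced it and a `backprop!` function walks those calls depth-first:
```julia
# Toy sketch only; the names `Node`, `leaf`, `grads` and `backprop!` are made
# up for illustration and do not exist in Flux.
struct Node
  data
  grad
  f      # the function that produced this node, or `nothing` for a parameter
  args   # the Node arguments of that call
end

leaf(x) = Node(x, zero(x), nothing, ())

# toy gradient rule for `*`, mirroring what the real `back(*, Δ, W, x)` computes
grads(::typeof(*), Δ, W, x) = (Δ * x.data', W.data' * Δ)

function backprop!(n::Node, Δ)
  n.grad .+= Δ                      # accumulate the incoming sensitivity
  n.f === nothing && return         # leaves have no call to recurse into
  for (arg, Δa) in zip(n.args, grads(n.f, Δ, n.args...))
    backprop!(arg, Δa)              # walk the tape recursively
  end
end

W, x = leaf([1.0 2.0; 3.0 4.0]), leaf([5.0, 6.0])
y = Node(W.data * x.data, zeros(2), *, (W, x))   # forward pass records the call
backprop!(y, [1, -1])
W.grad, x.grad                                   # ([5.0 6.0; -5.0 -6.0], [-2.0, -2.0])
```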
## Custom Gradients
We can hook in to the processes above to implement custom gradients for a function or kernel. For a toy example, imagine a custom implementation of `minus`:
```julia
julia> minus(a, b) = a - b
```
Firstly, we must tell the tracker system to stop when it sees a call to `minus`, and record it. We can do this using dispatch:
```julia
julia> minus(a::TrackedArray, b::TrackedArray) = Tracker.track(minus, a, b)
minus (generic function with 2 methods)
```
`Tracker.track` does two things: (1) it makes sure `minus` is called with *normal* arrays, not tracked ones (you can use `@show` inside `minus` to verify this), and (2) it uses the result to add a `minus` node to the tape. Look inside the result of calling `minus` to see what happened:
```julia
julia> a, b = param([6,5,4]), param([1,2,3])
(param([6.0, 5.0, 4.0]), param([1.0, 2.0, 3.0]))
julia> c = minus(a, b)
Tracked 3-element Array{Float64,1}:
5.0
3.0
1.0
julia> c.tracker.f
Flux.Tracker.Call{...}(minus, (param([6.0, 5.0, 4.0]), param([1.0, 2.0, 3.0])))
```
Finally, we have to specify the gradient of `minus`.
```julia
julia> Tracker.back(::typeof(minus), Δ, a, b) =
(Tracker.@back(a, Δ); Tracker.@back(b, -Δ))
```
`@back(x, Δ)` tells the tracker to continue propagating the sensitivity `Δ` through `x`. Now, AD will work with any program that calls `minus`.
```julia
julia> Flux.back!(c, 1)
julia> a.grad
3-element Array{Float64,1}:
1.0
1.0
1.0
julia> b.grad
3-element Array{Float64,1}:
-1.0
-1.0
-1.0
```
## Notes
For multi-argument functions with custom gradients, you likely want to catch not just `minus(::TrackedArray, ::TrackedArray)` but also `minus(::Array, ::TrackedArray)` and so on. To do so, just define those extra signatures as needed:
```julia
minus(a::AbstractArray, b::TrackedArray) = Tracker.track(minus, a, b)
minus(a::TrackedArray, b::AbstractArray) = Tracker.track(minus, a, b)
```
`@back` *must* be called exactly once on each tracked input argument. You do not need to do any special handling if one of the arguments is not tracked, as `@back` will just become a no-op.
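Putting the pieces together, here is an illustrative session (assuming the `minus` methods and the gradient defined above are in scope; the output is shown for orientation): a mixed call is tracked, and the gradient flows only to the tracked argument.
```julia
julia> a = param([6, 5, 4])
Tracked 3-element Array{Float64,1}:
 6.0
 5.0
 4.0

julia> c = minus(a, [1, 2, 3])   # second argument is a plain Array
Tracked 3-element Array{Float64,1}:
 5.0
 3.0
 1.0

julia> Tracker.back!(c, [1, 1, 1])

julia> a.grad                    # `@back` on the untracked argument was a no-op
3-element Array{Float64,1}:
 1.0
 1.0
 1.0
```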


@@ -5,7 +5,7 @@ These core layers form the foundation of almost all neural networks.
```@docs
Chain
Dense
-Conv2D
Conv
```
## Recurrent Layers


@@ -7,22 +7,23 @@ module Flux
using Juno, Requires, Reexport
using MacroTools: @forward
-export Chain, Dense, RNN, LSTM, GRU, Conv, Conv2D,
export Chain, Dense, RNN, LSTM, GRU, Conv,
Dropout, LayerNorm, BatchNorm,
-SGD, ADAM, Momentum, Nesterov, AMSGrad,
-param, params, mapleaves, cpu, gpu
params, mapleaves, cpu, gpu
@reexport using NNlib
using NNlib: @fix
include("tracker/Tracker.jl")
using .Tracker
-export Tracker
-import .Tracker: data
using .Tracker: data
export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param
include("optimise/Optimise.jl")
using .Optimise
using .Optimise: @epochs
export SGD, ADAM, AdaMax, Momentum, Nesterov,
RMSProp, ADAGrad, ADADelta, AMSGrad
include("utils.jl")
include("onehot.jl")


@@ -1,7 +1,7 @@
module Optimise
-export update!, params, train!,
-SGD, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad
export train!,
SGD, ADAM, AdaMax, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad
struct Param{T}
x::T


@@ -56,6 +56,15 @@ RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) =
ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
"""
AdaMax(params, η = 0.002; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser. Variant of ADAM based on
the ∞-norm.
"""
AdaMax(ps, η = 0.002; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p->adamax(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
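As a usage sketch (the model and data below are hypothetical; the constructor and `train!` call follow the same pattern as the other optimisers above):
```julia
using Flux

m = Chain(Dense(10, 5, relu), Dense(5, 2), softmax)   # hypothetical model
loss(x, y) = Flux.crossentropy(m(x), y)

opt = AdaMax(params(m))          # defaults: η = 0.002, β1 = 0.9, β2 = 0.999
data = [(rand(10, 8), Flux.onehotbatch(rand(1:2, 8), 1:2))]
Flux.train!(loss, data, opt)
```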
""" """
ADAGrad(params, η = 0.01; ϵ = 1e-8, decay = 0) ADAGrad(params, η = 0.01; ϵ = 1e-8, decay = 0)


@@ -62,6 +62,18 @@ function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ
end
end
function adamax(p::Param; η::Real = 0.002, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
mt = zeros(p.x)
ut = zeros(p.x)
β1p = β1
function ()
@. mt = β1 * mt + (1 - β1) * p.Δ
@. ut = max(β2 * ut, abs(p.Δ))
@. p.Δ = (η/(1 - β1p)) * mt/(ut + ϵ)
β1p *= β1
end
end
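For intuition, here is a hand-computed sketch of the first AdaMax step with made-up gradient values, showing why the very first update has magnitude roughly `η`:
```julia
η, β1, β2, ϵ = 0.002, 0.9, 0.999, 1e-8
g  = [0.5, -0.2]                     # stands in for the gradient p.Δ
mt = (1 - β1) .* g                   # β1*0 + (1-β1)*g on the first step
ut = max.(β2 .* 0.0, abs.(g))        # == abs.(g) on the first step
Δθ = (η / (1 - β1)) .* mt ./ (ut .+ ϵ)
# mt ./ ut ≈ (1 - β1) .* sign.(g), so Δθ ≈ η .* sign.(g) = [0.002, -0.002]
```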
function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
mt = zeros(p.x)
vt = zeros(p.x) .+ ϵ


@@ -41,7 +41,7 @@ end
Base.setindex!(xs::TrackedArray, v, i...) =
error("Can't differentiate `setindex!`")
-back!(::TrackedArray) = error("Use back!(x, Δ)")
back!(::TrackedArray) = error("Value is not scalar; use `back!(sum(x))` or `back!(x, Δ)`")
# Fallthrough methods
@@ -81,21 +81,6 @@ back(::typeof(ctranspose), Δ, xs) = @back(xs, trim(xs, Δ'))
Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...)
Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...)
-Base.vcat(a::TrackedVector, b::TrackedVector) = track(vcat, a, b)
-Base.vcat(a::TrackedVector, b::TrackedVector...) = track(vcat, a, b...)
-Base.vcat(a::TrackedVector, b::AbstractVector) = track(vcat, a, b)
-Base.vcat(a::AbstractVector, b::TrackedVector) = track(vcat, a, b)
-Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat) = track(vcat, a, b)
-Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat...) = track(vcat, a, b...)
-Base.vcat(a::TrackedVecOrMat, b::AbstractVecOrMat) = track(vcat, a, b)
-Base.vcat(a::AbstractVecOrMat, b::TrackedVecOrMat) = track(vcat, a, b)
-Base.vcat(a::TrackedMatrix, b::TrackedMatrix) = track(vcat, a, b)
-Base.vcat(a::TrackedMatrix, b::TrackedMatrix...) = track(vcat, a, b...)
-Base.vcat(a::TrackedMatrix, b::AbstractMatrix) = track(vcat, a, b)
-Base.vcat(a::AbstractMatrix, b::TrackedMatrix) = track(vcat, a, b)
function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1)
Δ′ = similar(xs.data)
S = size(xs.data)
@@ -108,15 +93,70 @@ function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1)
back(xs, Δ′)
end
for f in [:vcat, :hcat]
@eval begin
# This section is a bit of a hack since julia doesn't have a standardised
# promotion mechanism for concatenation yet
# https://github.com/JuliaLang/julia/pull/20815
# It should support tracked concatenation with rank ∈ (1,2) with a
# TrackedArray anywhere among the arguments. This works as long as Base has
# other functions that capture `(::Union{Vector,RowVector,Matrix}...)`.
Base.$f(a::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a...)
# It should support tracked concatenation with rank>2 if the TrackedArray is
# first
Base.$f(a::TrackedArray, b::AbstractArray...) = track($f, a, b...)
Base.$f(a::TrackedArray, b::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a, b...) # resolves ambiguity introduced by previous row
# It should support tracked concatenation with rank>2 if the TrackedArray is
# second
Base.$f(a::Array, b::TrackedArray, c::AbstractArray...) = track($f, a, b, c...)
Base.$f(a::Union{Vector,RowVector,Matrix}, b::TrackedArray,
c::Union{TrackedArray,Vector,RowVector,Matrix}...) =
track($f, a, b, c...) # resolves ambiguity introduced by previous row
end
end
function back(::typeof(vcat), Δ, xs...)
-i = Base.tail(map(_ -> :, size(Δ)))
start = 0
for xsi in xs
i = map(_ -> :, size(xsi)) |> Base.tail
@back(xsi, Δ[start+1:start+size(xsi,1), i...])
start += size(xsi, 1)
end
end
function back(::typeof(hcat), Δ, xs...)
start = 0
for xsi in xs
if ndims(xsi) == 1
@back(xsi, Δ[:, start+1])
else
i = map(_ -> :, size(xsi)) |> Base.tail |> Base.tail
@back(xsi, Δ[:, start+1:start+size(xsi,2), i...])
end
start += size(xsi, 2)
end
end
Base.cat(dims, a::TrackedArray, b::AbstractArray...) = track(cat, dims, a, b...)
Base.cat(dims, a::Union{RowVector,Array}, b::TrackedArray, c::AbstractArray...) = track(cat, dims, a, b, c...)
function back(::typeof(cat), Δ, dims, Xs...)
start = ntuple(i -> 0, Val{ndims(Δ)})
for xs in Xs
dim_xs = 1:ndims(xs)
till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val{ndims(Δ)})
xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val{ndims(Δ)})
@back(xs, reshape(Δ[xs_in_Δ...],size(xs)))
start = start .+ till_xs
end
end
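An illustrative check of the new concatenation support (assuming a Flux build containing these methods): mixing tracked and plain arrays yields a `TrackedArray`, and `back!` slices the sensitivity back to each block.
```julia
using Flux, Flux.Tracker

a, b = param(rand(2, 3)), rand(3, 3)
y = vcat(a, b)                 # 5×3 TrackedArray
Tracker.back!(y, ones(5, 3))
a.grad                         # 2×3 array of ones; rows 3:5 belonged to `b`
```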
Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) = reshape(xs, dims)
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Union{Int64,Colon}}}) = reshape(xs, Base._reshape_uncolon(xs, dims))
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64}}) = track(reshape, xs, dims)
@@ -156,12 +196,16 @@ Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs))
back(::typeof(prod), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= (prod(xs.data, dim...) ./ xs.data) .* Δ)
back(::typeof(prod), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= (reshape(.*(circshift.([reshape(xs.data, length(xs.data))], 1:length(xs.data)-1)...), size(xs.data))) .* Δ)
Base.maximum(xs::TrackedArray, args...) = maximum(xs.data, args...)
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
Base.mean(xs::TrackedArray) = track(mean, xs)
Base.mean(xs::TrackedArray, region) = track(mean, xs, region)
Base.maximum(xs::TrackedArray) = track(maximum, xs)
Base.maximum(xs::TrackedArray, region) = track(maximum, xs, region)
Base.minimum(xs::TrackedArray) = track(minimum, xs)
Base.minimum(xs::TrackedArray, region) = track(minimum, xs, region)
LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = track(dot, xs, ys)
LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = track(dot, xs, ys)
LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys)
@@ -184,6 +228,31 @@ back(::typeof(mean), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= Δ ./
back(::typeof(mean), Δ, xs::TrackedArray, region) =
back(xs, similar(xs.data) .= Δ ./ prod(size(xs.data, region...)))
function back(::typeof(maximum), Δ, xs::TrackedArray)
Δ′ = zeros(xs.data)
_, i = findmax(xs.data)
Δ′[i] = Δ
@back(xs, Δ′)
end
function back(::typeof(maximum), Δ, xs::TrackedArray, region)
Δ′ = zeros(xs.data)
_, is = findmax(xs.data, region)
Δ′[is] = Δ
@back(xs, Δ′)
end
function back(::typeof(minimum), Δ, xs::TrackedArray)
Δ′ = zeros(xs.data)
_, i = findmin(xs.data)
Δ′[i] = Δ
@back(xs, Δ′)
end
function back(::typeof(minimum), Δ, xs::TrackedArray, region)
Δ′ = zeros(xs.data)
_, is = findmin(xs.data, region)
Δ′[is] = Δ
@back(xs, Δ′)
end
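A quick illustrative check of the new `maximum` gradient (again assuming a Flux build containing these methods): the sensitivity lands on the maximal entry of each column and nowhere else.
```julia
using Flux, Flux.Tracker

x = param([1.0 3.0; 2.0 0.5])
y = maximum(x, 1)              # column-wise maxima: [2.0 3.0]
Tracker.back!(y, [1.0 1.0])
x.grad                         # [0.0 1.0; 1.0 0.0]
```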
# BLAS
Base.diagm(x::TrackedVector) = track(diagm, x)


@@ -3,7 +3,7 @@ using Flux.Tracker
@testset "Optimise" begin
w = randn(10, 10)
-@testset for Opt in [SGD, Nesterov, Momentum, ADAM, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad]
@testset for Opt in [SGD, Nesterov, Momentum, ADAM, AdaMax, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad]
w = param(randn(10, 10))
loss(x) = Flux.mse(w*x, w*x)
opt = Opt([w])


@@ -29,17 +29,94 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@test gradtest(x -> x', rand(5))
-@test gradtest(vcat, rand(5), rand(3))
-@test gradtest(vcat, rand(5), rand(3), rand(8))
-@test gradtest(vcat, rand(5,2), rand(3,2), rand(8,2))
function promotiontest(f, A, B, C)
r0 = f(A, B, C)
r1 = f(param(A), B, C)
r2 = f(A, param(B), C)
if all(ndims.((A,B,C)) .≤ 2) && f ∈ [hcat, vcat]
r3 = f(A, B, param(C))
else
@test_throws MethodError f(A, B, param(C)) # until julia#20815 is resolved
r3 = r2
end
r4 = f(param(A), param(B), param(C))
@test !isa(r0, TrackedArray)
@test all(isa.([r1,r2,r3,r4], TrackedArray))
@test r1 == r2 == r3 == r4
@test r0 == Flux.data(r4)
end
@testset "concat" begin
cat1(x...) = cat(1, x...)
cat2(x...) = cat(2, x...)
@testset for vcatf in [vcat, cat1]
@test gradtest(vcatf, rand(5), rand(3))
@test gradtest(vcatf, rand(5), rand(3), rand(8))
@test gradtest(vcatf, rand(5)', rand(5)')
@test gradtest(vcatf, rand(5,2), rand(3,2), rand(8,2))
@test gradtest(vcatf, rand(5,2,3), rand(3,2,3), rand(8,2,3))
@test gradtest(vcatf, rand(5), rand(3,1))
@test gradtest(vcatf, rand(5)', rand(2,5))
end
@testset for hcatf in [hcat, cat2]
@test gradtest(hcatf, rand(5), rand(5))
@test gradtest(hcatf, rand(5)', rand(5)')
@test gradtest(hcatf, rand(2,5), rand(2,3), rand(2,8))
@test gradtest(hcatf, rand(2,5,3), rand(2,3,3), rand(2,8,3))
@test gradtest(hcatf, rand(5), rand(5), rand(5,2))
@test gradtest(hcatf, rand(5)', rand(1,3))
@test gradtest(hcatf, rand(5), rand(5,2))
end
@testset for catf in [vcat, cat1, hcat, cat2, (x...) -> cat(3, x...), (x...) -> cat((1,2), x...)]
@test gradtest(catf, rand(5))
@test gradtest(catf, rand(5)')
@test gradtest(catf, rand(2,5))
@test gradtest(catf, rand(2,5,3))
end
@test gradtest((x...) -> cat(3, x...), rand(2,5,2), rand(2,5,3), rand(2,5,4))
@testset "cat($dim, ...)" for dim in 3:5
catdim = (x...) -> cat(dim, x...)
@test gradtest(catdim, rand(5), rand(5), rand(5))
@test gradtest(catdim, rand(2,5), rand(2,5), rand(2,5))
@test gradtest(catdim, rand(2,5,3), rand(2,5,3), rand(2,5,3))
end
@test !isa(vcat(rand(2)), TrackedArray)
@test !isa(hcat(rand(2)), TrackedArray)
@test !isa(cat(1,rand(2)), TrackedArray)
@test gradtest((a,b)->cat((2,3,5), a, b), rand(2,3), rand(2,4,2,1))
@testset "promotiontest" begin
@testset for fcat in [hcat, vcat, (x...) -> cat(3, x...), (x...) -> cat((1,2), x...)]
promotiontest(fcat, rand(2), rand(2), rand(2))
promotiontest(fcat, rand(2)', rand(2)', rand(2)')
promotiontest(fcat, rand(2,2), rand(2,2), rand(2,2))
promotiontest(fcat, rand(2,2,2), rand(2,2,2), rand(2,2,2))
end
promotiontest(vcat, rand(1,2), rand(2)', rand(2,2))
promotiontest(hcat, rand(2,1), rand(2), rand(2,2))
promotiontest(vcat, rand(3,4,5), rand(1,4,5), rand(2,4,5))
promotiontest(hcat, rand(4,3,5), rand(4,1,5), rand(4,2,5))
promotiontest((x...) -> cat(3, x...), rand(4,5,3), rand(4,5,1), rand(4,5,2))
end
end
@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))
@test gradtest(x -> repmat(x, 5,5), rand(4,5))
@test gradtest(x -> repmat(x, 5), rand(4,5))
@test gradtest(kron, rand(5), rand(3))
@test gradtest(kron, rand(5), rand(3), rand(8))
@test gradtest(kron, rand(5,1), rand(3,1))
@test gradtest(kron, rand(5,1), rand(3,1), rand(8,1))
@test gradtest(kron, rand(5,2), rand(3,2), rand(8,2))
@@ -55,6 +132,26 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@test gradtest(x -> mean(x, [1, 2]), rand(2, 3, 4))
end
@testset "maximum" begin
@test gradtest(maximum, rand(2, 3))
@test gradtest(x -> maximum(x, 1), rand(2, 3))
@test gradtest(x -> maximum(x, 2), rand(2, 3))
@test gradtest(x -> maximum(x, 3), rand(2, 3, 4))
@test gradtest(x -> maximum(x, [1, 2]), rand(2, 3, 4))
end
@testset "minimum" begin
@test gradtest(minimum, rand(2, 3))
@test gradtest(x -> minimum(x, 1), rand(2, 3))
@test gradtest(x -> minimum(x, 2), rand(2, 3))
@test gradtest(x -> minimum(x, 3), rand(2, 3, 4))
@test gradtest(x -> minimum(x, [1, 2]), rand(2, 3, 4))
end
@test gradtest(x -> std(x), rand(5,5))
@test gradtest(x -> std(x, 1), rand(5,5))