Merge branch 'master' into nadam-opt

This commit is contained in:
Tejan Karmali 2018-06-08 16:54:41 +05:30 committed by GitHub
commit 4a24b69976
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 603 additions and 326 deletions

View File

@ -1,6 +1,8 @@
# Flux
<p align="center">
<img width="400px" src="https://raw.githubusercontent.com/FluxML/fluxml.github.io/master/logo.png"/>
</p>
[![Build Status](https://travis-ci.org/FluxML/Flux.jl.svg?branch=master)](https://travis-ci.org/FluxML/Flux.jl) [![](https://img.shields.io/badge/docs-stable-blue.svg)](https://fluxml.github.io/Flux.jl/stable/) [![](https://img.shields.io/badge/chat-on%20slack-yellow.svg)](https://slackinvite.julialang.org/)
[![Build Status](https://travis-ci.org/FluxML/Flux.jl.svg?branch=master)](https://travis-ci.org/FluxML/Flux.jl) [![](https://img.shields.io/badge/docs-stable-blue.svg)](https://fluxml.github.io/Flux.jl/stable/) [![](https://img.shields.io/badge/chat-on%20slack-yellow.svg)](https://slackinvite.julialang.org/) [![DOI](http://joss.theoj.org/papers/10.21105/joss.00602/status.svg)](https://doi.org/10.21105/joss.00602)
Flux is an elegant approach to machine learning. It's a 100% pure-Julia stack, and provides lightweight abstractions on top of Julia's native GPU and AD support. Flux makes the easy things easy while remaining fully hackable.
@ -10,6 +12,18 @@ julia> Pkg.add("Flux")
See the [documentation](http://fluxml.github.io/Flux.jl/) or the [model zoo](https://github.com/FluxML/model-zoo/) for examples.
If you use Flux in research, please cite the following paper:
```
@article{innes:2018,
author = {Mike Innes},
title = {Flux: Elegant Machine Learning with Julia},
journal = {Journal of Open Source Software},
year = {2018},
doi = {10.21105/joss.00602},
}
```
## Features
Flux has powerful high-level features, and common architectures can be defined in a few lines.
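For instance, a minimal sketch of such a definition (a model in the style the README itself shows; the layer sizes here are illustrative, not taken from this diff):
```julia
using Flux

model = Chain(
  Dense(28^2, 64, relu),
  Dense(64, 10),
  softmax)

model(rand(28^2))  # run the model on a random input vector
```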
@ -36,7 +50,7 @@ b = param(randn(2))
y(x) = σ.(W * x .+ b)
```
If that's *still* not enough, you can go as deep as you want, even writing your own CUDA kernels! All this can be freely mixed-and-matched in a single model or script, and it all runs interactively via Jupyter or Juno.
If that's *still* not enough, you can go as deep as you want, even writing your own CUDA kernels with [CUDAnative](https://github.com/JuliaGPU/CUDAnative.jl)! All this can be freely mixed-and-matched in a single model or script, and it all runs interactively via Jupyter or Juno.
```julia
function gpu_add(a, b, c)
@ -77,3 +91,9 @@ For general questions and help, check out Julia's [community forum](https://disc
Flux development is carried out via our [GitHub issues](https://github.com/FluxML/Flux.jl/issues), so feel free to open feature requests or PRs here.
For more informal discussions we'd love to have you on the [Julia slack](https://slackinvite.julialang.org/), where we hang out on the #machine-learning channel.
## Related Packages
Check out [Metalhead.jl](https://github.com/FluxML/Metalhead.jl) for common computer vision datasets and trained models.
[MLDatasets.jl](https://github.com/JuliaML/MLDatasets.jl) provides further common datasets.

View File

@ -18,6 +18,8 @@ makedocs(modules=[Flux, NNlib],
"One-Hot Encoding" => "data/onehot.md",
"GPU Support" => "gpu.md",
"Saving & Loading" => "saving.md",
"Internals" =>
["Backpropagation" => "internals/tracker.md"],
"Community" => "community.md"])
deploydocs(

View File

@ -0,0 +1,156 @@
# Flux.Tracker
Backpropagation, or reverse-mode automatic differentiation, is handled by the `Flux.Tracker` module.
```julia
julia> using Flux.Tracker
```
The `param` function converts a normal Julia array into a new object that, while behaving like an array, tracks extra information that allows us to calculate derivatives. For example, say we multiply two parameters:
```julia
julia> W = param([1 2; 3 4])
Tracked 2×2 Array{Float64,2}:
1.0 2.0
3.0 4.0
julia> x = param([5, 6])
Tracked 2-element Array{Float64,1}:
5.0
6.0
julia> y = W*x
Tracked 2-element Array{Float64,1}:
17.0
39.0
```
The output `y` is also a `TrackedArray` object. We can now backpropagate sensitivities to `W` and `x` via the `back!` function, and see the gradients accumulated in the `W` and `x` tracked arrays:
```julia
julia> Tracker.back!(y, [1, -1])
julia> W.grad
2×2 Array{Float64,2}:
5.0 6.0
-5.0 -6.0
julia> x.grad
2-element Array{Float64,1}:
-2.0
-2.0
```
## Internals
All `Tracked*` objects (`TrackedArray`, `TrackedReal`) are light wrappers around the `Tracked` type, which you can access via the `.tracker` field.
```julia
julia> x.tracker
Flux.Tracker.Tracked{Array{Float64,1}}(0x00000000, Flux.Tracker.Call{Void,Tuple{}}(nothing, ()), true, [5.0, 6.0], [-2.0, -2.0])
```
The `Tracker` stores the value and gradient of a given object, which we've seen before.
```julia
julia> x.tracker.data
2-element Array{Float64,1}:
5.0
6.0
julia> x.tracker.grad
2-element Array{Float64,1}:
-2.0
-2.0
```
The tracker also contains a `Call` object, which simply represents a function call that was made at some point during the forward pass. For example, the `+` call would look like this:
```julia
julia> Tracker.Call(+, 1, 2)
Flux.Tracker.Call{Base.#+,Tuple{Int64,Int64}}(+, (1, 2))
```
In the case of the `y` we produced above, we can see that it stores the call that produced it -- that is, `W*x`.
```julia
julia> y.tracker.f
Flux.Tracker.Call{...}(*, (param([1.0 2.0; 3.0 4.0]), param([5.0, 6.0])))
```
Because the arguments to the call may themselves be tracked arrays, storing their own calls, `Tracker` ends up forming a data structure that records everything that happened during the forward pass (often known as a *tape*).
When we call `back!(y, [1, -1])`, the sensitivities `[1, -1]` simply get forwarded to `y`'s call (`*`), effectively calling
```julia
Tracker.back(*, [1, -1], W, x)
```
which in turn calculates the sensitivities of the arguments (`W` and `x`) and backpropagates through their calls. This is recursive, so it will walk the entire program graph and propagate gradients to the original model parameters.
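As a small illustration of that recursion (a sketch reusing the `W` and `x` defined above), composing two tracked calls records both on the tape, and `back!` propagates sensitivities through the outer call and then the inner one:
```julia
using Flux.Tracker

W = param([1 2; 3 4])
x = param([5, 6])

z = W * (W * x)            # two `*` calls, each recorded as a `Call` on the tape
Tracker.back!(z, [1, -1])  # sensitivities flow through the outer `*`, then the inner one

W.grad                     # accumulates contributions from both uses of W
x.grad
```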
## Custom Gradients
We can hook in to the processes above to implement custom gradients for a function or kernel. For a toy example, imagine a custom implementation of `minus`:
```julia
julia> minus(a, b) = a - b
```
Firstly, we must tell the tracker system to stop when it sees a call to `minus`, and record it. We can do this using dispatch:
```julia
julia> minus(a::TrackedArray, b::TrackedArray) = Tracker.track(minus, a, b)
minus (generic function with 2 methods)
```
`Tracker.track` does two things: (1) it makes sure `minus` is called with *normal* arrays, not tracked ones (you can use `@show` inside `minus` to verify this), and (2) it uses the result to add a `minus` node to the tape. Look inside the result of calling `minus` to see what happened:
```julia
julia> a, b = param([6,5,4]), param([1,2,3])
(param([6.0, 5.0, 4.0]), param([1.0, 2.0, 3.0]))
julia> c = minus(a, b)
Tracked 3-element Array{Float64,1}:
5.0
3.0
1.0
julia> c.tracker.f
Flux.Tracker.Call{...}(minus, (param([6.0, 5.0, 4.0]), param([1.0, 2.0, 3.0])))
```
Finally, we have to specify the gradient of `minus`.
```julia
julia> Tracker.back(::typeof(minus), Δ, a, b) =
(Tracker.@back(a, Δ); Tracker.@back(b, -Δ))
```
`@back(x, Δ)` tells the tracker to continue propagating the sensitivity `Δ` through `x`. Now, AD will work with any program that calls `minus`.
```julia
julia> Flux.back!(c, 1)
julia> a.grad
3-element Array{Float64,1}:
1.0
1.0
1.0
julia> b.grad
3-element Array{Float64,1}:
-1.0
-1.0
-1.0
```
## Notes
For multi-argument functions with custom gradients, you likely want to catch not just `minus(::TrackedArray, ::TrackedArray)` but also `minus(::Array, ::TrackedArray)` and so on. To do so, just define those extra signatures as needed:
```julia
minus(a::AbstractArray, b::TrackedArray) = Tracker.track(minus, a, b)
minus(a::TrackedArray, b::AbstractArray) = Tracker.track(minus, a, b)
```
`@back` *must* be called exactly once on each tracked input argument. You do not need to do any special handling if one of the arguments is not tracked, as `@back` will just become a no-op.
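For example (a sketch building on the `minus` definitions above), with those extra methods in place a plain, untracked argument just works, and the corresponding `@back` call is a no-op:
```julia
a = param([6, 5, 4])
b = [1, 2, 3]                 # a plain Array, not tracked

c = minus(a, b)               # dispatches to minus(::TrackedArray, ::AbstractArray)
Tracker.back!(c, [1, 1, 1])

a.grad                        # == [1.0, 1.0, 1.0]
# b has no gradient; @back(b, -Δ) silently does nothing
```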

View File

@ -5,7 +5,7 @@ These core layers form the foundation of almost all neural networks.
```@docs
Chain
Dense
Conv2D
Conv
```
## Recurrent Layers
@ -15,6 +15,7 @@ Much like the core layers above, but can be used to process sequence data (as we
```@docs
RNN
LSTM
GRU
Flux.Recur
```
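A brief usage sketch for the recurrent wrappers listed above (assuming the Flux 0.5-era API shown elsewhere in this diff; sizes are illustrative):
```julia
using Flux

rnn = LSTM(10, 5)              # maps 10-dimensional inputs to a 5-dimensional hidden state
xs  = [rand(10) for _ = 1:3]   # a sequence of three inputs
ys  = rnn.(xs)                 # the hidden state carries over between calls

Flux.reset!(rnn)               # clear the state before processing the next sequence
```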

View File

@ -7,6 +7,7 @@ add the result to the overall loss.
For example, say we have a simple regression.
```julia
using Flux: crossentropy
m = Dense(10, 5)
loss(x, y) = crossentropy(softmax(m(x)), y)
```
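A sketch of where this example is headed: compute a norm penalty over the parameters and add it to the loss (using Julia 0.6's `vecnorm`; this continues the snippet above rather than quoting the diff):
```julia
penalty() = vecnorm(m.W) + vecnorm(m.b)
loss(x, y) = crossentropy(softmax(m(x)), y) + penalty()

loss(rand(10), rand(5))  # the penalty is tracked, so gradients flow through it too
```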

View File

@ -17,14 +17,6 @@
url = {http://arxiv.org/abs/1712.03112},
}
@online{CuArrays,
author = {Mike Innes},
title = {Generic GPU Kernels},
year = 2017,
url = {http://mikeinnes.github.io/2017/08/24/cudanative.html},
urldate = {2018-02-16}
}
@online{MLPL,
author = {Mike Innes and others},
title = {On Machine Learning and Programming Languages},
@ -33,10 +25,26 @@
urldate = {2018-02-16}
}
@online{Fusion,
author = {Steven G. Johnson},
title = {More Dots: Syntactic Loop Fusion in Julia},
@online{CuArrays,
author = {Mike Innes and others},
title = {Generic GPU Kernels},
year = 2017,
url = {https://julialang.org/blog/2017/01/moredots},
url = {http://mikeinnes.github.io/2017/08/24/cudanative.html},
urldate = {2018-02-16}
}
@online{Zoo,
author = {Mike Innes and others},
title = {Flux Model Zoo},
year = 2018,
url = {https://github.com/FluxML/model-zoo/},
urldate = {2018-02-16}
}
@online{Minibatch,
author = {James Bradbury},
title = {Minibatch.jl},
year = 2018,
url = {https://github.com/jekbradbury/Minibatch.jl},
urldate = {2018-02-16}
}

View File

@ -24,8 +24,8 @@ bibliography: paper.bib
Flux is a library for machine learning (ML), written using the numerical computing language Julia [@Julia]. The package allows models to be written using Julia's simple mathematical syntax, and applies automatic differentiation (AD) to seamlessly calculate derivatives and train the model. Meanwhile, it makes heavy use of Julia's language and compiler features to carry out code analysis and make optimisations. For example, Julia's GPU compilation support [@besard:2017] can be used to JIT-compile custom GPU kernels for model layers [@CuArrays].
The machine learning community has traditionally been divided between "static" and "dynamic" frameworks that are easy to optimise and easy to use, respectively [@MLPL]. Flux blurs the line between these two approaches, combining a highly intuitive programming model with the compiler techniques needed by ML. As a result of this approach, it already supports several features not available in any other dynamic framework, such as kernel fusion [@Fusion], memory usage optimisations, importing of models via ONNX, and deployment of models to JavaScript for running in the browser.
The machine learning community has traditionally been divided between "static" and "dynamic" frameworks that are easy to optimise and easy to use, respectively [@MLPL]. Flux blurs the line between these two approaches, combining a highly intuitive programming model with the compiler techniques needed by ML. This enables research into advanced compiler transforms such as batching [@Minibatch] without changing any user code.
Flux has been used heavily for natural language processing, but can also support state-of-the-art research models in areas like computer vision, reinforcement learning and robotics.
Flux has been used heavily for natural language processing, but can also support state-of-the-art research models in areas like computer vision, reinforcement learning and robotics. Many examples of such models can be found in the model zoo [@Zoo].
# References

View File

@ -7,22 +7,23 @@ module Flux
using Juno, Requires, Reexport
using MacroTools: @forward
export Chain, Dense, RNN, LSTM, GRU, Conv, Conv2D,
Dropout, LayerNorm, BatchNorm,
SGD, ADAM, Momentum, Nesterov, AMSGrad, NADAM,
param, params, mapleaves, cpu, gpu
export Chain, Dense, RNN, LSTM, GRU, Conv,
Dropout, LayerNorm, BatchNorm,
params, mapleaves, cpu, gpu
@reexport using NNlib
using NNlib: @fix
include("tracker/Tracker.jl")
using .Tracker
export Tracker
import .Tracker: data
using .Tracker: data
export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param
include("optimise/Optimise.jl")
using .Optimise
using .Optimise: @epochs
export SGD, ADAM, AdaMax, Momentum, Nesterov,
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
include("utils.jl")
include("onehot.jl")
@ -32,12 +33,10 @@ include("layers/stateless.jl")
include("layers/basic.jl")
include("layers/conv.jl")
include("layers/recurrent.jl")
include("layers/normalisation.jl")
include("layers/normalise.jl")
include("data/Data.jl")
include("jit/JIT.jl")
@require CuArrays include("cuda/cuda.jl")
end # module

View File

@ -4,11 +4,4 @@ using CuArrays
CuArrays.cudnn_available() && include("cudnn.jl")
import ..Flux.JIT: Shape, restructure
function restructure(sh::Shape{T}, buf::CuVector{UInt8}) where T
buf = buf[1:sizeof(sh)]
reshape(reinterpret(T, buf), size(sh))
end
end

View File

@ -10,13 +10,15 @@ const cache_prefix = "https://cache.julialang.org"
function load()
suffixes = ["", ".phones", ".symbols"]
if isdir(deps("cmudict"))
if all(isfile.(["cmudict$x" for x in suffixes]))
if all(isfile(deps("cmudict", "cmudict$x")) for x in suffixes)
return
end
end
info("Downloading CMUDict dataset")
mkpath(deps("cmudict"))
for x in suffixes
download("$cache_prefix/http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-$version$x", deps("cmudict", "cmudict$x"))
download("$cache_prefix/http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-$version$x",
deps("cmudict", "cmudict$x"))
end
end

View File

@ -14,6 +14,7 @@ function load()
"t10k-images-idx3-ubyte",
"t10k-labels-idx1-ubyte"]
isfile(file) && continue
info("Downloading MNIST dataset")
download("https://cache.julialang.org/http://yann.lecun.com/exdb/mnist/$file.gz", "$file.gz")
open(file, "w") do io
write(io, GZip.open(read, "$file.gz"))

View File

@ -4,10 +4,10 @@ using ZipFile
using ..Data: deps
function load()
isfile(deps("sentiment.zip")) ||
download("https://cache.julialang.org/https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip",
deps("sentiment.zip"))
return
isfile(deps("sentiment.zip")) && return
info("Downloading sentiment treebank dataset")
download("https://cache.julialang.org/https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip",
deps("sentiment.zip"))
end
getfile(r, name) = r.files[findfirst(x -> x.name == name, r.files)]

View File

@ -1,9 +0,0 @@
module JIT
using MacroTools
include("shapes.jl")
include("trace.jl")
include("lib.jl")
end

View File

@ -1,40 +0,0 @@
# Primitive definitions
shape(::typeof(*), A::MatShape{T}, B::VecShape{T}) where T =
Shape{T}(size(A,1))
shape(::typeof(*), A::MatShape{T}, B::MatShape{T}) where T =
Shape{T}(size(A,1),size(B,2))
inplace!(::typeof(*), C::AbstractArray, A::AbstractMatrix, B::AbstractArray) =
A_mul_B!(C, A, B)
shape(::typeof(broadcast), f, xs...) =
Shape{eltype(xs[1])}(Base.Broadcast.broadcast_shape(size.(xs)...)...)
inplace!(::typeof(broadcast), y, f, xs...) = broadcast!(f, y, xs...)
shape(::typeof(reshape), x::Shape{T}, i...) where T =
Shape{T}(Base._reshape_uncolon(x, i))
inplace!(::typeof(reshape), y, x, i...) = copy!(y, x)
# NNlib
using NNlib
using ..Tracker: _conv, _maxpool
shape(::typeof(softmax), x) = x
inplace!(::typeof(softmax), y, x) = NNlib.softmax!(y, x)
shape(::typeof(_conv), x::Shape{T}, w::Shape{T}, stride, pad) where T =
Shape{T}(NNlib.cdims(size(x), size(w), pad, stride))
inplace!(::typeof(_conv), y, x, w, stride, pad) =
NNlib.conv!(y, x, w, stride = stride, pad = pad)
shape(::typeof(_maxpool), x::Shape{T}, k, pad) where T =
Shape{T}(NNlib.pdims(size(x), k, pad, k))
inplace!(::typeof(_maxpool), y, x, k, pad) =
NNlib.maxpool!(y, x, k, pad = pad)

View File

@ -1,56 +0,0 @@
using ..Tracker: TrackedArray
struct Shape{T,N}
dims::NTuple{N,Int}
end
VecShape{T} = Shape{T,1}
MatShape{T} = Shape{T,2}
Shape{T}(dims::Vararg{Integer,N}) where {T,N} = Shape{T,N}(dims)
Shape{T}(dims::NTuple{N,Integer}) where {T,N} = Shape{T,N}(dims)
Base.size(s::Shape) = s.dims
Base.size(s::Shape, n) = s.dims[n]
Base.ndims(s::Shape{T,N}) where {T,N} = N
Base.length(s::Shape) = prod(s.dims)
Base.eltype(s::Shape{T}) where T = T
Base.sizeof(s::Shape{T}) where T = sizeof(T)*prod(size(s))
function Base.show(io::IO, s::Shape{T}) where T
print(io, "Shape{$T}(")
join(io, s.dims, ", ")
print(io, ")")
end
shape(x) = x
shape(x::Shape) = x
shape(x::Tuple) = shape.(x)
shape(x::AbstractArray) = Shape{eltype(x)}(size(x)...)
shape(x::TrackedArray) = shape(x.data)
bytes(s::Shape) = sizeof(s)
bytes(x::Tuple) = sum(bytes.(x))
# Recover structure from byte buffers
# Make sure to hold on to the parent buffer for the lifetime of the data.
function restructure(sh::Shape{T}, buf::Vector{UInt8}) where T
buf = unsafe_wrap(Array, pointer(buf), sizeof(sh))
reshape(reinterpret(T, buf), size(sh))
end
# Execution with caches
mutable struct Cached{F,A}
f::F
buffer::A
end
function (c::Cached)(args...)
sh = shape(c.f, shape(args)...)
bytes(sh) > length(c.buffer) && (c.buffer = similar(c.buffer, bytes(sh)))
y = restructure(sh, c.buffer)
inplace!(c.f, y, args...)
end

View File

@ -1,75 +0,0 @@
# This is hacky; we'll eventually reuse Cassette for better tracing.
using ..Tracker, DataFlow
using ..Tracker: Tracked, Broadcasted, param, tracker, istracked, isleaf
using DataFlow: Call, Lambda, iscall, isconstant, prewalk, vertex, syntax,
inputnode, constant
vcall(f, args...) = vertex(DataFlow.Call(), constant(f), args...)
vcall(f::Broadcasted, args...) = vcall(broadcast, constant(f.f), args...)
graph(x::Tracked, inputs...; cache = ObjectIdDict()) =
vcall(x.f.func, map(x -> graph(x, inputs...; cache = cache), x.f.args)...)
function graph(x, inputs...; cache = ObjectIdDict())
haskey(cache, x) && return cache[x]
i = findfirst(y -> x === y, inputs)
cache[x] =
i > 0 ? inputnode(i) :
istracked(x) && !isleaf(x) ? graph(tracker(x), inputs...; cache = cache) :
constant(x)
end
function trace(f, args...)
inputs = param.(args)
graph(f(inputs...), inputs...)
end
# Graph manipulation
function liftparams(v)
ps = []
v = prewalk(DataFlow.bumpinputs(v)) do v
isconstant(v) && istracked(v.value.value) || return v
push!(ps, v.value.value)
DataFlow.vcall(getindex, inputnode(1), length(ps))
end
return v, ps
end
function cacheall(v, buf = () -> UInt8[])
prewalk(v) do v
iscall(v) && isconstant(v[1]) || return v
f = v[1].value.value
return vertex(Call(), constant(Cached(f, buf())), v[2:end]...)
end
end
code(v, n) = syntax(vertex(Lambda(n, v)))
struct Compiled{F,T<:Tuple}
model
func::F
params::T
end
# TODO when we support derivatives
# (c::Compiled)(args...) =
# Tracker.track(Tracker.Call(c, args...),
# c.func(Tracker.data.(c.params), args...))
(c::Compiled)(args...) = c.func(Tracker.data.(c.params), Tracker.data.(args)...)
Base.show(io::IO, c::Compiled) = print(io, "Compiled(", c.model, ")")
function compile(f, args...)
v = trace(f, args...)
v, ps = liftparams(cacheall(v, () -> similar(args[1], UInt8, 1))) # no empty arrays on GPU
Compiled(f, eval(code(v, length(args)+1)), (ps...,))
end
function source(f, args...)
v = trace(f, args...)
v, ps = liftparams(v)
code(v, length(args)+1) |> prettify
end

View File

@ -10,7 +10,7 @@ Standard convolutional layer. `size` should be a tuple like `(2, 2)`.
Data should be stored in WHCN order. In other words, a 100×100 RGB image would
be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
Takes the keyword arguments `pad` and `stride`.
Takes the keyword arguments `pad`, `stride` and `dilation`.
"""
struct Conv{N,F,A,V}
σ::F
@ -18,17 +18,19 @@ struct Conv{N,F,A,V}
bias::V
stride::NTuple{N,Int}
pad::NTuple{N,Int}
dilation::NTuple{N,Int}
end
Conv(w::AbstractArray{T}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0) where T =
Conv(σ, w, b, stride, pad)
stride = 1, pad = 0, dilation=1) where T =
Conv(σ, w, b, stride, pad, dilation)
Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
stride::NTuple{N,Integer} = map(_->1,k),
pad::NTuple{N,Integer} = map(_->0,k)) where N =
pad::NTuple{N,Integer} = map(_->0,k),
dilation::NTuple{N,Integer} = map(_->1,k)) where N =
Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ,
stride = stride, pad = pad)
stride = stride, pad = pad, dilation = dilation)
Flux.treelike(Conv)
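As a usage sketch of the constructor above with the new `dilation` keyword (values are illustrative and not part of the diff):
```julia
using Flux

# 3×3 filters, 1 input channel => 16 output channels, dilation of 2 in each spatial dimension.
layer = Conv((3, 3), 1=>16, relu, stride = (1, 1), pad = (1, 1), dilation = (2, 2))

x = rand(28, 28, 1, 8)   # WHCN order: 28×28 images, 1 channel, batch of 8
y = layer(x)             # 26×26×16×8 output with this padding/dilation
```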
@ -36,7 +38,7 @@ function (c::Conv)(x)
# TODO: breaks gpu broadcast :(
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
σ.(conv(x, c.weight, stride = c.stride, pad = c.pad) .+ b)
σ.(conv(x, c.weight, stride = c.stride, pad = c.pad, dilation = c.dilation) .+ b)
end
function Base.show(io::IO, l::Conv)
@ -45,6 +47,3 @@ function Base.show(io::IO, l::Conv)
l.σ == identity || print(io, ", ", l.σ)
print(io, ")")
end
# v0.5
@deprecate Conv2D(args...; kw...) Conv(args...; kw...)

View File

@ -31,15 +31,14 @@ function Dropout(p)
Dropout{typeof(p)}(p, true)
end
_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
function (a::Dropout)(x)
a.active || return x
y = similar(x)
rand!(y)
q = 1 - a.p
@inbounds for i=1:length(y)
y[i] = y[i] > a.p ? 1 / q : 0
end
return y .* x
y .= _dropout_kernel.(y, a.p, 1 - a.p)
return x .* y
end
_testmode!(a::Dropout, test) = (a.active = !test)
@ -68,70 +67,88 @@ function Base.show(io::IO, l::LayerNorm)
end
"""
BatchNorm(dims...; λ = identity,
initβ = zeros, initγ = ones, ϵ = 1e-8, momentum = .1)
BatchNorm(channels::Integer, σ = identity;
initβ = zeros, initγ = ones,
ϵ = 1e-8, momentum = .1)
Batch Normalization Layer for [`Dense`](@ref) layer.
Batch Normalization layer. The `channels` input should be the size of the
channel dimension in your data (see below).
Given an array with `N` dimensions, call the `N-1`th the channel dimension. (For
a batch of feature vectors this is just the data dimension, for `WHCN` images
it's the usual channel dimension.)
`BatchNorm` computes the mean and variance for each `W×H×1×N` slice and
shifts them to have a new mean and variance (corresponding to the learnable,
per-channel `bias` and `scale` parameters).
See [Batch Normalization: Accelerating Deep Network Training by Reducing
Internal Covariate Shift](https://arxiv.org/pdf/1502.03167.pdf)
Internal Covariate Shift](https://arxiv.org/pdf/1502.03167.pdf).
In the MNIST example below, in order to normalize the input of the next layer, put the `BatchNorm` layer before the activation function.
Example:
```julia
m = Chain(
Dense(28^2, 64),
BatchNorm(64, λ = relu),
BatchNorm(64, relu),
Dense(64, 10),
BatchNorm(10),
softmax)
```
"""
mutable struct BatchNorm{F,V,N}
mutable struct BatchNorm{F,V,W,N}
λ::F # activation function
β::V # bias
γ::V # scale
μ # moving mean
σ # moving std
μ::W # moving mean
σ::W # moving std
ϵ::N
momentum::N
active::Bool
end
BatchNorm(dims::Integer...; λ = identity,
BatchNorm(chs::Integer, λ = identity;
initβ = zeros, initγ = ones, ϵ = 1e-8, momentum = .1) =
BatchNorm(λ, param(initβ(dims)), param(initγ(dims)), 0., 1., ϵ, momentum, true)
BatchNorm(λ, param(initβ(chs)), param(initγ(chs)),
zeros(chs), ones(chs), ϵ, momentum, true)
function (BN::BatchNorm)(x)
λ, γ, β = BN.λ, BN.γ, BN.β
size(x, ndims(x)-1) == length(BN.β) ||
error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))")
γ, β = BN.γ, BN.β
dims = length(size(x))
channels = size(x, dims-1)
affine_shape = ones(Int, dims)
affine_shape[end-1] = channels
m = prod(size(x)[1:end-2]) * size(x)[end]
if !BN.active
μ = BN.μ
σ = BN.σ
μ = reshape(BN.μ, affine_shape...)
σ = reshape(BN.σ, affine_shape...)
else
T = eltype(x)
ϵ = data(convert(T, BN.ϵ))
m = size(x, 2) # batch size
μ = mean(x, 2)
σ = sqrt.(sum((x .- μ).^2, 2) ./ m .+ ϵ)
axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
μ = mean(x, axes)
σ = sqrt.(mean((x .- μ).^2, axes) .+ ϵ)
# update moving mean/std
mtm = data(convert(T, BN.momentum))
BN.μ = (1 - mtm) .* BN.μ .+ mtm .* data(μ)
BN.σ = (1 - mtm) .* BN.σ .+ mtm .* data(σ) .* m ./ (m - 1)
BN.μ = (1 - mtm) .* BN.μ .+ mtm .* squeeze(data(μ), (axes...))
BN.σ = (1 - mtm) .* BN.σ .+ mtm .* squeeze(data(σ), (axes...)) .* m ./ (m - 1)
end
λ.(γ .* ((x .- μ) ./ σ) .+ β)
let λ = BN.λ
λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ σ) .+ reshape(β, affine_shape...))
end
end
children(BN::BatchNorm) =
(BN.λ, BN.β, BN.γ, BN.μ, BN.σ, BN.momentum, BN.ϵ, BN.active)
(BN.λ, BN.β, BN.γ, BN.μ, BN.σ, BN.ϵ, BN.momentum, BN.active)
mapchildren(f, BN::BatchNorm) = # e.g. mapchildren(cu, BN)
BatchNorm(BN.λ, f(BN.β), f(BN.γ), BN.μ, BN.σ, BN.momentum, BN.ϵ, BN.active)
BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ), BN.ϵ, BN.momentum, BN.active)
_testmode!(BN::BatchNorm, test) = (BN.active = !test)

View File

@ -5,7 +5,7 @@ using NNlib: logsoftmax, logσ
mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)
function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
return @fix -sum(y .* log.(ŷ) .* weight) / size(y, 2)
@fix -sum(y .* log.(ŷ) .* weight) / size(y, 2)
end
@deprecate logloss(x, y) crossentropy(x, y)

View File

@ -1,7 +1,8 @@
module Optimise
export update!, params, train!,
SGD, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
export train!,
SGD, ADAM, AdaMax, Momentum, Nesterov,
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
struct Param{T}
x::T

View File

@ -56,6 +56,15 @@ RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) =
ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
"""
AdaMax(params, η = 0.002; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser. Variant of ADAM based on
the ∞-norm.
"""
AdaMax(ps, η = 0.002; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p->adamax(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
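A usage sketch for the new optimiser, following the same pattern as the other Flux 0.5-era optimisers (the model, loss and data here are placeholders, not part of the diff):
```julia
using Flux

m = Dense(10, 5)
loss(x, y) = Flux.mse(m(x), y)
data = [(rand(10), rand(5))]

opt = AdaMax(params(m), 0.002)  # the optimiser closes over the tracked parameters
Flux.train!(loss, data, opt)    # one pass over `data`, calling `opt()` after each gradient
```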
"""
ADAGrad(params, η = 0.01; ϵ = 1e-8, decay = 0)

View File

@ -62,6 +62,18 @@ function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ
end
end
function adamax(p::Param; η::Real = 0.002, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
mt = zeros(p.x)
ut = zeros(p.x)
β1p = β1
function ()
@. mt = β1 * mt + (1 - β1) * p.Δ
@. ut = max(β2 * ut, abs(p.Δ))
@. p.Δ = (η/(1 - β1p)) * mt/(ut + ϵ)
β1p *= β1
end
end
function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
mt = zeros(p.x)
vt = zeros(p.x) .+ ϵ

View File

@ -41,7 +41,7 @@ end
Base.setindex!(xs::TrackedArray, v, i...) =
error("Can't differentiate `setindex!`")
back!(::TrackedArray) = error("Use back!(x, Δ)")
back!(::TrackedArray) = error("Value is not scalar; use `back!(sum(x))` or `back!(x, Δ)`")
# Fallthrough methods
@ -81,21 +81,6 @@ back(::typeof(ctranspose), Δ, xs) = @back(xs, trim(xs, Δ'))
Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...)
Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...)
Base.vcat(a::TrackedVector, b::TrackedVector) = track(vcat, a, b)
Base.vcat(a::TrackedVector, b::TrackedVector...) = track(vcat, a, b...)
Base.vcat(a::TrackedVector, b::AbstractVector) = track(vcat, a, b)
Base.vcat(a::AbstractVector, b::TrackedVector) = track(vcat, a, b)
Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat) = track(vcat, a, b)
Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat...) = track(vcat, a, b...)
Base.vcat(a::TrackedVecOrMat, b::AbstractVecOrMat) = track(vcat, a, b)
Base.vcat(a::AbstractVecOrMat, b::TrackedVecOrMat) = track(vcat, a, b)
Base.vcat(a::TrackedMatrix, b::TrackedMatrix) = track(vcat, a, b)
Base.vcat(a::TrackedMatrix, b::TrackedMatrix...) = track(vcat, a, b...)
Base.vcat(a::TrackedMatrix, b::AbstractMatrix) = track(vcat, a, b)
Base.vcat(a::AbstractMatrix, b::TrackedMatrix) = track(vcat, a, b)
function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1)
Δ′ = similar(xs.data)
S = size(xs.data)
@ -108,20 +93,93 @@ function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1)
back(xs, Δ′)
end
_repeat(A, inner, outer) = Base.repeat(A; inner=inner, outer=outer)
Base.repeat(A::TrackedArray; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A))) = track(_repeat, A, inner, outer)
function back(::typeof(_repeat), Δ, xs::TrackedArray, inner, outer)
Δ′ = similar(xs.data)
Δ′ .= 0
S = size(xs.data)
# Loop through each element of Δ, calculate source dimensions, accumulate into Δ′
for (dest_idx, val) in enumerate(IndexCartesian(), Δ)
# First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then
# wrap around based on original size S.
src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)]
Δ′[src_idx...] += val
end
back(xs, Δ′)
end
for f in [:vcat, :hcat]
@eval begin
# This section is a bit of a hack since julia doesn't have a standardised
# promotion mechanism for concatenation yet
# https://github.com/JuliaLang/julia/pull/20815
# It should support tracked concatenation with rank ∈ (1,2) with a
# TrackedArray anywhere among the arguments. This works as long as Base has
# other functions that capture `(::Union{Vector,RowVector,Matrix}...)`.
Base.$f(a::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a...)
# It should support tracked concatenation with rank>2 if the TrackedArray is
# first
Base.$f(a::TrackedArray, b::AbstractArray...) = track($f, a, b...)
Base.$f(a::TrackedArray, b::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a, b...) # resolves ambiguity introduced by previous row
# It should support tracked concatenation with rank>2 if the TrackedArray is
# second
Base.$f(a::Array, b::TrackedArray, c::AbstractArray...) = track($f, a, b, c...)
Base.$f(a::Union{Vector,RowVector,Matrix}, b::TrackedArray,
c::Union{TrackedArray,Vector,RowVector,Matrix}...) =
track($f, a, b, c...) # resolves ambiguity introduced by previous row
end
end
function back(::typeof(vcat), Δ, xs...)
i = Base.tail(map(_ -> :, size(Δ)))
start = 0
for xsi in xs
i = map(_ -> :, size(xsi)) |> Base.tail
@back(xsi, Δ[start+1:start+size(xsi,1), i...])
start += size(xsi, 1)
end
end
Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) =
track(reshape, xs, dims...)
function back(::typeof(hcat), Δ, xs...)
start = 0
for xsi in xs
if ndims(xsi) == 1
@back(xsi, Δ[:, start+1])
else
i = map(_ -> :, size(xsi)) |> Base.tail |> Base.tail
@back(xsi, Δ[:, start+1:start+size(xsi,2), i...])
end
start += size(xsi, 2)
end
end
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64,N}} where N) =
track(reshape, xs, dims)
Base.cat(dims, a::TrackedArray, b::AbstractArray...) = track(cat, dims, a, b...)
Base.cat(dims, a::Union{RowVector,Array}, b::TrackedArray, c::AbstractArray...) = track(cat, dims, a, b, c...)
function back(::typeof(cat), Δ, dims, Xs...)
start = ntuple(i -> 0, Val{ndims(Δ)})
for xs in Xs
dim_xs = 1:ndims(xs)
till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val{ndims(Δ)})
xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val{ndims(Δ)})
@back(xs, reshape(Δ[xs_in_Δ...],size(xs)))
start = start .+ till_xs
end
end
Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) = reshape(xs, dims)
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Union{Int64,Colon}}}) = reshape(xs, Base._reshape_uncolon(xs, dims))
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64}}) = track(reshape, xs, dims)
back(::typeof(reshape), Δ, xs::TrackedArray, _...) =
back(xs, reshape(Δ, size(xs)))
@ -158,12 +216,16 @@ Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs))
back(::typeof(prod), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= (prod(xs.data, dim...) ./ xs.data) .* Δ)
back(::typeof(prod), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= (reshape(.*(circshift.([reshape(xs.data, length(xs.data))], 1:length(xs.data)-1)...), size(xs.data))) .* Δ)
Base.maximum(xs::TrackedArray, args...) = maximum(xs.data, args...)
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
Base.mean(xs::TrackedArray) = track(mean, xs)
Base.mean(xs::TrackedArray, region) = track(mean, xs, region)
Base.maximum(xs::TrackedArray) = track(maximum, xs)
Base.maximum(xs::TrackedArray, region) = track(maximum, xs, region)
Base.minimum(xs::TrackedArray) = track(minimum, xs)
Base.minimum(xs::TrackedArray, region) = track(minimum, xs, region)
LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = track(dot, xs, ys)
LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = track(dot, xs, ys)
LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys)
@ -186,6 +248,31 @@ back(::typeof(mean), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= Δ ./
back(::typeof(mean), Δ, xs::TrackedArray, region) =
back(xs, similar(xs.data) .= Δ ./ prod(size(xs.data, region...)))
function back(::typeof(maximum), Δ, xs::TrackedArray)
Δ′ = zeros(xs.data)
_, i = findmax(xs.data)
Δ′[i] = Δ
@back(xs, Δ′)
end
function back(::typeof(maximum), Δ, xs::TrackedArray, region)
Δ′ = zeros(xs.data)
_, is = findmax(xs.data, region)
Δ′[is] = Δ
@back(xs, Δ′)
end
function back(::typeof(minimum), Δ, xs::TrackedArray)
Δ′ = zeros(xs.data)
_, i = findmin(xs.data)
Δ′[i] = Δ
@back(xs, Δ′)
end
function back(::typeof(minimum), Δ, xs::TrackedArray, region)
Δ′ = zeros(xs.data)
_, is = findmin(xs.data, region)
Δ′[is] = Δ
@back(xs, Δ′)
end
# BLAS
Base.diagm(x::TrackedVector) = track(diagm, x)
@ -247,35 +334,35 @@ logsoftmax(xs::TrackedArray) = track(logsoftmax, xs)
back(::typeof(logsoftmax), Δ, xs) = @back(xs, ∇logsoftmax(Δ, data(xs)))
# TODO: can store kwargs efficiently in namedtuples
_conv(x, w, stride, pad) = conv(x, w, stride = stride, pad = pad)
_conv(x, w, stride, pad, dilation) = conv(x, w, stride = stride, pad = pad, dilation = dilation)
conv(x::TrackedArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0) where N =
track(_conv, x, w, stride, pad)
conv(x::AbstractArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0) where N =
track(_conv, x, w, stride, pad)
conv(x::TrackedArray{<:Real,N}, w::AbstractArray{<:Real,N}; stride = 1, pad = 0) where N =
track(_conv, x, w, stride, pad)
conv(x::TrackedArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N =
track(_conv, x, w, stride, pad, dilation)
conv(x::AbstractArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N =
track(_conv, x, w, stride, pad, dilation)
conv(x::TrackedArray{<:Real,N}, w::AbstractArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N =
track(_conv, x, w, stride, pad, dilation)
function back(::typeof(_conv), Δ, x, w, stride, pad)
@back(x, NNlib.∇conv_data(Δ, data(x), data(w); stride = stride, pad = pad))
@back(w, NNlib.∇conv_filter(Δ, data(x), data(w); stride = stride, pad = pad))
function back(::typeof(_conv), Δ, x, w, stride, pad, dilation)
@back(x, NNlib.∇conv_data(Δ, data(x), data(w); stride = stride, pad = pad, dilation = dilation))
@back(w, NNlib.∇conv_filter(Δ, data(x), data(w); stride = stride, pad = pad, dilation = dilation))
end
_maxpool(x, k, pad) = maxpool(x, k; pad = pad)
_maxpool(x, k, pad, stride) = maxpool(x, k; pad = pad, stride = stride)
maxpool(x::TrackedArray, k; pad = map(_->0,k)) =
track(_maxpool, x, k, pad)
maxpool(x::TrackedArray, k; pad = map(_->0,k), stride = k) =
track(_maxpool, x, k, pad, stride)
back_(::typeof(_maxpool), y, Δ, x, k, pad) =
back(x, NNlib.∇maxpool(Δ, y, data(x), k, pad=pad))
back_(::typeof(_maxpool), y, Δ, x, k, pad, stride) =
back(x, NNlib.∇maxpool(Δ, y, data(x), k, pad=pad, stride=stride))
_meanpool(x, k, pad) = meanpool(x, k; pad = pad)
_meanpool(x, k, pad, stride) = meanpool(x, k; pad = pad, stride = stride)
meanpool(x::TrackedArray, k; pad = map(_->0,k)) =
track(_meanpool, x, k, pad)
meanpool(x::TrackedArray, k; pad = map(_->0,k), stride = k) =
track(_meanpool, x, k, pad, stride)
back_(::typeof(_meanpool), y, Δ, x, k, pad) =
back(x, NNlib.∇meanpool(Δ, y, data(x), k, pad=pad))
back_(::typeof(_meanpool), y, Δ, x, k, pad, stride) =
back(x, NNlib.∇meanpool(Δ, y, data(x), k, pad=pad, stride=stride))
# Broadcasting

View File

@ -19,8 +19,9 @@ Base.decompose(x::TrackedReal) = Base.decompose(data(x))
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{T}) where T = x
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal) where T =
TrackedReal(Tracked(x.tracker.f, convert(T, x.tracker.data)))
# This cuts derivatives, fix if needed.
# Base.convert(::Type{TrackedReal{T}}, x::TrackedReal) where T =
# TrackedReal(Tracked(x.tracker.f, convert(T, x.tracker.data)))
Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x))
@ -91,3 +92,18 @@ Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i)
back(::typeof(getindex), Δ, t, i) =
back(t, ntuple(j -> i == j ? Δ : 0, length(t)))
# Array collection
function collect(xs)
xs = Base.collect(xs)
track(Call(collect, xs), data.(xs))
end
function scan(c::Call{typeof(collect)})
foreach(scan, c.args[1])
end
function back(::typeof(collect), Δ, xs)
foreach((x, Δ) -> @back(x, Δ), xs, Δ)
end

View File

@ -21,6 +21,10 @@ cm = gpu(m)
@test all(p isa TrackedArray && p.data isa CuArray for p in params(cm))
@test cm(gpu(rand(10, 10))) isa TrackedArray{Float32,2,CuArray{Float32,2}}
x = [1,2,3]
cx = gpu(x)
@test Flux.crossentropy(x,x) ≈ Flux.crossentropy(cx,cx)
# Fails in Pkg.test ffs
# c = gpu(Conv((2,2),3=>4))
# l = c(gpu(rand(10,10,3,2)))

View File

@ -1,12 +0,0 @@
using Flux, Base.Test
using Flux.JIT: compile
@testset "JIT" begin
m = Dense(10, 5)
f = compile(m, rand(10))
x = rand(10)
@test Tracker.data(m(x)) == f(x)
end

View File

@ -67,7 +67,7 @@ end
end
# with activation function
let m = BatchNorm(2, λ = σ), x = param([1 2; 3 4; 5 6]')
let m = BatchNorm(2, σ), x = param([1 2; 3 4; 5 6]')
@test m.active
m(x)
@ -77,4 +77,22 @@ end
x = m(x).data
@test x[1] ≈ σ((1 - 0.3) / 1.1449489742783179)
end
let m = BatchNorm(2), x = param(reshape(1:6, 3, 2, 1))
y = reshape(permutedims(x, [2, 1, 3]), 2, :)
y = permutedims(reshape(m(y), 2, 3, 1), [2, 1, 3])
@test m(x) == y
end
let m = BatchNorm(2), x = param(reshape(1:12, 2, 3, 2, 1))
y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :)
y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4])
@test m(x) == y
end
let m = BatchNorm(2), x = param(reshape(1:24, 2, 2, 3, 2, 1))
y = reshape(permutedims(x, [4, 1, 2, 3, 5]), 2, :)
y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5])
@test m(x) == y
end
end

View File

@ -3,7 +3,7 @@ using Flux.Tracker
@testset "Optimise" begin
w = randn(10, 10)
@testset for Opt in [SGD, Nesterov, Momentum, ADAM, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad, NADAM]
@testset for Opt in [SGD, Nesterov, Momentum, ADAM, AdaMax, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad, NADAM]
w = param(randn(10, 10))
loss(x) = Flux.mse(w*x, w*x)
opt = Opt([w])

View File

@ -10,7 +10,6 @@ include("layers/normalisation.jl")
include("layers/stateless.jl")
include("optimise.jl")
include("data.jl")
include("jit.jl")
if Base.find_in_path("CuArrays") ≠ nothing
include("cuda/cuda.jl")

View File

@ -1,5 +1,5 @@
using Flux.Tracker, Base.Test, NNlib
using Flux.Tracker: TrackedReal, gradcheck
using Flux.Tracker: TrackedReal, gradcheck, grad
using NNlib: conv
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
@ -29,17 +29,97 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@test gradtest(x -> x', rand(5))
@test gradtest(vcat, rand(5), rand(3))
@test gradtest(vcat, rand(5), rand(3), rand(8))
@test gradtest(vcat, rand(5,2), rand(3,2), rand(8,2))
function promotiontest(f, A, B, C)
r0 = f(A, B, C)
r1 = f(param(A), B, C)
r2 = f(A, param(B), C)
if all(ndims.((A,B,C)) .≤ 2) && f ∈ [hcat, vcat]
r3 = f(A, B, param(C))
else
@test_throws MethodError f(A, B, param(C)) # until julia#20815 is resolved
r3 = r2
end
r4 = f(param(A), param(B), param(C))
@test !isa(r0, TrackedArray)
@test all(isa.([r1,r2,r3,r4], TrackedArray))
@test r1 == r2 == r3 == r4
@test r0 == Flux.data(r4)
end
@testset "concat" begin
cat1(x...) = cat(1, x...)
cat2(x...) = cat(2, x...)
@testset for vcatf in [vcat, cat1]
@test gradtest(vcatf, rand(5), rand(3))
@test gradtest(vcatf, rand(5), rand(3), rand(8))
@test gradtest(vcatf, rand(5)', rand(5)')
@test gradtest(vcatf, rand(5,2), rand(3,2), rand(8,2))
@test gradtest(vcatf, rand(5,2,3), rand(3,2,3), rand(8,2,3))
@test gradtest(vcatf, rand(5), rand(3,1))
@test gradtest(vcatf, rand(5)', rand(2,5))
end
@testset for hcatf in [hcat, cat2]
@test gradtest(hcatf, rand(5), rand(5))
@test gradtest(hcatf, rand(5)', rand(5)')
@test gradtest(hcatf, rand(2,5), rand(2,3), rand(2,8))
@test gradtest(hcatf, rand(2,5,3), rand(2,3,3), rand(2,8,3))
@test gradtest(hcatf, rand(5), rand(5), rand(5,2))
@test gradtest(hcatf, rand(5)', rand(1,3))
@test gradtest(hcatf, rand(5), rand(5,2))
end
@testset for catf in [vcat, cat1, hcat, cat2, (x...) -> cat(3, x...), (x...) -> cat((1,2), x...)]
@test gradtest(catf, rand(5))
@test gradtest(catf, rand(5)')
@test gradtest(catf, rand(2,5))
@test gradtest(catf, rand(2,5,3))
end
@test gradtest((x...) -> cat(3, x...), rand(2,5,2), rand(2,5,3), rand(2,5,4))
@testset "cat($dim, ...)" for dim in 3:5
catdim = (x...) -> cat(dim, x...)
@test gradtest(catdim, rand(5), rand(5), rand(5))
@test gradtest(catdim, rand(2,5), rand(2,5), rand(2,5))
@test gradtest(catdim, rand(2,5,3), rand(2,5,3), rand(2,5,3))
end
@test !isa(vcat(rand(2)), TrackedArray)
@test !isa(hcat(rand(2)), TrackedArray)
@test !isa(cat(1,rand(2)), TrackedArray)
@test gradtest((a,b)->cat((2,3,5), a, b), rand(2,3), rand(2,4,2,1))
@testset "promotiontest" begin
@testset for fcat in [hcat, vcat, (x...) -> cat(3, x...), (x...) -> cat((1,2), x...)]
promotiontest(fcat, rand(2), rand(2), rand(2))
promotiontest(fcat, rand(2)', rand(2)', rand(2)')
promotiontest(fcat, rand(2,2), rand(2,2), rand(2,2))
promotiontest(fcat, rand(2,2,2), rand(2,2,2), rand(2,2,2))
end
promotiontest(vcat, rand(1,2), rand(2)', rand(2,2))
promotiontest(hcat, rand(2,1), rand(2), rand(2,2))
promotiontest(vcat, rand(3,4,5), rand(1,4,5), rand(2,4,5))
promotiontest(hcat, rand(4,3,5), rand(4,1,5), rand(4,2,5))
promotiontest((x...) -> cat(3, x...), rand(4,5,3), rand(4,5,1), rand(4,5,2))
end
end
@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))
@test gradtest(x -> repmat(x, 5,5), rand(4,5))
@test gradtest(x -> repmat(x, 5), rand(4,5))
@test gradtest(kron,rand(5), rand(3))
@test gradtest(x -> repeat(x; inner=2, outer=3), rand(5))
@test gradtest(x -> repeat(x; inner=(2,2,1), outer=(1,1,3)), rand(5,4,3))
@test gradtest(kron, rand(5), rand(3))
@test gradtest(kron, rand(5), rand(3), rand(8))
@test gradtest(kron,rand(5,1), rand(3,1))
@test gradtest(kron, rand(5,1), rand(3,1))
@test gradtest(kron, rand(5,1), rand(3,1), rand(8,1))
@test gradtest(kron, rand(5,2), rand(3,2), rand(8,2))
@ -55,6 +135,26 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@test gradtest(x -> mean(x, [1, 2]), rand(2, 3, 4))
end
@testset "maximum" begin
@test gradtest(maximum, rand(2, 3))
@test gradtest(x -> maximum(x, 1), rand(2, 3))
@test gradtest(x -> maximum(x, 2), rand(2, 3))
@test gradtest(x -> maximum(x, 3), rand(2, 3, 4))
@test gradtest(x -> maximum(x, [1, 2]), rand(2, 3, 4))
end
@testset "minimum" begin
@test gradtest(minimum, rand(2, 3))
@test gradtest(x -> minimum(x, 1), rand(2, 3))
@test gradtest(x -> minimum(x, 2), rand(2, 3))
@test gradtest(x -> minimum(x, 3), rand(2, 3, 4))
@test gradtest(x -> minimum(x, [1, 2]), rand(2, 3, 4))
end
@test gradtest(x -> std(x), rand(5,5))
@test gradtest(x -> std(x, 1), rand(5,5))
@ -82,6 +182,21 @@ end
@test param(2)^2 == 4.0
@testset "reshape" begin
x = reshape(param(rand(2,2,2)), 4, 2)
@test x isa TrackedArray
@test size(x) == (4,2)
x = reshape(param([1]), (1,:))
@test x isa TrackedArray
@test size(x) == (1,1)
x = reshape(param(rand(2)), (2,:))
@test x isa TrackedArray
@test size(x) == (2,1)
x = reshape(param(rand(2,2)), (1,:,2))
@test x isa TrackedArray
@test size(x) == (1,2,2)
end
@testset "Intermediates" begin
x = param([1])
l = sum((x .+ x).^2)
@ -108,4 +223,13 @@ b = param(rand())
Tracker.back!(b)
@test Tracker.grad(b) == 1
@testset "collect" begin
x, y = param(2), param(3)
xy = Tracker.collect([x, y])
@test xy isa TrackedArray{Float64}
z = xy[1]*xy[2]
back!(z)
@test grad.((x,y)) == (3, 2)
end
end #testset