Merge branch 'master' into nadam-opt

Tejan Karmali 2018-06-08 16:54:41 +05:30 committed by GitHub
commit 4a24b69976
30 changed files with 603 additions and 326 deletions

View File

@ -1,6 +1,8 @@
# Флукс
<p align="center">
<img width="400px" src=""/>
[![Build Status](]( [![](]( [![](](
[![Build Status](]( [![](]( [![](]( [![DOI](](
Flux is an elegant approach to machine learning. It's a 100% pure-Julia stack, and provides lightweight abstractions on top of Julia's native GPU and AD support. Flux makes the easy things easy while remaining fully hackable.
@ -10,6 +12,18 @@ julia> Pkg.add("Flux")
See the [documentation]( or the [model zoo]( for examples.
If you use Flux in research, please cite the following paper:
author = {Mike Innes},
title = {Flux: Elegant Machine Learning with Julia},
journal = {Journal of Open Source Software},
year = {2018},
doi = {10.21105/joss.00602},
## Features
Flux has powerful high-level features, and common architectures can be defined in a few lines.
@ -36,7 +50,7 @@ b = param(randn(2))
y(x) = σ.(W * x .+ b)
If that's *still* not enough, you can go as deep as you want, even writing your own CUDA kernels! All this can be freely mixed-and-matched in a single model or script, and it all runs interactively via Jupyter or Juno.
If that's *still* not enough, you can go as deep as you want, even writing your own CUDA kernels with [CUDAnative](! All this can be freely mixed-and-matched in a single model or script, and it all runs interactively via Jupyter or Juno.
function gpu_add(a, b, c)
@ -77,3 +91,9 @@ For general questions and help, check out Julia's [community forum](https://disc
Flux development is carried out via our [GitHub issues](, so feel free to open feature requests or PRs here.
For more informal discussions we'd love to have you on the [Julia slack](, where we hang out on the #machine-learning channel.
## Related Packages
Check out [Metalhead.jl]( for common computer vision datasets and trained models.
[MLDatasets.jl]( provides further common datasets.

View File

@ -18,6 +18,8 @@ makedocs(modules=[Flux, NNlib],
"One-Hot Encoding" => "data/",
"GPU Support" => "",
"Saving & Loading" => "",
"Internals" =>
["Backpropagation" => "internals/"],
"Community" => ""])

View File

@ -0,0 +1,156 @@
# Flux.Tracker
Backpropagation, or reverse-mode automatic differentiation, is handled by the `Flux.Tracker` module.
julia> using Flux.Tracker
The `param` function converts a normal Julia array into a new object that, while behaving like an array, tracks extra information that allows us to calculate derivatives. For example, say we multiply two parameters:
julia> W = param([1 2; 3 4])
Tracked 2×2 Array{Float64,2}:
1.0 2.0
3.0 4.0
julia> x = param([5, 6])
Tracked 2-element Array{Float64,1}:
julia> y = W*x
Tracked 2-element Array{Float64,1}:
The output `y` is also a `TrackedArray` object. We can now backpropagate sensitivities to `W` and `x` via the `back!` function, and see the gradients accumulated in the `W` and `x` tracked arrays:
julia> Tracker.back!(y, [1, -1])
julia> W.grad
2×2 Array{Float64,2}:
5.0 6.0
-5.0 -6.0
julia> x.grad
2-element Array{Float64,1}:
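The gradients above can be checked by hand without Flux. The following is a minimal plain-Julia sketch; the names `dW` and `dx` are illustrative and not part of the Tracker API. For `y = W*x` with sensitivity `Δ`, the gradient with respect to `W` is the outer product `Δ * x'`, and the gradient with respect to `x` is `W' * Δ`.

```julia
# Plain-Julia check of the gradients shown above; no Flux required.
W = [1.0 2.0; 3.0 4.0]
x = [5.0, 6.0]
Δ = [1.0, -1.0]   # the sensitivity passed to back!

dW = Δ * x'       # outer product: gradient of y = W*x with respect to W
dx = W' * Δ       # gradient of y = W*x with respect to x

dW  # [5.0 6.0; -5.0 -6.0]
dx  # [-2.0, -2.0]
```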
## Internals
All `Tracked*` objects (`TrackedArray`, `TrackedReal`) are light wrappers around the `Tracked` type, which you can access via the `.tracker` field.
julia> x.tracker
Flux.Tracker.Tracked{Array{Float64,1}}(0x00000000, Flux.Tracker.Call{Void,Tuple{}}(nothing, ()), true, [5.0, 6.0], [-2.0, -2.0])
The `Tracker` stores the value and gradient of a given object, which we've seen before.
2-element Array{Float64,1}:
julia> x.tracker.grad
2-element Array{Float64,1}:
The tracker also contains a `Call` object, which simply represents a function call that was made at some point during the forward pass. For example, the `+` call would look like this:
julia> Tracker.Call(+, 1, 2)
Flux.Tracker.Call{Base.#+,Tuple{Int64,Int64}}(+, (1, 2))
In the case of the `y` we produced above, we can see that it stores the call that produced it -- that is, `W*x`.
julia> y.tracker.f
Flux.Tracker.Call{...}(*, (param([1.0 2.0; 3.0 4.0]), param([5.0, 6.0])))
Because the arguments to the call may themselves be tracked arrays that store their own calls, `Tracker` ends up forming a data structure that records everything that happened during the forward pass (often known as a *tape*).
When we call `back!(y, [1, -1])`, the sensitivities `[1, -1]` simply get forwarded to `y`'s call (`*`), effectively calling
Tracker.back(*, [1, -1], W, x)
which in turn calculates the sensitivities of the arguments (`W` and `x`) and backpropagates through their calls. This is recursive, so it will walk the entire program graph and propagate gradients to the original model parameters.
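The recursive walk described above can be illustrated with a toy tape. This is only a sketch of the general idea on scalars, not Flux's implementation: the `Node` type, `leaf`, and `backprop!` are hypothetical names introduced here, and each node simply stores its parents plus a closure that maps an incoming sensitivity to one sensitivity per parent.

```julia
# Toy reverse-mode tape on scalars. All names here are hypothetical,
# chosen for illustration; they are not part of Flux.Tracker.
mutable struct Node
    value::Float64
    grad::Float64
    parents::Vector{Node}
    pullback::Function    # Δ -> tuple of sensitivities, one per parent
end

# A leaf has no parents, so its pullback returns an empty tuple.
leaf(v) = Node(v, 0.0, Node[], Δ -> ())

# Each operation records the call that produced its result.
Base.:*(a::Node, b::Node) =
    Node(a.value * b.value, 0.0, [a, b], Δ -> (Δ * b.value, Δ * a.value))
Base.:+(a::Node, b::Node) =
    Node(a.value + b.value, 0.0, [a, b], Δ -> (Δ, Δ))

# Accumulate the sensitivity at this node, then recurse into its parents.
function backprop!(n::Node, Δ)
    n.grad += Δ
    for (p, δ) in zip(n.parents, n.pullback(Δ))
        backprop!(p, δ)
    end
    return nothing
end

a, b = leaf(2.0), leaf(3.0)
c = a * b + a
backprop!(c, 1.0)
(a.grad, b.grad)  # (4.0, 2.0)
```

Here `a` receives a contribution from both paths (`a * b` and the bare `a`), which is why gradients are accumulated with `+=` rather than overwritten.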
## Custom Gradients
We can hook into the process above to implement custom gradients for a function or kernel. For a toy example, imagine a custom implementation of `minus`:
julia> minus(a, b) = a - b
Firstly, we must tell the tracker system to stop when it sees a call to `minus`, and record it. We can do this using dispatch:
julia> minus(a::TrackedArray, b::TrackedArray) = Tracker.track(minus, a, b)
minus (generic function with 2 methods)
`Tracker.track` does two things: (1) it makes sure `minus` is called with *normal* arrays, not tracked ones (you can use `@show` inside `minus` to verify this), and (2) it uses the result to add a `minus` node to the tape. Look inside the result of calling `minus` to see what happened:
julia> a, b = param([6,5,4]), param([1,2,3])
(param([6.0, 5.0, 4.0]), param([1.0, 2.0, 3.0]))
julia> c = minus(a, b)
Tracked 3-element Array{Float64,1}:
julia> c.tracker.f
Flux.Tracker.Call{...}(minus, (param([6.0, 5.0, 4.0]), param([1.0, 2.0, 3.0])))
Finally, we have to specify the gradient of `minus`.
julia> Tracker.back(::typeof(minus), Δ, a, b) =
(Tracker.@back(a, Δ); Tracker.@back(b, -Δ))
`@back(x, Δ)` tells the tracker to continue propagating the sensitivity `Δ` through `x`. Now, AD will work with any program that calls `minus`.
julia> Flux.back!(c, 1)
julia> a.grad
3-element Array{Float64,1}:
julia> b.grad
3-element Array{Float64,1}:
## Notes
For multi-argument functions with custom gradients, you likely want to catch not just `minus(::TrackedArray, ::TrackedArray)` but also `minus(::Array, ::TrackedArray)` and so on. To do so, just define those extra signatures as needed:
minus(a::AbstractArray, b::TrackedArray) = Tracker.track(minus, a, b)
minus(a::TrackedArray, b::AbstractArray) = Tracker.track(minus, a, b)
`@back` *must* be called exactly once on each tracked input argument. You do not need to do any special handling if one of the arguments is not tracked, as `@back` will just become a no-op.

View File

@ -5,7 +5,7 @@ These core layers form the foundation of almost all neural networks.
## Recurrent Layers
@ -15,6 +15,7 @@ Much like the core layers above, but can be used to process sequence data (as we

View File

@ -7,6 +7,7 @@ add the result to the overall loss.
For example, say we have a simple regression.
using Flux: crossentropy
m = Dense(10, 5)
loss(x, y) = crossentropy(softmax(m(x)), y)

View File

@ -17,14 +17,6 @@
url = {},
author = {Mike Innes},
title = {Generic GPU Kernels},
year = 2017,
url = {},
urldate = {2018-02-16}
author = {Mike Innes and others},
title = {On Machine Learning and Programming Languages},
@ -33,10 +25,26 @@
urldate = {2018-02-16}
author = {Steven G. Johnson},
title = {More Dots: Syntactic Loop Fusion in Julia},
author = {Mike Innes and others},
title = {Generic GPU Kernels},
year = 2017,
url = {},
url = {},
urldate = {2018-02-16}
author = {Mike Innes and others},
title = {Flux Model Zoo},
year = 2018,
url = {},
urldate = {2018-02-16}
author = {James Bradbury},
title = {Minibatch.jl},
year = 2018,
url = {},
urldate = {2018-02-16}

View File

@ -24,8 +24,8 @@ bibliography: paper.bib
Flux is a library for machine learning (ML), written using the numerical computing language Julia [@Julia]. The package allows models to be written using Julia's simple mathematical syntax, and applies automatic differentiation (AD) to seamlessly calculate derivatives and train the model. Meanwhile, it makes heavy use of Julia's language and compiler features to carry out code analysis and make optimisations. For example, Julia's GPU compilation support [@besard:2017] can be used to JIT-compile custom GPU kernels for model layers [@CuArrays].
The machine learning community has traditionally been divided between "static" and "dynamic" frameworks that are easy to optimise and easy to use, respectively [@MLPL]. Flux blurs the line between these two approaches, combining a highly intuitive programming model with the compiler techniques needed by ML. As a result of this approach, it already supports several features not available in any other dynamic framework, such as kernel fusion [@Fusion], memory usage optimisations, importing of models via ONNX, and deployment of models to JavaScript for running in the browser.
The machine learning community has traditionally been divided between "static" and "dynamic" frameworks that are easy to optimise and easy to use, respectively [@MLPL]. Flux blurs the line between these two approaches, combining a highly intuitive programming model with the compiler techniques needed by ML. This enables research into advanced compiler transforms such as batching [@Minibatch] without changing any user code.
Flux has been used heavily for natural language processing, but can also support state-of-the-art research models in areas like computer vision, reinforcement learning and robotics.
Flux has been used heavily for natural language processing, but can also support state-of-the-art research models in areas like computer vision, reinforcement learning and robotics. Many examples of such models can be found in the model zoo [@Zoo].
# References

View File

@ -7,22 +7,23 @@ module Flux
using Juno, Requires, Reexport
using MacroTools: @forward
export Chain, Dense, RNN, LSTM, GRU, Conv, Conv2D,
Dropout, LayerNorm, BatchNorm,
SGD, ADAM, Momentum, Nesterov, AMSGrad, NADAM,
param, params, mapleaves, cpu, gpu
export Chain, Dense, RNN, LSTM, GRU, Conv,
Dropout, LayerNorm, BatchNorm,
params, mapleaves, cpu, gpu
@reexport using NNlib
using NNlib: @fix
using .Tracker
export Tracker
import .Tracker: data
using .Tracker: data
export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param
using .Optimise
using .Optimise: @epochs
export SGD, ADAM, AdaMax, Momentum, Nesterov,
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
@ -32,12 +33,10 @@ include("layers/stateless.jl")
@require CuArrays include("cuda/cuda.jl")
end # module

View File

@ -4,11 +4,4 @@ using CuArrays
CuArrays.cudnn_available() && include("cudnn.jl")
import ..Flux.JIT: Shape, restructure
function restructure(sh::Shape{T}, buf::CuVector{UInt8}) where T
buf = buf[1:sizeof(sh)]
reshape(reinterpret(T, buf), size(sh))

View File

@ -10,13 +10,15 @@ const cache_prefix = ""
function load()
suffixes = ["", ".phones", ".symbols"]
if isdir(deps("cmudict"))
if all(isfile.(["cmudict$x" for x in suffixes]))
if all(isfile(deps("cmudict", "cmudict$x")) for x in suffixes)
info("Downloading CMUDict dataset")
for x in suffixes
download("$cache_prefix/$version$x", deps("cmudict", "cmudict$x"))
deps("cmudict", "cmudict$x"))

View File

@ -14,6 +14,7 @@ function load()
isfile(file) && continue
info("Downloading MNIST dataset")
download("$file.gz", "$file.gz")
open(file, "w") do io
write(io,, "$file.gz"))

View File

@ -4,10 +4,10 @@ using ZipFile
using ..Data: deps
function load()
isfile(deps("")) ||
isfile(deps("")) || return
info("Downloading sentiment treebank dataset")
getfile(r, name) = r.files[findfirst(x -> == name, r.files)]

View File

@ -1,9 +0,0 @@
module JIT
using MacroTools

View File

@ -1,40 +0,0 @@
# Primitive definitions
shape(::typeof(*), A::MatShape{T}, B::VecShape{T}) where T =
shape(::typeof(*), A::MatShape{T}, B::MatShape{T}) where T =
inplace!(::typeof(*), C::AbstractArray, A::AbstractMatrix, B::AbstractArray) =
A_mul_B!(C, A, B)
shape(::typeof(broadcast), f, xs...) =
inplace!(::typeof(broadcast), y, f, xs...) = broadcast!(f, y, xs...)
shape(::typeof(reshape), x::Shape{T}, i...) where T =
Shape{T}(Base._reshape_uncolon(x, i))
inplace!(::typeof(reshape), y, x, i...) = copy!(y, x)
# NNlib
using NNlib
using ..Tracker: _conv, _maxpool
shape(::typeof(softmax), x) = x
inplace!(::typeof(softmax), y, x) = NNlib.softmax!(y, x)
shape(::typeof(_conv), x::Shape{T}, w::Shape{T}, stride, pad) where T =
Shape{T}(NNlib.cdims(size(x), size(w), pad, stride))
inplace!(::typeof(_conv), y, x, w, stride, pad) =
NNlib.conv!(y, x, w, stride = stride, pad = pad)
shape(::typeof(_maxpool), x::Shape{T}, k, pad) where T =
Shape{T}(NNlib.pdims(size(x), k, pad, k))
inplace!(::typeof(_maxpool), y, x, k, pad) =
NNlib.maxpool!(y, x, k, pad = pad)

View File

@ -1,56 +0,0 @@
using ..Tracker: TrackedArray
struct Shape{T,N}
VecShape{T} = Shape{T,1}
MatShape{T} = Shape{T,2}
Shape{T}(dims::Vararg{Integer,N}) where {T,N} = Shape{T,N}(dims)
Shape{T}(dims::NTuple{N,Integer}) where {T,N} = Shape{T,N}(dims)
Base.size(s::Shape) = s.dims
Base.size(s::Shape, n) = s.dims[n]
Base.ndims(s::Shape{T,N}) where {T,N} = N
Base.length(s::Shape) = prod(s.dims)
Base.eltype(s::Shape{T}) where T = T
Base.sizeof(s::Shape{T}) where T = sizeof(T)*prod(size(s))
function, s::Shape{T}) where T
print(io, "Shape{$T}(")
join(io, s.dims, ", ")
print(io, ")")
shape(x) = x
shape(x::Shape) = x
shape(x::Tuple) = shape.(x)
shape(x::AbstractArray) = Shape{eltype(x)}(size(x)...)
shape(x::TrackedArray) = shape(
bytes(s::Shape) = sizeof(s)
bytes(x::Tuple) = sum(bytes.(x))
# Recover structure from byte buffers
# Make sure to hold on to the parent buffer for the lifetime of the data.
function restructure(sh::Shape{T}, buf::Vector{UInt8}) where T
buf = unsafe_wrap(Array, pointer(buf), sizeof(sh))
reshape(reinterpret(T, buf), size(sh))
# Execution with caches
mutable struct Cached{F,A}
function (c::Cached)(args...)
sh = shape(c.f, shape(args)...)
bytes(sh) > length(c.buffer) && (c.buffer = similar(c.buffer, bytes(sh)))
y = restructure(sh, c.buffer)
inplace!(c.f, y, args...)

View File

@ -1,75 +0,0 @@
# This is hacky; we'll eventually reuse Cassette for better tracing.
using ..Tracker, DataFlow
using ..Tracker: Tracked, Broadcasted, param, tracker, istracked, isleaf
using DataFlow: Call, Lambda, iscall, isconstant, prewalk, vertex, syntax,
inputnode, constant
vcall(f, args...) = vertex(DataFlow.Call(), constant(f), args...)
vcall(f::Broadcasted, args...) = vcall(broadcast, constant(f.f), args...)
graph(x::Tracked, inputs...; cache = ObjectIdDict()) =
vcall(x.f.func, map(x -> graph(x, inputs...; cache = cache), x.f.args)...)
function graph(x, inputs...; cache = ObjectIdDict())
haskey(cache, x) && return cache[x]
i = findfirst(y -> x === y, inputs)
cache[x] =
i > 0 ? inputnode(i) :
istracked(x) && !isleaf(x) ? graph(tracker(x), inputs...; cache = cache) :
function trace(f, args...)
inputs = param.(args)
graph(f(inputs...), inputs...)
# Graph manipulation
function liftparams(v)
ps = []
v = prewalk(DataFlow.bumpinputs(v)) do v
isconstant(v) && istracked(v.value.value) || return v
push!(ps, v.value.value)
DataFlow.vcall(getindex, inputnode(1), length(ps))
return v, ps
function cacheall(v, buf = () -> UInt8[])
prewalk(v) do v
iscall(v) && isconstant(v[1]) || return v
f = v[1].value.value
return vertex(Call(), constant(Cached(f, buf())), v[2:end]...)
code(v, n) = syntax(vertex(Lambda(n, v)))
struct Compiled{F,T<:Tuple}
# TODO when we support derivatives
# (c::Compiled)(args...) =
# Tracker.track(Tracker.Call(c, args...),
# c.func(, args...))
(c::Compiled)(args...) = c.func(,, c::Compiled) = print(io, "Compiled(", c.model, ")")
function compile(f, args...)
v = trace(f, args...)
v, ps = liftparams(cacheall(v, () -> similar(args[1], UInt8, 1))) # no empty arrays on GPU
Compiled(f, eval(code(v, length(args)+1)), (ps...,))
function source(f, args...)
v = trace(f, args...)
v, ps = liftparams(v)
code(v, length(args)+1) |> prettify

View File

@ -10,7 +10,7 @@ Standard convolutional layer. `size` should be a tuple like `(2, 2)`.
Data should be stored in WHCN order. In other words, a 100×100 RGB image would
be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
Takes the keyword arguments `pad` and `stride`.
Takes the keyword arguments `pad`, `stride` and `dilation`.
struct Conv{N,F,A,V}
@ -18,17 +18,19 @@ struct Conv{N,F,A,V}
Conv(w::AbstractArray{T}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0) where T =
Conv(σ, w, b, stride, pad)
stride = 1, pad = 0, dilation=1) where T =
Conv(σ, w, b, stride, pad, dilation)
Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
stride::NTuple{N,Integer} = map(_->1,k),
pad::NTuple{N,Integer} = map(_->0,k)) where N =
pad::NTuple{N,Integer} = map(_->0,k),
dilation::NTuple{N,Integer} = map(_->1,k)) where N =
Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ,
stride = stride, pad = pad)
stride = stride, pad = pad, dilation = dilation)
@ -36,7 +38,7 @@ function (c::Conv)(x)
# TODO: breaks gpu broadcast :(
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
σ.(conv(x, c.weight, stride = c.stride, pad = c.pad) .+ b)
σ.(conv(x, c.weight, stride = c.stride, pad = c.pad, dilation = c.dilation) .+ b)
function, l::Conv)
@ -45,6 +47,3 @@ function, l::Conv)
l.σ == identity || print(io, ", ", l.σ)
print(io, ")")
# v0.5
@deprecate Conv2D(args...; kw...) Conv(args...; kw...)
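To see how the `stride`, `pad` and `dilation` keywords interact, the sketch below computes the output length along one spatial dimension using the standard convolution arithmetic. `conv_outsize` is a hypothetical helper written for this note, not a Flux or NNlib function, and it assumes the same padding on both sides. A dilation of 1 means an ordinary, undilated kernel.

```julia
# Hypothetical helper (not part of Flux or NNlib): output length of a
# convolution along one dimension, assuming symmetric padding.
# A kernel of width k with dilation d spans d*(k-1) + 1 input elements.
conv_outsize(n, k; stride = 1, pad = 0, dilation = 1) =
    div(n + 2pad - dilation * (k - 1) - 1, stride) + 1

conv_outsize(100, 3)                       # 98
conv_outsize(100, 3, dilation = 2)         # 96
conv_outsize(100, 3, pad = 1, stride = 2)  # 50
```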

View File

@ -31,15 +31,14 @@ function Dropout(p)
Dropout{typeof(p)}(p, true)
_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
function (a::Dropout)(x) || return x
y = similar(x)
q = 1 - a.p
@inbounds for i=1:length(y)
y[i] = y[i] > a.p ? 1 / q : 0
return y .* x
y .= _dropout_kernel.(y, a.p, 1 - a.p)
return x .* y
_testmode!(a::Dropout, test) = ( = !test)
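The broadcast form of the dropout mask can be tried in isolation. The following is a standalone sketch under stated assumptions: `dropout_kernel` and `dropout` are illustrative names rather than the layer's API, and the mask is filled with uniform random numbers via `rand!` before the kernel is applied. Surviving elements are scaled by `1 / (1 - p)` so the expected value of the output matches the input.

```julia
using Random

# Illustrative standalone version; names are assumptions, not Flux's API.
# Keep an element (scaled by 1/q) when its random draw exceeds p, else zero it.
dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)

function dropout(x::AbstractArray, p; rng = Random.default_rng())
    y = rand!(rng, similar(x))          # uniform draws in [0, 1)
    y .= dropout_kernel.(y, p, 1 - p)   # turn the draws into a scaled 0/1 mask
    return x .* y
end

out = dropout(ones(1000), 0.5; rng = MersenneTwister(1))
all(v -> v == 0.0 || v == 2.0, out)  # every element is dropped or doubled
```

Writing the kernel as a small scalar function and broadcasting it keeps the mask computation generic over array types, which is the usual reason to prefer it over an explicit indexed loop.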
@ -68,70 +67,88 @@ function, l::LayerNorm)
BatchNorm(dims...; λ = identity,
initβ = zeros, initγ = ones, ϵ = 1e-8, momentum = .1)
BatchNorm(channels::Integer, σ = identity;
initβ = zeros, initγ = ones,
ϵ = 1e-8, momentum = .1)
Batch Normalization Layer for [`Dense`](@ref) layer.
Batch Normalization layer. The `channels` input should be the size of the
channel dimension in your data (see below).
Given an array with `N` dimensions, call the `N-1`th the channel dimension. (For
a batch of feature vectors this is just the data dimension, for `WHCN` images
it's the usual channel dimension.)
`BatchNorm` computes the mean and variance for each `W×H×1×N` slice and
shifts them to have a new mean and variance (corresponding to the learnable,
per-channel `bias` and `scale` parameters).
See [Batch Normalization: Accelerating Deep Network Training by Reducing
Internal Covariate Shift](
Internal Covariate Shift](
For example, on MNIST,
to normalise the input to the next layer,
place the `BatchNorm` layer before the activation function.
m = Chain(
Dense(28^2, 64),
BatchNorm(64, λ = relu),
BatchNorm(64, relu),
Dense(64, 10),
mutable struct BatchNorm{F,V,N}
mutable struct BatchNorm{F,V,W,N}
λ::F # activation function
β::V # bias
γ::V # scale
μ # moving mean
σ # moving std
μ::W # moving mean
σ::W # moving std
BatchNorm(dims::Integer...; λ = identity,
BatchNorm(chs::Integer, λ = identity;
initβ = zeros, initγ = ones, ϵ = 1e-8, momentum = .1) =
BatchNorm(λ, param(initβ(dims)), param(initγ(dims)), 0., 1., ϵ, momentum, true)
BatchNorm(λ, param(initβ(chs)), param(initγ(chs)),
zeros(chs), ones(chs), ϵ, momentum, true)
function (BN::BatchNorm)(x)
λ, γ, β = BN.λ, BN.γ, BN.β
size(x, ndims(x)-1) == length(BN.β) ||
error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))")
γ, β = BN.γ, BN.β
dims = length(size(x))
channels = size(x, dims-1)
affine_shape = ones(Int, dims)
affine_shape[end-1] = channels
m = prod(size(x)[1:end-2]) * size(x)[end]
if !
μ = BN.μ
σ = BN.σ
μ = reshape(BN.μ, affine_shape...)
σ = reshape(BN.σ, affine_shape...)
T = eltype(x)
ϵ = data(convert(T, BN.ϵ))
m = size(x, 2) # batch size
μ = mean(x, 2)
σ = sqrt.(sum((x .- μ).^2, 2) ./ m .+ ϵ)
axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
μ = mean(x, axes)
σ = sqrt.(mean((x .- μ).^2, axes) .+ ϵ)
# update moving mean/std
mtm = data(convert(T, BN.momentum))
BN.μ = (1 - mtm) .* BN.μ .+ mtm .* data(μ)
BN.σ = (1 - mtm) .* BN.σ .+ mtm .* data(σ) .* m ./ (m - 1)
BN.μ = (1 - mtm) .* BN.μ .+ mtm .* squeeze(data(μ), (axes...))
BN.σ = (1 - mtm) .* BN.σ .+ mtm .* squeeze(data(σ), (axes...)) .* m ./ (m - 1)
λ.(γ .* ((x .- μ) ./ σ) .+ β)
let λ = BN.λ
λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ σ) .+ reshape(β, affine_shape...))
children(BN::BatchNorm) =
(BN.λ, BN.β, BN.γ, BN.μ, BN.σ, BN.momentum, BN.ϵ,
(BN.λ, BN.β, BN.γ, BN.μ, BN.σ, BN.ϵ, BN.momentum,
mapchildren(f, BN::BatchNorm) = # e.g. mapchildren(cu, BN)
BatchNorm(BN.λ, f(BN.β), f(BN.γ), BN.μ, BN.σ, BN.momentum, BN.ϵ,
BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ), BN.ϵ, BN.momentum,
_testmode!(BN::BatchNorm, test) = ( = !test)
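The per-channel reduction used in the training branch above can be sketched in plain Julia. This is a simplified illustration written against current Julia syntax (`dims` keyword, `Statistics` standard library) rather than the code in the diff; `batchnorm_stats` is a hypothetical name, and the learnable scale and bias as well as the moving averages are left out.

```julia
using Statistics

# Hypothetical helper: normalise over every axis except the channel axis,
# which is taken to be the second-to-last dimension (as in WHCN layout).
function batchnorm_stats(x::AbstractArray; ϵ = 1e-8)
    n = ndims(x)
    reduce_dims = ((1:n-2)..., n)        # all axes but the channel axis
    μ = mean(x, dims = reduce_dims)
    σ = sqrt.(mean((x .- μ) .^ 2, dims = reduce_dims) .+ ϵ)
    return (x .- μ) ./ σ, μ, σ
end

x = reshape(collect(1.0:24.0), 2, 2, 3, 2)   # W×H×C×N with 3 channels
y, μ, σ = batchnorm_stats(x)
size(μ)  # (1, 1, 3, 1)
```

For a plain feature matrix (features × batch) the same function reduces over the batch axis only, so each feature row is normalised independently.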

View File

@ -5,7 +5,7 @@ using NNlib: logsoftmax, logσ
mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)
function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
return @fix -sum(y .* log.(ŷ) .* weight) / size(y, 2)
@fix -sum(y .* log.(ŷ) .* weight) / size(y, 2)
@deprecate logloss(x, y) crossentropy(x, y)

View File

@ -1,7 +1,8 @@
module Optimise
export update!, params, train!,
SGD, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
export train!,
SGD, ADAM, AdaMax, Momentum, Nesterov,
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
struct Param{T}

View File

@ -56,6 +56,15 @@ RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) =
ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[AdaMax]( optimiser. Variant of ADAM based on
the ∞-norm.
AdaMax(ps, η = 0.002; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p->adamax(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
ADAGrad(params, η = 0.01; ϵ = 1e-8, decay = 0)

View File

@ -62,6 +62,18 @@ function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ
function adamax(p::Param; η::Real = 0.002, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
mt = zeros(p.x)
ut = zeros(p.x)
β1p = β1
function ()
@. mt = β1 * mt + (1 - β1) * p.Δ
@. ut = max(β2 * ut, abs(p.Δ))
@. p.Δ = (η/(1 - β1p)) * mt/(ut + ϵ)
β1p *= β1
function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
mt = zeros(p.x)
vt = zeros(p.x) .+ ϵ
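The `adamax` rule added in this file can be exercised on its own with a small sketch. This is a hedged, standalone illustration: `adamax_step!` and `adamax_demo` are hypothetical names, and unlike the closure in the diff (which rewrites `p.Δ` and leaves the descent step to the optimiser chain) this version applies the update directly to `x`. The state `mt` is the first-moment estimate, `ut` is the exponentially weighted infinity norm, and `β1p` carries the running power of `β1` used for bias correction.

```julia
# Hypothetical standalone AdaMax step; applies the update to x in place
# and returns the next bias-correction factor.
function adamax_step!(x, Δ, mt, ut, β1p; η = 0.002, β1 = 0.9, β2 = 0.999, ϵ = 1e-8)
    @. mt = β1 * mt + (1 - β1) * Δ        # first-moment estimate
    @. ut = max(β2 * ut, abs(Δ))          # exponentially weighted ∞-norm
    @. x -= (η / (1 - β1p)) * mt / (ut + ϵ)
    return β1p * β1
end

# Minimise sum(abs2, x), whose gradient is 2x.
function adamax_demo(steps = 100)
    x = [1.0, -2.0]
    mt, ut, β1p = zero(x), zero(x), 0.9
    for _ in 1:steps
        Δ = 2 .* x
        β1p = adamax_step!(x, Δ, mt, ut, β1p)
    end
    return x
end

x = adamax_demo()
sum(abs2, x) < 5.0   # the loss has decreased from its starting value of 5.0
```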

View File

@ -41,7 +41,7 @@ end
Base.setindex!(xs::TrackedArray, v, i...) =
error("Can't differentiate `setindex!`")
back!(::TrackedArray) = error("Use back!(x, Δ)")
back!(::TrackedArray) = error("Value is not scalar; use `back!(sum(x))` or `back!(x, Δ)`")
# Fallthrough methods
@ -81,21 +81,6 @@ back(::typeof(ctranspose), Δ, xs) = @back(xs, trim(xs, Δ'))
Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...)
Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...)
Base.vcat(a::TrackedVector, b::TrackedVector) = track(vcat, a, b)
Base.vcat(a::TrackedVector, b::TrackedVector...) = track(vcat, a, b...)
Base.vcat(a::TrackedVector, b::AbstractVector) = track(vcat, a, b)
Base.vcat(a::AbstractVector, b::TrackedVector) = track(vcat, a, b)
Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat) = track(vcat, a, b)
Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat...) = track(vcat, a, b...)
Base.vcat(a::TrackedVecOrMat, b::AbstractVecOrMat) = track(vcat, a, b)
Base.vcat(a::AbstractVecOrMat, b::TrackedVecOrMat) = track(vcat, a, b)
Base.vcat(a::TrackedMatrix, b::TrackedMatrix) = track(vcat, a, b)
Base.vcat(a::TrackedMatrix, b::TrackedMatrix...) = track(vcat, a, b...)
Base.vcat(a::TrackedMatrix, b::AbstractMatrix) = track(vcat, a, b)
Base.vcat(a::AbstractMatrix, b::TrackedMatrix) = track(vcat, a, b)
function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1)
Δ′ = similar(
S = size(
@ -108,20 +93,93 @@ function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1)
back(xs, Δ′)
_repeat(A, inner, outer) = Base.repeat(A; inner=inner, outer=outer)
Base.repeat(A::TrackedArray; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A))) = track(_repeat, A, inner, outer)
function back(::typeof(_repeat), Δ, xs::TrackedArray, inner, outer)
Δ′ = similar(
Δ′ .= 0
S = size(
# Loop through each element of Δ, calculate source dimensions, accumulate into Δ′
for (dest_idx, val) in enumerate(IndexCartesian(), Δ)
# First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then
# wrap around based on original size S.
src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)]
Δ′[src_idx...] += val
back(xs, Δ′)
for f in [:vcat, :hcat]
@eval begin
# This section is a bit of a hack since julia doesn't have a standardised
# promotion mechanism for concatenation yet.
# It should support tracked concatenation with rank ∈ (1,2) with a
# TrackedArray anywhere among the arguments. This works as long as base has
# other functions that capture `(::Union{Vector,RowVector,Matrix}...)`.
Base.$f(a::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a...)
# It should support tracked concatenation with rank>2 if the TrackedArray is
# first
Base.$f(a::TrackedArray, b::AbstractArray...) = track($f, a, b...)
Base.$f(a::TrackedArray, b::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a, b...) # resolves ambiguity introduced by previous row
# It should support tracked concatenation with rank>2 if the TrackedArray is
# second
Base.$f(a::Array, b::TrackedArray, c::AbstractArray...) = track($f, a, b, c...)
Base.$f(a::Union{Vector,RowVector,Matrix}, b::TrackedArray,
c::Union{TrackedArray,Vector,RowVector,Matrix}...) =
track($f, a, b, c...) # resolves ambiguity introduced by previous row
function back(::typeof(vcat), Δ, xs...)
i = Base.tail(map(_ -> :, size(Δ)))
start = 0
for xsi in xs
i = map(_ -> :, size(xsi)) |> Base.tail
@back(xsi, Δ[start+1:start+size(xsi,1), i...])
start += size(xsi, 1)
Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) =
track(reshape, xs, dims...)
function back(::typeof(hcat), Δ, xs...)
start = 0
for xsi in xs
if ndims(xsi) == 1
@back(xsi, Δ[:, start+1])
i = map(_ -> :, size(xsi)) |> Base.tail |> Base.tail
@back(xsi, Δ[:, start+1:start+size(xsi,2), i...])
start += size(xsi, 2)
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64,N}} where N) =
track(reshape, xs, dims), a::TrackedArray, b::AbstractArray...) = track(cat, dims, a, b...), a::Union{RowVector,Array}, b::TrackedArray, c::AbstractArray...) = track(cat, dims, a, b, c...)
function back(::typeof(cat), Δ, dims, Xs...)
start = ntuple(i -> 0, Val{ndims(Δ)})
for xs in Xs
dim_xs = 1:ndims(xs)
till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val{ndims(Δ)})
xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val{ndims(Δ)})
@back(xs, reshape(Δ[xs_in_Δ...],size(xs)))
start = start .+ till_xs
Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) = reshape(xs, dims)
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Union{Int64,Colon}}}) = reshape(xs, Base._reshape_uncolon(xs, dims))
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64}}) = track(reshape, xs, dims)
back(::typeof(reshape), Δ, xs::TrackedArray, _...) =
back(xs, reshape(Δ, size(xs)))
@ -158,12 +216,16 @@{Function, Type}, xs::TrackedArray) = prod(f.(xs))
back(::typeof(prod), Δ, xs::TrackedArray, dim...) = back(xs, similar( .= (prod(, dim...) ./ .* Δ)
back(::typeof(prod), Δ, xs::TrackedArray) = back(xs, similar( .= (reshape(.*(circshift.([reshape(, length(], 1:length(, size( .* Δ)
Base.maximum(xs::TrackedArray, args...) = maximum(, args...)
Base.findfirst(xs::TrackedArray, args...) = findfirst(, args...)
Base.mean(xs::TrackedArray) = track(mean, xs)
Base.mean(xs::TrackedArray, region) = track(mean, xs, region)
Base.maximum(xs::TrackedArray) = track(maximum, xs)
Base.maximum(xs::TrackedArray, region) = track(maximum, xs, region)
Base.minimum(xs::TrackedArray) = track(minimum, xs)
Base.minimum(xs::TrackedArray, region) = track(minimum, xs, region), ys::TrackedVector) = track(dot, xs, ys), ys::TrackedVector) = track(dot, xs, ys), ys::AbstractVector) = track(dot, xs, ys)
@ -186,6 +248,31 @@ back(::typeof(mean), Δ, xs::TrackedArray) = back(xs, similar( .= Δ ./
back(::typeof(mean), Δ, xs::TrackedArray, region) =
back(xs, similar( .= Δ ./ prod(size(, region...)))
function back(::typeof(maximum), Δ, xs::TrackedArray)
Δ′ = zeros(
_, i = findmax(
Δ′[i] = Δ
@back(xs, Δ′)
function back(::typeof(maximum), Δ, xs::TrackedArray, region)
Δ′ = zeros(
_, is = findmax(, region)
Δ′[is] = Δ
@back(xs, Δ′)
function back(::typeof(minimum), Δ, xs::TrackedArray)
Δ′ = zeros(
_, i = findmin(
Δ′[i] = Δ
@back(xs, Δ′)
function back(::typeof(minimum), Δ, xs::TrackedArray, region)
Δ′ = zeros(
_, is = findmin(, region)
Δ′[is] = Δ
@back(xs, Δ′)
Base.diagm(x::TrackedVector) = track(diagm, x)
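The gradient rule for `maximum` and `minimum` above routes the whole sensitivity to the position of the extreme element and leaves every other entry at zero. A minimal plain-Julia sketch of the whole-array case follows; `maximum_pullback` is a hypothetical name used only for illustration.

```julia
# Hypothetical helper: gradient of maximum(x) with respect to x,
# given the incoming sensitivity Δ.
function maximum_pullback(x::AbstractArray, Δ)
    Δ′ = zero(x)          # zeros with the same shape and element type as x
    _, i = findmax(x)     # index of the largest element
    Δ′[i] = Δ
    return Δ′
end

maximum_pullback([1.0, 5.0, 3.0], 2.0)  # [0.0, 2.0, 0.0]
```

The `minimum` case is identical with `findmin` in place of `findmax`.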
@ -247,35 +334,35 @@ logsoftmax(xs::TrackedArray) = track(logsoftmax, xs)
back(::typeof(logsoftmax), Δ, xs) = @back(xs, ∇logsoftmax(Δ, data(xs)))
# TODO: can store kwargs efficiently in namedtuples
_conv(x, w, stride, pad) = conv(x, w, stride = stride, pad = pad)
_conv(x, w, stride, pad, dilation) = conv(x, w, stride = stride, pad = pad, dilation = dilation)
conv(x::TrackedArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0) where N =
track(_conv, x, w, stride, pad)
conv(x::AbstractArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0) where N =
track(_conv, x, w, stride, pad)
conv(x::TrackedArray{<:Real,N}, w::AbstractArray{<:Real,N}; stride = 1, pad = 0) where N =
track(_conv, x, w, stride, pad)
conv(x::TrackedArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N =
track(_conv, x, w, stride, pad, dilation)
conv(x::AbstractArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N =
track(_conv, x, w, stride, pad, dilation)
conv(x::TrackedArray{<:Real,N}, w::AbstractArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N =
track(_conv, x, w, stride, pad, dilation)
function back(::typeof(_conv), Δ, x, w, stride, pad)
@back(x, NNlib.∇conv_data(Δ, data(x), data(w); stride = stride, pad = pad))
@back(w, NNlib.∇conv_filter(Δ, data(x), data(w); stride = stride, pad = pad))
function back(::typeof(_conv), Δ, x, w, stride, pad, dilation)
@back(x, NNlib.∇conv_data(Δ, data(x), data(w); stride = stride, pad = pad, dilation = dilation))
@back(w, NNlib.∇conv_filter(Δ, data(x), data(w); stride = stride, pad = pad, dilation = dilation))
_maxpool(x, k, pad) = maxpool(x, k; pad = pad)
_maxpool(x, k, pad, stride) = maxpool(x, k; pad = pad, stride = stride)
maxpool(x::TrackedArray, k; pad = map(_->0,k)) =
track(_maxpool, x, k, pad)
maxpool(x::TrackedArray, k; pad = map(_->0,k), stride = k) =
track(_maxpool, x, k, pad, stride)
back_(::typeof(_maxpool), y, Δ, x, k, pad) =
back(x, NNlib.∇maxpool(Δ, y, data(x), k, pad=pad))
back_(::typeof(_maxpool), y, Δ, x, k, pad, stride) =
back(x, NNlib.∇maxpool(Δ, y, data(x), k, pad=pad, stride=stride))
_meanpool(x, k, pad) = meanpool(x, k; pad = pad)
_meanpool(x, k, pad, stride) = meanpool(x, k; pad = pad, stride = stride)
meanpool(x::TrackedArray, k; pad = map(_->0,k)) =
track(_meanpool, x, k, pad)
meanpool(x::TrackedArray, k; pad = map(_->0,k), stride = k) =
track(_meanpool, x, k, pad, stride)
back_(::typeof(_meanpool), y, Δ, x, k, pad) =
back(x, NNlib.∇meanpool(Δ, y, data(x), k, pad=pad))
back_(::typeof(_meanpool), y, Δ, x, k, pad, stride) =
back(x, NNlib.∇meanpool(Δ, y, data(x), k, pad=pad, stride=stride))
# Broadcasting

View File

@ -19,8 +19,9 @@ Base.decompose(x::TrackedReal) = Base.decompose(data(x))
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{T}) where T = x
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal) where T =
TrackedReal(Tracked(x.tracker.f, convert(T,
# This cuts derivatives, fix if needed.
# Base.convert(::Type{TrackedReal{T}}, x::TrackedReal) where T =
# TrackedReal(Tracked(x.tracker.f, convert(T,
Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x))
@@ -91,3 +92,18 @@ Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i)
back(::typeof(getindex), Δ, t, i) =
back(t, ntuple(j -> i == j ? Δ : 0, length(t)))
# Array collection
function collect(xs)
xs = Base.collect(xs)
track(Call(collect, xs), data.(xs))
function scan(c::Call{typeof(collect)})
foreach(scan, c.args[1])
function back(::typeof(collect), Δ, xs)
foreach((x, Δ) -> @back(x, Δ), xs, Δ)


@@ -21,6 +21,10 @@ cm = gpu(m)
@test all(p isa TrackedArray && isa CuArray for p in params(cm))
@test cm(gpu(rand(10, 10))) isa TrackedArray{Float32,2,CuArray{Float32,2}}
x = [1,2,3]
cx = gpu(x)
@test Flux.crossentropy(x,x) ≈ Flux.crossentropy(cx,cx)
# Fails in Pkg.test ffs
# c = gpu(Conv((2,2),3=>4))
# l = c(gpu(rand(10,10,3,2)))


@@ -1,12 +0,0 @@
using Flux, Base.Test
using Flux.JIT: compile
@testset "JIT" begin
m = Dense(10, 5)
f = compile(m, rand(10))
x = rand(10)
@test == f(x)


@@ -67,7 +67,7 @@ end
# with activation function
let m = BatchNorm(2, λ = σ), x = param([1 2; 3 4; 5 6]')
let m = BatchNorm(2, σ), x = param([1 2; 3 4; 5 6]')
@@ -77,4 +77,22 @@ end
x = m(x).data
@test x[1] ≈ σ((1 - 0.3) / 1.1449489742783179)
let m = BatchNorm(2), x = param(reshape(1:6, 3, 2, 1))
y = reshape(permutedims(x, [2, 1, 3]), 2, :)
y = permutedims(reshape(m(y), 2, 3, 1), [2, 1, 3])
@test m(x) == y
let m = BatchNorm(2), x = param(reshape(1:12, 2, 3, 2, 1))
y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :)
y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4])
@test m(x) == y
let m = BatchNorm(2), x = param(reshape(1:24, 2, 2, 3, 2, 1))
y = reshape(permutedims(x, [4, 1, 2, 3, 5]), 2, :)
y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5])
@test m(x) == y


@@ -3,7 +3,7 @@ using Flux.Tracker
@testset "Optimise" begin
w = randn(10, 10)
@testset for Opt in [SGD, Nesterov, Momentum, ADAM, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad, NADAM]
@testset for Opt in [SGD, Nesterov, Momentum, ADAM, AdaMax, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad, NADAM]
w = param(randn(10, 10))
loss(x) = Flux.mse(w*x, w*x)
opt = Opt([w])


@@ -10,7 +10,6 @@ include("layers/normalisation.jl")
if Base.find_in_path("CuArrays") ≠ nothing


@@ -1,5 +1,5 @@
using Flux.Tracker, Base.Test, NNlib
using Flux.Tracker: TrackedReal, gradcheck
using Flux.Tracker: TrackedReal, gradcheck, grad
using NNlib: conv
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
@@ -29,17 +29,97 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@test gradtest(x -> x', rand(5))
@test gradtest(vcat, rand(5), rand(3))
@test gradtest(vcat, rand(5), rand(3), rand(8))
@test gradtest(vcat, rand(5,2), rand(3,2), rand(8,2))
function promotiontest(f, A, B, C)
r0 = f(A, B, C)
r1 = f(param(A), B, C)
r2 = f(A, param(B), C)
if all(ndims.((A,B,C)) .≤ 2) && f ∈ [hcat, vcat]
r3 = f(A, B, param(C))
@test_throws MethodError f(A, B, param(C)) # until julia#20815 is resolved
r3 = r2
r4 = f(param(A), param(B), param(C))
@test !isa(r0, TrackedArray)
@test all(isa.([r1,r2,r3,r4], TrackedArray))
@test r1 == r2 == r3 == r4
@test r0 ==
@testset "concat" begin
cat1(x...) = cat(1, x...)
cat2(x...) = cat(2, x...)
@testset for vcatf in [vcat, cat1]
@test gradtest(vcatf, rand(5), rand(3))
@test gradtest(vcatf, rand(5), rand(3), rand(8))
@test gradtest(vcatf, rand(5)', rand(5)')
@test gradtest(vcatf, rand(5,2), rand(3,2), rand(8,2))
@test gradtest(vcatf, rand(5,2,3), rand(3,2,3), rand(8,2,3))
@test gradtest(vcatf, rand(5), rand(3,1))
@test gradtest(vcatf, rand(5)', rand(2,5))
@testset for hcatf in [hcat, cat2]
@test gradtest(hcatf, rand(5), rand(5))
@test gradtest(hcatf, rand(5)', rand(5)')
@test gradtest(hcatf, rand(2,5), rand(2,3), rand(2,8))
@test gradtest(hcatf, rand(2,5,3), rand(2,3,3), rand(2,8,3))
@test gradtest(hcatf, rand(5), rand(5), rand(5,2))
@test gradtest(hcatf, rand(5)', rand(1,3))
@test gradtest(hcatf, rand(5), rand(5,2))
@testset for catf in [vcat, cat1, hcat, cat2, (x...) -> cat(3, x...), (x...) -> cat((1,2), x...)]
@test gradtest(catf, rand(5))
@test gradtest(catf, rand(5)')
@test gradtest(catf, rand(2,5))
@test gradtest(catf, rand(2,5,3))
@test gradtest((x...) -> cat(3, x...), rand(2,5,2), rand(2,5,3), rand(2,5,4))
@testset "cat($dim, ...)" for dim in 3:5
catdim = (x...) -> cat(dim, x...)
@test gradtest(catdim, rand(5), rand(5), rand(5))
@test gradtest(catdim, rand(2,5), rand(2,5), rand(2,5))
@test gradtest(catdim, rand(2,5,3), rand(2,5,3), rand(2,5,3))
@test !isa(vcat(rand(2)), TrackedArray)
@test !isa(hcat(rand(2)), TrackedArray)
@test !isa(cat(1,rand(2)), TrackedArray)
@test gradtest((a,b)->cat((2,3,5), a, b), rand(2,3), rand(2,4,2,1))
@testset "promotiontest" begin
@testset for fcat in [hcat, vcat, (x...) -> cat(3, x...), (x...) -> cat((1,2), x...)]
promotiontest(fcat, rand(2), rand(2), rand(2))
promotiontest(fcat, rand(2)', rand(2)', rand(2)')
promotiontest(fcat, rand(2,2), rand(2,2), rand(2,2))
promotiontest(fcat, rand(2,2,2), rand(2,2,2), rand(2,2,2))
promotiontest(vcat, rand(1,2), rand(2)', rand(2,2))
promotiontest(hcat, rand(2,1), rand(2), rand(2,2))
promotiontest(vcat, rand(3,4,5), rand(1,4,5), rand(2,4,5))
promotiontest(hcat, rand(4,3,5), rand(4,1,5), rand(4,2,5))
promotiontest((x...) -> cat(3, x...), rand(4,5,3), rand(4,5,1), rand(4,5,2))
@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))
@test gradtest(x -> repmat(x, 5,5), rand(4,5))
@test gradtest(x -> repmat(x, 5), rand(4,5))
@test gradtest(kron,rand(5), rand(3))
@test gradtest(x -> repeat(x; inner=2, outer=3), rand(5))
@test gradtest(x -> repeat(x; inner=(2,2,1), outer=(1,1,3)), rand(5,4,3))
@test gradtest(kron, rand(5), rand(3))
@test gradtest(kron, rand(5), rand(3), rand(8))
@test gradtest(kron,rand(5,1), rand(3,1))
@test gradtest(kron, rand(5,1), rand(3,1))
@test gradtest(kron, rand(5,1), rand(3,1), rand(8,1))
@test gradtest(kron, rand(5,2), rand(3,2), rand(8,2))
@@ -55,6 +135,26 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@test gradtest(x -> mean(x, [1, 2]), rand(2, 3, 4))
@testset "maximum" begin
@test gradtest(maximum, rand(2, 3))
@test gradtest(x -> maximum(x, 1), rand(2, 3))
@test gradtest(x -> maximum(x, 2), rand(2, 3))
@test gradtest(x -> maximum(x, 3), rand(2, 3, 4))
@test gradtest(x -> maximum(x, [1, 2]), rand(2, 3, 4))
@testset "minimum" begin
@test gradtest(minimum, rand(2, 3))
@test gradtest(x -> minimum(x, 1), rand(2, 3))
@test gradtest(x -> minimum(x, 2), rand(2, 3))
@test gradtest(x -> minimum(x, 3), rand(2, 3, 4))
@test gradtest(x -> minimum(x, [1, 2]), rand(2, 3, 4))
@test gradtest(x -> std(x), rand(5,5))
@test gradtest(x -> std(x, 1), rand(5,5))
@@ -82,6 +182,21 @@ end
@test param(2)^2 == 4.0
@testset "reshape" begin
x = reshape(param(rand(2,2,2)), 4, 2)
@test x isa TrackedArray
@test size(x) == (4,2)
x = reshape(param([1]), (1,:))
@test x isa TrackedArray
@test size(x) == (1,1)
x = reshape(param(rand(2)), (2,:))
@test x isa TrackedArray
@test size(x) == (2,1)
x = reshape(param(rand(2,2)), (1,:,2))
@test x isa TrackedArray
@test size(x) == (1,2,2)
@testset "Intermediates" begin
x = param([1])
l = sum((x .+ x).^2)
@@ -108,4 +223,13 @@ b = param(rand())
@test Tracker.grad(b) == 1
@testset "collect" begin
x, y = param(2), param(3)
xy = Tracker.collect([x, y])
@test xy isa TrackedArray{Float64}
z = xy[1]*xy[2]
@test grad.((x,y)) == (3, 2)
end #testset