commit 2ef4397cbb

README.md (20 changed lines)
@@ -2,7 +2,7 @@
<img width="400px" src="https://raw.githubusercontent.com/FluxML/fluxml.github.io/master/logo.png"/>
</p>

-[](https://travis-ci.org/FluxML/Flux.jl) [](https://fluxml.github.io/Flux.jl/stable/) [](https://slackinvite.julialang.org/)
+[](https://travis-ci.org/FluxML/Flux.jl) [](https://fluxml.github.io/Flux.jl/stable/) [](https://slackinvite.julialang.org/) [](https://doi.org/10.21105/joss.00602)

Flux is an elegant approach to machine learning. It's a 100% pure-Julia stack, and provides lightweight abstractions on top of Julia's native GPU and AD support. Flux makes the easy things easy while remaining fully hackable.

@@ -12,6 +12,18 @@ julia> Pkg.add("Flux")

See the [documentation](http://fluxml.github.io/Flux.jl/) or the [model zoo](https://github.com/FluxML/model-zoo/) for examples.

+If you use Flux in research, please cite the following paper:
+
+```
+@article{innes:2018,
+  author  = {Mike Innes},
+  title   = {Flux: Elegant Machine Learning with Julia},
+  journal = {Journal of Open Source Software},
+  year    = {2018},
+  doi     = {10.21105/joss.00602},
+}
+```
+
## Features

Flux has powerful high-level features, and common architectures can be defined in a few lines.

@@ -79,3 +91,9 @@ For general questions and help, check out Julia's [community forum](https://disc

Flux development is carried out via our [GitHub issues](https://github.com/FluxML/Flux.jl/issues), so feel free to open feature requests or PRs here.

For more informal discussions we'd love to have you on the [Julia slack](https://slackinvite.julialang.org/), where we hang out on the #machine-learning channel.
+
+## Related Packages
+
+Check out [Metalhead.jl](https://github.com/FluxML/Metalhead.jl) for common computer vision datasets and trained models.
+
+[MLDatasets.jl](https://github.com/JuliaML/MLDatasets.jl) provides further common datasets.
@@ -18,6 +18,8 @@ makedocs(modules=[Flux, NNlib],
"One-Hot Encoding" => "data/onehot.md",
"GPU Support" => "gpu.md",
"Saving & Loading" => "saving.md",
+"Internals" =>
+  ["Backpropagation" => "internals/tracker.md"],
"Community" => "community.md"])

deploydocs(
docs/src/internals/tracker.md (new file, 156 lines)
@@ -0,0 +1,156 @@
# Flux.Tracker

Backpropagation, or reverse-mode automatic differentiation, is handled by the `Flux.Tracker` module.

```julia
julia> using Flux.Tracker
```

The `param` function converts a normal Julia array into a new object that, while behaving like an array, tracks extra information that allows us to calculate derivatives. For example, say we multiply two parameters:

```julia
julia> W = param([1 2; 3 4])
Tracked 2×2 Array{Float64,2}:
 1.0  2.0
 3.0  4.0

julia> x = param([5, 6])
Tracked 2-element Array{Float64,1}:
 5.0
 6.0

julia> y = W*x
Tracked 2-element Array{Float64,1}:
 17.0
 39.0
```

The output `y` is also a `TrackedArray` object. We can now backpropagate sensitivities to `W` and `x` via the `back!` function, and see the gradients accumulated in the `W` and `x` tracked arrays:

```julia
julia> Tracker.back!(y, [1, -1])

julia> W.grad
2×2 Array{Float64,2}:
  5.0   6.0
 -5.0  -6.0

julia> x.grad
2-element Array{Float64,1}:
 -2.0
 -2.0
```

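Since the gradients accumulate in place, they can be used directly for a crude, by-hand gradient-descent step. The following is a rough sketch, not part of the committed file; it only relies on the `.data` and `.grad` fields used throughout this commit:

```julia
# one hand-rolled gradient-descent step on W, using the gradient accumulated by back!
η = 0.1
W.data .-= η .* W.grad   # update the underlying array in place
W.grad .= 0              # gradients accumulate across back! calls, so reset them
```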
## Internals

All `Tracked*` objects (`TrackedArray`, `TrackedReal`) are light wrappers around the `Tracked` type, which you can access via the `.tracker` field.

```julia
julia> x.tracker
Flux.Tracker.Tracked{Array{Float64,1}}(0x00000000, Flux.Tracker.Call{Void,Tuple{}}(nothing, ()), true, [5.0, 6.0], [-2.0, -2.0])
```

The `Tracker` stores the value and gradient of a given object, which we've seen before.

```julia
julia> x.tracker.data
2-element Array{Float64,1}:
 5.0
 6.0

julia> x.tracker.grad
2-element Array{Float64,1}:
 -2.0
 -2.0
```

The tracker also contains a `Call` object, which simply represents a function call that was made at some point during the forward pass. For example, the `+` call would look like this:

```julia
julia> Tracker.Call(+, 1, 2)
Flux.Tracker.Call{Base.#+,Tuple{Int64,Int64}}(+, (1, 2))
```

In the case of the `y` we produced above, we can see that it stores the call that produced it -- that is, `W*x`.

```julia
julia> y.tracker.f
Flux.Tracker.Call{...}(*, (param([1.0 2.0; 3.0 4.0]), param([5.0, 6.0])))
```

Notice that because the arguments to the call may also be tracked arrays, storing their own calls, `Tracker` ends up forming a data structure that records everything that happened during the forward pass (often known as a *tape*).

When we call `back!(y, [1, -1])`, the sensitivities `[1, -1]` simply get forwarded to `y`'s call (`*`), effectively calling

```julia
Tracker.back(*, [1, -1], W, x)
```

which in turn calculates the sensitivities of the arguments (`W` and `x`) and backpropagates through their calls. This is recursive, so it will walk the entire program graph and propagate gradients to the original model parameters.

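For concreteness, the rule registered for `*` behaves roughly like the following simplified sketch (illustrative, not the exact source in this commit): it sends `Δ * xᵀ` back to `W` and `Wᵀ * Δ` back to `x`, which is exactly where the `[5 6; -5 -6]` and `[-2, -2]` gradients above come from.

```julia
# simplified sketch of a matrix–vector product rule (illustration only)
Tracker.back(::typeof(*), Δ, W, x) =
  (Tracker.@back(W, Δ * data(x)');   # dL/dW = Δ * xᵀ
   Tracker.@back(x, data(W)' * Δ))   # dL/dx = Wᵀ * Δ
```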
## Custom Gradients

We can hook into the processes above to implement custom gradients for a function or kernel. For a toy example, imagine a custom implementation of `minus`:

```julia
julia> minus(a, b) = a - b
```

Firstly, we must tell the tracker system to stop when it sees a call to `minus`, and record it. We can do this using dispatch:

```julia
julia> minus(a::TrackedArray, b::TrackedArray) = Tracker.track(minus, a, b)
minus (generic function with 2 methods)
```

`Tracker.track` does two things: (1) it makes sure `minus` is called with *normal* arrays, not tracked ones (you can use `@show` inside `minus` to verify this), and (2) it uses the result to add a `minus` node to the tape. Look inside the result of calling `minus` to see what happened:

```julia
julia> a, b = param([6,5,4]), param([1,2,3])
(param([6.0, 5.0, 4.0]), param([1.0, 2.0, 3.0]))

julia> c = minus(a, b)
Tracked 3-element Array{Float64,1}:
 5.0
 3.0
 1.0

julia> c.tracker.f
Flux.Tracker.Call{...}(minus, (param([6.0, 5.0, 4.0]), param([1.0, 2.0, 3.0])))
```

Finally, we have to specify the gradient of `minus`.

```julia
julia> Tracker.back(::typeof(minus), Δ, a, b) =
         (Tracker.@back(a, Δ); Tracker.@back(b, -Δ))
```

`@back(x, Δ)` tells the tracker to continue propagating the sensitivity `Δ` through `x`. Now, AD will work with any program that calls `minus`.

```julia
julia> Flux.back!(c, 1)

julia> a.grad
3-element Array{Float64,1}:
 1.0
 1.0
 1.0

julia> b.grad
3-element Array{Float64,1}:
 -1.0
 -1.0
 -1.0
```

## Notes

For multi-argument functions with custom gradients, you likely want to catch not just `minus(::TrackedArray, ::TrackedArray)` but also `minus(::Array, ::TrackedArray)` and so on. To do so, just define those extra signatures as needed:

```julia
minus(a::AbstractArray, b::TrackedArray) = Tracker.track(minus, a, b)
minus(a::TrackedArray, b::AbstractArray) = Tracker.track(minus, a, b)
```

`@back` *must* be called exactly once on each tracked input argument. You do not need to do any special handling if one of the arguments is not tracked, as `@back` will just become a no-op.
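To see that no-op behaviour, one can call `minus` with only one tracked argument. A quick sketch continuing the example above (`a` is the tracked array from before):

```julia
# only `a` is tracked here; the plain second argument needs no special handling
c2 = minus(a, [1.0, 2.0, 3.0])
Flux.back!(c2, 1)
a.grad   # accumulates another +1 per element; the plain array receives no gradient
```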
@@ -5,7 +5,7 @@ These core layers form the foundation of almost all neural networks.
```@docs
Chain
Dense
-Conv2D
+Conv
```

## Recurrent Layers
@@ -7,6 +7,7 @@ add the result to the overall loss.
For example, say we have a simple regression.

```julia
using Flux: crossentropy
m = Dense(10, 5)
loss(x, y) = crossentropy(softmax(m(x)), y)
```
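The surrounding page is about building a parameter penalty and adding the result to the overall loss, as the hunk context says. A rough sketch of that idea, under the assumption that `Dense` exposes `W` and `b` fields and using Julia 0.6's `vecnorm` (not a quote of the file):

```julia
# an L2-style penalty on the layer's parameters, added to the regression loss above
penalty() = vecnorm(m.W) + vecnorm(m.b)
loss(x, y) = crossentropy(softmax(m(x)), y) + penalty()
```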
src/Flux.jl (13 changed lines)

@@ -7,22 +7,23 @@ module Flux
using Juno, Requires, Reexport
using MacroTools: @forward

-export Chain, Dense, RNN, LSTM, GRU, Conv, Conv2D,
-  Dropout, LayerNorm, BatchNorm,
-  SGD, ADAM, Momentum, Nesterov, AMSGrad,
-  param, params, mapleaves, cpu, gpu
+export Chain, Dense, RNN, LSTM, GRU, Conv,
+  Dropout, LayerNorm, BatchNorm,
+  params, mapleaves, cpu, gpu

@reexport using NNlib
using NNlib: @fix

include("tracker/Tracker.jl")
using .Tracker
-export Tracker
-import .Tracker: data
+using .Tracker: data
+export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param

include("optimise/Optimise.jl")
using .Optimise
using .Optimise: @epochs
+export SGD, ADAM, AdaMax, Momentum, Nesterov,
+  RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM

include("utils.jl")
include("onehot.jl")
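With the export changes above, `param` and the tracked array types come straight from `using Flux`, and the optimisers (including the new `AdaMax` and `NADAM`) are exported from the Optimise submodule. A small usage sketch (illustrative, not from the commit):

```julia
using Flux

x = param(rand(3))    # param is still exported, now via the Tracker export line
x isa TrackedArray    # true — TrackedArray/TrackedVector/TrackedMatrix are exported too
opt = SGD(params(Chain(Dense(3, 2))), 0.1)   # optimisers come from the re-exported Optimise
```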
@@ -10,7 +10,7 @@ Standard convolutional layer. `size` should be a tuple like `(2, 2)`.
Data should be stored in WHCN order. In other words, a 100×100 RGB image would
be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.

-Takes the keyword arguments `pad` and `stride`.
+Takes the keyword arguments `pad`, `stride` and `dilation`.
"""
struct Conv{N,F,A,V}
  σ::F

@@ -18,17 +18,19 @@ struct Conv{N,F,A,V}
  bias::V
  stride::NTuple{N,Int}
  pad::NTuple{N,Int}
+  dilation::NTuple{N,Int}
end

Conv(w::AbstractArray{T}, b::AbstractVector{T}, σ = identity;
-     stride = 1, pad = 0) where T =
-  Conv(σ, w, b, stride, pad)
+     stride = 1, pad = 0, dilation = 1) where T =
+  Conv(σ, w, b, stride, pad, dilation)

Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
     stride::NTuple{N,Integer} = map(_->1,k),
-     pad::NTuple{N,Integer} = map(_->0,k)) where N =
+     pad::NTuple{N,Integer} = map(_->0,k),
+     dilation::NTuple{N,Integer} = map(_->1,k)) where N =
  Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ,
-       stride = stride, pad = pad)
+       stride = stride, pad = pad, dilation = dilation)

Flux.treelike(Conv)

@@ -36,7 +38,7 @@ function (c::Conv)(x)
  # TODO: breaks gpu broadcast :(
  # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
  σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
-  σ.(conv(x, c.weight, stride = c.stride, pad = c.pad) .+ b)
+  σ.(conv(x, c.weight, stride = c.stride, pad = c.pad, dilation = c.dilation) .+ b)
end

function Base.show(io::IO, l::Conv)

@@ -45,6 +47,3 @@ function Base.show(io::IO, l::Conv)
  l.σ == identity || print(io, ", ", l.σ)
  print(io, ")")
end
-
-# v0.5
-@deprecate Conv2D(args...; kw...) Conv(args...; kw...)
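With the `dilation` keyword threaded through the struct, constructors and forward pass above, a dilated convolution layer can be built directly. A usage sketch based on the new keyword (not a line from the commit):

```julia
# 3×3 kernel, 1 input → 16 output channels, dilation of 2 in each spatial dimension
c = Conv((3, 3), 1=>16, relu, stride = (1, 1), pad = (1, 1), dilation = (2, 2))
y = c(rand(28, 28, 1, 5))   # a batch of five 28×28 single-channel images, WHCN order
```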
@@ -31,15 +31,14 @@ function Dropout(p)
  Dropout{typeof(p)}(p, true)
end

+_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
+
function (a::Dropout)(x)
  a.active || return x
  y = similar(x)
  rand!(y)
-  q = 1 - a.p
-  @inbounds for i=1:length(y)
-    y[i] = y[i] > a.p ? 1 / q : 0
-  end
-  return y .* x
+  y .= _dropout_kernel.(y, a.p, 1 - a.p)
+  return x .* y
end

_testmode!(a::Dropout, test) = (a.active = !test)
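The rewritten forward pass replaces the explicit loop with a broadcast of `_dropout_kernel`, which keeps an entry with probability `1 - p` and rescales it by `1/(1 - p)`. A quick check of the kernel's behaviour (illustration only):

```julia
p = 0.5; q = 1 - p
_dropout_kernel(0.8, p, q)   # 2.0 — above p, so kept and rescaled by 1/(1 - p)
_dropout_kernel(0.2, p, q)   # 0.0 — at or below p, so dropped
```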
@@ -1,7 +1,8 @@
module Optimise

-export update!, params, train!,
-  SGD, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad
+export train!,
+  SGD, ADAM, AdaMax, Momentum, Nesterov,
+  RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM

struct Param{T}
  x::T
@@ -56,6 +56,15 @@ RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) =
ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
  optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))

+"""
+    AdaMax(params, η = 0.002; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
+
+[AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser. Variant of ADAM based on
+the ∞-norm.
+"""
+AdaMax(ps, η = 0.002; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
+  optimiser(ps, p->adamax(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
+
"""
    ADAGrad(params, η = 0.01; ϵ = 1e-8, decay = 0)

@@ -82,3 +91,12 @@ tuning.
"""
AMSGrad(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
  optimiser(ps, p -> amsgrad(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
+
+"""
+    NADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
+
+[NADAM](https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ) optimiser. Parameters other
+than learning rate don't need tuning.
+"""
+NADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
+  optimiser(ps, p->nadam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
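The new optimisers plug into the existing training loop like the others. A brief usage sketch (the model, loss and data here are placeholders, not from the commit):

```julia
m = Dense(10, 2)
loss(x, y) = Flux.mse(m(x), y)
opt = AdaMax(params(m), 0.002)                 # or NADAM(params(m), 0.001)
Flux.train!(loss, [(rand(10), rand(2))], opt)
```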
@@ -27,7 +27,7 @@ function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
  acc = zeros(p.x)
  function ()
    @. acc = ρ * acc + (1 - ρ) * p.Δ^2
-    @. p.Δ *= η / (√acc + ϵ)
+    @. p.Δ *= η / √(acc + ϵ)
  end
end

@@ -35,7 +35,7 @@ function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
  acc = zeros(p.x) .+ ϵ
  function ()
    @. acc += p.Δ^2
-    @. p.Δ *= η / √acc
+    @. p.Δ *= η / √(acc + ϵ)
  end
end
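The ϵ stabiliser now sits inside the square root; when the accumulated squared gradient is near zero this makes a large difference to the effective step size, as a quick numeric check shows (illustration only):

```julia
acc, ϵ = 0.0, 1e-8
1 / (√acc + ϵ)   # ≈ 1.0e8 — old form, ϵ added outside the square root
1 / √(acc + ϵ)   # ≈ 1.0e4 — new form, ϵ added inside the square root
```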
@@ -56,12 +56,24 @@ function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ
  function ()
    @. mt = β1 * mt + (1 - β1) * p.Δ
    @. vt = β2 * vt + (1 - β2) * p.Δ^2
-    @. p.Δ = mt / (1 - β1p) / (√(vt / (1 - β2p)) + ϵ) * η
+    @. p.Δ = mt / (1 - β1p) / √(vt / (1 - β2p) + ϵ) * η
    β1p *= β1
    β2p *= β2
  end
end

+function adamax(p::Param; η::Real = 0.002, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
+  mt = zeros(p.x)
+  ut = zeros(p.x)
+  β1p = β1
+  function ()
+    @. mt = β1 * mt + (1 - β1) * p.Δ
+    @. ut = max(β2 * ut, abs(p.Δ))
+    @. p.Δ = (η/(1 - β1p)) * mt/(ut + ϵ)
+    β1p *= β1
+  end
+end
+
function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
  mt = zeros(p.x)
  vt = zeros(p.x) .+ ϵ

@@ -74,6 +86,19 @@ function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999,
  end
end

+function nadam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
+  mt = zeros(p.x)
+  vt = zeros(p.x)
+  β1p, β2p = β1, β2
+  function ()
+    @. mt = β1 * mt + (1 - β1) * p.Δ
+    @. vt = β2 * vt + (1 - β2) * p.Δ^2
+    @. p.Δ = (β1 * mt / (1 - β1 * β1p) + (1 - β1) * p.Δ / (1 - β1p)) / √(vt * β2 / (1 - β2p) + ϵ) * η
+    β1p *= β1
+    β2p *= β2
+  end
+end
+
clip(p::Param, thresh::Real) = () -> clamp!(p.Δ, -thresh, thresh)

function expdecay(p::Param, γ::Real)
@@ -41,7 +41,7 @@ end
Base.setindex!(xs::TrackedArray, v, i...) =
  error("Can't differentiate `setindex!`")

-back!(::TrackedArray) = error("Use back!(x, Δ)")
+back!(::TrackedArray) = error("Value is not scalar; use `back!(sum(x))` or `back!(x, Δ)`")

# Fallthrough methods

@@ -81,21 +81,6 @@ back(::typeof(ctranspose), Δ, xs) = @back(xs, trim(xs, Δ'))
Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...)
Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...)

-Base.vcat(a::TrackedVector, b::TrackedVector) = track(vcat, a, b)
-Base.vcat(a::TrackedVector, b::TrackedVector...) = track(vcat, a, b...)
-Base.vcat(a::TrackedVector, b::AbstractVector) = track(vcat, a, b)
-Base.vcat(a::AbstractVector, b::TrackedVector) = track(vcat, a, b)
-
-Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat) = track(vcat, a, b)
-Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat...) = track(vcat, a, b...)
-Base.vcat(a::TrackedVecOrMat, b::AbstractVecOrMat) = track(vcat, a, b)
-Base.vcat(a::AbstractVecOrMat, b::TrackedVecOrMat) = track(vcat, a, b)
-
-Base.vcat(a::TrackedMatrix, b::TrackedMatrix) = track(vcat, a, b)
-Base.vcat(a::TrackedMatrix, b::TrackedMatrix...) = track(vcat, a, b...)
-Base.vcat(a::TrackedMatrix, b::AbstractMatrix) = track(vcat, a, b)
-Base.vcat(a::AbstractMatrix, b::TrackedMatrix) = track(vcat, a, b)
-
function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1)
  Δ′ = similar(xs.data)
  S = size(xs.data)

@@ -108,15 +93,90 @@ function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1)
  back(xs, Δ′)
end

_repeat(A, inner, outer) = Base.repeat(A; inner=inner, outer=outer)
Base.repeat(A::TrackedArray; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A))) = track(_repeat, A, inner, outer)

function back(::typeof(_repeat), Δ, xs::TrackedArray, inner, outer)
  Δ′ = similar(xs.data)
  Δ′ .= 0
  S = size(xs.data)

  # Loop through each element of Δ, calculate source dimensions, accumulate into Δ′
  for (dest_idx, val) in enumerate(IndexCartesian(), Δ)
    # First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then
    # wrap around based on original size S.
    src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)]
    Δ′[src_idx...] += val
  end
  back(xs, Δ′)
end

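`repeat` on tracked arrays now routes through `_repeat` and is differentiable; since every source entry is copied `prod(inner) * prod(outer)` times, its gradient is the sum over those copies. A quick sketch (illustration only):

```julia
x = param(rand(2, 3))
y = repeat(x, inner = (2, 1), outer = (1, 3))   # size (4, 9)
Tracker.back!(y, ones(size(y)))
x.grad   # every entry is 6.0 — each source element has 2 × 3 = 6 copies
```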
for f in [:vcat, :hcat]
  @eval begin
    # This section is a bit of a hack since julia doesn't have a standardised
    # promotion mechanism for concatenation yet
    # https://github.com/JuliaLang/julia/pull/20815

    # It should support tracked concatenation with rank ∈ (1,2) with a
    # TrackedArray anywhere among the arguments. This works as long as base has
    # other functions that capture `(::Union{Vector,RowVector,Matrix}...)`.
    Base.$f(a::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a...)

    # It should support tracked concatenation with rank>2 if the TrackedArray is
    # first
    Base.$f(a::TrackedArray, b::AbstractArray...) = track($f, a, b...)
    Base.$f(a::TrackedArray, b::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a, b...) # resolves ambiguity introduced by previous row

    # It should support tracked concatenation with rank>2 if the TrackedArray is
    # second
    Base.$f(a::Array, b::TrackedArray, c::AbstractArray...) = track($f, a, b, c...)
    Base.$f(a::Union{Vector,RowVector,Matrix}, b::TrackedArray,
            c::Union{TrackedArray,Vector,RowVector,Matrix}...) =
      track($f, a, b, c...) # resolves ambiguity introduced by previous row
  end
end

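The effect of these methods is that concatenating any mix of tracked and plain arrays promotes the result to a tracked array, which is what the new `promotiontest` in the test suite below exercises. A small sketch (illustration only):

```julia
xs = param(rand(2, 2))
vcat(xs, rand(2, 2)) isa TrackedArray            # true
hcat(rand(2), param(rand(2))) isa TrackedArray   # true — plain-then-tracked works via the Union methods
```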
function back(::typeof(vcat), Δ, xs...)
  i = Base.tail(map(_ -> :, size(Δ)))
  start = 0
  for xsi in xs
    i = map(_ -> :, size(xsi)) |> Base.tail
    @back(xsi, Δ[start+1:start+size(xsi,1), i...])
    start += size(xsi, 1)
  end
end

function back(::typeof(hcat), Δ, xs...)
  start = 0
  for xsi in xs
    if ndims(xsi) == 1
      @back(xsi, Δ[:, start+1])
    else
      i = map(_ -> :, size(xsi)) |> Base.tail |> Base.tail
      @back(xsi, Δ[:, start+1:start+size(xsi,2), i...])
    end
    start += size(xsi, 2)
  end
end

Base.cat(dims, a::TrackedArray, b::AbstractArray...) = track(cat, dims, a, b...)
Base.cat(dims, a::Union{RowVector,Array}, b::TrackedArray, c::AbstractArray...) = track(cat, dims, a, b, c...)

function back(::typeof(cat), Δ, dims, Xs...)
  start = ntuple(i -> 0, Val{ndims(Δ)})
  for xs in Xs
    dim_xs = 1:ndims(xs)
    till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val{ndims(Δ)})

    xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val{ndims(Δ)})

    @back(xs, reshape(Δ[xs_in_Δ...],size(xs)))

    start = start .+ till_xs
  end
end

Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) = reshape(xs, dims)
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Union{Int64,Colon}}}) = reshape(xs, Base._reshape_uncolon(xs, dims))
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64}}) = track(reshape, xs, dims)

@@ -156,12 +216,16 @@ Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs))
back(::typeof(prod), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= (prod(xs.data, dim...) ./ xs.data) .* Δ)
back(::typeof(prod), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= (reshape(.*(circshift.([reshape(xs.data, length(xs.data))], 1:length(xs.data)-1)...), size(xs.data))) .* Δ)

-Base.maximum(xs::TrackedArray, args...) = maximum(xs.data, args...)
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)

Base.mean(xs::TrackedArray) = track(mean, xs)
Base.mean(xs::TrackedArray, region) = track(mean, xs, region)

+Base.maximum(xs::TrackedArray) = track(maximum, xs)
+Base.maximum(xs::TrackedArray, region) = track(maximum, xs, region)
+Base.minimum(xs::TrackedArray) = track(minimum, xs)
+Base.minimum(xs::TrackedArray, region) = track(minimum, xs, region)
+
LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = track(dot, xs, ys)
LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = track(dot, xs, ys)
LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys)

@@ -184,6 +248,31 @@ back(::typeof(mean), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= Δ ./
back(::typeof(mean), Δ, xs::TrackedArray, region) =
  back(xs, similar(xs.data) .= Δ ./ prod(size(xs.data, region...)))

function back(::typeof(maximum), Δ, xs::TrackedArray)
  Δ′ = zeros(xs.data)
  _, i = findmax(xs.data)
  Δ′[i] = Δ
  @back(xs, Δ′)
end
function back(::typeof(maximum), Δ, xs::TrackedArray, region)
  Δ′ = zeros(xs.data)
  _, is = findmax(xs.data, region)
  Δ′[is] = Δ
  @back(xs, Δ′)
end
function back(::typeof(minimum), Δ, xs::TrackedArray)
  Δ′ = zeros(xs.data)
  _, i = findmin(xs.data)
  Δ′[i] = Δ
  @back(xs, Δ′)
end
function back(::typeof(minimum), Δ, xs::TrackedArray, region)
  Δ′ = zeros(xs.data)
  _, is = findmin(xs.data, region)
  Δ′[is] = Δ
  @back(xs, Δ′)
end

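With `maximum`/`minimum` tracked, the gradient is an indicator placed at the arg-max (or arg-min) position, exactly as the `back` methods above construct it. A tiny check (illustration only):

```julia
x = param([1.0 3.0; 2.0 0.5])
Tracker.back!(maximum(x))
x.grad   # all zeros except a 1.0 where the maximum (3.0) sits
```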
# BLAS

Base.diagm(x::TrackedVector) = track(diagm, x)

@@ -245,18 +334,18 @@ logsoftmax(xs::TrackedArray) = track(logsoftmax, xs)
back(::typeof(logsoftmax), Δ, xs) = @back(xs, ∇logsoftmax(Δ, data(xs)))

# TODO: can store kwargs efficiently in namedtuples
-_conv(x, w, stride, pad) = conv(x, w, stride = stride, pad = pad)
+_conv(x, w, stride, pad, dilation) = conv(x, w, stride = stride, pad = pad, dilation = dilation)

-conv(x::TrackedArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0) where N =
-  track(_conv, x, w, stride, pad)
-conv(x::AbstractArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0) where N =
-  track(_conv, x, w, stride, pad)
-conv(x::TrackedArray{<:Real,N}, w::AbstractArray{<:Real,N}; stride = 1, pad = 0) where N =
-  track(_conv, x, w, stride, pad)
+conv(x::TrackedArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N =
+  track(_conv, x, w, stride, pad, dilation)
+conv(x::AbstractArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N =
+  track(_conv, x, w, stride, pad, dilation)
+conv(x::TrackedArray{<:Real,N}, w::AbstractArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N =
+  track(_conv, x, w, stride, pad, dilation)

-function back(::typeof(_conv), Δ, x, w, stride, pad)
-  @back(x, NNlib.∇conv_data(Δ, data(x), data(w); stride = stride, pad = pad))
-  @back(w, NNlib.∇conv_filter(Δ, data(x), data(w); stride = stride, pad = pad))
+function back(::typeof(_conv), Δ, x, w, stride, pad, dilation)
+  @back(x, NNlib.∇conv_data(Δ, data(x), data(w); stride = stride, pad = pad, dilation = dilation))
+  @back(w, NNlib.∇conv_filter(Δ, data(x), data(w); stride = stride, pad = pad, dilation = dilation))
end

_maxpool(x, k, pad, stride) = maxpool(x, k; pad = pad, stride = stride)
@@ -19,8 +19,9 @@ Base.decompose(x::TrackedReal) = Base.decompose(data(x))

Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{T}) where T = x

-Base.convert(::Type{TrackedReal{T}}, x::TrackedReal) where T =
-  TrackedReal(Tracked(x.tracker.f, convert(T, x.tracker.data)))
+# This cuts derivatives, fix if needed.
+# Base.convert(::Type{TrackedReal{T}}, x::TrackedReal) where T =
+#   TrackedReal(Tracked(x.tracker.f, convert(T, x.tracker.data)))

Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x))

@@ -91,3 +92,18 @@ Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i)

back(::typeof(getindex), Δ, t, i) =
  back(t, ntuple(j -> i == j ? Δ : 0, length(t)))

# Array collection

function collect(xs)
  xs = Base.collect(xs)
  track(Call(collect, xs), data.(xs))
end

function scan(c::Call{typeof(collect)})
  foreach(scan, c.args[1])
end

function back(::typeof(collect), Δ, xs)
  foreach((x, Δ) -> @back(x, Δ), xs, Δ)
end
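`Tracker.collect` turns a collection of individually tracked values into a single `TrackedArray` whose tape still points back at the originals, which the new `"collect"` testset at the end of this commit exercises. In short, mirroring that test:

```julia
x, y = param(2), param(3)
xy = Tracker.collect([x, y])   # a TrackedArray built from two tracked scalars
back!(xy[1]*xy[2])
Tracker.grad(x), Tracker.grad(y)   # (3, 2) — as the new testset asserts
```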
@@ -3,7 +3,7 @@ using Flux.Tracker

@testset "Optimise" begin
  w = randn(10, 10)
-  @testset for Opt in [SGD, Nesterov, Momentum, ADAM, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad]
+  @testset for Opt in [SGD, Nesterov, Momentum, ADAM, AdaMax, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad, NADAM]
    w′ = param(randn(10, 10))
    loss(x) = Flux.mse(w*x, w′*x)
    opt = Opt([w′])
test/tracker.jl (121 changed lines)
@@ -1,5 +1,5 @@
using Flux.Tracker, Base.Test, NNlib
-using Flux.Tracker: TrackedReal, gradcheck
+using Flux.Tracker: TrackedReal, gradcheck, grad
using NNlib: conv

gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)

@@ -29,17 +29,97 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)

@test gradtest(x -> x', rand(5))

-@test gradtest(vcat, rand(5), rand(3))
-@test gradtest(vcat, rand(5), rand(3), rand(8))
-@test gradtest(vcat, rand(5,2), rand(3,2), rand(8,2))
function promotiontest(f, A, B, C)
  r0 = f(A, B, C)
  r1 = f(param(A), B, C)
  r2 = f(A, param(B), C)
  if all(ndims.((A,B,C)) .≤ 2) && f ∈ [hcat, vcat]
    r3 = f(A, B, param(C))
  else
    @test_throws MethodError f(A, B, param(C)) # until julia#20815 is resolved
    r3 = r2
  end
  r4 = f(param(A), param(B), param(C))

  @test !isa(r0, TrackedArray)
  @test all(isa.([r1,r2,r3,r4], TrackedArray))
  @test r1 == r2 == r3 == r4
  @test r0 == Flux.data(r4)
end

@testset "concat" begin
  cat1(x...) = cat(1, x...)
  cat2(x...) = cat(2, x...)

  @testset for vcatf in [vcat, cat1]
    @test gradtest(vcatf, rand(5), rand(3))
    @test gradtest(vcatf, rand(5), rand(3), rand(8))
    @test gradtest(vcatf, rand(5)', rand(5)')
    @test gradtest(vcatf, rand(5,2), rand(3,2), rand(8,2))
    @test gradtest(vcatf, rand(5,2,3), rand(3,2,3), rand(8,2,3))
    @test gradtest(vcatf, rand(5), rand(3,1))
    @test gradtest(vcatf, rand(5)', rand(2,5))
  end

  @testset for hcatf in [hcat, cat2]
    @test gradtest(hcatf, rand(5), rand(5))
    @test gradtest(hcatf, rand(5)', rand(5)')
    @test gradtest(hcatf, rand(2,5), rand(2,3), rand(2,8))
    @test gradtest(hcatf, rand(2,5,3), rand(2,3,3), rand(2,8,3))
    @test gradtest(hcatf, rand(5), rand(5), rand(5,2))
    @test gradtest(hcatf, rand(5)', rand(1,3))
    @test gradtest(hcatf, rand(5), rand(5,2))
  end

  @testset for catf in [vcat, cat1, hcat, cat2, (x...) -> cat(3, x...), (x...) -> cat((1,2), x...)]
    @test gradtest(catf, rand(5))
    @test gradtest(catf, rand(5)')
    @test gradtest(catf, rand(2,5))
    @test gradtest(catf, rand(2,5,3))
  end

  @test gradtest((x...) -> cat(3, x...), rand(2,5,2), rand(2,5,3), rand(2,5,4))

  @testset "cat($dim, ...)" for dim in 3:5
    catdim = (x...) -> cat(dim, x...)
    @test gradtest(catdim, rand(5), rand(5), rand(5))
    @test gradtest(catdim, rand(2,5), rand(2,5), rand(2,5))
    @test gradtest(catdim, rand(2,5,3), rand(2,5,3), rand(2,5,3))
  end

  @test !isa(vcat(rand(2)), TrackedArray)
  @test !isa(hcat(rand(2)), TrackedArray)
  @test !isa(cat(1,rand(2)), TrackedArray)

  @test gradtest((a,b)->cat((2,3,5), a, b), rand(2,3), rand(2,4,2,1))

  @testset "promotiontest" begin
    @testset for fcat in [hcat, vcat, (x...) -> cat(3, x...), (x...) -> cat((1,2), x...)]
      promotiontest(fcat, rand(2), rand(2), rand(2))
      promotiontest(fcat, rand(2)', rand(2)', rand(2)')
      promotiontest(fcat, rand(2,2), rand(2,2), rand(2,2))
      promotiontest(fcat, rand(2,2,2), rand(2,2,2), rand(2,2,2))
    end

    promotiontest(vcat, rand(1,2), rand(2)', rand(2,2))
    promotiontest(hcat, rand(2,1), rand(2), rand(2,2))
    promotiontest(vcat, rand(3,4,5), rand(1,4,5), rand(2,4,5))
    promotiontest(hcat, rand(4,3,5), rand(4,1,5), rand(4,2,5))
    promotiontest((x...) -> cat(3, x...), rand(4,5,3), rand(4,5,1), rand(4,5,2))
  end
end

@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))

@test gradtest(x -> repmat(x, 5,5), rand(4,5))
@test gradtest(x -> repmat(x, 5), rand(4,5))

-@test gradtest(kron,rand(5), rand(3))
+@test gradtest(x -> repeat(x; inner=2, outer=3), rand(5))
+@test gradtest(x -> repeat(x; inner=(2,2,1), outer=(1,1,3)), rand(5,4,3))

+@test gradtest(kron, rand(5), rand(3))
+@test gradtest(kron, rand(5), rand(3), rand(8))
-@test gradtest(kron,rand(5,1), rand(3,1))
+@test gradtest(kron, rand(5,1), rand(3,1))
+@test gradtest(kron, rand(5,1), rand(3,1), rand(8,1))
@test gradtest(kron, rand(5,2), rand(3,2), rand(8,2))

@@ -55,6 +135,26 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
  @test gradtest(x -> mean(x, [1, 2]), rand(2, 3, 4))
end

@testset "maximum" begin
  @test gradtest(maximum, rand(2, 3))

  @test gradtest(x -> maximum(x, 1), rand(2, 3))
  @test gradtest(x -> maximum(x, 2), rand(2, 3))
  @test gradtest(x -> maximum(x, 3), rand(2, 3, 4))

  @test gradtest(x -> maximum(x, [1, 2]), rand(2, 3, 4))
end

@testset "minimum" begin
  @test gradtest(minimum, rand(2, 3))

  @test gradtest(x -> minimum(x, 1), rand(2, 3))
  @test gradtest(x -> minimum(x, 2), rand(2, 3))
  @test gradtest(x -> minimum(x, 3), rand(2, 3, 4))

  @test gradtest(x -> minimum(x, [1, 2]), rand(2, 3, 4))
end

@test gradtest(x -> std(x), rand(5,5))
@test gradtest(x -> std(x, 1), rand(5,5))

@@ -123,4 +223,13 @@ b = param(rand())
Tracker.back!(b)
@test Tracker.grad(b) == 1

@testset "collect" begin
  x, y = param(2), param(3)
  xy = Tracker.collect([x, y])
  @test xy isa TrackedArray{Float64}
  z = xy[1]*xy[2]
  back!(z)
  @test grad.((x,y)) == (3, 2)
end

end #testset