Merge branch 'master' into depthwiseconv

This commit is contained in:
Avik Pal 2018-07-13 14:04:19 +05:30 committed by GitHub
commit 0aabf9d86b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 664 additions and 403 deletions

View File

@ -30,7 +30,7 @@ Flux has powerful high-level features, and common architectures can be defined i
```julia
model = Chain(
Dense(768, 128),
Dense(768, 128, σ),
LSTM(128, 256)
LSTM(256, 128)
Dense(128, 10),

View File

@ -1,5 +1,4 @@
julia 0.6.0
DataFlow 0.2.1
Juno
MacroTools 0.3.3
NNlib

View File

@ -4,6 +4,8 @@ Support for array operations on other hardware backends, like GPUs, is provided
For example, we can use `CuArrays` (with the `cu` converter) to run our [basic example](models/basics.md) on an NVIDIA GPU.
(Note that you need to build Julia 0.6 from source and have CUDA available to use CuArrays please see the [CUDAnative.jl](https://github.com/JuliaGPU/CUDAnative.jl) instructions for more details.)
```julia
using CuArrays

View File

@ -14,3 +14,5 @@ Pkg.test("Flux") # Check things installed correctly
```
Start with the [basics](models/basics.md). The [model zoo](https://github.com/FluxML/model-zoo/) is also a good starting point for many common kinds of models.
See [GPU support](gpu.md) for more details on installing and using Flux with GPUs.

View File

@ -6,6 +6,52 @@ Backpropagation, or reverse-mode automatic differentiation, is handled by the `F
julia> using Flux.Tracker
```
Here we discuss some more advanced uses of this module, as well as covering its internals.
## Taking Gradients
In the [basics section](../models/basics.md) we covered basic usage of the `gradient` function.
```julia
using Flux.Tracker
Tracker.gradient((a, b) -> a*b, 2, 3) # (3.0 (tracked), 2.0 (tracked))
```
`gradient` is actually just a thin wrapper around the backpropagator-based interface, `forward`.
```julia
using Flux.Tracker: forward
y, back = forward((a, b) -> a*b, 2, 3) # (6.0 (tracked), Flux.Tracker.#9)
back(1) # (3.0 (tracked), 2.0 (tracked))
```
The `forward` function returns two results. The first, `y`, is the original value of the function (perhaps with tracking applied). The second, `back`, is a new function which, given a sensitivity, returns the sensitivity of the inputs to `forward` (we call this a "backpropagator"). One use of this interface is to provide custom sensitivities when outputs are not scalar.
```julia
julia> y, back = forward((a, b) -> a.*b, [1,2,3],[4,5,6])
(param([4.0, 10.0, 18.0]), Flux.Tracker.#9)
julia> back([1,1,1])
(param([4.0, 5.0, 6.0]), param([1.0, 2.0, 3.0]))
```
We can also take gradients in-place. This can be useful if you only care about first-order gradients.
```julia
a, b = param(2), param(3)
c = a*b # 6.0 (tracked)
Tracker.back!(c)
Tracker.grad(a), Tracker.grad(b) # (3.0, 2.0)
```
## Tracked Arrays
The `param` function converts a normal Julia array into a new object that, while behaving like an array, tracks extra information that allows us to calculate derivatives. For example, say we multiply two parameters:
```julia
@ -41,7 +87,48 @@ julia> x.grad
-2.0
```
## Internals
You may sometimes want to drop derivative information and just get the plain value back. You can do this by calling `Tracker.data(W)`.
## Custom Gradients
We can hook in to the processes above to implement custom gradients for a function or kernel. For a toy example, imagine a custom implementation of `minus`:
```julia
minus(a, b) = a - b
```
Firstly, we must tell the tracker system to stop when it sees a call to `minus`, and record it. We can do this using dispatch:
```julia
using Flux.Tracker: TrackedReal, track, @grad
minus(a::TrackedArray, b::TrackedArray) = Tracker.track(minus, a, b)
```
`track` takes care of building a new `Tracked` object and recording the operation on the tape. We just need to provide a gradient definition.
```julia
@grad function minus(a, b)
return minus(data(a),data(b)), Δ -> (Δ, -Δ)
end
```
This is essentially just a way of overloading the `forward` function we saw above. We strip tracking from `a` and `b` so that we are calling the original definition of `minus` (otherwise, we'd just try to track the call again and hit an infinite regress).
Note that in the backpropagator we don't call `data(a)`; we *do* in fact want to track this, since nest AD will take a derivative through the backpropagator itself. For example, the gradient of `*` might look like this.
```julia
@grad a * b = data(a)*data(b), Δ -> (Δ*b, a*Δ)
```
For multi-argument functions with custom gradients, you likely want to catch not just `minus(::TrackedArray, ::TrackedArray)` but also `minus(::Array, TrackedArray)` and so on. To do so, just define those extra signatures as needed:
```julia
minus(a::AbstractArray, b::TrackedArray) = Tracker.track(minus, a, b)
minus(a::TrackedArray, b::AbstractArray) = Tracker.track(minus, a, b)
```
## Tracked Internals
All `Tracked*` objects (`TrackedArray`, `TrackedReal`) are light wrappers around the `Tracked` type, which you can access via the `.tracker` field.
@ -50,14 +137,9 @@ julia> x.tracker
Flux.Tracker.Tracked{Array{Float64,1}}(0x00000000, Flux.Tracker.Call{Void,Tuple{}}(nothing, ()), true, [5.0, 6.0], [-2.0, -2.0])
```
The `Tracker` stores the value and gradient of a given object, which we've seen before.
The `Tracker` stores the gradient of a given object, which we've seen before.
```julia
julia> x.tracker.data
2-element Array{Float64,1}:
5.0
6.0
julia> x.tracker.grad
2-element Array{Float64,1}:
-2.0
@ -86,71 +168,4 @@ When we call `back!(y, [1, -1])`, the sensitivities `[1, -1]` simply get forward
Tracker.back(*, [1, -1], W, x)
```
which in turn calculates the sensitivities of the arguments (`W` and `x`) and backpropagates through their calls. This is recursive, so it will walk the entire program graph and propagate gradients to the original model parameters.
## Custom Gradients
We can hook in to the processes above to implement custom gradients for a function or kernel. For a toy example, imagine a custom implementation of `minus`:
```julia
julia> minus(a, b) = a - b
```
Firstly, we must tell the tracker system to stop when it sees a call to `minus`, and record it. We can do this using dispatch:
```julia
julia> minus(a::TrackedArray, b::TrackedArray) = Tracker.track(minus, a, b)
minus (generic function with 2 methods)
```
`Tracker.track` does two things: (1) it makes sure `minus` is called with *normal* array, not tracked ones (you can use `@show` inside `minus` to verify this), and (2) it uses the result to add a `minus` node to the tape. Look inside the result of calling `minus` to see what happened:
```julia
julia> a, b = param([6,5,4]), param([1,2,3])
(param([6.0, 5.0, 4.0]), param([1.0, 2.0, 3.0]))
julia> c = minus(a, b)
Tracked 3-element Array{Float64,1}:
5.0
3.0
1.0
julia> c.tracker.f
Flux.Tracker.Call{...}(minus, (param([6.0, 5.0, 4.0]), param([1.0, 2.0, 3.0])))
```
Finally, we have to specify the gradient of `minus`.
```julia
julia> Tracker.back(::typeof(minus), Δ, a, b) =
(Tracker.@back(a, Δ); Tracker.@back(b, -Δ))
```
`@back(x, Δ)` tells the tracker to continue propagating the sensitivity `Δ` through `x`. Now, AD will work with any program that calls `minus`.
```julia
julia> Flux.back!(c, 1)
julia> a.grad
3-element Array{Float64,1}:
1.0
1.0
1.0
julia> b.grad
3-element Array{Float64,1}:
-1.0
-1.0
-1.0
```
## Notes
For multi-argument functions with custom gradients, you likely want to catch not just `minus(::TrackedArray, ::TrackedArray)` but also `minus(::Array, TrackedArray)` and so on. To do so, just define those extra signatures as needed:
```julia
minus(a::AbstractArray, b::TrackedArray) = Tracker.track(minus, a, b)
minus(a::TrackedArray, b::AbstractArray) = Tracker.track(minus, a, b)
```
`@back` *must* be called exactly once on each tracked input argument. You do not need to do any special handling if one of the arguments is not tracked, as `@back` will just become a no-op.
which in turn calculates the sensitivities of the arguments (`W` and `x`) and back-propagates through their calls. This is recursive, so it will walk the entire program graph and propagate gradients to the original model parameters.

View File

@ -2,20 +2,74 @@
## Taking Gradients
Consider a simple linear regression, which tries to predict an output array `y` from an input `x`. (It's a good idea to follow this example in the Julia repl.)
Flux's core feature is taking gradients of Julia code. The `gradient` function takes another Julia function `f` and a set of arguments, and returns the gradient with respect to each argument. (It's a good idea to try pasting these examples in the Julia terminal.)
```julia
using Flux.Tracker
f(x) = 3x^2 + 2x + 1
# df/dx = 6x + 2
f(x) = Tracker.gradient(f, x)[1]
f(2) # 14.0 (tracked)
# d²f/dx² = 6
f(x) = Tracker.gradient(f, x)[1]
f(2) # 6.0 (tracked)
```
(We'll learn more about why these numbers show up as `(tracked)` below.)
When a function has many parameters, we can pass them all in explicitly:
```julia
f(W, b, x) = W * x + b
Tracker.gradient(f, 2, 3, 4)
(4.0 (tracked), 1.0, 2.0 (tracked))
```
But machine learning models can have *hundreds* of parameters! Flux offers a nice way to handle this. We can tell Flux to treat something as a parameter via `param`. Then we can collect these together and tell `gradient` to collect the gradients of all of them at once.
```julia
W = param(2) # 2.0 (tracked)
b = param(3) # 3.0 (tracked)
f(x) = W * x + b
params = Params([W, b])
grads = Tracker.gradient(() -> f(4), params)
grads[W] # 4.0
grads[b] # 1.0
```
There are a few things to notice here. Firstly, `W` and `b` now show up as *tracked*. Tracked things behave like normal numbers or arrays, but keep records of everything you do with them, allowing Flux to calculate their gradients. `gradient` takes a zero-argument function; no arguments are necessary because the `Params` tell it what to differentiate.
This will come in really handy when dealing with big, complicated models. For now, though, let's start with something simple.
## Simple Models
Consider a simple linear regression, which tries to predict an output array `y` from an input `x`.
```julia
W = rand(2, 5)
b = rand(2)
predict(x) = W*x .+ b
loss(x, y) = sum((predict(x) .- y).^2)
function loss(x, y)
ŷ = predict(x)
sum((y .- ŷ).^2)
end
x, y = rand(5), rand(2) # Dummy data
loss(x, y) # ~ 3
```
To improve the prediction we can take the gradients of `W` and `b` with respect to the loss function and perform gradient descent. We could calculate gradients by hand, but Flux will do it for us if we tell it that `W` and `b` are trainable *parameters*.
To improve the prediction we can take the gradients of `W` and `b` with respect to the loss and perform gradient descent. Let's tell Flux that `W` and `b` are parameters, just like we did above.
```julia
using Flux.Tracker
@ -23,25 +77,25 @@ using Flux.Tracker
W = param(W)
b = param(b)
l = loss(x, y)
back!(l)
gs = Tracker.gradient(() -> loss(x, y), Params([W, b]))
```
`loss(x, y)` returns the same number, but it's now a *tracked* value that records gradients as it goes along. Calling `back!` then calculates the gradient of `W` and `b`. We can see what this gradient is, and modify `W` to train the model.
Now that we have gradients, we can pull them out and update `W` to train the model. The `update!(W, Δ)` function applies `W = W + Δ`, which we can use for gradient descent.
```julia
W.grad
using Flux.Tracker: update!
# Update the parameter
W.data .-= 0.1(W.grad)
Δ = gs[W]
# Update the parameter and reset the gradient
update!(W, -0.1Δ)
loss(x, y) # ~ 2.5
```
The loss has decreased a little, meaning that our prediction `x` is closer to the target `y`. If we have some data we can already try [training the model](../training/training.md).
All deep learning in Flux, however complex, is a simple generalisation of this example. Of course, models can *look* very different they might have millions of parameters or complex control flow, and there are ways to manage this complexity. Let's see what that looks like.
All deep learning in Flux, however complex, is a simple generalisation of this example. Of course, models can *look* very different they might have millions of parameters or complex control flow. Let's see how Flux handles more complex models.
## Building Layers

View File

@ -103,7 +103,7 @@ m.(seq)
## Truncating Gradients
By default, calculating the gradients in a recurrent layer involves the entire history. For example, if we call the model on 100 inputs, calling `back!` will calculate the gradient for those 100 calls. If we then calculate another 10 inputs we have to calculate 110 gradients this accumulates and quickly becomes expensive.
By default, calculating the gradients in a recurrent layer involves its entire history. For example, if we call the model on 100 inputs, we'll have to calculate the gradient for those 100 calls. If we then calculate another 10 inputs we have to calculate 110 gradients this accumulates and quickly becomes expensive.
To avoid this we can *truncate* the gradient calculation, forgetting the history.

View File

@ -44,3 +44,19 @@ loss(x, y) = crossentropy(m(x), y) + sum(vecnorm, params(m))
loss(rand(28^2), rand(10))
```
One can also easily add per-layer regularisation via the `activations` function:
```julia
julia> c = Chain(Dense(10,5,σ),Dense(5,2),softmax)
Chain(Dense(10, 5, NNlib.σ), Dense(5, 2), NNlib.softmax)
julia> activations(c, rand(10))
3-element Array{Any,1}:
param([0.71068, 0.831145, 0.751219, 0.227116, 0.553074])
param([0.0330606, -0.456104])
param([0.61991, 0.38009])
julia> sum(vecnorm, ans)
2.639678767773633 (tracked)
```

View File

@ -3,6 +3,8 @@
Consider a [simple linear regression](../models/basics.md). We create some dummy data, calculate a loss, and backpropagate to calculate gradients for the parameters `W` and `b`.
```julia
using Flux.Tracker
W = param(rand(2, 5))
b = param(rand(2))
@ -11,22 +13,25 @@ loss(x, y) = sum((predict(x) .- y).^2)
x, y = rand(5), rand(2) # Dummy data
l = loss(x, y) # ~ 3
back!(l)
params = Params([W, b])
grads = Tracker.gradient(() -> loss(x, y), params)
```
We want to update each parameter, using the gradient, in order to improve (reduce) the loss. Here's one way to do that:
```julia
function update()
using Flux.Tracker: grad, update!
function sgd()
η = 0.1 # Learning Rate
for p in (W, b)
p.data .-= η .* p.grad # Apply the update
p.grad .= 0 # Clear the gradient
update!(p, -η * grads[p])
end
end
```
If we call `update`, the parameters `W` and `b` will change and our loss should go down.
If we call `sgd`, the parameters `W` and `b` will change and our loss should go down.
There are two pieces here: one is that we need a list of trainable parameters for the model (`[W, b]` in this case), and the other is the update step. In this case the update is simply gradient descent (`x .-= η .* Δ`), but we might choose to do something more advanced, like adding momentum.

View File

@ -22,7 +22,7 @@ export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param
include("optimise/Optimise.jl")
using .Optimise
using .Optimise: @epochs
export SGD, ADAM, AdaMax, Momentum, Nesterov,
export SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
include("utils.jl")

View File

@ -286,41 +286,28 @@ function desc(rnn)
return d
end
import Flux.Tracker: data, isleaf, istracked, track, back_, @back, unbroadcast
mutable struct RNNCall{R}
rnn::R
reserve::CuVector{UInt8}
RNNCall{R}(rnn::R) where R = new(rnn)
end
RNNCall(rnn) = RNNCall{typeof(rnn)}(rnn)
function (c::RNNCall)(args...)
rs, result = forwardTrain(desc(c.rnn), args...)
c.reserve = rs
return result
end
import Flux.Tracker
import Flux.Tracker: data, istracked, track, unbroadcast, @grad, nobacksies
istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...))
function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
result = istrain(m, h, x) ?
track(RNNCall(m), x, h) :
track(m, x, h, m.Wi, m.Wh, m.b) :
forward(desc(m), x, h)
return result[2], result[1]
end
function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
result = istrain(m, h, x) ?
track(RNNCall(m), x, h) :
track(m, x, h, m.Wi, m.Wh, m.b) :
forward(desc(m), x, h)
return result[2], result[1]
end
function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64}
result = istrain(m, h, x) ?
track(RNNCall(m), x, h[1], h[2]) :
track(m, x, h[1], h[2], m.Wi, m.Wh, m.b) :
forward(desc(m), x, h[1], h[2])
return (result[2], result[3]), result[1]
end
@ -329,44 +316,29 @@ end
(m::CuGRU{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
(m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
function accum_transpose!(dst::CuArray, src::CuArray)
function kernel(dst, src)
I = @cuindex dst
dst[I...] += src[reverse(I)...]
return
@grad function (m::Union{CuRNN,CuGRU})(x, h, Wi, Wh, b)
reserve, result = forwardTrain(desc(m), data(x), data(h))
result, function (Δ)
y, ho = result
dy, dho = Δ
h_ = hBatch(x, data(h))
dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
nobacksies(:RNN, (dx, unbroadcast(size(h), dh), dWi.', dWh.', db))
end
blk, thr = cudims(dst)
@cuda (blk, thr) kernel(dst, src)
return dst
end
function back_(m::RNNCall{<:Union{CuRNN,CuGRU}}, y_, Δ, x, h)
y, ho = y_
dy, dho = Δ
h_ = hBatch(x, data(h))
dx, dh = backwardData(descs[m.rnn], y, dy, dho, h_, m.reserve)
@back(x, dx)
@back(h, unbroadcast(h, dh))
(dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y, m.reserve)
# We don't have to make this assumption, it's just slightly more complex.
@assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b)))
istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi)
istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh)
istracked(m.rnn.b) && accum_transpose!(m.rnn.b.grad, db)
end
function back_(m::RNNCall{<:CuLSTM}, y_, Δ, x, h, c)
y, ho, co = y_
dy, dho, dco = Δ
h_ = hBatch(x, data(h))
c_ = hBatch(x, data(c))
dx, dh, dc = backwardData(descs[m.rnn], y, dy, dho, dco, h_, c_, m.reserve)
@back(x, dx)
@back(h, unbroadcast(h, dh))
@back(c, unbroadcast(h, dc))
(dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y, m.reserve)
@assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b)))
istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi)
istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh)
istracked(m.rnn.b) && accum_transpose!(m.rnn.b.grad, db)
@grad function (m::CuLSTM)(x, h, c, Wi, Wh, b)
reserve, result = forwardTrain(desc(m), data.((x, h, c))...)
result, function (Δ)
y, ho = result
dy, dho, dco = Δ
h_ = hBatch(x, data(h))
c_ = hBatch(x, data(c))
dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
nobacksies(:RNN,
(dx, unbroadcast(size(h), dh), unbroadcast(size(c), dc),
dWi.', dWh.', db))
end
end

View File

@ -38,6 +38,11 @@ function Base.show(io::IO, c::Chain)
print(io, ")")
end
# Seem to need this for `accumulate`; try removing on 0.7
Base.rcum_promote_type(op, ::Type, ::Type{Any}) = Any
activations(c::Chain, x) = accumulate((x, m) -> m(x), x, c.layers)
"""
Dense(in::Integer, out::Integer, σ = identity)

View File

@ -1,5 +1,10 @@
using NNlib: conv, depthwiseconv
@generated sub2(::Type{Val{N}}) where N = :(Val{$(N-2)})
expand(N, i::Tuple) = i
expand(N, i::Integer) = ntuple(_ -> i, N)
"""
Conv(size, in=>out)
Conv(size, in=>out, relu)
@ -21,14 +26,12 @@ struct Conv{N,F,A,V}
dilation::NTuple{N,Int}
end
Conv(w::AbstractArray{T}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0, dilation=1) where T =
Conv(σ, w, b, stride, pad, dilation)
Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N} =
Conv(σ, w, b, expand.(sub2(Val{N}), (stride, pad, dilation))...)
Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
stride::NTuple{N,Integer} = map(_->1,k),
pad::NTuple{N,Integer} = map(_->0,k),
dilation::NTuple{N,Integer} = map(_->1,k)) where N =
stride = 1, pad = 0, dilation = 1) where N =
Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ,
stride = stride, pad = pad, dilation = dilation)

View File

@ -15,9 +15,9 @@ function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight
end
"""
binarycrossentropy(, y)
binarycrossentropy(, y; ϵ=eps())
Return `-y*log(ŷ) - (1-y)*log(1-ŷ)`.
Return `-y*log(ŷ + ϵ) - (1-y)*log(1-ŷ + ϵ)`. The ϵ term provides numerical stability.
julia> binarycrossentropy.(σ.([-1.1491, 0.8619, 0.3127]), [1, 1, 0.])
3-element Array{Float64,1}:
@ -25,7 +25,7 @@ Return `-y*log(ŷ) - (1-y)*log(1-ŷ)`.
0.352317
0.86167
"""
binarycrossentropy(, y) = -y*log() - (1 - y)*log(1 - )
binarycrossentropy(, y; ϵ=eps()) = -y*log( + ϵ) - (1 - y)*log(1 - + ϵ)
"""
logitbinarycrossentropy(logŷ, y)

View File

@ -1,7 +1,7 @@
module Optimise
export train!,
SGD, ADAM, AdaMax, Momentum, Nesterov,
SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
struct Param{T}

View File

@ -1,7 +1,7 @@
call(f, xs...) = f(xs...)
# note for optimisers: set to zero
# p.Δ at the end of the weigths update
# p.Δ at the end of the weights update
function optimiser(ps, fs...)
ps = [Param(p) for p in ps]
fs = map(ps) do p
@ -56,6 +56,14 @@ RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) =
ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
"""
ADAMW((params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam.
"""
ADAMW(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->descentweightdecay(p,1,decay))
"""
AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)

View File

@ -5,6 +5,14 @@ function descent(p::Param, η::Real)
end
end
# Ref: https://arxiv.org/abs/1711.05101.pdf
function descentweightdecay(p::Param, η::Real, γ::Real)
function ()
@. p.x = p.x - η * (p.Δ + γ * p.x)
@. p.Δ = 0
end
end
function momentum(p::Param, ρ, η)
v = zeros(p.x)
function ()

View File

@ -37,8 +37,6 @@ function train!(loss, data, opt; cb = () -> ())
opt = runall(opt)
@progress for d in data
l = loss(d...)
isinf(l) && error("Loss is Inf")
isnan(l) && error("Loss is NaN")
@interrupts back!(l)
opt()
cb() == :stop && break

View File

@ -1,22 +1,27 @@
module Tracker
using MacroTools
using MacroTools: @q, @forward
import Base: ==
export TrackedArray, TrackedVector, TrackedMatrix, param, back!
export TrackedArray, TrackedVector, TrackedMatrix, Params, param, back!
tracker(x) = nothing
istracked(x) = tracker(x) nothing
isleaf(x) = !istracked(x) || isleaf(tracker(x))
data(x) = istracked(x) ? data(tracker(x)) : x
grad(x) = grad(tracker(x))
grad(::Void) = nothing
data(x) = x
struct Call{F,As<:Tuple}
func::F
args::As
end
Call(f, args...) = Call{typeof(f),typeof(args)}(f, args)
Call(f, args) = Call{typeof(f),typeof(args)}(f, args)
Call() = Call(nothing, ())
# When deserialising, the object_id changes
a::Call == b::Call = a.func == b.func && a.args == b.args
@ -27,33 +32,83 @@ mutable struct Tracked{T}
ref::UInt32
f::Call
isleaf::Bool
data::T
grad::T
Tracked{T}(f::Call, data::T) where T = new(0, f, false, data)
Tracked{T}(f::Call, data::T, grad::T) where T = new(0, f, false, data, grad)
Tracked{T}(f::Call{Void}, data::T, grad::T) where T = new(0, f, true, data, grad)
Tracked{T}(f::Call) where T = new(0, f, false)
Tracked{T}(f::Call, grad::T) where T = new(0, f, false, grad)
Tracked{T}(f::Call{Void}, grad::T) where T = new(0, f, true, grad)
end
Tracked(f::Call, x) = Tracked{typeof(x)}(f, x)
Tracked(f::Call, x, Δ) = Tracked{typeof(x)}(f, x, Δ)
track(f::Call, x) = Tracked(f, x)
track(f::Call) = track(f, f())
track(f, xs...) = track(Call(f, xs...))
istracked(x::Tracked) = true
isleaf(x::Tracked) = x.f == Call(nothing)
data(x::Tracked) = x.data
isleaf(x::Tracked) = x.f == Call()
grad(x::Tracked) = x.grad
track(f::Call, x) = Tracked{typeof(x)}(f)
function _forward end
function track(f, xs...; kw...)
y, back = _forward(f, xs...; kw...)
track(Call(back, tracker.(xs)), y)
end
macro grad(ex)
@capture(shortdef(ex), (name_(args__) = body_) |
(name_(args__) where {T__} = body_)) || error("Need a function definition")
T == nothing && (T = [])
isexpr(name, :(::)) || (name = :(::typeof($name)))
insert!(args, 1+isexpr(args[1], :parameters) , name)
@q(Tracker._forward($(args...)) where $(T...) = $body) |> esc
end
function update!(x, Δ)
x.data .+= data(Δ)
tracker(x).grad .= 0
return x
end
include("idset.jl")
include("back.jl")
include("scalar.jl")
include("array.jl")
include("numeric.jl")
"""
hook(f, x) -> x
Hook into gradient backpropagation. `x` is unmodified, but when backpropagating
`f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse
the sign of the gradient applied to `x`.
"""
hook(f, x) = istracked(x) ? track(hook, f, x) : x
@grad hook(f, x) = x, Δ -> (nothing, f(Δ))
"""
checkpoint(f, args...)
Behaves like `f(args...)`, but avoids storing the intermediate values needed for
calculating gradients. Instead, `f(args...)` will be called again during the
backward pass. This can be used to save memory in larger models.
"""
checkpoint(f, args...) = track(checkpoint, f, args...)
@grad function checkpoint(f, args...)
data(f(args...)), function (Δ)
y, back = forward(f, args...)
(nothing, back(Δ)...)
end
end
nobacksies(f, x) = track(nobacksies, f, x)
nobacksies(f, xs::Tuple) = map(x -> nobacksies(f, x), xs)
@grad nobacksies(f, x) = data(x), Δ -> error("Nested AD not defined for $f")
param(x::Number) = TrackedReal(float(x))
param(xs::AbstractArray) = TrackedArray(float.(xs))
@grad identity(x) = data(x), Δ -> (Δ,)
param(x::TrackedReal) = track(identity, x)
param(x::TrackedArray) = track(identity, x)
import NNlib.cudata
import Adapt.adapt

View File

@ -6,6 +6,7 @@ struct TrackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
TrackedArray{T,N,A}(t::Tracked{A}, data::A, grad::A) where {T,N,A} = new(t, data, grad)
end
data(x::TrackedArray) = x.data
tracker(x::TrackedArray) = x.tracker
TrackedVector{T,A} = TrackedArray{T,1,A}
@ -15,12 +16,12 @@ TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}}
track(c::Call, x::AbstractArray) = TrackedArray(c, x)
TrackedArray(c::Call, x::A) where A <: AbstractArray =
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, x), x)
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c), x)
TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, x, Δ), x, Δ)
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, Δ), x, Δ)
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x))
TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, zeros(x))
Base.eltype(x::Type{<:TrackedArray{T}}) where T <: Real = TrackedReal{T}
@ -49,6 +50,9 @@ for f in :[Base.size, Base.ndims].args
@eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...)
end
Base.size(x::TrackedArray, i::Integer, j::Integer, is::Integer...) =
size(data(x), i, j, is...)
Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
similar(data(x), dims...)
@ -62,54 +66,57 @@ Base.:(==)(x::TrackedArray, y::TrackedArray) = data(x) == data(y)
Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...)
function back(::typeof(getindex), Δ, xs::TrackedArray, i...)
Δ′ = zeros(xs.data)
Δ′[i...] = Δ
@back(xs, Δ′)
@grad function getindex(xs::AbstractArray, i...)
data(xs)[i...], function (Δ)
Δ′ = zero(xs)
Δ′[i...] = data(Δ)
(nobacksies(:getindex, Δ′), map(_->nothing, i)...)
end
end
Base.:-(xs::TrackedArray) = track(-, xs)
back(::typeof(-), Δ, xs::TrackedArray) = back(xs, -Δ)
@grad -(xs) = -data(xs), Δ -> (-Δ,)
Base.transpose(xs::TrackedArray) = track(transpose, xs)
Base.ctranspose(xs::TrackedArray) = track(ctranspose, xs)
back(::typeof(transpose), Δ, xs) = @back(xs, trim(xs, Δ.'))
back(::typeof(ctranspose), Δ, xs) = @back(xs, trim(xs, Δ'))
@grad transpose(xs) = data(xs).', Δ -> (reshape(Δ.', size(xs)),)
@grad ctranspose(xs) = data(xs)', Δ -> (reshape(Δ', size(xs)),)
Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...)
Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...)
function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1)
Δ′ = similar(xs.data)
S = size(xs.data)
for (i,v) in enumerate(Δ)
@grad function repmat(xs, m, n = 1)
repmat(data(xs), m, n), function (Δ)
Δ′ = similar(xs)
S = size(xs)
for (i,v) in enumerate(data(Δ))
d1 = divrem(i-1, S[1]*m)
x = d1[2] % S[1]+1
y = d1[1] % S[2]+1
Δ′[x, y] += v
end
back(xs, Δ′)
return (nobacksies(:repmat, Δ′), nothing, nothing)
end
end
Base.repeat(A::TrackedArray; kw...) = track(repeat, A; kw...)
_repeat(A, inner, outer) = Base.repeat(A; inner=inner, outer=outer)
Base.repeat(A::TrackedArray; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A))) = track(_repeat, A, inner, outer)
@grad function repeat(xs; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A)))
repeat(data(xs), inner = inner, outer = outer), function (Δ)
Δ′ = zero(xs)
S = size(xs)
function back(::typeof(_repeat), Δ, xs::TrackedArray, inner, outer)
Δ′ = similar(xs.data)
Δ′ .= 0
S = size(xs.data)
# Loop through each element of Δ, calculate source dimensions, accumulate into Δ′
for (dest_idx, val) in enumerate(IndexCartesian(), Δ)
for (dest_idx, val) in enumerate(IndexCartesian(), data(Δ))
# First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then
# wrap around based on original size S.
src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)]
Δ′[src_idx...] += val
end
back(xs, Δ′)
(nobacksies(:repeat, Δ′),)
end
end
@ -138,42 +145,51 @@ for f in [:vcat, :hcat]
end
end
function back(::typeof(vcat), Δ, xs...)
start = 0
for xsi in xs
i = map(_ -> :, size(xsi)) |> Base.tail
@back(xsi, Δ[start+1:start+size(xsi,1), i...])
start += size(xsi, 1)
@grad function vcat(xs...)
vcat(data.(xs)...), function (Δ)
start = 0
Δs = [begin
i = map(_ -> :, size(xsi)) |> Base.tail
d = Δ[start+1:start+size(xsi,1), i...]
start += size(xsi, 1)
d
end for xsi in xs]
return (Δs...,)
end
end
function back(::typeof(hcat), Δ, xs...)
start = 0
for xsi in xs
if ndims(xsi) == 1
@back(xsi, Δ[:, start+1])
else
i = map(_ -> :, size(xsi)) |> Base.tail |> Base.tail
@back(xsi, Δ[:, start+1:start+size(xsi,2), i...])
end
start += size(xsi, 2)
@grad function hcat(xs...)
hcat(data.(xs)...), function (Δ)
start = 0
Δs = [begin
d = if ndims(xsi) == 1
Δ[:, start+1]
else
i = map(_ -> :, size(xsi)) |> Base.tail |> Base.tail
Δ[:, start+1:start+size(xsi,2), i...]
end
start += size(xsi, 2)
d
end for xsi in xs]
return (Δs...,)
end
end
Base.cat(dims, a::TrackedArray, b::AbstractArray...) = track(cat, dims, a, b...)
Base.cat(dims, a::Union{RowVector,Array}, b::TrackedArray, c::AbstractArray...) = track(cat, dims, a, b, c...)
function back(::typeof(cat), Δ, dims, Xs...)
start = ntuple(i -> 0, Val{ndims(Δ)})
for xs in Xs
dim_xs = 1:ndims(xs)
till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val{ndims(Δ)})
xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val{ndims(Δ)})
@back(xs, reshape(Δ[xs_in_Δ...],size(xs)))
start = start .+ till_xs
@grad function cat(dims, Xs...)
cat(dims, data.(Xs)...), function (Δ)
start = ntuple(i -> 0, Val{ndims(Δ)})
Δs = [begin
dim_xs = 1:ndims(xs)
till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val{ndims(Δ)})
xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val{ndims(Δ)})
d = reshape(Δ[xs_in_Δ...],size(xs))
start = start .+ till_xs
d
end for xs in Xs]
return (nothing, Δs...,)
end
end
@ -181,11 +197,10 @@ Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) = reshape(xs, dims)
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Union{Int64,Colon}}}) = reshape(xs, Base._reshape_uncolon(xs, dims))
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64}}) = track(reshape, xs, dims)
back(::typeof(reshape), Δ, xs::TrackedArray, _...) =
back(xs, reshape(Δ, size(xs)))
@grad reshape(xs, dims) = reshape(data(xs), dims), Δ -> (reshape(Δ, size(xs)),nothing)
Base.permutedims(xs::TrackedArray, dims) = track(permutedims, xs, dims)
back(::typeof(permutedims), Δ, xs::TrackedArray, dims) = back(xs, permutedims(Δ, invperm(dims)))
@grad permutedims(xs, dims) = permutedims(data(xs), dims), Δ -> (permutedims(Δ, invperm(dims)),nothing)
function _kron(mat1::AbstractMatrix,mat2::AbstractMatrix)
m1, n1 = size(mat1)
@ -207,14 +222,18 @@ Base.sum(xs::TrackedArray, dim) = track(sum, xs, dim)
Base.sum(xs::TrackedArray) = track(sum, xs)
Base.sum(f::Union{Function,Type},xs::TrackedArray) = sum(f.(xs))
back(::typeof(sum), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= Δ)
@grad sum(xs, dim...) = sum(data(xs), dim...),
Δ -> (zero(xs) .+ Δ, map(_->nothing,dim)...)
Base.prod(xs::TrackedArray, dim) = track(prod, xs, dim)
Base.prod(xs::TrackedArray) = track(prod, xs)
Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs))
back(::typeof(prod), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= (prod(xs.data, dim...) ./ xs.data) .* Δ)
back(::typeof(prod), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= (reshape(.*(circshift.([reshape(xs.data, length(xs.data))], 1:length(xs.data)-1)...), size(xs.data))) .* Δ)
@grad prod(xs) = prod(data(xs)), Δ -> (prod(xs) ./ xs .* Δ,)
@grad prod(xs, dim) = prod(data(xs), dim),
Δ -> (nobacksies(:sum,
reshape(.*(circshift.([reshape(data(xs), length(xs))], 1:length(xs)-1)...), size(xs)) .* Δ),
nothing)
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
@ -230,10 +249,7 @@ LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = track(dot, xs, ys)
LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = track(dot, xs, ys)
LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys)
function back(::typeof(dot), Δ, xs, ys)
@back(xs, Δ.*data(ys))
@back(ys, Δ.*data(xs))
end
@grad dot(xs, ys) = dot(data(xs), data(ys)), Δ -> (Δ .* ys, Δ .* xs)
# Hacks to get std working
Base.std(x::TrackedArray; mean = Base.mean(x)) =
@ -244,41 +260,32 @@ Base.std(x::TrackedArray, dim; mean = Base.mean(x, dim)) =
Base.vecnorm(x::TrackedArray, p::Real = 2) =
sum(abs.(x).^p .+ eps(0f0))^(1/p) # avoid d(sqrt(x))/dx == Inf at 0
back(::typeof(mean), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= Δ ./ length(xs.data))
back(::typeof(mean), Δ, xs::TrackedArray, region) =
back(xs, similar(xs.data) .= Δ ./ prod(size(xs.data, region...)))
@grad mean(xs) = mean(data(xs)), Δ -> (Δ / length(xs),)
@grad mean(xs, region) = mean(data(xs), region), Δ -> (zero(xs) .+ Δ ./ prod(size(xs, region...)),nothing)
function back(::typeof(maximum), Δ, xs::TrackedArray)
Δ′ = zeros(xs.data)
_, i = findmax(xs.data)
Δ′[i] = Δ
@back(xs, Δ′)
@grad function maximum(xs, r...)
maximum(data(xs), r...), function (Δ)
Δ′ = zero(xs)
_, i = findmax(data(xs), r...)
Δ′[i] = data(Δ)
return (nobacksies(:maximum, Δ′),map(_->nothing,r)...)
end
end
function back(::typeof(maximum), Δ, xs::TrackedArray, region)
Δ′ = zeros(xs.data)
_, is = findmax(xs.data, region)
Δ′[is] = Δ
@back(xs, Δ′)
end
function back(::typeof(minimum), Δ, xs::TrackedArray)
Δ′ = zeros(xs.data)
_, i = findmin(xs.data)
Δ′[i] = Δ
@back(xs, Δ′)
end
function back(::typeof(minimum), Δ, xs::TrackedArray, region)
Δ′ = zeros(xs.data)
_, is = findmin(xs.data, region)
Δ′[is] = Δ
@back(xs, Δ′)
@grad function minimum(xs, r...)
minimum(data(xs), r...), function (Δ)
Δ′ = zero(xs)
_, i = findmin(data(xs), r...)
Δ′[i] = data(Δ)
return (nobacksies(:minimum, Δ′),map(_->nothing,r)...)
end
end
# BLAS
Base.diagm(x::TrackedVector) = track(diagm, x)
back(::typeof(diagm), Δ, x) = @back(x, diag(Δ))
@grad diagm(x) = diagm(data(x)), Δ -> (diag(Δ),)
for f in :[*, Ac_mul_B, A_mul_Bc].args
for f in :[*, Ac_mul_B, A_mul_Bc, A_mul_Bt, At_mul_B].args
@eval begin
import Base.$f
$f(a::TrackedMatrix, b::TrackedMatrix) = track($f, a, b)
@ -295,30 +302,14 @@ for f in :[*, Ac_mul_B, A_mul_Bc].args
end
end
function back(::typeof(*), Δ, a::AbstractMatrix, b::AbstractVecOrMat)
@back(a, A_mul_Bt(Δ, data(b)))
@back(b, At_mul_B(data(a), Δ))
end
@grad a::AbstractMatrix * b::AbstractVecOrMat =
data(a)*data(b), Δ -> (A_mul_Bt(Δ, b), At_mul_B(a, Δ))
function back(::typeof(Ac_mul_B), Δ, a::AbstractVecOrMat{<:Real}, b::AbstractVecOrMat{<:Real})
@back(a, A_mul_Bt(Δ, data(b))')
@back(b, data(a)*Δ)
end
@grad Ac_mul_B(a, b) = Ac_mul_B(data(a), data(b)), Δ -> (A_mul_Bt(Δ, b)', a*Δ)
@grad A_mul_Bc(a, b) = A_mul_Bc(data(a), data(b)), Δ -> (Δ * b, At_mul_B(a, Δ)')
function back(::typeof(A_mul_Bc), Δ, a::AbstractVecOrMat{<:Real}, b::AbstractVecOrMat{<:Real})
@back(a, Δ * data(b))
@back(b, At_mul_B(data(a), Δ)')
end
# Fast path for matrix-vector
function back(::typeof(*), Δ::AbstractVector, W::TrackedMatrix, x::AbstractVector)
if isleaf(W)
W.grad .+= Δ .* data(x).'
else
back(W, A_mul_Bt(Δ, data(x)))
end
@back(x, At_mul_B(data(W), Δ))
end
@grad At_mul_B(a, b) = At_mul_B(data(a), data(b)), Δ -> (A_mul_Bt(Δ, b)', a*Δ)
@grad A_mul_Bt(a, b) = A_mul_Bt(data(a), data(b)), Δ -> (Δ * b, At_mul_B(a, Δ)')
# NNlib
@ -327,26 +318,11 @@ import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, depthwisecon
softmax(xs::TrackedArray) = track(softmax, xs)
back(::typeof(softmax), Δ, xs) = @back(xs, ∇softmax(Δ, data(xs)))
@grad softmax(xs) = softmax(data(xs)), Δ -> (nobacksies(:softmax, ∇softmax(data(Δ), data(xs))),)
logsoftmax(xs::TrackedArray) = track(logsoftmax, xs)
back(::typeof(logsoftmax), Δ, xs) = @back(xs, ∇logsoftmax(Δ, data(xs)))
# TODO: can store kwargs efficiently in namedtuples
_conv(x, w, stride, pad, dilation) = conv(x, w, stride = stride, pad = pad, dilation = dilation)
conv(x::TrackedArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N =
track(_conv, x, w, stride, pad, dilation)
conv(x::AbstractArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N =
track(_conv, x, w, stride, pad, dilation)
conv(x::TrackedArray{<:Real,N}, w::AbstractArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N =
track(_conv, x, w, stride, pad, dilation)
function back(::typeof(_conv), Δ, x, w, stride, pad, dilation)
@back(x, NNlib.∇conv_data(Δ, data(x), data(w); stride = stride, pad = pad, dilation = dilation))
@back(w, NNlib.∇conv_filter(Δ, data(x), data(w); stride = stride, pad = pad, dilation = dilation))
end
@grad logsoftmax(xs) = logsoftmax(data(xs)), Δ -> (nobacksies(:logsoftmax, ∇logsoftmax(data(Δ), data(xs))),)
_depthwiseconv(x, w, stride, pad) = depthwiseconv(x, w, stride = stride, pad = pad)
@ -362,61 +338,63 @@ function back(::typeof(_depthwiseconv), Δ, x, w, stride, pad)
@back(w, NNlib.∇depthwiseconv_filter(Δ, data(x), data(w), stride = stride, pad = pad))
end
_maxpool(x, k, pad, stride) = maxpool(x, k; pad = pad, stride = stride)
conv(x::TrackedArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
conv(x::AbstractArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...)
maxpool(x::TrackedArray, k; pad = map(_->0,k), stride = k) =
track(_maxpool, x, k, pad, stride)
@grad conv(x, w; kw...) =
conv(data(x), data(w); kw...),
Δ -> nobacksies(:conv,
(NNlib.∇conv_data(data.((Δ, x, w))...; kw...),
NNlib.∇conv_filter(data.((Δ, x, w))...; kw...)))
back_(::typeof(_maxpool), y, Δ, x, k, pad, stride) =
back(x, NNlib.∇maxpool(Δ, y, data(x), k, pad=pad, stride=stride))
maxpool(x::TrackedArray, k; kw...) = track(maxpool, x, k; kw...)
_meanpool(x, k, pad, stride) = meanpool(x, k; pad = pad, stride = stride)
@grad function maxpool(x, k; kw...)
y = maxpool(data(x), k; kw...)
y, Δ -> (nobacksies(:maxpool, NNlib.∇maxpool(data.((Δ, y, x))..., k; kw...)), nothing)
end
meanpool(x::TrackedArray, k; pad = map(_->0,k), stride = k) =
track(_meanpool, x, k, pad, stride)
meanpool(x::TrackedArray, k; kw...) = track(meanpool, x, k; kw...)
back_(::typeof(_meanpool), y, Δ, x, k, pad, stride) =
back(x, NNlib.∇meanpool(Δ, y, data(x), k, pad=pad, stride=stride))
@grad function meanpool(x, k; kw...)
y = meanpool(data(x), k; kw...)
y, Δ -> (nobacksies(:maxpool, NNlib.∇meanpool(data.((Δ, y, x))..., k; kw...)), nothing)
end
# Broadcasting
using ForwardDiff: Dual, partials
struct Broadcasted{F,T}
f::F
data::T
end
(b::Broadcasted)(xs...) = map(x -> x.value, b.data)
using ForwardDiff: Dual, partials, value
dualify(xs, n) = xs
dualify(xs::TrackedArray, ps) = map(x -> Dual(x, ps), data(xs))
dualify(xs::TrackedReal, ps) = Dual(data(xs), ps)
dualify(xs::AbstractArray, ps) = map(x -> Dual(x, ps), xs)
dualify(xs::Real, ps) = Dual(xs, ps)
function tracked_broadcast(f, args::Vararg{Any,N}) where N
dargs = map((x,i) -> dualify(x, ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N}))
out = broadcast(f, dargs...)
eltype(out) <: Dual || return out
b = Broadcasted(f, out)
track(Call(b, args...), b())
end
unbroadcast(x::Tuple, Δ) =
x == size(Δ) ? Δ :
reshape(sum(Δ, filter(n -> n > length(x) || x[n] == 1, 1:ndims(Δ))), x)
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val{ndims(x)}))
unbroadcast(x::AbstractArray, Δ) =
size(x) == size(Δ) ? Δ :
trim(x, sum(Δ, filter(n -> size(x, n) == 1, 1:ndims(Δ))))
unbroadcast(x::Number, Δ) = sum(Δ)
unbroadcast(x::Tuple{}, Δ) = sum(Δ)
function getpartial(Δ, x, i)
@inbounds p = getindex(partials(x), i)
return Δ * p
end
function back(b::Broadcasted, Δ, args::Vararg{Any,N}) where N
Δargs = ntuple(i -> getpartial.(Δ, b.data, i), Val{N})
foreach((x, Δ) -> @back(x, unbroadcast(x, Δ)), args, Δargs)
function ∇broadcast(f, args::Vararg{Any,N}) where N
sizes = size.(args)
dargs = map((x,i) -> dualify(data(x), ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N}))
out = broadcast(f, dargs...)
eltype(out) <: Dual || return out
y = value.(out)
back = function (Δ_)
Δ = data(Δ_)
Δargs = ntuple(i -> getpartial.(Δ, out, i), Val{N})
dxs = map((x, Δ) -> unbroadcast(x, Δ), sizes, Δargs)
nobacksies(:broadcast, dxs)
end
# So we can return non-tracked arrays
track(Call(back, tracker.(args)), y)
end
Base.Broadcast._containertype(::Type{<:TrackedReal}) = TrackedArray
@ -429,4 +407,4 @@ Base.Broadcast.promote_containertype(ct, ::Type{TrackedArray}) = TrackedArray
Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A::Ref) = ()
Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A) = indices(A)
Base.Broadcast.broadcast_c(f, ::Type{TrackedArray}, A, Bs...) = tracked_broadcast(f, A, Bs...)
Base.Broadcast.broadcast_c(f, ::Type{TrackedArray}, A, Bs...) = broadcast(f, A, Bs...)

View File

@ -10,8 +10,6 @@ function scan(x::Tracked)
if ref == 1
scan(x.f)
isdefined(x, :grad) && (x.grad = zero_grad!(x.grad))
else
isdefined(x, :grad) || (x.grad = init_grad(x.data))
end
return
end
@ -21,9 +19,14 @@ function scan(x)
return
end
back_(f, y, args...) = back(f, args...)
back_(c::Call, y, Δ) = back_(c.func, y, Δ, c.args...)
back_(::Call{Void}, y, Δ) = nothing
function back_(c::Call, Δ)
Δs = c.func(Δ)
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
error("Gradient is not a tuple of length $(length(c.args))")
foreach(back, c.args, data.(Δs))
end
back_(::Call{Void}, Δ) = nothing
accum!(x, Δ) = x .+ Δ
accum!(x::AbstractArray, Δ) = (x .+= Δ)
@ -31,33 +34,121 @@ accum!(x::AbstractArray, Δ) = (x .+= Δ)
function back(x::Tracked, Δ)
x.isleaf && (x.grad = accum!(x.grad, Δ); return)
ref = x.ref -= 1
if isdefined(x, :grad)
x.grad = accum!(x.grad, Δ)
ref == 0 && back_(x.f, x.data, x.grad)
if ref > 0 || isdefined(x, :grad)
if isdefined(x, :grad)
x.grad = accum!(x.grad, Δ)
else
x.grad = Δ
end
ref == 0 && back_(x.f, x.grad)
else
ref == 0 && back_(x.f, x.data, Δ)
ref == 0 && back_(x.f, Δ)
end
return
end
back(x, Δ) = back(tracker(x), Δ)
back(x::Void, Δ) = error("Can't backpropagate through `nothing`")
macro back(x, Δ)
quote
x = $(esc(x))
istracked(x) && back(x, $(esc(Δ)))
end
end
back(::Void, _) = return
# Interface methods
# TODO: if an error occurs in `back` the refcounts will be broken
# and `back` will silently fail to update.
# Refcounts are also probably not safe in some situations (e.g. back called
# from within a backpropagator)
function back!(x::Tracked, Δ)
function back!(x, Δ)
istracked(x) || return
scan(x)
back(x, Δ)
back(tracker(x), Δ)
return
end
back!(x, Δ) = back!(tracker(x), Δ)
# Out-of-place gradients
struct Params
params::IdSet
Params(xs) = new(IdSet(xs))
end
@forward Params.params Base.start, Base.next, Base.done
function Base.show(io::IO, ps::Params)
print(io, "Params([")
join(io, ps.params, ", ")
print(io, "])")
end
struct Grads
grads::ObjectIdDict
end
Base.show(io::IO, ps::Grads) = println(io, "Grads(...)")
Grads() = Grads(ObjectIdDict())
Grads(ps::Params) = Grads(ObjectIdDict(tracker(p) => init_grad(data(p)) for p in ps))
Base.getindex(g::Grads, x::Tracked) = g.grads[x]
function Base.getindex(g::Grads, x)
istracked(x) || error("Object not tracked: $x")
g[tracker(x)]
end
@forward Grads.grads Base.setindex!, Base.haskey
accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] + Δ : Δ
function back_(g::Grads, c::Call, Δ)
Δs = c.func(Δ)
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
error("Gradient is not a tuple of length $(length(c.args))")
foreach((x, Δ) -> back(g, x, Δ), c.args, Δs)
end
back_(g::Grads, ::Call{Void}, Δ) = nothing
function back(g::Grads, x::Tracked, Δ)
x.isleaf && (accum!(g, x, Δ); return)
ref = x.ref -= 1
if ref > 0 || haskey(g, x)
accum!(g, x, Δ)
ref == 0 && back_(g, x.f, g[x])
else
ref == 0 && back_(g, x.f, Δ)
end
return
end
back(::Grads, ::Void, _) = return
function forward(f, ps::Params)
y = f()
y, function (Δ)
g = Grads(ps)
if istracked(y)
scan(y)
back(g, tracker(y), Δ)
end
return g
end
end
function forward(f, args...)
args = param.(args)
y, back = forward(() -> f(args...), Params(args))
y, Δ -> getindex.(back(Δ), args)
end
function losscheck(x)
x isa Real || error("Function output is not scalar")
isinf(x) && error("Loss is infinite")
isnan(x) && error("Loss is NaN")
end
function gradient(f, args...)
y, back = forward(f, args...)
losscheck(y)
return back(1)
end
derivative(f, x) = gradient(f, x)[1]

25
src/tracker/idset.jl Normal file
View File

@ -0,0 +1,25 @@
struct IdSet{T} <: AbstractSet{T}
dict::ObjectIdDict
IdSet{T}() where T = new(ObjectIdDict())
end
Base.eltype{T}(::IdSet{T}) = T
IdSet() = IdSet{Any}()
Base.push!{T}(s::IdSet{T}, x::T) = (s.dict[x] = nothing; s)
Base.delete!{T}(s::IdSet{T}, x::T) = (delete!(s.dict, x); s)
Base.in(x, s::IdSet) = haskey(s.dict, x)
(::Type{IdSet{T}}){T}(xs) = push!(IdSet{T}(), xs...)
IdSet(xs) = IdSet{eltype(xs)}(xs)
Base.collect(s::IdSet) = Base.collect(keys(s.dict))
Base.similar(s::IdSet, T::Type) = IdSet{T}()
@forward IdSet.dict Base.length
Base.start(s::IdSet) = start(keys(s.dict))
Base.next(s::IdSet, st) = next(keys(s.dict), st)
Base.done(s::IdSet, st) = done(keys(s.dict), st)

View File

@ -1,9 +1,3 @@
function gradient(f, xs::AbstractArray...)
xs = param.(xs)
back!(f(xs...))
grad.(xs)
end
function ngradient(f, xs::AbstractArray...)
grads = zeros.(xs)
for (x, Δ) in zip(xs, grads), i in 1:length(x)
@ -21,4 +15,4 @@ end
gradcheck(f, xs...) =
all(isapprox.(ngradient(f, xs...),
gradient(f, xs...), rtol = 1e-5, atol = 1e-5))
data.(gradient(f, xs...)), rtol = 1e-5, atol = 1e-5))

View File

@ -1,14 +1,20 @@
struct TrackedReal{T<:Real} <: Real
data::T
tracker::Tracked{T}
end
TrackedReal(x::Real) = TrackedReal(Tracked(Call(nothing), x, zero(x)))
TrackedReal(x::Real) = TrackedReal(x, Tracked{typeof(x)}(Call(), zero(x)))
data(x::TrackedReal) = x.data
tracker(x::TrackedReal) = x.tracker
track(f::Call, x::Real) = TrackedReal(Tracked(f, x, zero(x)))
track(f::Call, x::Real) = TrackedReal(x, Tracked{typeof(x)}(f, zero(x)))
back!(x::TrackedReal) = back!(x, 1)
function back!(x::TrackedReal)
isinf(x) && error("Loss is Inf")
isnan(x) && error("Loss is NaN")
return back!(x, 1)
end
function Base.show(io::IO, x::TrackedReal)
show(io, data(x))
@ -19,15 +25,16 @@ Base.decompose(x::TrackedReal) = Base.decompose(data(x))
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{T}) where T = x
# This cuts derivatives, fix if needed.
# Base.convert(::Type{TrackedReal{T}}, x::TrackedReal) where T =
# TrackedReal(Tracked(x.tracker.f, convert(T, x.tracker.data)))
Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x))
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} =
error("Not implemented: convert tracked $S to tracked $T")
Base.:(<)(x::TrackedReal, y::TrackedReal) = data(x) < data(y)
Base.:(==)(x::TrackedReal, y::TrackedReal) = data(x) == data(y)
Base.eps(x::TrackedReal) = eps(data(x))
for f in :[isinf, isnan, isfinite].args
@eval Base.$f(x::TrackedReal) = Base.$f(data(x))
end
@ -42,23 +49,21 @@ using DiffRules, SpecialFunctions, NaNMath
for (M, f, arity) in DiffRules.diffrules()
arity == 1 || continue
@eval begin
@grad $M.$f(a::Real) =
$M.$f(data(a)), Δ -> (Δ * $(DiffRules.diffrule(M, f, :a)),)
$M.$f(a::TrackedReal) = track($M.$f, a)
back(::typeof($M.$f), Δ::Real, a::TrackedReal) =
back(a, Δ * $(DiffRules.diffrule(M, f, :(data(a)))))
end
end
for (M, f, arity) in DiffRules.diffrules()
arity == 2 || continue
da, db = DiffRules.diffrule(M, f, :(data(a)), :(data(b)))
da, db = DiffRules.diffrule(M, f, :a, :b)
f = :($M.$f)
@eval begin
$M.$f(a::TrackedReal, b::TrackedReal) = track($M.$f, a, b)
$M.$f(a::TrackedReal, b::Real) = track($M.$f, a, b)
$M.$f(a::Real, b::TrackedReal) = track($M.$f, a, b)
function back(::typeof($M.$f), Δ::Real, a::Real, b::Real)
@back(a, Δ * $da)
@back(b, Δ * $db)
end
@grad $f(a::Real, b::Real) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db)
$f(a::TrackedReal, b::TrackedReal) = track($f, a, b)
$f(a::TrackedReal, b::Real) = track($f, a, b)
$f(a::Real, b::TrackedReal) = track($f, a, b)
end
end
@ -70,16 +75,18 @@ import Base:^
# Tuples
struct TrackedTuple{T<:Tuple}
data::T
tracker::Tracked{T}
end
data(xs::TrackedTuple) = xs.data
tracker(xs::TrackedTuple) = xs.tracker
accum!(x::Tuple, Δ::Tuple) = accum!.(x, Δ)
init_grad(x::Tuple) = init_grad.(x)
zero_grad!(x::Tuple) = zero_grad!.(x)
track(f::Call, xs::Tuple) = TrackedTuple(Tracked(f, xs))
track(f::Call, xs::Tuple) = TrackedTuple(xs, Tracked{typeof(xs)}(f, zero.(xs)))
function Base.show(io::IO, xs::TrackedTuple)
show(io, data(xs))
@ -90,20 +97,21 @@ Base.length(x::TrackedTuple) = length(data(x))
Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i)
back(::typeof(getindex), Δ, t, i) =
back(t, ntuple(j -> i == j ? Δ : 0, length(t)))
@grad function getindex(xs::TrackedTuple, i)
data(xs)[i], Δ -> (ntuple(j -> i == j ? Δ : 0, length(xs)), nothing)
end
# Array collection
function collect(xs)
xs = Base.collect(xs)
track(Call(collect, xs), data.(xs))
track(Call(collect, (tracker.(xs),)), data.(xs))
end
function scan(c::Call{typeof(collect)})
foreach(scan, c.args[1])
end
function back(::typeof(collect), Δ, xs)
foreach((x, Δ) -> @back(x, Δ), xs, Δ)
function back_(c::Call{typeof(collect)}, Δ)
foreach(back, c.args[1], data(Δ))
end

View File

@ -1,4 +1,5 @@
import Adapt: adapt
import .Tracker: IdSet
children(x) = ()
mapchildren(f, x) = x
@ -20,9 +21,7 @@ function mapleaves(f, x; cache = ObjectIdDict())
cache[x] = isleaf(x) ? f(x) : mapchildren(x -> mapleaves(f, x, cache = cache), x)
end
using DataFlow: OSet
function prefor(f, x; seen = OSet())
function prefor(f, x; seen = IdSet())
x seen && return
f(x)
foreach(x -> prefor(f, x, seen = seen), children(x))

View File

@ -1,7 +1,9 @@
using Base.Test
using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
σ, binarycrossentropy, logitbinarycrossentropy
const ϵ = 1e-7
@testset "losses" begin
# First, regression-style y's
y = [1, 1, 0, 0]
@ -40,10 +42,11 @@ using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
logŷ, y = randn(3), rand(3)
@testset "binarycrossentropy" begin
@test binarycrossentropy.(σ.(logŷ), y) -y.*log.(σ.(logŷ)) - (1 - y).*log.(1 - σ.(logŷ))
@test binarycrossentropy.(σ.(logŷ), y; ϵ=0) -y.*log.(σ.(logŷ)) - (1 - y).*log.(1 - σ.(logŷ))
@test binarycrossentropy.(σ.(logŷ), y) -y.*log.(σ.(logŷ) .+ eps.(σ.(logŷ))) - (1 - y).*log.(1 - σ.(logŷ) .+ eps.(σ.(logŷ)))
end
@testset "logitbinarycrossentropy" begin
@test logitbinarycrossentropy.(logŷ, y) binarycrossentropy.(σ.(logŷ), y)
@test logitbinarycrossentropy.(logŷ, y) binarycrossentropy.(σ.(logŷ), y; ϵ=0)
end
end

View File

@ -1,5 +1,5 @@
using Flux.Tracker, Base.Test, NNlib
using Flux.Tracker: TrackedReal, gradcheck, grad
using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint
using NNlib: conv, depthwiseconv
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
@ -111,6 +111,7 @@ end
@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))
# TODO unreliable
@test gradtest(x -> repmat(x, 5,5), rand(4,5))
@test gradtest(x -> repmat(x, 5), rand(4,5))
@ -234,4 +235,24 @@ Tracker.back!(b)
@test grad.((x,y)) == (3, 2)
end
# Gradient Hooks
@testset "Hooks" begin
x = param(2)
y = Tracker.hook(-, x)
back!(y)
@test grad(x) == -1
end
@testset "Checkpointing" begin
count = 0
function mul(a, b)
count += 1
a * b
end
@test derivative(x -> mul(5, x), 3) == 5
@test count == 1
@test derivative(x -> checkpoint(mul, 5, x), 3) == 5
@test count == 3
end
end #testset