Merge branch 'master' into depthwiseconv
This commit is contained in: commit 0aabf9d86b
@@ -30,7 +30,7 @@ Flux has powerful high-level features, and common architectures can be defined i

```julia
model = Chain(
  Dense(768, 128),
  Dense(768, 128, σ),
  LSTM(128, 256)
  LSTM(256, 128)
  Dense(128, 10),
@@ -4,6 +4,8 @@ Support for array operations on other hardware backends, like GPUs, is provided

For example, we can use `CuArrays` (with the `cu` converter) to run our [basic example](models/basics.md) on an NVIDIA GPU.

(Note that you need to build Julia 0.6 from source and have CUDA available to use CuArrays – please see the [CUDAnative.jl](https://github.com/JuliaGPU/CUDAnative.jl) instructions for more details.)

```julia
using CuArrays
@@ -14,3 +14,5 @@ Pkg.test("Flux") # Check things installed correctly
```

Start with the [basics](models/basics.md). The [model zoo](https://github.com/FluxML/model-zoo/) is also a good starting point for many common kinds of models.

See [GPU support](gpu.md) for more details on installing and using Flux with GPUs.
@@ -6,6 +6,52 @@ Backpropagation, or reverse-mode automatic differentiation, is handled by the `F
julia> using Flux.Tracker
```

Here we discuss some more advanced uses of this module, as well as covering its internals.

## Taking Gradients

In the [basics section](../models/basics.md) we covered basic usage of the `gradient` function.

```julia
using Flux.Tracker

Tracker.gradient((a, b) -> a*b, 2, 3) # (3.0 (tracked), 2.0 (tracked))
```

`gradient` is actually just a thin wrapper around the backpropagator-based interface, `forward`.

```julia
using Flux.Tracker: forward

y, back = forward((a, b) -> a*b, 2, 3) # (6.0 (tracked), Flux.Tracker.#9)

back(1) # (3.0 (tracked), 2.0 (tracked))
```

The `forward` function returns two results. The first, `y`, is the original value of the function (perhaps with tracking applied). The second, `back`, is a new function which, given a sensitivity, returns the sensitivity of the inputs to `forward` (we call this a "backpropagator"). One use of this interface is to provide custom sensitivities when outputs are not scalar.

```julia
julia> y, back = forward((a, b) -> a.*b, [1,2,3], [4,5,6])
(param([4.0, 10.0, 18.0]), Flux.Tracker.#9)

julia> back([1,1,1])
(param([4.0, 5.0, 6.0]), param([1.0, 2.0, 3.0]))
```

We can also take gradients in-place. This can be useful if you only care about first-order gradients.

```julia
a, b = param(2), param(3)

c = a*b # 6.0 (tracked)

Tracker.back!(c)

Tracker.grad(a), Tracker.grad(b) # (3.0, 2.0)
```

## Tracked Arrays

The `param` function converts a normal Julia array into a new object that, while behaving like an array, tracks extra information that allows us to calculate derivatives. For example, say we multiply two parameters:

```julia
@@ -41,7 +87,48 @@ julia> x.grad
 -2.0
```

## Internals
You may sometimes want to drop derivative information and just get the plain value back. You can do this by calling `Tracker.data(W)`.

## Custom Gradients

We can hook into the processes above to implement custom gradients for a function or kernel. For a toy example, imagine a custom implementation of `minus`:

```julia
minus(a, b) = a - b
```

Firstly, we must tell the tracker system to stop when it sees a call to `minus`, and record it. We can do this using dispatch:

```julia
using Flux.Tracker: TrackedReal, track, @grad

minus(a::TrackedArray, b::TrackedArray) = Tracker.track(minus, a, b)
```

`track` takes care of building a new `Tracked` object and recording the operation on the tape. We just need to provide a gradient definition.

```julia
@grad function minus(a, b)
  return minus(data(a), data(b)), Δ -> (Δ, -Δ)
end
```

This is essentially just a way of overloading the `forward` function we saw above. We strip tracking from `a` and `b` so that we are calling the original definition of `minus` (otherwise, we'd just try to track the call again and hit an infinite regress).

Note that in the backpropagator we don't call `data(a)`; we *do* in fact want to track this, since nested AD will take a derivative through the backpropagator itself. For example, the gradient of `*` might look like this.

```julia
@grad a * b = data(a)*data(b), Δ -> (Δ*b, a*Δ)
```

For multi-argument functions with custom gradients, you likely want to catch not just `minus(::TrackedArray, ::TrackedArray)` but also `minus(::Array, ::TrackedArray)` and so on. To do so, just define those extra signatures as needed:

```julia
minus(a::AbstractArray, b::TrackedArray) = Tracker.track(minus, a, b)
minus(a::TrackedArray, b::AbstractArray) = Tracker.track(minus, a, b)
```
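As a quick end-to-end check (a sketch added here, not part of the original docs), the custom rule can be exercised with the `minus` methods and the `@grad` definition above:

```julia
using Flux.Tracker

a, b = param([2.0, 3.0]), param([1.0, 1.0])

c = minus(a, b)        # tracked array, recorded through our custom rule
Tracker.back!(sum(c))  # backpropagate from a scalar loss
Tracker.grad(a)        # [1.0, 1.0]
Tracker.grad(b)        # [-1.0, -1.0]
```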
## Tracked Internals

All `Tracked*` objects (`TrackedArray`, `TrackedReal`) are light wrappers around the `Tracked` type, which you can access via the `.tracker` field.
@@ -50,14 +137,9 @@ julia> x.tracker
Flux.Tracker.Tracked{Array{Float64,1}}(0x00000000, Flux.Tracker.Call{Void,Tuple{}}(nothing, ()), true, [5.0, 6.0], [-2.0, -2.0])
```

The `Tracker` stores the value and gradient of a given object, which we've seen before.
The `Tracker` stores the gradient of a given object, which we've seen before.

```julia
julia> x.tracker.data
2-element Array{Float64,1}:
 5.0
 6.0

julia> x.tracker.grad
2-element Array{Float64,1}:
 -2.0
@@ -86,71 +168,4 @@ When we call `back!(y, [1, -1])`, the sensitivities `[1, -1]` simply get forward
Tracker.back(*, [1, -1], W, x)
```

which in turn calculates the sensitivities of the arguments (`W` and `x`) and backpropagates through their calls. This is recursive, so it will walk the entire program graph and propagate gradients to the original model parameters.

## Custom Gradients

We can hook into the processes above to implement custom gradients for a function or kernel. For a toy example, imagine a custom implementation of `minus`:

```julia
julia> minus(a, b) = a - b
```

Firstly, we must tell the tracker system to stop when it sees a call to `minus`, and record it. We can do this using dispatch:

```julia
julia> minus(a::TrackedArray, b::TrackedArray) = Tracker.track(minus, a, b)
minus (generic function with 2 methods)
```

`Tracker.track` does two things: (1) it makes sure `minus` is called with *normal* arrays, not tracked ones (you can use `@show` inside `minus` to verify this), and (2) it uses the result to add a `minus` node to the tape. Look inside the result of calling `minus` to see what happened:

```julia
julia> a, b = param([6,5,4]), param([1,2,3])
(param([6.0, 5.0, 4.0]), param([1.0, 2.0, 3.0]))

julia> c = minus(a, b)
Tracked 3-element Array{Float64,1}:
 5.0
 3.0
 1.0

julia> c.tracker.f
Flux.Tracker.Call{...}(minus, (param([6.0, 5.0, 4.0]), param([1.0, 2.0, 3.0])))
```

Finally, we have to specify the gradient of `minus`.

```julia
julia> Tracker.back(::typeof(minus), Δ, a, b) =
         (Tracker.@back(a, Δ); Tracker.@back(b, -Δ))
```

`@back(x, Δ)` tells the tracker to continue propagating the sensitivity `Δ` through `x`. Now, AD will work with any program that calls `minus`.

```julia
julia> Flux.back!(c, 1)

julia> a.grad
3-element Array{Float64,1}:
 1.0
 1.0
 1.0

julia> b.grad
3-element Array{Float64,1}:
 -1.0
 -1.0
 -1.0
```

## Notes

For multi-argument functions with custom gradients, you likely want to catch not just `minus(::TrackedArray, ::TrackedArray)` but also `minus(::Array, ::TrackedArray)` and so on. To do so, just define those extra signatures as needed:

```julia
minus(a::AbstractArray, b::TrackedArray) = Tracker.track(minus, a, b)
minus(a::TrackedArray, b::AbstractArray) = Tracker.track(minus, a, b)
```

`@back` *must* be called exactly once on each tracked input argument. You do not need to do any special handling if one of the arguments is not tracked, as `@back` will just become a no-op.
which in turn calculates the sensitivities of the arguments (`W` and `x`) and back-propagates through their calls. This is recursive, so it will walk the entire program graph and propagate gradients to the original model parameters.
@@ -2,20 +2,74 @@

## Taking Gradients

Consider a simple linear regression, which tries to predict an output array `y` from an input `x`. (It's a good idea to follow this example in the Julia REPL.)
Flux's core feature is taking gradients of Julia code. The `gradient` function takes another Julia function `f` and a set of arguments, and returns the gradient with respect to each argument. (It's a good idea to try pasting these examples in the Julia terminal.)

```julia
using Flux.Tracker

f(x) = 3x^2 + 2x + 1

# df/dx = 6x + 2
f′(x) = Tracker.gradient(f, x)[1]

f′(2) # 14.0 (tracked)

# d²f/dx² = 6
f′′(x) = Tracker.gradient(f′, x)[1]

f′′(2) # 6.0 (tracked)
```

(We'll learn more about why these numbers show up as `(tracked)` below.)

When a function has many parameters, we can pass them all in explicitly:

```julia
f(W, b, x) = W * x + b

Tracker.gradient(f, 2, 3, 4)
(4.0 (tracked), 1.0, 2.0 (tracked))
```

But machine learning models can have *hundreds* of parameters! Flux offers a nice way to handle this. We can tell Flux to treat something as a parameter via `param`. Then we can collect these together and tell `gradient` to collect the gradients of all of them at once.

```julia
W = param(2) # 2.0 (tracked)
b = param(3) # 3.0 (tracked)

f(x) = W * x + b

params = Params([W, b])
grads = Tracker.gradient(() -> f(4), params)

grads[W] # 4.0
grads[b] # 1.0
```

There are a few things to notice here. Firstly, `W` and `b` now show up as *tracked*. Tracked things behave like normal numbers or arrays, but keep records of everything you do with them, allowing Flux to calculate their gradients. `gradient` takes a zero-argument function; no arguments are necessary because the `Params` tell it what to differentiate.

This will come in really handy when dealing with big, complicated models. For now, though, let's start with something simple.

## Simple Models

Consider a simple linear regression, which tries to predict an output array `y` from an input `x`.

```julia
W = rand(2, 5)
b = rand(2)

predict(x) = W*x .+ b
loss(x, y) = sum((predict(x) .- y).^2)

function loss(x, y)
  ŷ = predict(x)
  sum((y .- ŷ).^2)
end

x, y = rand(5), rand(2) # Dummy data
loss(x, y) # ~ 3
```

To improve the prediction we can take the gradients of `W` and `b` with respect to the loss function and perform gradient descent. We could calculate gradients by hand, but Flux will do it for us if we tell it that `W` and `b` are trainable *parameters*.
To improve the prediction we can take the gradients of `W` and `b` with respect to the loss and perform gradient descent. Let's tell Flux that `W` and `b` are parameters, just like we did above.

```julia
using Flux.Tracker
@@ -23,25 +77,25 @@ using Flux.Tracker
W = param(W)
b = param(b)

l = loss(x, y)

back!(l)
gs = Tracker.gradient(() -> loss(x, y), Params([W, b]))
```

`loss(x, y)` returns the same number, but it's now a *tracked* value that records gradients as it goes along. Calling `back!` then calculates the gradient of `W` and `b`. We can see what this gradient is, and modify `W` to train the model.
Now that we have gradients, we can pull them out and update `W` to train the model. The `update!(W, Δ)` function applies `W = W + Δ`, which we can use for gradient descent.

```julia
W.grad
using Flux.Tracker: update!

# Update the parameter
W.data .-= 0.1(W.grad)
Δ = gs[W]

# Update the parameter and reset the gradient
update!(W, -0.1Δ)

loss(x, y) # ~ 2.5
```

The loss has decreased a little, meaning that our prediction for `x` is closer to the target `y`. If we have some data we can already try [training the model](../training/training.md).

All deep learning in Flux, however complex, is a simple generalisation of this example. Of course, models can *look* very different – they might have millions of parameters or complex control flow, and there are ways to manage this complexity. Let's see what that looks like.
All deep learning in Flux, however complex, is a simple generalisation of this example. Of course, models can *look* very different – they might have millions of parameters or complex control flow. Let's see how Flux handles more complex models.

## Building Layers
@@ -103,7 +103,7 @@ m.(seq)

## Truncating Gradients

By default, calculating the gradients in a recurrent layer involves the entire history. For example, if we call the model on 100 inputs, calling `back!` will calculate the gradient for those 100 calls. If we then calculate another 10 inputs we have to calculate 110 gradients – this accumulates and quickly becomes expensive.
By default, calculating the gradients in a recurrent layer involves its entire history. For example, if we call the model on 100 inputs, we'll have to calculate the gradient for those 100 calls. If we then calculate another 10 inputs we have to calculate 110 gradients – this accumulates and quickly becomes expensive.

To avoid this we can *truncate* the gradient calculation, forgetting the history.
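A minimal sketch of what this looks like in practice, assuming the `Flux.truncate!` helper from this era of the library and a small recurrent model:

```julia
m = Chain(LSTM(10, 5), Dense(5, 2))

seq = [rand(10) for i = 1:100]
m.(seq)            # run over a long sequence, building up history

Flux.truncate!(m)  # keep the current hidden state but drop the stored history,
                   # so later gradient calculations stay cheap
```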
@@ -44,3 +44,19 @@ loss(x, y) = crossentropy(m(x), y) + sum(vecnorm, params(m))

loss(rand(28^2), rand(10))
```

One can also easily add per-layer regularisation via the `activations` function:

```julia
julia> c = Chain(Dense(10,5,σ),Dense(5,2),softmax)
Chain(Dense(10, 5, NNlib.σ), Dense(5, 2), NNlib.softmax)

julia> activations(c, rand(10))
3-element Array{Any,1}:
 param([0.71068, 0.831145, 0.751219, 0.227116, 0.553074])
 param([0.0330606, -0.456104])
 param([0.61991, 0.38009])

julia> sum(vecnorm, ans)
2.639678767773633 (tracked)
```
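The same idea can be folded into a loss; a sketch (the chain `c` is the one constructed above, and the penalty weight `0.01` is illustrative, not from the original text):

```julia
penalty(x) = sum(vecnorm, activations(c, x))   # norm of every layer's output
loss(x, y) = crossentropy(c(x), y) + 0.01 * penalty(x)
```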
@@ -3,6 +3,8 @@
Consider a [simple linear regression](../models/basics.md). We create some dummy data, calculate a loss, and backpropagate to calculate gradients for the parameters `W` and `b`.

```julia
using Flux.Tracker

W = param(rand(2, 5))
b = param(rand(2))
@@ -11,22 +13,25 @@ loss(x, y) = sum((predict(x) .- y).^2)

x, y = rand(5), rand(2) # Dummy data
l = loss(x, y) # ~ 3
back!(l)

params = Params([W, b])
grads = Tracker.gradient(() -> loss(x, y), params)
```

We want to update each parameter, using the gradient, in order to improve (reduce) the loss. Here's one way to do that:

```julia
function update()
using Flux.Tracker: grad, update!

function sgd()
  η = 0.1 # Learning Rate
  for p in (W, b)
    p.data .-= η .* p.grad # Apply the update
    p.grad .= 0 # Clear the gradient
    update!(p, -η * grads[p])
  end
end
```

If we call `update`, the parameters `W` and `b` will change and our loss should go down.
If we call `sgd`, the parameters `W` and `b` will change and our loss should go down.

There are two pieces here: one is that we need a list of trainable parameters for the model (`[W, b]` in this case), and the other is the update step. In this case the update is simply gradient descent (`x .-= η .* Δ`), but we might choose to do something more advanced, like adding momentum.
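For example, a hand-rolled momentum variant of `sgd` might look like the following sketch (the velocity buffers `vW`, `vb` and the decay constant `ρ` are illustrative additions, not part of the original text; in a real training loop you would also recompute `grads` on every step):

```julia
using Flux.Tracker: update!

vW, vb = zeros(Tracker.data(W)), zeros(Tracker.data(b))

function sgd_momentum(ρ = 0.9, η = 0.1)
  for (p, v) in zip((W, b), (vW, vb))
    v .= ρ .* v .+ η .* grads[p]  # decaying running sum of past gradients
    update!(p, -v)                # apply the accumulated step
  end
end
```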
@@ -22,7 +22,7 @@ export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param
include("optimise/Optimise.jl")
using .Optimise
using .Optimise: @epochs
export SGD, ADAM, AdaMax, Momentum, Nesterov,
export SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
  RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM

include("utils.jl")
@ -286,41 +286,28 @@ function desc(rnn)
|
|||
return d
|
||||
end
|
||||
|
||||
import Flux.Tracker: data, isleaf, istracked, track, back_, @back, unbroadcast
|
||||
|
||||
mutable struct RNNCall{R}
|
||||
rnn::R
|
||||
reserve::CuVector{UInt8}
|
||||
RNNCall{R}(rnn::R) where R = new(rnn)
|
||||
end
|
||||
|
||||
RNNCall(rnn) = RNNCall{typeof(rnn)}(rnn)
|
||||
|
||||
function (c::RNNCall)(args...)
|
||||
rs, result = forwardTrain(desc(c.rnn), args...)
|
||||
c.reserve = rs
|
||||
return result
|
||||
end
|
||||
import Flux.Tracker
|
||||
import Flux.Tracker: data, istracked, track, unbroadcast, @grad, nobacksies
|
||||
|
||||
istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...))
|
||||
|
||||
function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
|
||||
result = istrain(m, h, x) ?
|
||||
track(RNNCall(m), x, h) :
|
||||
track(m, x, h, m.Wi, m.Wh, m.b) :
|
||||
forward(desc(m), x, h)
|
||||
return result[2], result[1]
|
||||
end
|
||||
|
||||
function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
|
||||
result = istrain(m, h, x) ?
|
||||
track(RNNCall(m), x, h) :
|
||||
track(m, x, h, m.Wi, m.Wh, m.b) :
|
||||
forward(desc(m), x, h)
|
||||
return result[2], result[1]
|
||||
end
|
||||
|
||||
function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64}
|
||||
result = istrain(m, h, x) ?
|
||||
track(RNNCall(m), x, h[1], h[2]) :
|
||||
track(m, x, h[1], h[2], m.Wi, m.Wh, m.b) :
|
||||
forward(desc(m), x, h[1], h[2])
|
||||
return (result[2], result[3]), result[1]
|
||||
end
|
||||
|
@ -329,44 +316,29 @@ end
|
|||
(m::CuGRU{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||
(m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||
|
||||
function accum_transpose!(dst::CuArray, src::CuArray)
|
||||
function kernel(dst, src)
|
||||
I = @cuindex dst
|
||||
dst[I...] += src[reverse(I)...]
|
||||
return
|
||||
@grad function (m::Union{CuRNN,CuGRU})(x, h, Wi, Wh, b)
|
||||
reserve, result = forwardTrain(desc(m), data(x), data(h))
|
||||
result, function (Δ)
|
||||
y, ho = result
|
||||
dy, dho = Δ
|
||||
h_ = hBatch(x, data(h))
|
||||
dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
|
||||
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
|
||||
nobacksies(:RNN, (dx, unbroadcast(size(h), dh), dWi.', dWh.', db))
|
||||
end
|
||||
blk, thr = cudims(dst)
|
||||
@cuda (blk, thr) kernel(dst, src)
|
||||
return dst
|
||||
end
|
||||
|
||||
function back_(m::RNNCall{<:Union{CuRNN,CuGRU}}, y_, Δ, x, h)
|
||||
y, ho = y_
|
||||
dy, dho = Δ
|
||||
h_ = hBatch(x, data(h))
|
||||
dx, dh = backwardData(descs[m.rnn], y, dy, dho, h_, m.reserve)
|
||||
@back(x, dx)
|
||||
@back(h, unbroadcast(h, dh))
|
||||
(dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y, m.reserve)
|
||||
# We don't have to make this assumption, it's just slightly more complex.
|
||||
@assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b)))
|
||||
istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi)
|
||||
istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh)
|
||||
istracked(m.rnn.b) && accum_transpose!(m.rnn.b.grad, db)
|
||||
end
|
||||
|
||||
function back_(m::RNNCall{<:CuLSTM}, y_, Δ, x, h, c)
|
||||
y, ho, co = y_
|
||||
dy, dho, dco = Δ
|
||||
h_ = hBatch(x, data(h))
|
||||
c_ = hBatch(x, data(c))
|
||||
dx, dh, dc = backwardData(descs[m.rnn], y, dy, dho, dco, h_, c_, m.reserve)
|
||||
@back(x, dx)
|
||||
@back(h, unbroadcast(h, dh))
|
||||
@back(c, unbroadcast(h, dc))
|
||||
(dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y, m.reserve)
|
||||
@assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b)))
|
||||
istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi)
|
||||
istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh)
|
||||
istracked(m.rnn.b) && accum_transpose!(m.rnn.b.grad, db)
|
||||
@grad function (m::CuLSTM)(x, h, c, Wi, Wh, b)
|
||||
reserve, result = forwardTrain(desc(m), data.((x, h, c))...)
|
||||
result, function (Δ)
|
||||
y, ho = result
|
||||
dy, dho, dco = Δ
|
||||
h_ = hBatch(x, data(h))
|
||||
c_ = hBatch(x, data(c))
|
||||
dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
|
||||
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
|
||||
nobacksies(:RNN,
|
||||
(dx, unbroadcast(size(h), dh), unbroadcast(size(c), dc),
|
||||
dWi.', dWh.', db))
|
||||
end
|
||||
end
|
||||
|
|
|
@@ -38,6 +38,11 @@ function Base.show(io::IO, c::Chain)
  print(io, ")")
end

# Seem to need this for `accumulate`; try removing on 0.7
Base.rcum_promote_type(op, ::Type, ::Type{Any}) = Any

activations(c::Chain, x) = accumulate((x, m) -> m(x), x, c.layers)

"""
    Dense(in::Integer, out::Integer, σ = identity)
@@ -1,5 +1,10 @@
using NNlib: conv, depthwiseconv

@generated sub2(::Type{Val{N}}) where N = :(Val{$(N-2)})

expand(N, i::Tuple) = i
expand(N, i::Integer) = ntuple(_ -> i, N)

"""
    Conv(size, in=>out)
    Conv(size, in=>out, relu)
@@ -21,14 +26,12 @@ struct Conv{N,F,A,V}
  dilation::NTuple{N,Int}
end

Conv(w::AbstractArray{T}, b::AbstractVector{T}, σ = identity;
     stride = 1, pad = 0, dilation=1) where T =
  Conv(σ, w, b, stride, pad, dilation)
Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
     stride = 1, pad = 0, dilation = 1) where {T,N} =
  Conv(σ, w, b, expand.(sub2(Val{N}), (stride, pad, dilation))...)

Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
     stride::NTuple{N,Integer} = map(_->1,k),
     pad::NTuple{N,Integer} = map(_->0,k),
     dilation::NTuple{N,Integer} = map(_->1,k)) where N =
     stride = 1, pad = 0, dilation = 1) where N =
  Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ,
       stride = stride, pad = pad, dilation = dilation)
@@ -15,9 +15,9 @@ function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight
end

"""
    binarycrossentropy(ŷ, y)
    binarycrossentropy(ŷ, y; ϵ=eps(ŷ))

Return `-y*log(ŷ) - (1-y)*log(1-ŷ)`.
Return `-y*log(ŷ + ϵ) - (1-y)*log(1-ŷ + ϵ)`. The ϵ term provides numerical stability.

    julia> binarycrossentropy.(σ.([-1.1491, 0.8619, 0.3127]), [1, 1, 0.])
    3-element Array{Float64,1}:

@@ -25,7 +25,7 @@ Return `-y*log(ŷ) - (1-y)*log(1-ŷ)`.
     0.352317
     0.86167
"""
binarycrossentropy(ŷ, y) = -y*log(ŷ) - (1 - y)*log(1 - ŷ)
binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ)

"""
    logitbinarycrossentropy(logŷ, y)
@@ -1,7 +1,7 @@
module Optimise

export train!,
  SGD, ADAM, AdaMax, Momentum, Nesterov,
  SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
  RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM

struct Param{T}
@@ -1,7 +1,7 @@
call(f, xs...) = f(xs...)

# note for optimisers: set to zero
# p.Δ at the end of the weigths update
# p.Δ at the end of the weights update
function optimiser(ps, fs...)
  ps = [Param(p) for p in ps]
  fs = map(ps) do p
@@ -56,6 +56,14 @@ RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) =
ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
  optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))

"""
    ADAMW(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)

[ADAMW](https://arxiv.org/abs/1711.05101) fixes weight decay regularization in Adam.
"""
ADAMW(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
  optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->descentweightdecay(p,1,decay))
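# A usage sketch for the ADAMW constructor above, assuming a model `m` and a
# `loss(x, y)` defined elsewhere; the `decay` value is purely illustrative:
#
#   opt = ADAMW(params(m), 0.001, decay = 0.0005)
#   Flux.train!(loss, data, opt)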
"""
|
||||
AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
|
||||
|
||||
|
|
|
@@ -5,6 +5,14 @@ function descent(p::Param, η::Real)
  end
end

# Ref: https://arxiv.org/abs/1711.05101.pdf
function descentweightdecay(p::Param, η::Real, γ::Real)
  function ()
    @. p.x = p.x - η * (p.Δ + γ * p.x)
    @. p.Δ = 0
  end
end

function momentum(p::Param, ρ, η)
  v = zeros(p.x)
  function ()
@@ -37,8 +37,6 @@ function train!(loss, data, opt; cb = () -> ())
  opt = runall(opt)
  @progress for d in data
    l = loss(d...)
    isinf(l) && error("Loss is Inf")
    isnan(l) && error("Loss is NaN")
    @interrupts back!(l)
    opt()
    cb() == :stop && break
@@ -1,22 +1,27 @@
module Tracker

using MacroTools
using MacroTools: @q, @forward

import Base: ==

export TrackedArray, TrackedVector, TrackedMatrix, param, back!
export TrackedArray, TrackedVector, TrackedMatrix, Params, param, back!

tracker(x) = nothing

istracked(x) = tracker(x) ≠ nothing
isleaf(x) = !istracked(x) || isleaf(tracker(x))
data(x) = istracked(x) ? data(tracker(x)) : x
grad(x) = grad(tracker(x))
grad(::Void) = nothing
data(x) = x

struct Call{F,As<:Tuple}
  func::F
  args::As
end

Call(f, args...) = Call{typeof(f),typeof(args)}(f, args)
Call(f, args) = Call{typeof(f),typeof(args)}(f, args)
Call() = Call(nothing, ())

# When deserialising, the object_id changes
a::Call == b::Call = a.func == b.func && a.args == b.args
@ -27,33 +32,83 @@ mutable struct Tracked{T}
|
|||
ref::UInt32
|
||||
f::Call
|
||||
isleaf::Bool
|
||||
data::T
|
||||
grad::T
|
||||
Tracked{T}(f::Call, data::T) where T = new(0, f, false, data)
|
||||
Tracked{T}(f::Call, data::T, grad::T) where T = new(0, f, false, data, grad)
|
||||
Tracked{T}(f::Call{Void}, data::T, grad::T) where T = new(0, f, true, data, grad)
|
||||
Tracked{T}(f::Call) where T = new(0, f, false)
|
||||
Tracked{T}(f::Call, grad::T) where T = new(0, f, false, grad)
|
||||
Tracked{T}(f::Call{Void}, grad::T) where T = new(0, f, true, grad)
|
||||
end
|
||||
|
||||
Tracked(f::Call, x) = Tracked{typeof(x)}(f, x)
|
||||
Tracked(f::Call, x, Δ) = Tracked{typeof(x)}(f, x, Δ)
|
||||
|
||||
track(f::Call, x) = Tracked(f, x)
|
||||
track(f::Call) = track(f, f())
|
||||
track(f, xs...) = track(Call(f, xs...))
|
||||
|
||||
istracked(x::Tracked) = true
|
||||
isleaf(x::Tracked) = x.f == Call(nothing)
|
||||
data(x::Tracked) = x.data
|
||||
isleaf(x::Tracked) = x.f == Call()
|
||||
grad(x::Tracked) = x.grad
|
||||
|
||||
track(f::Call, x) = Tracked{typeof(x)}(f)
|
||||
|
||||
function _forward end
|
||||
|
||||
function track(f, xs...; kw...)
|
||||
y, back = _forward(f, xs...; kw...)
|
||||
track(Call(back, tracker.(xs)), y)
|
||||
end
|
||||
|
||||
macro grad(ex)
|
||||
@capture(shortdef(ex), (name_(args__) = body_) |
|
||||
(name_(args__) where {T__} = body_)) || error("Need a function definition")
|
||||
T == nothing && (T = [])
|
||||
isexpr(name, :(::)) || (name = :(::typeof($name)))
|
||||
insert!(args, 1+isexpr(args[1], :parameters) , name)
|
||||
@q(Tracker._forward($(args...)) where $(T...) = $body) |> esc
|
||||
end
|
||||
|
||||
function update!(x, Δ)
|
||||
x.data .+= data(Δ)
|
||||
tracker(x).grad .= 0
|
||||
return x
|
||||
end
|
||||
|
||||
include("idset.jl")
|
||||
include("back.jl")
|
||||
include("scalar.jl")
|
||||
include("array.jl")
|
||||
include("numeric.jl")
|
||||
|
||||
"""
|
||||
hook(f, x) -> x′
|
||||
|
||||
Hook into gradient backpropagation. `x` is unmodified, but when backpropagating
|
||||
`f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse
|
||||
the sign of the gradient applied to `x`.
|
||||
"""
|
||||
hook(f, x) = istracked(x) ? track(hook, f, x) : x
|
||||
@grad hook(f, x) = x, Δ -> (nothing, f(Δ))
|
||||
|
||||
"""
|
||||
checkpoint(f, args...)
|
||||
|
||||
Behaves like `f(args...)`, but avoids storing the intermediate values needed for
|
||||
calculating gradients. Instead, `f(args...)` will be called again during the
|
||||
backward pass. This can be used to save memory in larger models.
|
||||
"""
|
||||
checkpoint(f, args...) = track(checkpoint, f, args...)
|
||||
|
||||
@grad function checkpoint(f, args...)
|
||||
data(f(args...)), function (Δ)
|
||||
y, back = forward(f, args...)
|
||||
(nothing, back(Δ)...)
|
||||
end
|
||||
end
|
||||
|
||||
nobacksies(f, x) = track(nobacksies, f, x)
|
||||
nobacksies(f, xs::Tuple) = map(x -> nobacksies(f, x), xs)
|
||||
@grad nobacksies(f, x) = data(x), Δ -> error("Nested AD not defined for $f")
|
||||
|
||||
param(x::Number) = TrackedReal(float(x))
|
||||
param(xs::AbstractArray) = TrackedArray(float.(xs))
|
||||
|
||||
@grad identity(x) = data(x), Δ -> (Δ,)
|
||||
param(x::TrackedReal) = track(identity, x)
|
||||
param(x::TrackedArray) = track(identity, x)
|
||||
|
||||
import NNlib.cudata
|
||||
import Adapt.adapt
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@ struct TrackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
|
|||
TrackedArray{T,N,A}(t::Tracked{A}, data::A, grad::A) where {T,N,A} = new(t, data, grad)
|
||||
end
|
||||
|
||||
data(x::TrackedArray) = x.data
|
||||
tracker(x::TrackedArray) = x.tracker
|
||||
|
||||
TrackedVector{T,A} = TrackedArray{T,1,A}
|
||||
|
@ -15,12 +16,12 @@ TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}}
|
|||
track(c::Call, x::AbstractArray) = TrackedArray(c, x)
|
||||
|
||||
TrackedArray(c::Call, x::A) where A <: AbstractArray =
|
||||
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, x), x)
|
||||
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c), x)
|
||||
|
||||
TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
|
||||
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, x, Δ), x, Δ)
|
||||
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, Δ), x, Δ)
|
||||
|
||||
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x))
|
||||
TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, zeros(x))
|
||||
|
||||
Base.eltype(x::Type{<:TrackedArray{T}}) where T <: Real = TrackedReal{T}
|
||||
|
||||
|
@ -49,6 +50,9 @@ for f in :[Base.size, Base.ndims].args
|
|||
@eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...)
|
||||
end
|
||||
|
||||
Base.size(x::TrackedArray, i::Integer, j::Integer, is::Integer...) =
|
||||
size(data(x), i, j, is...)
|
||||
|
||||
Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
|
||||
similar(data(x), dims...)
|
||||
|
||||
|
@ -62,54 +66,57 @@ Base.:(==)(x::TrackedArray, y::TrackedArray) = data(x) == data(y)
|
|||
|
||||
Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...)
|
||||
|
||||
function back(::typeof(getindex), Δ, xs::TrackedArray, i...)
|
||||
Δ′ = zeros(xs.data)
|
||||
Δ′[i...] = Δ
|
||||
@back(xs, Δ′)
|
||||
@grad function getindex(xs::AbstractArray, i...)
|
||||
data(xs)[i...], function (Δ)
|
||||
Δ′ = zero(xs)
|
||||
Δ′[i...] = data(Δ)
|
||||
(nobacksies(:getindex, Δ′), map(_->nothing, i)...)
|
||||
end
|
||||
end
|
||||
|
||||
Base.:-(xs::TrackedArray) = track(-, xs)
|
||||
|
||||
back(::typeof(-), Δ, xs::TrackedArray) = back(xs, -Δ)
|
||||
@grad -(xs) = -data(xs), Δ -> (-Δ,)
|
||||
|
||||
Base.transpose(xs::TrackedArray) = track(transpose, xs)
|
||||
Base.ctranspose(xs::TrackedArray) = track(ctranspose, xs)
|
||||
|
||||
back(::typeof(transpose), Δ, xs) = @back(xs, trim(xs, Δ.'))
|
||||
back(::typeof(ctranspose), Δ, xs) = @back(xs, trim(xs, Δ'))
|
||||
@grad transpose(xs) = data(xs).', Δ -> (reshape(Δ.', size(xs)),)
|
||||
@grad ctranspose(xs) = data(xs)', Δ -> (reshape(Δ', size(xs)),)
|
||||
|
||||
Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...)
|
||||
Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...)
|
||||
|
||||
function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1)
|
||||
Δ′ = similar(xs.data)
|
||||
S = size(xs.data)
|
||||
for (i,v) in enumerate(Δ)
|
||||
@grad function repmat(xs, m, n = 1)
|
||||
repmat(data(xs), m, n), function (Δ)
|
||||
Δ′ = similar(xs)
|
||||
S = size(xs)
|
||||
for (i,v) in enumerate(data(Δ))
|
||||
d1 = divrem(i-1, S[1]*m)
|
||||
x = d1[2] % S[1]+1
|
||||
y = d1[1] % S[2]+1
|
||||
Δ′[x, y] += v
|
||||
end
|
||||
back(xs, Δ′)
|
||||
return (nobacksies(:repmat, Δ′), nothing, nothing)
|
||||
end
|
||||
end
|
||||
|
||||
Base.repeat(A::TrackedArray; kw...) = track(repeat, A; kw...)
|
||||
|
||||
_repeat(A, inner, outer) = Base.repeat(A; inner=inner, outer=outer)
|
||||
Base.repeat(A::TrackedArray; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A))) = track(_repeat, A, inner, outer)
|
||||
@grad function repeat(xs; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A)))
|
||||
repeat(data(xs), inner = inner, outer = outer), function (Δ)
|
||||
Δ′ = zero(xs)
|
||||
S = size(xs)
|
||||
|
||||
function back(::typeof(_repeat), Δ, xs::TrackedArray, inner, outer)
|
||||
Δ′ = similar(xs.data)
|
||||
Δ′ .= 0
|
||||
S = size(xs.data)
|
||||
|
||||
# Loop through each element of Δ, calculate source dimensions, accumulate into Δ′
|
||||
for (dest_idx, val) in enumerate(IndexCartesian(), Δ)
|
||||
for (dest_idx, val) in enumerate(IndexCartesian(), data(Δ))
|
||||
# First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then
|
||||
# wrap around based on original size S.
|
||||
src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)]
|
||||
Δ′[src_idx...] += val
|
||||
end
|
||||
back(xs, Δ′)
|
||||
(nobacksies(:repeat, Δ′),)
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
|
@ -138,42 +145,51 @@ for f in [:vcat, :hcat]
|
|||
end
|
||||
end
|
||||
|
||||
function back(::typeof(vcat), Δ, xs...)
|
||||
start = 0
|
||||
for xsi in xs
|
||||
i = map(_ -> :, size(xsi)) |> Base.tail
|
||||
@back(xsi, Δ[start+1:start+size(xsi,1), i...])
|
||||
start += size(xsi, 1)
|
||||
@grad function vcat(xs...)
|
||||
vcat(data.(xs)...), function (Δ)
|
||||
start = 0
|
||||
Δs = [begin
|
||||
i = map(_ -> :, size(xsi)) |> Base.tail
|
||||
d = Δ[start+1:start+size(xsi,1), i...]
|
||||
start += size(xsi, 1)
|
||||
d
|
||||
end for xsi in xs]
|
||||
return (Δs...,)
|
||||
end
|
||||
end
|
||||
|
||||
function back(::typeof(hcat), Δ, xs...)
|
||||
start = 0
|
||||
for xsi in xs
|
||||
if ndims(xsi) == 1
|
||||
@back(xsi, Δ[:, start+1])
|
||||
else
|
||||
i = map(_ -> :, size(xsi)) |> Base.tail |> Base.tail
|
||||
@back(xsi, Δ[:, start+1:start+size(xsi,2), i...])
|
||||
end
|
||||
start += size(xsi, 2)
|
||||
@grad function hcat(xs...)
|
||||
hcat(data.(xs)...), function (Δ)
|
||||
start = 0
|
||||
Δs = [begin
|
||||
d = if ndims(xsi) == 1
|
||||
Δ[:, start+1]
|
||||
else
|
||||
i = map(_ -> :, size(xsi)) |> Base.tail |> Base.tail
|
||||
Δ[:, start+1:start+size(xsi,2), i...]
|
||||
end
|
||||
start += size(xsi, 2)
|
||||
d
|
||||
end for xsi in xs]
|
||||
return (Δs...,)
|
||||
end
|
||||
end
|
||||
|
||||
Base.cat(dims, a::TrackedArray, b::AbstractArray...) = track(cat, dims, a, b...)
|
||||
Base.cat(dims, a::Union{RowVector,Array}, b::TrackedArray, c::AbstractArray...) = track(cat, dims, a, b, c...)
|
||||
|
||||
function back(::typeof(cat), Δ, dims, Xs...)
|
||||
start = ntuple(i -> 0, Val{ndims(Δ)})
|
||||
for xs in Xs
|
||||
dim_xs = 1:ndims(xs)
|
||||
till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val{ndims(Δ)})
|
||||
|
||||
xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val{ndims(Δ)})
|
||||
|
||||
@back(xs, reshape(Δ[xs_in_Δ...],size(xs)))
|
||||
|
||||
start = start .+ till_xs
|
||||
@grad function cat(dims, Xs...)
|
||||
cat(dims, data.(Xs)...), function (Δ)
|
||||
start = ntuple(i -> 0, Val{ndims(Δ)})
|
||||
Δs = [begin
|
||||
dim_xs = 1:ndims(xs)
|
||||
till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val{ndims(Δ)})
|
||||
xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val{ndims(Δ)})
|
||||
d = reshape(Δ[xs_in_Δ...],size(xs))
|
||||
start = start .+ till_xs
|
||||
d
|
||||
end for xs in Xs]
|
||||
return (nothing, Δs...,)
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -181,11 +197,10 @@ Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) = reshape(xs, dims)
|
|||
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Union{Int64,Colon}}}) = reshape(xs, Base._reshape_uncolon(xs, dims))
|
||||
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64}}) = track(reshape, xs, dims)
|
||||
|
||||
back(::typeof(reshape), Δ, xs::TrackedArray, _...) =
|
||||
back(xs, reshape(Δ, size(xs)))
|
||||
@grad reshape(xs, dims) = reshape(data(xs), dims), Δ -> (reshape(Δ, size(xs)),nothing)
|
||||
|
||||
Base.permutedims(xs::TrackedArray, dims) = track(permutedims, xs, dims)
|
||||
back(::typeof(permutedims), Δ, xs::TrackedArray, dims) = back(xs, permutedims(Δ, invperm(dims)))
|
||||
@grad permutedims(xs, dims) = permutedims(data(xs), dims), Δ -> (permutedims(Δ, invperm(dims)),nothing)
|
||||
|
||||
function _kron(mat1::AbstractMatrix,mat2::AbstractMatrix)
|
||||
m1, n1 = size(mat1)
|
||||
|
@ -207,14 +222,18 @@ Base.sum(xs::TrackedArray, dim) = track(sum, xs, dim)
|
|||
Base.sum(xs::TrackedArray) = track(sum, xs)
|
||||
Base.sum(f::Union{Function,Type},xs::TrackedArray) = sum(f.(xs))
|
||||
|
||||
back(::typeof(sum), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= Δ)
|
||||
@grad sum(xs, dim...) = sum(data(xs), dim...),
|
||||
Δ -> (zero(xs) .+ Δ, map(_->nothing,dim)...)
|
||||
|
||||
Base.prod(xs::TrackedArray, dim) = track(prod, xs, dim)
|
||||
Base.prod(xs::TrackedArray) = track(prod, xs)
|
||||
Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs))
|
||||
|
||||
back(::typeof(prod), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= (prod(xs.data, dim...) ./ xs.data) .* Δ)
|
||||
back(::typeof(prod), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= (reshape(.*(circshift.([reshape(xs.data, length(xs.data))], 1:length(xs.data)-1)...), size(xs.data))) .* Δ)
|
||||
@grad prod(xs) = prod(data(xs)), Δ -> (prod(xs) ./ xs .* Δ,)
|
||||
@grad prod(xs, dim) = prod(data(xs), dim),
|
||||
Δ -> (nobacksies(:sum,
|
||||
reshape(.*(circshift.([reshape(data(xs), length(xs))], 1:length(xs)-1)...), size(xs)) .* Δ),
|
||||
nothing)
|
||||
|
||||
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
|
||||
|
||||
|
@ -230,10 +249,7 @@ LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = track(dot, xs, ys)
|
|||
LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = track(dot, xs, ys)
|
||||
LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys)
|
||||
|
||||
function back(::typeof(dot), Δ, xs, ys)
|
||||
@back(xs, Δ.*data(ys))
|
||||
@back(ys, Δ.*data(xs))
|
||||
end
|
||||
@grad dot(xs, ys) = dot(data(xs), data(ys)), Δ -> (Δ .* ys, Δ .* xs)
|
||||
|
||||
# Hacks to get std working
|
||||
Base.std(x::TrackedArray; mean = Base.mean(x)) =
|
||||
|
@ -244,41 +260,32 @@ Base.std(x::TrackedArray, dim; mean = Base.mean(x, dim)) =
|
|||
Base.vecnorm(x::TrackedArray, p::Real = 2) =
|
||||
sum(abs.(x).^p .+ eps(0f0))^(1/p) # avoid d(sqrt(x))/dx == Inf at 0
|
||||
|
||||
back(::typeof(mean), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= Δ ./ length(xs.data))
|
||||
back(::typeof(mean), Δ, xs::TrackedArray, region) =
|
||||
back(xs, similar(xs.data) .= Δ ./ prod(size(xs.data, region...)))
|
||||
@grad mean(xs) = mean(data(xs)), Δ -> (Δ / length(xs),)
|
||||
@grad mean(xs, region) = mean(data(xs), region), Δ -> (zero(xs) .+ Δ ./ prod(size(xs, region...)),nothing)
|
||||
|
||||
function back(::typeof(maximum), Δ, xs::TrackedArray)
|
||||
Δ′ = zeros(xs.data)
|
||||
_, i = findmax(xs.data)
|
||||
Δ′[i] = Δ
|
||||
@back(xs, Δ′)
|
||||
@grad function maximum(xs, r...)
|
||||
maximum(data(xs), r...), function (Δ)
|
||||
Δ′ = zero(xs)
|
||||
_, i = findmax(data(xs), r...)
|
||||
Δ′[i] = data(Δ)
|
||||
return (nobacksies(:maximum, Δ′),map(_->nothing,r)...)
|
||||
end
|
||||
end
|
||||
function back(::typeof(maximum), Δ, xs::TrackedArray, region)
|
||||
Δ′ = zeros(xs.data)
|
||||
_, is = findmax(xs.data, region)
|
||||
Δ′[is] = Δ
|
||||
@back(xs, Δ′)
|
||||
end
|
||||
function back(::typeof(minimum), Δ, xs::TrackedArray)
|
||||
Δ′ = zeros(xs.data)
|
||||
_, i = findmin(xs.data)
|
||||
Δ′[i] = Δ
|
||||
@back(xs, Δ′)
|
||||
end
|
||||
function back(::typeof(minimum), Δ, xs::TrackedArray, region)
|
||||
Δ′ = zeros(xs.data)
|
||||
_, is = findmin(xs.data, region)
|
||||
Δ′[is] = Δ
|
||||
@back(xs, Δ′)
|
||||
@grad function minimum(xs, r...)
|
||||
minimum(data(xs), r...), function (Δ)
|
||||
Δ′ = zero(xs)
|
||||
_, i = findmin(data(xs), r...)
|
||||
Δ′[i] = data(Δ)
|
||||
return (nobacksies(:minimum, Δ′),map(_->nothing,r)...)
|
||||
end
|
||||
end
|
||||
|
||||
# BLAS
|
||||
|
||||
Base.diagm(x::TrackedVector) = track(diagm, x)
|
||||
back(::typeof(diagm), Δ, x) = @back(x, diag(Δ))
|
||||
@grad diagm(x) = diagm(data(x)), Δ -> (diag(Δ),)
|
||||
|
||||
for f in :[*, Ac_mul_B, A_mul_Bc].args
|
||||
for f in :[*, Ac_mul_B, A_mul_Bc, A_mul_Bt, At_mul_B].args
|
||||
@eval begin
|
||||
import Base.$f
|
||||
$f(a::TrackedMatrix, b::TrackedMatrix) = track($f, a, b)
|
||||
|
@ -295,30 +302,14 @@ for f in :[*, Ac_mul_B, A_mul_Bc].args
|
|||
end
|
||||
end
|
||||
|
||||
function back(::typeof(*), Δ, a::AbstractMatrix, b::AbstractVecOrMat)
|
||||
@back(a, A_mul_Bt(Δ, data(b)))
|
||||
@back(b, At_mul_B(data(a), Δ))
|
||||
end
|
||||
@grad a::AbstractMatrix * b::AbstractVecOrMat =
|
||||
data(a)*data(b), Δ -> (A_mul_Bt(Δ, b), At_mul_B(a, Δ))
|
||||
|
||||
function back(::typeof(Ac_mul_B), Δ, a::AbstractVecOrMat{<:Real}, b::AbstractVecOrMat{<:Real})
|
||||
@back(a, A_mul_Bt(Δ, data(b))')
|
||||
@back(b, data(a)*Δ)
|
||||
end
|
||||
@grad Ac_mul_B(a, b) = Ac_mul_B(data(a), data(b)), Δ -> (A_mul_Bt(Δ, b)', a*Δ)
|
||||
@grad A_mul_Bc(a, b) = A_mul_Bc(data(a), data(b)), Δ -> (Δ * b, At_mul_B(a, Δ)')
|
||||
|
||||
function back(::typeof(A_mul_Bc), Δ, a::AbstractVecOrMat{<:Real}, b::AbstractVecOrMat{<:Real})
|
||||
@back(a, Δ * data(b))
|
||||
@back(b, At_mul_B(data(a), Δ)')
|
||||
end
|
||||
|
||||
# Fast path for matrix-vector
|
||||
function back(::typeof(*), Δ::AbstractVector, W::TrackedMatrix, x::AbstractVector)
|
||||
if isleaf(W)
|
||||
W.grad .+= Δ .* data(x).'
|
||||
else
|
||||
back(W, A_mul_Bt(Δ, data(x)))
|
||||
end
|
||||
@back(x, At_mul_B(data(W), Δ))
|
||||
end
|
||||
@grad At_mul_B(a, b) = At_mul_B(data(a), data(b)), Δ -> (A_mul_Bt(Δ, b)', a*Δ)
|
||||
@grad A_mul_Bt(a, b) = A_mul_Bt(data(a), data(b)), Δ -> (Δ * b, At_mul_B(a, Δ)')
|
||||
|
||||
# NNlib
|
||||
|
||||
|
@ -327,26 +318,11 @@ import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, depthwisecon
|
|||
|
||||
softmax(xs::TrackedArray) = track(softmax, xs)
|
||||
|
||||
back(::typeof(softmax), Δ, xs) = @back(xs, ∇softmax(Δ, data(xs)))
|
||||
@grad softmax(xs) = softmax(data(xs)), Δ -> (nobacksies(:softmax, ∇softmax(data(Δ), data(xs))),)
|
||||
|
||||
logsoftmax(xs::TrackedArray) = track(logsoftmax, xs)
|
||||
|
||||
back(::typeof(logsoftmax), Δ, xs) = @back(xs, ∇logsoftmax(Δ, data(xs)))
|
||||
|
||||
# TODO: can store kwargs efficiently in namedtuples
|
||||
_conv(x, w, stride, pad, dilation) = conv(x, w, stride = stride, pad = pad, dilation = dilation)
|
||||
|
||||
conv(x::TrackedArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N =
|
||||
track(_conv, x, w, stride, pad, dilation)
|
||||
conv(x::AbstractArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N =
|
||||
track(_conv, x, w, stride, pad, dilation)
|
||||
conv(x::TrackedArray{<:Real,N}, w::AbstractArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N =
|
||||
track(_conv, x, w, stride, pad, dilation)
|
||||
|
||||
function back(::typeof(_conv), Δ, x, w, stride, pad, dilation)
|
||||
@back(x, NNlib.∇conv_data(Δ, data(x), data(w); stride = stride, pad = pad, dilation = dilation))
|
||||
@back(w, NNlib.∇conv_filter(Δ, data(x), data(w); stride = stride, pad = pad, dilation = dilation))
|
||||
end
|
||||
@grad logsoftmax(xs) = logsoftmax(data(xs)), Δ -> (nobacksies(:logsoftmax, ∇logsoftmax(data(Δ), data(xs))),)
|
||||
|
||||
_depthwiseconv(x, w, stride, pad) = depthwiseconv(x, w, stride = stride, pad = pad)
|
||||
|
||||
|
@ -362,61 +338,63 @@ function back(::typeof(_depthwiseconv), Δ, x, w, stride, pad)
|
|||
@back(w, NNlib.∇depthwiseconv_filter(Δ, data(x), data(w), stride = stride, pad = pad))
|
||||
end
|
||||
|
||||
_maxpool(x, k, pad, stride) = maxpool(x, k; pad = pad, stride = stride)
|
||||
conv(x::TrackedArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
|
||||
conv(x::AbstractArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
|
||||
conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...)
|
||||
|
||||
maxpool(x::TrackedArray, k; pad = map(_->0,k), stride = k) =
|
||||
track(_maxpool, x, k, pad, stride)
|
||||
@grad conv(x, w; kw...) =
|
||||
conv(data(x), data(w); kw...),
|
||||
Δ -> nobacksies(:conv,
|
||||
(NNlib.∇conv_data(data.((Δ, x, w))...; kw...),
|
||||
NNlib.∇conv_filter(data.((Δ, x, w))...; kw...)))
|
||||
|
||||
back_(::typeof(_maxpool), y, Δ, x, k, pad, stride) =
|
||||
back(x, NNlib.∇maxpool(Δ, y, data(x), k, pad=pad, stride=stride))
|
||||
maxpool(x::TrackedArray, k; kw...) = track(maxpool, x, k; kw...)
|
||||
|
||||
_meanpool(x, k, pad, stride) = meanpool(x, k; pad = pad, stride = stride)
|
||||
@grad function maxpool(x, k; kw...)
|
||||
y = maxpool(data(x), k; kw...)
|
||||
y, Δ -> (nobacksies(:maxpool, NNlib.∇maxpool(data.((Δ, y, x))..., k; kw...)), nothing)
|
||||
end
|
||||
|
||||
meanpool(x::TrackedArray, k; pad = map(_->0,k), stride = k) =
|
||||
track(_meanpool, x, k, pad, stride)
|
||||
meanpool(x::TrackedArray, k; kw...) = track(meanpool, x, k; kw...)
|
||||
|
||||
back_(::typeof(_meanpool), y, Δ, x, k, pad, stride) =
|
||||
back(x, NNlib.∇meanpool(Δ, y, data(x), k, pad=pad, stride=stride))
|
||||
@grad function meanpool(x, k; kw...)
|
||||
y = meanpool(data(x), k; kw...)
|
||||
y, Δ -> (nobacksies(:maxpool, NNlib.∇meanpool(data.((Δ, y, x))..., k; kw...)), nothing)
|
||||
end
|
||||
|
||||
# Broadcasting
|
||||
|
||||
using ForwardDiff: Dual, partials
|
||||
|
||||
struct Broadcasted{F,T}
|
||||
f::F
|
||||
data::T
|
||||
end
|
||||
|
||||
(b::Broadcasted)(xs...) = map(x -> x.value, b.data)
|
||||
using ForwardDiff: Dual, partials, value
|
||||
|
||||
dualify(xs, n) = xs
|
||||
dualify(xs::TrackedArray, ps) = map(x -> Dual(x, ps), data(xs))
|
||||
dualify(xs::TrackedReal, ps) = Dual(data(xs), ps)
|
||||
dualify(xs::AbstractArray, ps) = map(x -> Dual(x, ps), xs)
|
||||
dualify(xs::Real, ps) = Dual(xs, ps)
|
||||
|
||||
function tracked_broadcast(f, args::Vararg{Any,N}) where N
|
||||
dargs = map((x,i) -> dualify(x, ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N}))
|
||||
out = broadcast(f, dargs...)
|
||||
eltype(out) <: Dual || return out
|
||||
b = Broadcasted(f, out)
|
||||
track(Call(b, args...), b())
|
||||
end
|
||||
unbroadcast(x::Tuple, Δ) =
|
||||
x == size(Δ) ? Δ :
|
||||
reshape(sum(Δ, filter(n -> n > length(x) || x[n] == 1, 1:ndims(Δ))), x)
|
||||
|
||||
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val{ndims(x)}))
|
||||
|
||||
unbroadcast(x::AbstractArray, Δ) =
|
||||
size(x) == size(Δ) ? Δ :
|
||||
trim(x, sum(Δ, filter(n -> size(x, n) == 1, 1:ndims(Δ))))
|
||||
|
||||
unbroadcast(x::Number, Δ) = sum(Δ)
|
||||
unbroadcast(x::Tuple{}, Δ) = sum(Δ)
|
||||
|
||||
function getpartial(Δ, x, i)
|
||||
@inbounds p = getindex(partials(x), i)
|
||||
return Δ * p
|
||||
end
|
||||
|
||||
function back(b::Broadcasted, Δ, args::Vararg{Any,N}) where N
|
||||
Δargs = ntuple(i -> getpartial.(Δ, b.data, i), Val{N})
|
||||
foreach((x, Δ) -> @back(x, unbroadcast(x, Δ)), args, Δargs)
|
||||
function ∇broadcast(f, args::Vararg{Any,N}) where N
|
||||
sizes = size.(args)
|
||||
dargs = map((x,i) -> dualify(data(x), ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N}))
|
||||
out = broadcast(f, dargs...)
|
||||
eltype(out) <: Dual || return out
|
||||
y = value.(out)
|
||||
back = function (Δ_)
|
||||
Δ = data(Δ_)
|
||||
Δargs = ntuple(i -> getpartial.(Δ, out, i), Val{N})
|
||||
dxs = map((x, Δ) -> unbroadcast(x, Δ), sizes, Δargs)
|
||||
nobacksies(:broadcast, dxs)
|
||||
end
|
||||
# So we can return non-tracked arrays
|
||||
track(Call(back, tracker.(args)), y)
|
||||
end
|
||||
|
||||
Base.Broadcast._containertype(::Type{<:TrackedReal}) = TrackedArray
|
||||
|
@ -429,4 +407,4 @@ Base.Broadcast.promote_containertype(ct, ::Type{TrackedArray}) = TrackedArray
|
|||
Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A::Ref) = ()
|
||||
Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A) = indices(A)
|
||||
|
||||
Base.Broadcast.broadcast_c(f, ::Type{TrackedArray}, A, Bs...) = tracked_broadcast(f, A, Bs...)
|
||||
Base.Broadcast.broadcast_c(f, ::Type{TrackedArray}, A, Bs...) = ∇broadcast(f, A, Bs...)
|
||||
|
|
|
@ -10,8 +10,6 @@ function scan(x::Tracked)
|
|||
if ref == 1
|
||||
scan(x.f)
|
||||
isdefined(x, :grad) && (x.grad = zero_grad!(x.grad))
|
||||
else
|
||||
isdefined(x, :grad) || (x.grad = init_grad(x.data))
|
||||
end
|
||||
return
|
||||
end
|
||||
|
@ -21,9 +19,14 @@ function scan(x)
|
|||
return
|
||||
end
|
||||
|
||||
back_(f, y, args...) = back(f, args...)
|
||||
back_(c::Call, y, Δ) = back_(c.func, y, Δ, c.args...)
|
||||
back_(::Call{Void}, y, Δ) = nothing
|
||||
function back_(c::Call, Δ)
|
||||
Δs = c.func(Δ)
|
||||
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
|
||||
error("Gradient is not a tuple of length $(length(c.args))")
|
||||
foreach(back, c.args, data.(Δs))
|
||||
end
|
||||
|
||||
back_(::Call{Void}, Δ) = nothing
|
||||
|
||||
accum!(x, Δ) = x .+ Δ
|
||||
accum!(x::AbstractArray, Δ) = (x .+= Δ)
|
||||
|
@ -31,33 +34,121 @@ accum!(x::AbstractArray, Δ) = (x .+= Δ)
|
|||
function back(x::Tracked, Δ)
|
||||
x.isleaf && (x.grad = accum!(x.grad, Δ); return)
|
||||
ref = x.ref -= 1
|
||||
if isdefined(x, :grad)
|
||||
x.grad = accum!(x.grad, Δ)
|
||||
ref == 0 && back_(x.f, x.data, x.grad)
|
||||
if ref > 0 || isdefined(x, :grad)
|
||||
if isdefined(x, :grad)
|
||||
x.grad = accum!(x.grad, Δ)
|
||||
else
|
||||
x.grad = Δ
|
||||
end
|
||||
ref == 0 && back_(x.f, x.grad)
|
||||
else
|
||||
ref == 0 && back_(x.f, x.data, Δ)
|
||||
ref == 0 && back_(x.f, Δ)
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
back(x, Δ) = back(tracker(x), Δ)
|
||||
back(x::Void, Δ) = error("Can't backpropagate through `nothing`")
|
||||
|
||||
macro back(x, Δ)
|
||||
quote
|
||||
x = $(esc(x))
|
||||
istracked(x) && back(x, $(esc(Δ)))
|
||||
end
|
||||
end
|
||||
back(::Void, _) = return
|
||||
|
||||
# Interface methods
|
||||
|
||||
# TODO: if an error occurs in `back` the refcounts will be broken
|
||||
# and `back` will silently fail to update.
|
||||
# Refcounts are also probably not safe in some situations (e.g. back called
|
||||
# from within a backpropagator)
|
||||
|
||||
function back!(x::Tracked, Δ)
|
||||
function back!(x, Δ)
|
||||
istracked(x) || return
|
||||
scan(x)
|
||||
back(x, Δ)
|
||||
back(tracker(x), Δ)
|
||||
return
|
||||
end
|
||||
|
||||
back!(x, Δ) = back!(tracker(x), Δ)
|
||||
# Out-of-place gradients
|
||||
|
||||
struct Params
|
||||
params::IdSet
|
||||
Params(xs) = new(IdSet(xs))
|
||||
end
|
||||
|
||||
@forward Params.params Base.start, Base.next, Base.done
|
||||
|
||||
function Base.show(io::IO, ps::Params)
|
||||
print(io, "Params([")
|
||||
join(io, ps.params, ", ")
|
||||
print(io, "])")
|
||||
end
|
||||
|
||||
struct Grads
|
||||
grads::ObjectIdDict
|
||||
end
|
||||
|
||||
Base.show(io::IO, ps::Grads) = println(io, "Grads(...)")
|
||||
|
||||
Grads() = Grads(ObjectIdDict())
|
||||
|
||||
Grads(ps::Params) = Grads(ObjectIdDict(tracker(p) => init_grad(data(p)) for p in ps))
|
||||
|
||||
Base.getindex(g::Grads, x::Tracked) = g.grads[x]
|
||||
function Base.getindex(g::Grads, x)
|
||||
istracked(x) || error("Object not tracked: $x")
|
||||
g[tracker(x)]
|
||||
end
|
||||
|
||||
@forward Grads.grads Base.setindex!, Base.haskey
|
||||
|
||||
accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] + Δ : Δ
|
||||
|
||||
function back_(g::Grads, c::Call, Δ)
|
||||
Δs = c.func(Δ)
|
||||
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
|
||||
error("Gradient is not a tuple of length $(length(c.args))")
|
||||
foreach((x, Δ) -> back(g, x, Δ), c.args, Δs)
|
||||
end
|
||||
|
||||
back_(g::Grads, ::Call{Void}, Δ) = nothing
|
||||
|
||||
function back(g::Grads, x::Tracked, Δ)
|
||||
x.isleaf && (accum!(g, x, Δ); return)
|
||||
ref = x.ref -= 1
|
||||
if ref > 0 || haskey(g, x)
|
||||
accum!(g, x, Δ)
|
||||
ref == 0 && back_(g, x.f, g[x])
|
||||
else
|
||||
ref == 0 && back_(g, x.f, Δ)
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
back(::Grads, ::Void, _) = return
|
||||
|
||||
function forward(f, ps::Params)
|
||||
y = f()
|
||||
y, function (Δ)
|
||||
g = Grads(ps)
|
||||
if istracked(y)
|
||||
scan(y)
|
||||
back(g, tracker(y), Δ)
|
||||
end
|
||||
return g
|
||||
end
|
||||
end
|
||||
|
||||
function forward(f, args...)
|
||||
args = param.(args)
|
||||
y, back = forward(() -> f(args...), Params(args))
|
||||
y, Δ -> getindex.(back(Δ), args)
|
||||
end
|
||||
|
||||
function losscheck(x)
|
||||
x isa Real || error("Function output is not scalar")
|
||||
isinf(x) && error("Loss is infinite")
|
||||
isnan(x) && error("Loss is NaN")
|
||||
end
|
||||
|
||||
function gradient(f, args...)
|
||||
y, back = forward(f, args...)
|
||||
losscheck(y)
|
||||
return back(1)
|
||||
end
|
||||
|
||||
derivative(f, x) = gradient(f, x)[1]
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
struct IdSet{T} <: AbstractSet{T}
|
||||
dict::ObjectIdDict
|
||||
IdSet{T}() where T = new(ObjectIdDict())
|
||||
end
|
||||
|
||||
Base.eltype{T}(::IdSet{T}) = T
|
||||
|
||||
IdSet() = IdSet{Any}()
|
||||
|
||||
Base.push!{T}(s::IdSet{T}, x::T) = (s.dict[x] = nothing; s)
|
||||
Base.delete!{T}(s::IdSet{T}, x::T) = (delete!(s.dict, x); s)
|
||||
Base.in(x, s::IdSet) = haskey(s.dict, x)
|
||||
|
||||
(::Type{IdSet{T}}){T}(xs) = push!(IdSet{T}(), xs...)
|
||||
|
||||
IdSet(xs) = IdSet{eltype(xs)}(xs)
|
||||
|
||||
Base.collect(s::IdSet) = Base.collect(keys(s.dict))
|
||||
Base.similar(s::IdSet, T::Type) = IdSet{T}()
|
||||
|
||||
@forward IdSet.dict Base.length
|
||||
|
||||
Base.start(s::IdSet) = start(keys(s.dict))
|
||||
Base.next(s::IdSet, st) = next(keys(s.dict), st)
|
||||
Base.done(s::IdSet, st) = done(keys(s.dict), st)
|
|
@ -1,9 +1,3 @@
function gradient(f, xs::AbstractArray...)
  xs = param.(xs)
  back!(f(xs...))
  grad.(xs)
end

function ngradient(f, xs::AbstractArray...)
  grads = zeros.(xs)
  for (x, Δ) in zip(xs, grads), i in 1:length(x)

@ -21,4 +15,4 @@ end

gradcheck(f, xs...) =
  all(isapprox.(ngradient(f, xs...),
                gradient(f, xs...), rtol = 1e-5, atol = 1e-5))
                data.(gradient(f, xs...)), rtol = 1e-5, atol = 1e-5))
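The local array-only `gradient` helper is removed in favour of the general `Tracker.gradient` above, so `gradcheck` now unwraps the tracked result with `data` before comparing it against `ngradient`'s finite differences. Roughly how it is used in the test suite:

```julia
using Flux.Tracker: gradcheck

gradcheck(x -> sum(sin.(x)), rand(3))                    # true if tracked and numeric gradients agree
gradcheck((a, b) -> sum(a * b), rand(2, 3), rand(3, 2))  # works with several array arguments
```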
@ -1,14 +1,20 @@
struct TrackedReal{T<:Real} <: Real
  data::T
  tracker::Tracked{T}
end

TrackedReal(x::Real) = TrackedReal(Tracked(Call(nothing), x, zero(x)))
TrackedReal(x::Real) = TrackedReal(x, Tracked{typeof(x)}(Call(), zero(x)))

data(x::TrackedReal) = x.data
tracker(x::TrackedReal) = x.tracker

track(f::Call, x::Real) = TrackedReal(Tracked(f, x, zero(x)))
track(f::Call, x::Real) = TrackedReal(x, Tracked{typeof(x)}(f, zero(x)))

back!(x::TrackedReal) = back!(x, 1)
function back!(x::TrackedReal)
  isinf(x) && error("Loss is Inf")
  isnan(x) && error("Loss is NaN")
  return back!(x, 1)
end
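`TrackedReal` now stores its value alongside the `Tracked` node, and `back!` refuses to run on an `Inf`/`NaN` loss before seeding the backward pass with 1. In use (a hedged sketch):

```julia
using Flux, Flux.Tracker

x = param(2.0)        # a TrackedReal wrapping 2.0
y = x*x + 3x          # tracked; records the calls that produced it
Tracker.back!(y)      # would error if y were Inf or NaN
Tracker.grad(x)       # 7.0, i.e. 2x + 3 at x = 2
```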
function Base.show(io::IO, x::TrackedReal)
  show(io, data(x))

@ -19,15 +25,16 @@ Base.decompose(x::TrackedReal) = Base.decompose(data(x))
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{T}) where T = x

# This cuts derivatives, fix if needed.
# Base.convert(::Type{TrackedReal{T}}, x::TrackedReal) where T =
#   TrackedReal(Tracked(x.tracker.f, convert(T, x.tracker.data)))

Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x))

Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} =
  error("Not implemented: convert tracked $S to tracked $T")

Base.:(<)(x::TrackedReal, y::TrackedReal) = data(x) < data(y)
Base.:(==)(x::TrackedReal, y::TrackedReal) = data(x) == data(y)

Base.eps(x::TrackedReal) = eps(data(x))

for f in :[isinf, isnan, isfinite].args
  @eval Base.$f(x::TrackedReal) = Base.$f(data(x))
end
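Comparisons and the numeric predicates fall through to the wrapped value, so a tracked scalar behaves like an ordinary number in control flow without creating extra graph nodes. A quick sketch, assuming the usual numeric promotion for tracked reals:

```julia
x = param(2.0)

x < 3.0, x == 2.0, isfinite(x)   # (true, true, true)
eps(x) == eps(2.0)               # true
```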
@ -42,23 +49,21 @@ using DiffRules, SpecialFunctions, NaNMath

for (M, f, arity) in DiffRules.diffrules()
  arity == 1 || continue
  @eval begin
    @grad $M.$f(a::Real) =
      $M.$f(data(a)), Δ -> (Δ * $(DiffRules.diffrule(M, f, :a)),)
    $M.$f(a::TrackedReal) = track($M.$f, a)
    back(::typeof($M.$f), Δ::Real, a::TrackedReal) =
      back(a, Δ * $(DiffRules.diffrule(M, f, :(data(a)))))
  end
end
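The old per-function `back` methods are replaced by `@grad` definitions generated from the same DiffRules table. As a rough sketch, one pass through this loop for `sin` expands to something like the following; the rule is spliced with `a` rather than `data(a)`, which, as far as I can tell, keeps the pullback itself differentiable:

```julia
@grad Base.sin(a::Real) = sin(data(a)), Δ -> (Δ * cos(a),)
Base.sin(a::TrackedReal) = track(sin, a)
```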
for (M, f, arity) in DiffRules.diffrules()
  arity == 2 || continue
  da, db = DiffRules.diffrule(M, f, :(data(a)), :(data(b)))
  da, db = DiffRules.diffrule(M, f, :a, :b)
  f = :($M.$f)
  @eval begin
    $M.$f(a::TrackedReal, b::TrackedReal) = track($M.$f, a, b)
    $M.$f(a::TrackedReal, b::Real) = track($M.$f, a, b)
    $M.$f(a::Real, b::TrackedReal) = track($M.$f, a, b)
    function back(::typeof($M.$f), Δ::Real, a::Real, b::Real)
      @back(a, Δ * $da)
      @back(b, Δ * $db)
    end
    @grad $f(a::Real, b::Real) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db)
    $f(a::TrackedReal, b::TrackedReal) = track($f, a, b)
    $f(a::TrackedReal, b::Real) = track($f, a, b)
    $f(a::Real, b::TrackedReal) = track($f, a, b)
  end
end
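The binary rules work the same way: one `@grad` per DiffRules entry plus `track` overloads for every tracked/untracked argument combination. A small hedged check that mixed arithmetic differentiates as expected:

```julia
using Flux, Flux.Tracker

a = param(2.0)
y = 3.0 * a + a / 2   # TrackedReal; both * and / come from the rules above
Tracker.back!(y)
Tracker.grad(a)       # 3.5
```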
@ -70,16 +75,18 @@ import Base:^

# Tuples

struct TrackedTuple{T<:Tuple}
  data::T
  tracker::Tracked{T}
end

data(xs::TrackedTuple) = xs.data
tracker(xs::TrackedTuple) = xs.tracker

accum!(x::Tuple, Δ::Tuple) = accum!.(x, Δ)
init_grad(x::Tuple) = init_grad.(x)
zero_grad!(x::Tuple) = zero_grad!.(x)

track(f::Call, xs::Tuple) = TrackedTuple(Tracked(f, xs))
track(f::Call, xs::Tuple) = TrackedTuple(xs, Tracked{typeof(xs)}(f, zero.(xs)))

function Base.show(io::IO, xs::TrackedTuple)
  show(io, data(xs))

@ -90,20 +97,21 @@ Base.length(x::TrackedTuple) = length(data(x))

Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i)

back(::typeof(getindex), Δ, t, i) =
  back(t, ntuple(j -> i == j ? Δ : 0, length(t)))
@grad function getindex(xs::TrackedTuple, i)
  data(xs)[i], Δ -> (ntuple(j -> i == j ? Δ : 0, length(xs)), nothing)
end
# Array collection

function collect(xs)
  xs = Base.collect(xs)
  track(Call(collect, xs), data.(xs))
  track(Call(collect, (tracker.(xs),)), data.(xs))
end

function scan(c::Call{typeof(collect)})
  foreach(scan, c.args[1])
end

function back(::typeof(collect), Δ, xs)
  foreach((x, Δ) -> @back(x, Δ), xs, Δ)
function back_(c::Call{typeof(collect)}, Δ)
  foreach(back, c.args[1], data(Δ))
end
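`Tracker.collect` gathers a collection of tracked values into one tracked array; the `Call` now stores the elements' trackers, and the backward pass fans each slice of the incoming gradient back to the corresponding element. A hedged sketch of what that enables, assuming the `TrackedArray` broadcasting and `sum` support defined elsewhere in Tracker:

```julia
using Flux, Flux.Tracker

xs = param.([1.0, 2.0, 3.0])       # a Vector of tracked scalars
y  = Tracker.collect(xs)           # a tracked array built from them
Tracker.back!(sum(2 .* y))
Tracker.grad.(xs)                  # [2.0, 2.0, 2.0] — each element gets its share
```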
@ -1,4 +1,5 @@
import Adapt: adapt
import .Tracker: IdSet

children(x) = ()
mapchildren(f, x) = x

@ -20,9 +21,7 @@ function mapleaves(f, x; cache = ObjectIdDict())
  cache[x] = isleaf(x) ? f(x) : mapchildren(x -> mapleaves(f, x, cache = cache), x)
end

using DataFlow: OSet

function prefor(f, x; seen = OSet())
function prefor(f, x; seen = IdSet())
  x ∈ seen && return
  f(x)
  foreach(x -> prefor(f, x, seen = seen), children(x))
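`prefor` is the depth-first walk that utilities such as `params` are built on; the only change here is that the visited-set now comes from Tracker's `IdSet` instead of DataFlow's `OSet`, so each object is still visited exactly once, by identity. A hedged sketch of it in action (`prefor` is internal, hence the module-qualified call):

```julia
using Flux
using Flux.Tracker: TrackedArray

m = Chain(Dense(10, 5, σ), Dense(5, 2))
Flux.prefor(x -> x isa TrackedArray && println(size(x)), m)
# prints something like (5, 10), (5,), (2, 5), (2,) — one line per parameter
```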
@ -1,7 +1,9 @@
using Base.Test
using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
  σ, binarycrossentropy, logitbinarycrossentropy

const ϵ = 1e-7

@testset "losses" begin
  # First, regression-style y's
  y = [1, 1, 0, 0]

@ -40,10 +42,11 @@ using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,

  logŷ, y = randn(3), rand(3)
  @testset "binarycrossentropy" begin
    @test binarycrossentropy.(σ.(logŷ), y) ≈ -y.*log.(σ.(logŷ)) - (1 - y).*log.(1 - σ.(logŷ))
    @test binarycrossentropy.(σ.(logŷ), y; ϵ=0) ≈ -y.*log.(σ.(logŷ)) - (1 - y).*log.(1 - σ.(logŷ))
    @test binarycrossentropy.(σ.(logŷ), y) ≈ -y.*log.(σ.(logŷ) .+ eps.(σ.(logŷ))) - (1 - y).*log.(1 - σ.(logŷ) .+ eps.(σ.(logŷ)))
  end

  @testset "logitbinarycrossentropy" begin
    @test logitbinarycrossentropy.(logŷ, y) ≈ binarycrossentropy.(σ.(logŷ), y)
    @test logitbinarycrossentropy.(logŷ, y) ≈ binarycrossentropy.(σ.(logŷ), y; ϵ=0)
  end
end
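These tests pin down the new behaviour: `binarycrossentropy` now adds an `ϵ` (defaulting to `eps` of the prediction) inside the logs so it cannot return `Inf` for saturated predictions, and passing `ϵ = 0` recovers the old formula. The definition they imply, written out as a hedged reconstruction rather than the verbatim source:

```julia
binarycrossentropy(ŷ, y; ϵ = eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ)
```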
@ -1,5 +1,5 @@
using Flux.Tracker, Base.Test, NNlib
using Flux.Tracker: TrackedReal, gradcheck, grad
using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint
using NNlib: conv, depthwiseconv

gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)

@ -111,6 +111,7 @@ end

@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))

# TODO unreliable
@test gradtest(x -> repmat(x, 5,5), rand(4,5))
@test gradtest(x -> repmat(x, 5), rand(4,5))

@ -234,4 +235,24 @@ Tracker.back!(b)
@test grad.((x,y)) == (3, 2)
end

# Gradient Hooks
@testset "Hooks" begin
  x = param(2)
  y = Tracker.hook(-, x)
  back!(y)
  @test grad(x) == -1
end
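`Tracker.hook(f, x)` leaves the forward value of `x` alone but applies `f` to the gradient on the way back, which is why `grad(x)` comes out as `-1` above. A hypothetical use is gradient clipping (a sketch, assuming the same `using` lines as this test file):

```julia
x = param([10.0, -10.0])
y = Tracker.hook(Δ -> clamp.(Δ, -1, 1), x)
back!(sum(3 .* y))
grad(x)   # [1.0, 1.0] — the raw gradient of 3 per element was clipped
```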
@testset "Checkpointing" begin
  count = 0
  function mul(a, b)
    count += 1
    a * b
  end
  @test derivative(x -> mul(5, x), 3) == 5
  @test count == 1
  @test derivative(x -> checkpoint(mul, 5, x), 3) == 5
  @test count == 3
end

end #testset
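The counts are the point of this test: a plain tracked call runs `mul` once and keeps what it needs for the backward pass, while `checkpoint(mul, 5, x)` deliberately forgets that intermediate state and re-runs `mul` during backpropagation. That gives one evaluation for the first derivative and two for the checkpointed one (forward plus the recomputation), hence `count == 3`, trading extra compute for lower memory use.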