doctests passing
This commit is contained in:
parent
67c38b3099
commit
c8d460ff84
|
@ -33,7 +33,8 @@ Zygote = "0.3"
|
||||||
julia = "1.1"
|
julia = "1.1"
|
||||||
|
|
||||||
[extras]
|
[extras]
|
||||||
|
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
|
||||||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
|
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
|
||||||
|
|
||||||
[targets]
|
[targets]
|
||||||
test = ["Test"]
|
test = ["Test", "Documenter"]
|
||||||
|
|
|
@ -5,55 +5,56 @@
|
||||||
Flux's core feature is taking gradients of Julia code. The `gradient` function takes another Julia function `f` and a set of arguments, and returns the gradient with respect to each argument. (It's a good idea to try pasting these examples in the Julia terminal.)
|
Flux's core feature is taking gradients of Julia code. The `gradient` function takes another Julia function `f` and a set of arguments, and returns the gradient with respect to each argument. (It's a good idea to try pasting these examples in the Julia terminal.)
|
||||||
|
|
||||||
```jldoctest basics
|
```jldoctest basics
|
||||||
julia> using Flux.Tracker
|
julia> using Flux
|
||||||
|
|
||||||
julia> f(x) = 3x^2 + 2x + 1;
|
julia> f(x) = 3x^2 + 2x + 1;
|
||||||
|
|
||||||
julia> df(x) = Tracker.gradient(f, x; nest = true)[1]; # df/dx = 6x + 2
|
julia> df(x) = gradient(f, x)[1]; # df/dx = 6x + 2
|
||||||
|
|
||||||
julia> df(2)
|
julia> df(2)
|
||||||
14.0 (tracked)
|
14
|
||||||
|
|
||||||
julia> d2f(x) = Tracker.gradient(df, x; nest = true)[1]; # d²f/dx² = 6
|
julia> d2f(x) = gradient(df, x)[1]; # d²f/dx² = 6
|
||||||
|
|
||||||
julia> d2f(2)
|
julia> d2f(2)
|
||||||
6.0 (tracked)
|
6
|
||||||
```
|
```
|
||||||
|
|
||||||
(We'll learn more about why these numbers show up as `(tracked)` below.)
|
When a function has many parameters, we can get gradients of each one at the same time:
|
||||||
|
|
||||||
When a function has many parameters, we can pass them all in explicitly:
|
|
||||||
|
|
||||||
```jldoctest basics
|
```jldoctest basics
|
||||||
julia> f(W, b, x) = W * x + b;
|
julia> f(x, y) = sum((x .- y).^2);
|
||||||
|
|
||||||
julia> Tracker.gradient(f, 2, 3, 4)
|
julia> gradient(f, [2, 1], [2, 0])
|
||||||
(4.0 (tracked), 1.0 (tracked), 2.0 (tracked))
|
([0, 2], [0, -2])
|
||||||
```
|
```
|
||||||
|
|
||||||
But machine learning models can have *hundreds* of parameters! Flux offers a nice way to handle this. We can tell Flux to treat something as a parameter via `param`. Then we can collect these together and tell `gradient` to collect the gradients of all `params` at once.
|
But machine learning models can have *hundreds* of parameters! To handle this, Flux lets you work with collections of parameters, via `params`. You can get the gradient of all parameters used in a program without explicitly passing them in.
|
||||||
|
|
||||||
```jldoctest basics
|
```jldoctest basics
|
||||||
julia> using Flux
|
julia> using Flux
|
||||||
|
|
||||||
julia> W = param(2)
|
julia> x = [2, 1];
|
||||||
2.0 (tracked)
|
|
||||||
|
|
||||||
julia> b = param(3)
|
julia> y = [2, 0];
|
||||||
3.0 (tracked)
|
|
||||||
|
|
||||||
julia> f(x) = W * x + b;
|
julia> gs = gradient(params(x, y)) do
|
||||||
|
f(x, y)
|
||||||
|
end
|
||||||
|
Grads(...)
|
||||||
|
|
||||||
julia> grads = Tracker.gradient(() -> f(4), params(W, b));
|
julia> gs[x]
|
||||||
|
2-element Array{Int64,1}:
|
||||||
|
0
|
||||||
|
2
|
||||||
|
|
||||||
julia> grads[W]
|
julia> gs[y]
|
||||||
4.0 (tracked)
|
2-element Array{Int64,1}:
|
||||||
|
0
|
||||||
julia> grads[b]
|
-2
|
||||||
1.0 (tracked)
|
|
||||||
```
|
```
|
||||||
|
|
||||||
There are a few things to notice here. Firstly, `W` and `b` now show up as *tracked*. Tracked things behave like normal numbers or arrays, but keep records of everything you do with them, allowing Flux to calculate their gradients. `gradient` takes a zero-argument function; no arguments are necessary because the `params` tell it what to differentiate.
|
Here, `gradient` takes a zero-argument function; no arguments are necessary because the `params` tell it what to differentiate.
|
||||||
|
|
||||||
This will come in really handy when dealing with big, complicated models. For now, though, let's start with something simple.
|
This will come in really handy when dealing with big, complicated models. For now, though, let's start with something simple.
|
||||||
|
|
||||||
|
@ -76,26 +77,20 @@ x, y = rand(5), rand(2) # Dummy data
|
||||||
loss(x, y) # ~ 3
|
loss(x, y) # ~ 3
|
||||||
```
|
```
|
||||||
|
|
||||||
To improve the prediction we can take the gradients of `W` and `b` with respect to the loss and perform gradient descent. Let's tell Flux that `W` and `b` are parameters, just like we did above.
|
To improve the prediction we can take the gradients of `W` and `b` with respect to the loss and perform gradient descent.
|
||||||
|
|
||||||
```julia
|
```julia
|
||||||
using Flux.Tracker
|
using Flux
|
||||||
|
|
||||||
W = param(W)
|
gs = gradient(() -> loss(x, y), params(W, b))
|
||||||
b = param(b)
|
|
||||||
|
|
||||||
gs = Tracker.gradient(() -> loss(x, y), params(W, b))
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Now that we have gradients, we can pull them out and update `W` to train the model. The `update!(W, Δ)` function applies `W = W + Δ`, which we can use for gradient descent.
|
Now that we have gradients, we can pull them out and update `W` to train the model.
|
||||||
|
|
||||||
```julia
|
```julia
|
||||||
using Flux.Tracker: update!
|
W̄ = gs[W]
|
||||||
|
|
||||||
Δ = gs[W]
|
W .-= 0.1 .* W̄
|
||||||
|
|
||||||
# Update the parameter and reset the gradient
|
|
||||||
update!(W, -0.1Δ)
|
|
||||||
|
|
||||||
loss(x, y) # ~ 2.5
|
loss(x, y) # ~ 2.5
|
||||||
```
|
```
|
||||||
|
@ -111,12 +106,12 @@ It's common to create more complex models than the linear regression above. For
|
||||||
```julia
|
```julia
|
||||||
using Flux
|
using Flux
|
||||||
|
|
||||||
W1 = param(rand(3, 5))
|
W1 = rand(3, 5)
|
||||||
b1 = param(rand(3))
|
b1 = rand(3)
|
||||||
layer1(x) = W1 * x .+ b1
|
layer1(x) = W1 * x .+ b1
|
||||||
|
|
||||||
W2 = param(rand(2, 3))
|
W2 = rand(2, 3)
|
||||||
b2 = param(rand(2))
|
b2 = rand(2)
|
||||||
layer2(x) = W2 * x .+ b2
|
layer2(x) = W2 * x .+ b2
|
||||||
|
|
||||||
model(x) = layer2(σ.(layer1(x)))
|
model(x) = layer2(σ.(layer1(x)))
|
||||||
|
@ -128,8 +123,8 @@ This works but is fairly unwieldy, with a lot of repetition – especially as we
|
||||||
|
|
||||||
```julia
|
```julia
|
||||||
function linear(in, out)
|
function linear(in, out)
|
||||||
W = param(randn(out, in))
|
W = randn(out, in)
|
||||||
b = param(randn(out))
|
b = randn(out)
|
||||||
x -> W * x .+ b
|
x -> W * x .+ b
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -150,7 +145,7 @@ struct Affine
|
||||||
end
|
end
|
||||||
|
|
||||||
Affine(in::Integer, out::Integer) =
|
Affine(in::Integer, out::Integer) =
|
||||||
Affine(param(randn(out, in)), param(randn(out)))
|
Affine(randn(out, in), randn(out))
|
||||||
|
|
||||||
# Overload call, so the object can be used as a function
|
# Overload call, so the object can be used as a function
|
||||||
(m::Affine)(x) = m.W * x .+ m.b
|
(m::Affine)(x) = m.W * x .+ m.b
|
||||||
|
|
|
@ -1,14 +1,10 @@
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
Iris
|
|
||||||
|
|
||||||
Fisher's classic iris dataset.
|
Fisher's classic iris dataset.
|
||||||
|
|
||||||
Measurements from 3 different species of iris: setosa, versicolor and
|
Measurements from 3 different species of iris: setosa, versicolor and
|
||||||
virginica. There are 50 examples of each species.
|
virginica. There are 50 examples of each species.
|
||||||
|
|
||||||
There are 4 measurements for each example: sepal length, sepal width, petal
|
There are 4 measurements for each example: sepal length, sepal width, petal
|
||||||
length and petal width. The measurements are in centimeters.
|
length and petal width. The measurements are in centimeters.
|
||||||
|
|
||||||
The module retrieves the data from the [UCI Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/iris).
|
The module retrieves the data from the [UCI Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/iris).
|
||||||
|
@ -35,10 +31,12 @@ end
|
||||||
|
|
||||||
labels()
|
labels()
|
||||||
|
|
||||||
Get the labels of the iris dataset, a 150 element array of strings listing the
|
Get the labels of the iris dataset, a 150 element array of strings listing the
|
||||||
species of each example.
|
species of each example.
|
||||||
|
|
||||||
```jldoctest
|
```jldoctest
|
||||||
|
julia> using Flux
|
||||||
|
|
||||||
julia> labels = Flux.Data.Iris.labels();
|
julia> labels = Flux.Data.Iris.labels();
|
||||||
|
|
||||||
julia> summary(labels)
|
julia> summary(labels)
|
||||||
|
@ -58,11 +56,13 @@ end
|
||||||
|
|
||||||
features()
|
features()
|
||||||
|
|
||||||
Get the features of the iris dataset. This is a 4x150 matrix of Float64
|
Get the features of the iris dataset. This is a 4x150 matrix of Float64
|
||||||
elements. It has a row for each feature (sepal length, sepal width,
|
elements. It has a row for each feature (sepal length, sepal width,
|
||||||
petal length, petal width) and a column for each example.
|
petal length, petal width) and a column for each example.
|
||||||
|
|
||||||
```jldoctest
|
```jldoctest
|
||||||
|
julia> using Flux
|
||||||
|
|
||||||
julia> features = Flux.Data.Iris.features();
|
julia> features = Flux.Data.Iris.features();
|
||||||
|
|
||||||
julia> summary(features)
|
julia> summary(features)
|
||||||
|
@ -81,6 +81,5 @@ function features()
|
||||||
iris = readdlm(deps("iris.data"), ',')
|
iris = readdlm(deps("iris.data"), ',')
|
||||||
Matrix{Float64}(iris[1:end, 1:4]')
|
Matrix{Float64}(iris[1:end, 1:4]')
|
||||||
end
|
end
|
||||||
|
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -54,17 +54,19 @@ it will error.
|
||||||
## Examples
|
## Examples
|
||||||
|
|
||||||
```jldoctest
|
```jldoctest
|
||||||
|
julia> using Flux: onehot
|
||||||
|
|
||||||
julia> onehot(:b, [:a, :b, :c])
|
julia> onehot(:b, [:a, :b, :c])
|
||||||
3-element Flux.OneHotVector:
|
3-element Flux.OneHotVector:
|
||||||
false
|
0
|
||||||
true
|
1
|
||||||
false
|
0
|
||||||
|
|
||||||
julia> onehot(:c, [:a, :b, :c])
|
julia> onehot(:c, [:a, :b, :c])
|
||||||
3-element Flux.OneHotVector:
|
3-element Flux.OneHotVector:
|
||||||
false
|
0
|
||||||
false
|
0
|
||||||
true
|
1
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
function onehot(l, labels)
|
function onehot(l, labels)
|
||||||
|
@ -88,12 +90,13 @@ Create an [`OneHotMatrix`](@ref) with a batch of labels based on possible `label
|
||||||
## Examples
|
## Examples
|
||||||
|
|
||||||
```jldoctest
|
```jldoctest
|
||||||
julia> onehotbatch([:b, :a, :b], [:a, :b, :c])
|
julia> using Flux: onehotbatch
|
||||||
3×3 Flux.OneHotMatrix:
|
|
||||||
false true false
|
|
||||||
true false true
|
|
||||||
false false false
|
|
||||||
|
|
||||||
|
julia> onehotbatch([:b, :a, :b], [:a, :b, :c])
|
||||||
|
3×3 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}:
|
||||||
|
0 1 0
|
||||||
|
1 0 1
|
||||||
|
0 0 0
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
onehotbatch(ls, labels, unk...) =
|
onehotbatch(ls, labels, unk...) =
|
||||||
|
@ -106,9 +109,9 @@ Base.argmax(xs::OneHotVector) = xs.ix
|
||||||
|
|
||||||
Inverse operations of [`onehot`](@ref).
|
Inverse operations of [`onehot`](@ref).
|
||||||
|
|
||||||
## Examples
|
|
||||||
|
|
||||||
```jldoctest
|
```jldoctest
|
||||||
|
julia> using Flux: onecold
|
||||||
|
|
||||||
julia> onecold([true, false, false], [:a, :b, :c])
|
julia> onecold([true, false, false], [:a, :b, :c])
|
||||||
:a
|
:a
|
||||||
|
|
||||||
|
|
|
@ -1,11 +1,8 @@
|
||||||
using Flux, Test, Random, Statistics
|
using Flux, Test, Random, Statistics, Documenter
|
||||||
using Random
|
using Random
|
||||||
|
|
||||||
Random.seed!(0)
|
Random.seed!(0)
|
||||||
|
|
||||||
# So we can use the system CuArrays
|
|
||||||
insert!(LOAD_PATH, 2, "@v#.#")
|
|
||||||
|
|
||||||
@testset "Flux" begin
|
@testset "Flux" begin
|
||||||
|
|
||||||
@info "Testing Basics"
|
@info "Testing Basics"
|
||||||
|
@ -32,4 +29,6 @@ else
|
||||||
@warn "CUDA unavailable, not testing GPU support"
|
@warn "CUDA unavailable, not testing GPU support"
|
||||||
end
|
end
|
||||||
|
|
||||||
|
doctest(Flux)
|
||||||
|
|
||||||
end
|
end
|
||||||
|
|
Loading…
Reference in New Issue