Merge branch 'master' into cat-fix
Commit: df06c3351d
@@ -8,7 +8,6 @@ julia:
 # uncomment the following lines to override the default test script
 script:
 - if [[ -a .git/shallow ]]; then git fetch --unshallow; fi
-- julia -e 'Pkg.clone("https://github.com/FluxML/NNlib.jl")'
 - julia -e 'Pkg.clone(pwd()); Pkg.build("Flux"); Pkg.test("Flux"; coverage=true)'
 after_success:
 - julia -e 'Pkg.add("Documenter")'
@@ -1,6 +1,6 @@
-using Documenter, Flux
+using Documenter, Flux, NNlib

-makedocs(modules=[Flux],
+makedocs(modules=[Flux, NNlib],
 doctest = false,
 format = :html,
 analytics = "UA-36890222-9",
@@ -10,13 +10,13 @@ makedocs(modules=[Flux],
 "Building Models" =>
 ["Basics" => "models/basics.md",
 "Recurrence" => "models/recurrence.md",
-"Layer Reference" => "models/layers.md"],
+"Model Reference" => "models/layers.md"],
 "Training Models" =>
 ["Optimisers" => "training/optimisers.md",
 "Training" => "training/training.md"],
 "One-Hot Encoding" => "data/onehot.md",
 "GPU Support" => "gpu.md",
-"Contributing & Help" => "contributing.md"])
+"Community" => "community.md"])

 deploydocs(
 repo = "github.com/FluxML/Flux.jl.git",
docs/src/community.md (new file)
@@ -0,0 +1,5 @@
+# Community
+
+All Flux users are welcome to join our community on the [Julia forum](https://discourse.julialang.org/), the [slack](https://discourse.julialang.org/t/announcing-a-julia-slack/4866) (channel #machine-learning), or Flux's [Gitter](https://gitter.im/FluxML/Lobby). If you have questions or issues we'll try to help you out.
+
+If you're interested in hacking on Flux, the [source code](https://github.com/FluxML/Flux.jl) is open and easy to understand -- it's all just the same Julia code you work with normally. You might be interested in our [intro issues](https://github.com/FluxML/Flux.jl/issues?q=is%3Aopen+is%3Aissue+label%3A%22help+wanted%22) to get started.
@@ -1,9 +0,0 @@
-# Contributing & Help
-
-If you need help, please ask on the [Julia forum](https://discourse.julialang.org/), the [slack](https://discourse.julialang.org/t/announcing-a-julia-slack/4866) (channel #machine-learning), or Flux's [Gitter](https://gitter.im/FluxML/Lobby).
-
-Right now, the best way to help out is to try out the examples and report any issues or missing features as you find them. The second best way is to help us spread the word, perhaps by [starring the repo](https://github.com/MikeInnes/Flux.jl).
-
-If you're interested in hacking on Flux, most of the [code](https://github.com/MikeInnes/Flux.jl/tree/master/src) is pretty straightforward. Adding new [layer definitions](https://github.com/MikeInnes/Flux.jl/tree/master/src/layers) or cost functions is simple using the Flux DSL itself, and things like data utilities and training processes are all plain Julia code.
-
-If you get stuck or need anything, let us know!
@@ -19,16 +19,16 @@ loss(x, y) # ~ 3

 Note that we convert both the parameters (`W`, `b`) and the data set (`x`, `y`) to cuda arrays. Taking derivatives and training works exactly as before.

-If you define a structured model, like a `Dense` layer or `Chain`, you just need to convert the internal parameters. Flux provides `mapparams`, which allows you to alter all parameters of a model at once.
+If you define a structured model, like a `Dense` layer or `Chain`, you just need to convert the internal parameters. Flux provides `mapleaves`, which allows you to alter all parameters of a model at once.

 ```julia
 d = Dense(10, 5, σ)
-d = mapparams(cu, d)
+d = mapleaves(cu, d)
 d.W # Tracked CuArray
 d(cu(rand(10))) # CuArray output

 m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
-m = mapparams(cu, m)
+m = mapleaves(cu, m)
 d(cu(rand(10)))
 ```

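For readers following the docs change above, a minimal end-to-end sketch of the renamed `mapleaves` call; it assumes CuArrays is installed and provides `cu`, exactly as the GPU Support page assumes.

```julia
using Flux, CuArrays

m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
gpu_m = mapleaves(cu, m)   # convert every parameter array in the model to a CuArray
x = cu(rand(10))           # inputs must be moved to the GPU as well
gpu_m(x)                   # forward pass now runs on the GPU
```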
@@ -18,7 +18,7 @@ loss(x, y) # ~ 3
 To improve the prediction we can take the gradients of `W` and `b` with respect to the loss function and perform gradient descent. We could calculate gradients by hand, but Flux will do it for us if we tell it that `W` and `b` are trainable *parameters*.

 ```julia
-using Flux.Tracker: param, back!, data, grad
+using Flux.Tracker

 W = param(W)
 b = param(b)
@@ -31,9 +31,10 @@ back!(l)
 `loss(x, y)` returns the same number, but it's now a *tracked* value that records gradients as it goes along. Calling `back!` then calculates the gradient of `W` and `b`. We can see what this gradient is, and modify `W` to train the model.

 ```julia
-grad(W)
+W.grad

-W.data .-= 0.1grad(W)
+# Update the parameter
+W.data .-= 0.1(W.grad)

 loss(x, y) # ~ 2.5
 ```
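Putting the two hunks above together, a minimal sketch of the full manual gradient step; the names `W`, `b`, `x`, `y`, `predict` and `loss` are placeholders in the style of the Basics page, not part of this diff.

```julia
using Flux.Tracker

W = param(rand(2, 5))
b = param(rand(2))
x, y = rand(5), rand(2)

predict(x) = W*x .+ b
loss(x, y) = sum((predict(x) .- y).^2)

l = loss(x, y)
back!(l)                    # fills in W.grad and b.grad
W.data .-= 0.1 .* W.grad    # one step of gradient descent on W
b.data .-= 0.1 .* b.grad    # and on b
```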
@@ -1,6 +1,32 @@
-## Model Layers
+## Basic Layers

+These core layers form the foundation of almost all neural networks.
+
 ```@docs
 Chain
 Dense
 ```
+
+## Recurrent Layers
+
+Much like the core layers above, but can be used to process sequence data (as well as other kinds of structured data).
+
+```@docs
+RNN
+LSTM
+Flux.Recur
+```
+
+## Activation Functions
+
+Non-linearities that go between layers of your model. Most of these functions are defined in [NNlib](https://github.com/FluxML/NNlib.jl) but are available by default in Flux.
+
+Note that, unless otherwise stated, activation functions operate on scalars. To apply them to an array you can call `σ.(xs)`, `relu.(xs)` and so on.
+
+```@docs
+σ
+relu
+leakyrelu
+elu
+swish
+```
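Since the new reference page notes that activations are scalar functions, a small sketch of broadcasting them over an array (values chosen arbitrarily):

```julia
xs = [-2.0, -0.5, 0.0, 3.0]

relu.(xs)       # [0.0, 0.0, 0.0, 3.0]
leakyrelu.(xs)  # small negative slope instead of zero below 0
σ.(xs)          # elementwise sigmoid, values in (0, 1)
```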
@@ -15,9 +15,9 @@ Recurrent networks introduce a *hidden state* that gets carried over each time w

 ```julia
 h = # ... initial state ...
-y₁, h = f(x₁, h)
-y₂, h = f(x₂, h)
-y₃, h = f(x₃, h)
+h, y₁ = f(h, x₁)
+h, y₂ = f(h, x₂)
+h, y₃ = f(h, x₃)
 # ...
 ```

@@ -25,7 +25,7 @@ Information stored in `h` is preserved for the next prediction, allowing it to f

 (This might be important if, for example, each `x` represents one word of a sentence; the model's interpretation of the word "bank" should change if the previous input was "river" rather than "investment".)

-Flux's RNN support closely follows this mathematical perspective. The most basic RNN is as close as possible to a standard `Dense` layer, and the output and hidden state are the same. By convention, the hidden state is the first input and output.
+Flux's RNN support closely follows this mathematical perspective. The most basic RNN is as close as possible to a standard `Dense` layer, and the output is also the hidden state.

 ```julia
 Wxh = randn(5, 10)
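To make the new convention concrete (hidden state first in, first out), here is a sketch of a hand-written cell in the style of the Recurrence page; `Wxh`, `Whh` and `b` continue the example started in the context line above.

```julia
Wxh = randn(5, 10)
Whh = randn(5, 5)
b   = randn(5)

function rnn(h, x)
  h = tanh.(Wxh * x .+ Whh * h .+ b)
  return h, h          # the output is also the hidden state
end

h = zeros(5)
x = rand(10)
h, y = rnn(h, x)       # hidden state goes in first and comes out first
```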
@@ -112,3 +112,5 @@ truncate!(m)
 ```

 Calling `truncate!` wipes the slate clean, so we can call the model with more inputs without building up an expensive gradient computation.
+
+`truncate!` makes sense when you are working with multiple chunks of a large sequence, but we may also want to work with a set of independent sequences. In this case the hidden state should be completely reset to its original value, throwing away any accumulated information. `reset!` does this for you.
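As a rough usage sketch of the two calls discussed above (the model and sequence are made up for illustration; `truncate!` and `reset!` are reached through the `Flux` module since they do not appear in the export list below):

```julia
m = Chain(LSTM(10, 5), Dense(5, 2), softmax)

chunk1 = [rand(10) for i = 1:50]
m.(chunk1)            # process the first chunk of a long sequence
Flux.truncate!(m)     # drop the gradient history but keep the hidden state

Flux.reset!(m)        # start an independent sequence from the initial state
```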
@@ -17,14 +17,11 @@ back!(l)
 We want to update each parameter, using the gradient, in order to improve (reduce) the loss. Here's one way to do that:

 ```julia
-using Flux.Tracker: data, grad
-
 function update()
 η = 0.1 # Learning Rate
 for p in (W, b)
-x, Δ = data(p), grad(p)
-x .-= η .* Δ # Apply the update
-Δ .= 0 # Clear the gradient
+p.data .-= η .* p.grad # Apply the update
+p.grad .= 0 # Clear the gradient
 end
 end
 ```
@@ -48,7 +45,21 @@ For the update step, there's nothing whatsoever wrong with writing the loop abov
 ```julia
 opt = SGD([W, b], 0.1) # Gradient descent with learning rate 0.1

-opt()
+opt() # Carry out the update, modifying `W` and `b`.
 ```

 An optimiser takes a parameter list and returns a function that does the same thing as `update` above. We can pass either `opt` or `update` to our [training loop](training.md), which will then run the optimiser after every mini-batch of data.
+
+## Optimiser Reference
+
+All optimisers return a function that, when called, will update the parameters passed to it.
+
+```@docs
+SGD
+Momentum
+Nesterov
+RMSProp
+ADAM
+ADAGrad
+ADADelta
+```
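A short sketch of replacing the hand-written `update` with one of the optimisers listed in the new reference; the training loop and the `loss`, `x`, `y`, `W`, `b` names are the placeholders used earlier on this page.

```julia
opt = Momentum([W, b], 0.9)   # same calling convention as SGD: parameters first

for i = 1:10
  l = loss(x, y)
  back!(l)
  opt()                       # apply one update to W and b
end
```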
@@ -30,7 +30,33 @@ loss(x, y) = Flux.mse(m(x), y)
 Flux.train!(loss, data, opt)
 ```

-The loss will almost always be defined in terms of some *cost function* that measures the distance of the prediction `m(x)` from the target `y`. Flux has several of these built in, like `mse` for mean squared error or `logloss` for cross entropy loss, but you can calculate it however you want.
+The loss will almost always be defined in terms of some *cost function* that measures the distance of the prediction `m(x)` from the target `y`. Flux has several of these built in, like `mse` for mean squared error or `crossentropy` for cross entropy loss, but you can calculate it however you want.
+
+## Datasets
+
+The `data` argument provides a collection of data to train with (usually a set of inputs `x` and target outputs `y`). For example, here's a dummy data set with only one data point:
+
+```julia
+x = rand(784)
+y = rand(10)
+data = [(x, y)]
+```
+
+`Flux.train!` will call `loss(x, y)`, calculate gradients, update the weights and then move on to the next data point if there is one. We can train the model on the same data three times:
+
+```julia
+data = [(x, y), (x, y), (x, y)]
+# Or equivalently
+data = Iterators.repeated((x, y), 3)
+```
+
+It's common to load the `x`s and `y`s separately. In this case you can use `zip`:
+
+```julia
+xs = [rand(784), rand(784), rand(784)]
+ys = [rand( 10), rand( 10), rand( 10)]
+data = zip(xs, ys)
+```

 ## Callbacks

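Combining the dataset patterns above with `train!`, a minimal sketch; the model, the sizes and the `evalcb` callback are illustrative, not part of the diff.

```julia
using Flux

m = Dense(784, 10)
loss(x, y) = Flux.mse(m(x), y)
opt = SGD(params(m), 0.1)

data = Iterators.repeated((rand(784), rand(10)), 3)
evalcb() = @show loss(rand(784), rand(10))

Flux.train!(loss, data, opt, cb = evalcb)
```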
@@ -8,10 +8,11 @@ using Juno, Requires
 using Lazy: @forward

 export Chain, Dense, RNN, LSTM,
-SGD, params, mapparams
+SGD, ADAM, Momentum, Nesterov,
+param, params, mapleaves

 using NNlib
-export σ, relu, softmax
+export σ, relu, leakyrelu, elu, swish, softmax

 include("tracker/Tracker.jl")
 using .Tracker
@@ -4,12 +4,14 @@
 Chain multiple layers / functions together, so that they are called in sequence
 on a given input.

-m = Chain(x -> x^2, x -> x+1)
-m(5) == 26
+```julia
+m = Chain(x -> x^2, x -> x+1)
+m(5) == 26

-m = Chain(Dense(10, 5), Dense(5, 2))
-x = rand(10)
-m(x) == m[2](m[1](x))
+m = Chain(Dense(10, 5), Dense(5, 2))
+x = rand(10)
+m(x) == m[2](m[1](x))
+```

 `Chain` also supports indexing and slicing, e.g. `m[2]` or `m[1:end-1]`.
 `m[1:3](x)` will calculate the output of the first three layers.
@@ -43,7 +45,17 @@ Creates a traditional `Dense` layer with parameters `W` and `b`.
 y = σ.(W * x .+ b)

 The input `x` must be a vector of length `in`, or a batch of vectors represented
-as an `in × N` matrix. The out `y` will be a vector or batch of length `in`.
+as an `in × N` matrix. The out `y` will be a vector or batch of length `out`.
+
+```julia
+julia> d = Dense(5, 2)
+Dense(5, 2)
+
+julia> d(rand(5))
+Tracked 2-element Array{Float64,1}:
+0.00257447
+-0.00449443
+```
 """
 struct Dense{F,S,T}
 σ::F
@@ -3,12 +3,33 @@ combine(x, h) = vcat(x, h .* trues(1, size(x, 2)))

 # Stateful recurrence

+"""
+Recur(cell)
+
+`Recur` takes a recurrent cell and makes it stateful, managing the hidden state
+in the background. `cell` should be a model of the form:
+
+h, y = cell(h, x...)
+
+For example, here's a recurrent network that keeps a running total of its inputs.
+
+```julia
+accum(h, x) = (h+x, x)
+rnn = Flux.Recur(accum, 0)
+rnn(2) # 2
+rnn(3) # 3
+rnn.state # 5
+rnn.(1:10) # apply to a sequence
+rnn.state # 60
+```
+"""
 mutable struct Recur{T}
 cell::T
+init
 state
 end

-Recur(m) = Recur(m, hidden(m))
+Recur(m, h = hidden(m)) = Recur(m, h, h)

 function (m::Recur)(xs...)
 h, y = m.cell(m.state, xs...)
@@ -20,12 +41,34 @@ treelike(Recur)

 Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")

-_truncate(x::AbstractArray) = x
-_truncate(x::TrackedArray) = x.data
+_truncate(x::AbstractArray) = Tracker.data(x)
 _truncate(x::Tuple) = _truncate.(x)

-truncate!(m) = foreach(truncate!, children(m))
-truncate!(m::Recur) = (m.state = _truncate(m.state))
+"""
+truncate!(rnn)
+
+Truncates the gradient of the hidden state in recurrent layers. The value of the
+state is preserved. See also `reset!`.
+
+Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to
+
+rnn.state = Tracker.data(rnn.state)
+"""
+truncate!(m) = prefor(x -> x isa Recur && (x.state = _truncate(x.state)), m)
+
+"""
+reset!(rnn)
+
+Reset the hidden state of a recurrent layer back to its original value. See also
+`truncate!`.
+
+Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to
+
+rnn.state = hidden(rnn.cell)
+"""
+reset!(m) = prefor(x -> x isa Recur && (x.state = x.init), m)

 flip(f, xs) = reverse(f.(reverse(xs)))

 # Vanilla RNN

@@ -50,6 +93,12 @@ function Base.show(io::IO, m::RNNCell)
 print(io, "RNNCell(", m.d, ")")
 end

+"""
+RNN(in::Integer, out::Integer, σ = tanh)
+
+The most basic recurrent layer; essentially acts as a `Dense` layer, but with the
+output fed back into the input each time step.
+"""
 RNN(a...; ka...) = Recur(RNNCell(a...; ka...))

 # LSTM
@@ -89,4 +138,13 @@ Base.show(io::IO, m::LSTMCell) =
 size(m.forget.W, 2) - size(m.forget.W, 1), ", ",
 size(m.forget.W, 1), ')')

+"""
+LSTM(in::Integer, out::Integer, σ = tanh)
+
+Long Short Term Memory recurrent layer. Behaves like an RNN but generally
+exhibits a longer memory span over sequences.
+
+See [this article](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
+for a good overview of the internals.
+"""
 LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
@@ -1,7 +1,14 @@
 # Cost functions

 mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)
-# back!(::typeof(mse), Δ, ŷ, y) = Δ .* (ŷ .- y)

-logloss(ŷ, y) = -sum(y .* log.(ŷ)) / size(y, 2)
-# back!(::typeof(logloss), Δ, ŷ, y) = 0 .- Δ .* y ./ ŷ
+crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat) =
+-sum(y .* log.(ŷ)) / size(y, 2)
+
+@deprecate logloss(x, y) crossentropy(x, y)
+
+function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat)
+logŷ = logŷ .- maximum(logŷ, 1)
+ypred = logŷ .- log.(sum(exp.(logŷ), 1))
+-sum(y .* ypred) / size(y, 2)
+end
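A sketch of how the two losses defined above relate: `logitcrossentropy` applied to raw scores should agree with `softmax` followed by `crossentropy` (random data; Julia 0.6-era syntax to match the code above).

```julia
using Flux
using Base.Test

ŷ = randn(10, 5)                          # raw scores: 10 classes, batch of 5
y = Flux.onehotbatch(rand(1:10, 5), 1:10)

@test Flux.logitcrossentropy(ŷ, y) ≈ Flux.crossentropy(softmax(ŷ), y)
```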
@@ -20,7 +20,9 @@ Base.getindex(xs::OneHotMatrix, i::Int, j::Int) = xs.data[j][i]

 Base.:*(A::AbstractMatrix, B::OneHotMatrix) = A[:, map(x->x.ix, B.data)]

-Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix([x, xs...])
+Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...])
+
+batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(length(first(xs)), xs)

 import NNlib.adapt

@@ -32,7 +34,12 @@ adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
 cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
 end

-onehot(l, labels) = OneHotVector(findfirst(labels, l), length(labels))
+function onehot(l, labels)
+i = findfirst(labels, l)
+i > 0 || error("Value $l is not in labels")
+OneHotVector(i, length(labels))
+end

 onehotbatch(ls, labels) = OneHotMatrix(length(labels), [onehot(l, labels) for l in ls])

 argmax(y::AbstractVector, labels = 1:length(y)) =
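For reference, a sketch of what the stricter `onehot` above does (labels and values are arbitrary):

```julia
Flux.onehot(:b, [:a, :b, :c])              # one-hot vector with a 1 at index 2
Flux.onehotbatch([:b, :a], [:a, :b, :c])   # 3×2 one-hot matrix
Flux.onehot(:d, [:a, :b, :c])              # now errors: "Value d is not in labels"
```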
@@ -1,7 +1,7 @@
 module Optimise

 export update!, params, train!,
-SGD
+SGD, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta

 struct Param{T}
 x::T
@@ -16,6 +16,6 @@ include("train.jl")

 using Flux.Tracker: TrackedArray

-Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad[])
+Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad)

 end
@@ -9,10 +9,65 @@ function optimiser(ps, fs...)
 () -> foreach(call, fs)
 end

-SGD(ps, η = 1) = optimiser(ps, p -> descent(p, η))
-ADAM(ps, η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0.0) = optimiser(ps, p -> adam(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
-Momentum(ps,ρ, decay = 0.0) = optimiser(ps, p -> momentum(p, ρ), p -> invdecay(p, decay), p -> descent(p, 1))
-Nesterov(ps,ρ, decay = 0.0) = optimiser(ps, p -> nesterov(p, ρ), p -> invdecay(p, decay), p -> descent(p, 1))
-RMSProp(ps, η = 0.001, ρ = 0.9, ϵ = 1e-8, decay = 0.0) = optimiser(ps, p -> rmsprop(p; η = η, ρ = ρ, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
-ADAGrad(ps, η = 0.01, ϵ = 1e-8, decay = 0.0) = optimiser(ps, p -> adagrad(p; η = η, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
-ADADelta(ps, η = 0.01, ρ = 0.95, ϵ = 1e-8, decay = 0.0) = optimiser(ps, p -> adadelta(p; ρ = ρ, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
+"""
+SGD(params, η = 1; decay = 0)
+
+Classic gradient descent optimiser. For each parameter `p` and its
+gradient `δp`, this runs `p -= η*δp`.
+
+Supports decayed learning rate decay if the `decay` argument is provided.
+"""
+SGD(ps, η = 1; decay = 0) =
+optimiser(ps, p -> invdecay(p, decay), p -> descent(p, η))
+
+"""
+Momentum(params, ρ, decay = 0)
+
+SGD with momentum `ρ` and optional learning rate decay.
+"""
+Momentum(ps, ρ; decay = 0) =
+optimiser(ps, p -> momentum(p, ρ), p -> invdecay(p, decay), p -> descent(p, 1))
+
+"""
+Nesterov(params, ρ, decay = 0)
+
+SGD with Nesterov momentum `ρ` and optional learning rate decay.
+"""
+Nesterov(ps, ρ; decay = 0) =
+optimiser(ps, p -> nesterov(p, ρ), p -> invdecay(p, decay), p -> descent(p, 1))
+
+"""
+RMSProp(params; η = 0.001, ρ = 0.9, ϵ = 1e-8, decay = 0)
+
+[RMSProp](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
+optimiser. Parameters other than learning rate don't need tuning. Often a good
+choice for recurrent networks.
+"""
+RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) =
+optimiser(ps, p -> rmsprop(p; η = η, ρ = ρ, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
+
+"""
+ADAM(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
+
+[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
+"""
+ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
+optimiser(ps, p -> adam(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
+
+"""
+ADAGrad(params; η = 0.01, ϵ = 1e-8, decay = 0)
+
+[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
+Parameters don't need tuning.
+"""
+ADAGrad(ps; η = 0.01, ϵ = 1e-8, decay = 0) =
+optimiser(ps, p -> adagrad(p; η = η, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
+
+"""
+ADADelta(params; η = 0.01, ρ = 0.95, ϵ = 1e-8, decay = 0)
+
+[ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
+tuning.
+"""
+ADADelta(ps; η = 0.01, ρ = 0.95, ϵ = 1e-8, decay = 0) =
+optimiser(ps, p -> adadelta(p; ρ = ρ, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
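The docstrings above move the hyperparameters to keyword arguments; a usage sketch, assuming `m` is a model and `params` collects its tracked parameters as in src/tree.jl further down:

```julia
ps = params(m)

opt1 = SGD(ps, 0.1; decay = 0.001)   # learning rate positional, decay as a keyword
opt2 = ADAM(ps)                      # η = 0.001 and the β defaults from the docstring
opt3 = RMSProp(ps, 0.01; ρ = 0.9)

opt1()                               # each optimiser is a closure that applies one update
```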
@@ -4,10 +4,19 @@ using Flux.Tracker: back!
 tocb(f) = f
 tocb(fs::AbstractVector) = () -> foreach(call, fs)

-function train!(m, data, opt; cb = () -> ())
+"""
+train!(loss, data, opt; cb = () -> ())
+
+For each datapoint `d` in `data` computes the gradient of `loss(d...)` through
+backpropagation and calls the optimizer `opt` and the callback `cb`
+(i.e. `opt()` and `cb()`).
+
+Multiple callbacks can be passed to `cb` as an array.
+"""
+function train!(loss, data, opt; cb = () -> ())
 cb = tocb(cb)
-@progress for x in data
-l = m(x...)
+@progress for d in data
+l = loss(d...)
 isinf(l.data[]) && error("Loss is Inf")
 isnan(l.data[]) && error("Loss is NaN")
 back!(l)
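Per the new docstring, `cb` also accepts an array of callbacks; a small sketch (the callbacks and the `loss`, `data`, `opt` names are illustrative):

```julia
losscb() = println("loss: ", loss(x, y))
savecb() = nothing                      # hypothetical: write a checkpoint here

Flux.train!(loss, data, opt, cb = [losscb, savecb])
```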
@@ -1,7 +1,5 @@
 module Tracker

-using Base: RefValue
-
 export TrackedArray, param, back!

 data(x) = x
@@ -16,11 +14,13 @@ Call(f, args...) = Call{typeof(f),typeof(args)}(f, args)

 (c::Call)() = c.func(data.(c.args)...)

-struct TrackedArray{T,N,A} <: AbstractArray{T,N}
-ref::RefValue{UInt32}
+mutable struct TrackedArray{T,N,A} <: AbstractArray{T,N}
+ref::UInt32
 f::Call
 data::A
-grad::RefValue{A}
+grad::A
+TrackedArray{T,N,A}(f::Call, data::A) where {T,N,A} = new(0, f, data)
+TrackedArray{T,N,A}(f::Call, data::A, grad::A) where {T,N,A} = new(0, f, data, grad)
 end

 TrackedScalar{T,A} = TrackedArray{T,0,A}
@@ -28,19 +28,22 @@ TrackedVector{T,A} = TrackedArray{T,1,A}
 TrackedMatrix{T,A} = TrackedArray{T,2,A}
 TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}}

-TrackedArray(c::Call, x::A, Δ::Ref{A}) where A <: AbstractArray =
-TrackedArray{eltype(A),ndims(A),A}(Ref(UInt32(0)), c, x, Δ)
+TrackedArray(c::Call, x::A) where A <: AbstractArray =
+TrackedArray{eltype(A),ndims(A),A}(c, x)

-TrackedArray(c::Call, x::AbstractArray) = TrackedArray(c, x, RefValue{typeof(x)}())
+TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
+TrackedArray{eltype(A),ndims(A),A}(c, x, Δ)

 TrackedArray(c::Call) = TrackedArray(c, c())

-TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, RefValue(zeros(x)))
+TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x))

 param(xs) = TrackedArray(AbstractFloat.(xs))
+param(xs::Real) = param(fill(xs))

 istracked(x::TrackedArray) = true
 data(x::TrackedArray) = x.data
-grad(x::TrackedArray) = x.grad[]
+grad(x::TrackedArray) = x.grad

 # Fallthrough methods

@@ -73,8 +76,6 @@ include("numeric.jl")

 import NNlib.adapt

-adapt(T, xs::TrackedArray) =
-TrackedArray(xs.f, adapt(T, xs.data),
-RefValue(adapt(T, grad(xs))))
+adapt(T, xs::TrackedArray) = TrackedArray(xs.f, adapt(T, xs.data), adapt(T, xs.grad))

 end
@@ -3,11 +3,11 @@ scan(x) = nothing
 scan(c::Call) = foreach(scan, c.args)

 function scan(x::TrackedArray)
-ref = x.ref[] += 1
+ref = x.ref += 1
 if ref == 1
 scan(x.f)
 else
-isassigned(x.grad) || (x.grad[] = zeros(x.data))
+isdefined(x, :grad) || (x.grad = zeros(x.data))
 end
 return
 end
@@ -16,10 +16,10 @@ back(c::Call, Δ) = back(c.func, Δ, c.args...)
 back(::Call{Void}, Δ) = nothing

 function back(x::TrackedArray, Δ)
-ref = x.ref[] -= 1
-if isassigned(x.grad)
-x.grad[] .+= Δ
-ref == 0 && back(x.f, x.grad[])
+ref = x.ref -= 1
+if isdefined(x, :grad)
+x.grad .+= Δ
+ref == 0 && back(x.f, x.grad)
 else
 ref == 0 && back(x.f, Δ)
 end
src/tree.jl
@@ -8,18 +8,27 @@ function treelike(T, fs = fieldnames(T))
 end
 end

+# TODO: prewalk/postwalk with correct caching
+# This is only correct in general for idempotent functions
+isleaf(x) = isempty(children(x))
+
-mapparams(f, x::AbstractArray) = f(x)
-mapparams(f, x) = mapchildren(x -> mapparams(f, x), x)
+function mapleaves(f, x; cache = ObjectIdDict())
+haskey(cache, x) && return cache[x]
+cache[x] = isleaf(x) ? f(x) : mapchildren(x -> mapleaves(f, x, cache = cache), x)
+end

-forparams(f, x) = (mapparams(x -> (f(x); x), x); return)
-export mapparams
+@deprecate mapparams(f, x) mapleaves(f, x)

 using DataFlow: OSet

-function params(m)
-ps = OSet()
-forparams(p -> push!(ps, p), m)
-return collect(ps)
+function prefor(f, x; seen = OSet())
+x ∈ seen && return
+f(x)
+foreach(x -> prefor(f, x, seen = seen), children(x))
+return
 end

+function params(m)
+ps = []
+prefor(p -> p isa TrackedArray && push!(ps, p), m)
+return ps
+end
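A sketch of how the new `prefor`-based helpers behave on a small model; the counts assume a plain two-layer `Chain`, and `prefor` is internal, hence the `Flux.` prefix.

```julia
m = Chain(Dense(10, 5, relu), Dense(5, 2))

length(params(m))       # 4: two weight matrices and two bias vectors
mapleaves(x -> x, m)    # rebuilds the model, visiting every leaf exactly once
Flux.prefor(println, m) # walks the tree top-down: the Chain, each Dense, each leaf
```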
src/utils.jl
@@ -9,6 +9,67 @@ unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...)
 stack(xs, dim) = cat(dim, unsqueeze.(xs, dim)...)
 unstack(xs, dim) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)]

+"""
+chunk(xs, n)
+
+Split `xs` into `n` parts.
+
+```julia
+julia> chunk(1:10, 3)
+3-element Array{Array{Int64,1},1}:
+[1, 2, 3, 4]
+[5, 6, 7, 8]
+[9, 10]
+```
+"""
+chunk(xs, n) = collect(Iterators.partition(xs, ceil(Int, length(xs)/n)))
+
+batchindex(xs, i) = (reverse(Base.tail(reverse(indices(xs))))..., i)
+
+"""
+batch(xs)
+
+Batch the arrays in `xs` into a single array.
+
+```julia
+julia> batch([[1,2,3],[4,5,6]])
+3×2 Array{Int64,2}:
+1 4
+2 5
+3 6
+```
+"""
+function batch(xs)
+data = first(xs) isa AbstractArray ?
+similar(first(xs), size(first(xs))..., length(xs)) :
+Vector{eltype(xs)}(length(xs))
+for (i, x) in enumerate(xs)
+data[batchindex(data, i)...] = x
+end
+return data
+end
+
+Base.rpad(v::AbstractVector, n::Integer, p) = [v; fill(p, max(n - length(v), 0))]
+
+"""
+batchseq(seqs, pad)
+
+Take a list of `N` sequences, and turn them into a single sequence where each
+item is a batch of `N`. Short sequences will be padded by `pad`.
+
+```julia
+julia> batchseq([[1, 2, 3], [4, 5]], 0)
+3-element Array{Array{Int64,1},1}:
+[1, 4]
+[2, 5]
+[3, 0]
+```
+"""
+function batchseq(xs, pad = nothing, n = maximum(length(x) for x in xs))
+xs_ = [rpad(x, n, pad) for x in xs]
+[batch([xs_[j][i] for j = 1:length(xs_)]) for i = 1:n]
+end
+
 # Other

 function accuracy(m, data)
@@ -15,7 +15,7 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
 @test gradtest(x -> softmax(x).*(1:3), (3,5))

 @test gradtest(Flux.mse, rand(5,5), rand(5, 5))
-@test gradtest(Flux.logloss, rand(5,5), rand(5, 5))
+@test gradtest(Flux.crossentropy, rand(5,5), rand(5, 5))

 @test gradtest(x -> x', rand(5))

@@ -27,4 +27,18 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
 2y + x
 end

+for T in [Float32, Float64]
+@test isa(param(T(1)), TrackedArray{T, 0})
+@test isa(param(rand(T, 2)), TrackedArray{T, 1})
+@test isa(param(rand(T, 2,2)), TrackedArray{T, 2})
+end
+
+# TODO: do we wand this behaviour ??
+F = typeof(AbstractFloat(1))
+for T in [Int32, Int64]
+@test isa(param(T(1)), TrackedArray{F, 0})
+@test isa(param(rand(T, 2)), TrackedArray{F, 1})
+@test isa(param(rand(T, 2,2)), TrackedArray{F, 2})
+end
+
 end #testset