Merge branch 'master' into cat-fix

GenaBitu 2017-10-26 00:52:29 +02:00
commit df06c3351d
No known key found for this signature in database
GPG Key ID: 6E647E317A9DD426
23 changed files with 386 additions and 91 deletions

View File

@@ -8,7 +8,6 @@ julia:
# uncomment the following lines to override the default test script
script:
- if [[ -a .git/shallow ]]; then git fetch --unshallow; fi
- julia -e 'Pkg.clone("https://github.com/FluxML/NNlib.jl")'
- julia -e 'Pkg.clone(pwd()); Pkg.build("Flux"); Pkg.test("Flux"; coverage=true)'
after_success:
- julia -e 'Pkg.add("Documenter")'

View File

@@ -1,6 +1,6 @@
using Documenter, Flux
using Documenter, Flux, NNlib
makedocs(modules=[Flux],
makedocs(modules=[Flux, NNlib],
doctest = false,
format = :html,
analytics = "UA-36890222-9",
@@ -10,13 +10,13 @@ makedocs(modules=[Flux],
"Building Models" =>
["Basics" => "models/basics.md",
"Recurrence" => "models/recurrence.md",
"Layer Reference" => "models/layers.md"],
"Model Reference" => "models/layers.md"],
"Training Models" =>
["Optimisers" => "training/optimisers.md",
"Training" => "training/training.md"],
"One-Hot Encoding" => "data/onehot.md",
"GPU Support" => "gpu.md",
"Contributing & Help" => "contributing.md"])
"Community" => "community.md"])
deploydocs(
repo = "github.com/FluxML/Flux.jl.git",

docs/src/community.md Normal file
View File

@@ -0,0 +1,5 @@
# Community
All Flux users are welcome to join our community on the [Julia forum](https://discourse.julialang.org/), the [slack](https://discourse.julialang.org/t/announcing-a-julia-slack/4866) (channel #machine-learning), or Flux's [Gitter](https://gitter.im/FluxML/Lobby). If you have questions or issues we'll try to help you out.
If you're interested in hacking on Flux, the [source code](https://github.com/FluxML/Flux.jl) is open and easy to understand -- it's all just the same Julia code you work with normally. You might be interested in our [intro issues](https://github.com/FluxML/Flux.jl/issues?q=is%3Aopen+is%3Aissue+label%3A%22help+wanted%22) to get started.

View File

@@ -1,9 +0,0 @@
# Contributing & Help
If you need help, please ask on the [Julia forum](https://discourse.julialang.org/), the [slack](https://discourse.julialang.org/t/announcing-a-julia-slack/4866) (channel #machine-learning), or Flux's [Gitter](https://gitter.im/FluxML/Lobby).
Right now, the best way to help out is to try out the examples and report any issues or missing features as you find them. The second best way is to help us spread the word, perhaps by [starring the repo](https://github.com/MikeInnes/Flux.jl).
If you're interested in hacking on Flux, most of the [code](https://github.com/MikeInnes/Flux.jl/tree/master/src) is pretty straightforward. Adding new [layer definitions](https://github.com/MikeInnes/Flux.jl/tree/master/src/layers) or cost functions is simple using the Flux DSL itself, and things like data utilities and training processes are all plain Julia code.
If you get stuck or need anything, let us know!

View File

@@ -19,16 +19,16 @@ loss(x, y) # ~ 3
Note that we convert both the parameters (`W`, `b`) and the data set (`x`, `y`) to cuda arrays. Taking derivatives and training works exactly as before.
If you define a structured model, like a `Dense` layer or `Chain`, you just need to convert the internal parameters. Flux provides `mapparams`, which allows you to alter all parameters of a model at once.
If you define a structured model, like a `Dense` layer or `Chain`, you just need to convert the internal parameters. Flux provides `mapleaves`, which allows you to alter all parameters of a model at once.
```julia
d = Dense(10, 5, σ)
d = mapparams(cu, d)
d = mapleaves(cu, d)
d.W # Tracked CuArray
d(cu(rand(10))) # CuArray output
m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
m = mapparams(cu, m)
m = mapleaves(cu, m)
d(cu(rand(10)))
```
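For reference, the raw-parameter case described above looks roughly like this (a sketch assuming the CuArrays package is installed and provides `cu`; the variable names mirror the earlier example):
```julia
using Flux, CuArrays   # `cu` converts arrays to the GPU

W = cu(rand(2, 5))                 # parameters as CuArrays
b = cu(rand(2))
predict(x) = W*x .+ b
loss(x, y) = sum((predict(x) .- y).^2)

x, y = cu(rand(5)), cu(rand(2))    # data as CuArrays
loss(x, y)                         # computed on the GPU
```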

View File

@@ -18,7 +18,7 @@ loss(x, y) # ~ 3
To improve the prediction we can take the gradients of `W` and `b` with respect to the loss function and perform gradient descent. We could calculate gradients by hand, but Flux will do it for us if we tell it that `W` and `b` are trainable *parameters*.
```julia
using Flux.Tracker: param, back!, data, grad
using Flux.Tracker
W = param(W)
b = param(b)
@@ -31,9 +31,10 @@ back!(l)
`loss(x, y)` returns the same number, but it's now a *tracked* value that records gradients as it goes along. Calling `back!` then calculates the gradient of `W` and `b`. We can see what this gradient is, and modify `W` to train the model.
```julia
grad(W)
W.grad
W.data .-= 0.1grad(W)
# Update the parameter
W.data .-= 0.1(W.grad)
loss(x, y) # ~ 2.5
```

View File

@@ -1,6 +1,32 @@
## Model Layers
## Basic Layers
These core layers form the foundation of almost all neural networks.
```@docs
Chain
Dense
```
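As a quick illustrative sketch (the dimensions here are arbitrary), these layers compose directly:
```julia
using Flux

m = Chain(Dense(28^2, 32, relu), Dense(32, 10), softmax)
m(rand(28^2))   # returns a 10-element output
```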
## Recurrent Layers
Much like the core layers above, but can be used to process sequence data (as well as other kinds of structured data).
```@docs
RNN
LSTM
Flux.Recur
```
## Activation Functions
Non-linearities that go between layers of your model. Most of these functions are defined in [NNlib](https://github.com/FluxML/NNlib.jl) but are available by default in Flux.
Note that, unless otherwise stated, activation functions operate on scalars. To apply them to an array you can call `σ.(xs)`, `relu.(xs)` and so on.
```@docs
σ
relu
leakyrelu
elu
swish
```
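A small sketch of the broadcasting convention noted above (the values are arbitrary):
```julia
xs = randn(5)
σ.(xs)      # apply the sigmoid elementwise
relu.(xs)   # likewise for relu
```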

View File

@@ -15,9 +15,9 @@ Recurrent networks introduce a *hidden state* that gets carried over each time w
```julia
h = # ... initial state ...
y₁, h = f(x₁, h)
y₂, h = f(x₂, h)
y₃, h = f(x₃, h)
h, y₁ = f(h, x₁)
h, y₂ = f(h, x₂)
h, y₃ = f(h, x₃)
# ...
```
@@ -25,7 +25,7 @@ Information stored in `h` is preserved for the next prediction, allowing it to f
(This might be important if, for example, each `x` represents one word of a sentence; the model's interpretation of the word "bank" should change if the previous input was "river" rather than "investment".)
Flux's RNN support closely follows this mathematical perspective. The most basic RNN is as close as possible to a standard `Dense` layer, and the output and hidden state are the same. By convention, the hidden state is the first input and output.
Flux's RNN support closely follows this mathematical perspective. The most basic RNN is as close as possible to a standard `Dense` layer, and the output is also the hidden state.
```julia
Wxh = randn(5, 10)
@@ -112,3 +112,5 @@ truncate!(m)
```
Calling `truncate!` wipes the slate clean, so we can call the model with more inputs without building up an expensive gradient computation.
`truncate!` makes sense when you are working with multiple chunks of a large sequence, but we may also want to work with a set of independent sequences. In this case the hidden state should be completely reset to its original value, throwing away any accumulated information. `reset!` does this for you.
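A rough sketch of the difference (the model and sequence here are only illustrative):
```julia
using Flux

m = Chain(LSTM(10, 5), Dense(5, 2))
seq = [rand(10) for _ = 1:4]   # a toy sequence of inputs

m.(seq)             # process one chunk of a long sequence
Flux.truncate!(m)   # keep the hidden state, drop the gradient history
m.(seq)             # continue the same sequence cheaply

Flux.reset!(m)      # start fresh: hidden state back to its initial value
```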

View File

@@ -17,14 +17,11 @@ back!(l)
We want to update each parameter, using the gradient, in order to improve (reduce) the loss. Here's one way to do that:
```julia
using Flux.Tracker: data, grad
function update()
η = 0.1 # Learning Rate
for p in (W, b)
x, Δ = data(p), grad(p)
x .-= η .* Δ # Apply the update
Δ .= 0 # Clear the gradient
p.data .-= η .* p.grad # Apply the update
p.grad .= 0 # Clear the gradient
end
end
```
@@ -48,7 +45,21 @@ For the update step, there's nothing whatsoever wrong with writing the loop abov
```julia
opt = SGD([W, b], 0.1) # Gradient descent with learning rate 0.1
opt()
opt() # Carry out the update, modifying `W` and `b`.
```
An optimiser takes a parameter list and returns a function that does the same thing as `update` above. We can pass either `opt` or `update` to our [training loop](training.md), which will then run the optimiser after every mini-batch of data.
## Optimiser Reference
All optimisers return a function that, when called, will update the parameters passed to it.
```@docs
SGD
Momentum
Nesterov
RMSProp
ADAM
ADAGrad
ADADelta
```
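For instance (an illustrative sketch; the model and hyperparameters are placeholders), each of these takes a parameter list and gives back a function that applies one update when called:
```julia
m = Chain(Dense(10, 5, relu), Dense(5, 2))
opt = ADAM(params(m))   # or SGD(params(m), 0.1), Momentum(params(m), 0.9), ...

# ... compute a loss and call back! on it, then:
opt()                   # update every parameter of `m` in place
```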

View File

@@ -30,7 +30,33 @@ loss(x, y) = Flux.mse(m(x), y)
Flux.train!(loss, data, opt)
```
The loss will almost always be defined in terms of some *cost function* that measures the distance of the prediction `m(x)` from the target `y`. Flux has several of these built in, like `mse` for mean squared error or `logloss` for cross entropy loss, but you can calculate it however you want.
The loss will almost always be defined in terms of some *cost function* that measures the distance of the prediction `m(x)` from the target `y`. Flux has several of these built in, like `mse` for mean squared error or `crossentropy` for cross entropy loss, but you can calculate it however you want.
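For example (a sketch using the `m`, `x`, `y` placeholders from above):
```julia
loss(x, y) = Flux.mse(m(x), y)            # squared-error loss
# or, for classification with softmax outputs:
loss(x, y) = Flux.crossentropy(m(x), y)
```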
## Datasets
The `data` argument provides a collection of data to train with (usually a set of inputs `x` and target outputs `y`). For example, here's a dummy data set with only one data point:
```julia
x = rand(784)
y = rand(10)
data = [(x, y)]
```
`Flux.train!` will call `loss(x, y)`, calculate gradients, update the weights and then move on to the next data point if there is one. We can train the model on the same data three times:
```julia
data = [(x, y), (x, y), (x, y)]
# Or equivalently
data = Iterators.repeated((x, y), 3)
```
It's common to load the `x`s and `y`s separately. In this case you can use `zip`:
```julia
xs = [rand(784), rand(784), rand(784)]
ys = [rand( 10), rand( 10), rand( 10)]
data = zip(xs, ys)
```
## Callbacks

View File

@@ -8,10 +8,11 @@ using Juno, Requires
using Lazy: @forward
export Chain, Dense, RNN, LSTM,
SGD, params, mapparams
SGD, ADAM, Momentum, Nesterov,
param, params, mapleaves
using NNlib
export σ, relu, softmax
export σ, relu, leakyrelu, elu, swish, softmax
include("tracker/Tracker.jl")
using .Tracker

View File

@@ -4,12 +4,14 @@
Chain multiple layers / functions together, so that they are called in sequence
on a given input.
m = Chain(x -> x^2, x -> x+1)
m(5) == 26
```julia
m = Chain(x -> x^2, x -> x+1)
m(5) == 26
m = Chain(Dense(10, 5), Dense(5, 2))
x = rand(10)
m(x) == m[2](m[1](x))
m = Chain(Dense(10, 5), Dense(5, 2))
x = rand(10)
m(x) == m[2](m[1](x))
```
`Chain` also supports indexing and slicing, e.g. `m[2]` or `m[1:end-1]`.
`m[1:3](x)` will calculate the output of the first three layers.
@@ -43,7 +45,17 @@ Creates a traditional `Dense` layer with parameters `W` and `b`.
y = σ.(W * x .+ b)
The input `x` must be a vector of length `in`, or a batch of vectors represented
as an `in × N` matrix. The out `y` will be a vector or batch of length `in`.
as an `in × N` matrix. The out `y` will be a vector or batch of length `out`.
```julia
julia> d = Dense(5, 2)
Dense(5, 2)
julia> d(rand(5))
Tracked 2-element Array{Float64,1}:
0.00257447
-0.00449443
```
"""
struct Dense{F,S,T}
σ::F

View File

@@ -3,12 +3,33 @@ combine(x, h) = vcat(x, h .* trues(1, size(x, 2)))
# Stateful recurrence
"""
Recur(cell)
`Recur` takes a recurrent cell and makes it stateful, managing the hidden state
in the background. `cell` should be a model of the form:
h, y = cell(h, x...)
For example, here's a recurrent network that keeps a running total of its inputs.
```julia
accum(h, x) = (h+x, x)
rnn = Flux.Recur(accum, 0)
rnn(2) # 2
rnn(3) # 3
rnn.state # 5
rnn.(1:10) # apply to a sequence
rnn.state # 60
```
"""
mutable struct Recur{T}
cell::T
init
state
end
Recur(m) = Recur(m, hidden(m))
Recur(m, h = hidden(m)) = Recur(m, h, h)
function (m::Recur)(xs...)
h, y = m.cell(m.state, xs...)
@@ -20,12 +41,34 @@ treelike(Recur)
Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
_truncate(x::AbstractArray) = x
_truncate(x::TrackedArray) = x.data
_truncate(x::AbstractArray) = Tracker.data(x)
_truncate(x::Tuple) = _truncate.(x)
truncate!(m) = foreach(truncate!, children(m))
truncate!(m::Recur) = (m.state = _truncate(m.state))
"""
truncate!(rnn)
Truncates the gradient of the hidden state in recurrent layers. The value of the
state is preserved. See also `reset!`.
Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to
rnn.state = Tracker.data(rnn.state)
"""
truncate!(m) = prefor(x -> x isa Recur && (x.state = _truncate(x.state)), m)
"""
reset!(rnn)
Reset the hidden state of a recurrent layer back to its original value. See also
`truncate!`.
Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to
rnn.state = hidden(rnn.cell)
"""
reset!(m) = prefor(x -> x isa Recur && (x.state = x.init), m)
flip(f, xs) = reverse(f.(reverse(xs)))
# Vanilla RNN
@@ -50,6 +93,12 @@ function Base.show(io::IO, m::RNNCell)
print(io, "RNNCell(", m.d, ")")
end
"""
RNN(in::Integer, out::Integer, σ = tanh)
The most basic recurrent layer; essentially acts as a `Dense` layer, but with the
output fed back into the input each time step.
"""
RNN(a...; ka...) = Recur(RNNCell(a...; ka...))
# LSTM
@@ -89,4 +138,13 @@ Base.show(io::IO, m::LSTMCell) =
size(m.forget.W, 2) - size(m.forget.W, 1), ", ",
size(m.forget.W, 1), ')')
"""
LSTM(in::Integer, out::Integer, σ = tanh)
Long Short Term Memory recurrent layer. Behaves like an RNN but generally
exhibits a longer memory span over sequences.
See [this article](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.
"""
LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))

View File

@@ -1,7 +1,14 @@
# Cost functions
mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)
# back!(::typeof(mse), Δ, ŷ, y) = Δ .* (ŷ .- y)
logloss(ŷ, y) = -sum(y .* log.(ŷ)) / size(y, 2)
# back!(::typeof(logloss), Δ, ŷ, y) = 0 .- Δ .* y ./ ŷ
crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat) =
-sum(y .* log.(ŷ)) / size(y, 2)
@deprecate logloss(x, y) crossentropy(x, y)
function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat)
logŷ = logŷ .- maximum(logŷ, 1)
ypred = logŷ .- log.(sum(exp.(logŷ), 1))
-sum(y .* ypred) / size(y, 2)
end

View File

@@ -20,7 +20,9 @@ Base.getindex(xs::OneHotMatrix, i::Int, j::Int) = xs.data[j][i]
Base.:*(A::AbstractMatrix, B::OneHotMatrix) = A[:, map(x->x.ix, B.data)]
Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix([x, xs...])
Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...])
batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(length(first(xs)), xs)
import NNlib.adapt
@@ -32,7 +34,12 @@ adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
end
onehot(l, labels) = OneHotVector(findfirst(labels, l), length(labels))
function onehot(l, labels)
i = findfirst(labels, l)
i > 0 || error("Value $l is not in labels")
OneHotVector(i, length(labels))
end
onehotbatch(ls, labels) = OneHotMatrix(length(labels), [onehot(l, labels) for l in ls])
argmax(y::AbstractVector, labels = 1:length(y)) =

View File

@@ -1,7 +1,7 @@
module Optimise
export update!, params, train!,
SGD
SGD, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta
struct Param{T}
x::T
@@ -16,6 +16,6 @@ include("train.jl")
using Flux.Tracker: TrackedArray
Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad[])
Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad)
end

View File

@@ -9,10 +9,65 @@ function optimiser(ps, fs...)
() -> foreach(call, fs)
end
SGD(ps, η = 1) = optimiser(ps, p -> descent(p, η))
ADAM(ps, η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0.0) = optimiser(ps, p -> adam(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
Momentum(ps,ρ, decay = 0.0) = optimiser(ps, p -> momentum(p, ρ), p -> invdecay(p, decay), p -> descent(p, 1))
Nesterov(ps,ρ, decay = 0.0) = optimiser(ps, p -> nesterov(p, ρ), p -> invdecay(p, decay), p -> descent(p, 1))
RMSProp(ps, η = 0.001, ρ = 0.9, ϵ = 1e-8, decay = 0.0) = optimiser(ps, p -> rmsprop(p; η = η, ρ = ρ, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
ADAGrad(ps, η = 0.01, ϵ = 1e-8, decay = 0.0) = optimiser(ps, p -> adagrad(p; η = η, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
ADADelta(ps, η = 0.01, ρ = 0.95, ϵ = 1e-8, decay = 0.0) = optimiser(ps, p -> adadelta(p; ρ = ρ, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
"""
SGD(params, η = 1; decay = 0)
Classic gradient descent optimiser. For each parameter `p` and its
gradient `δp`, this runs `p -= η*δp`.
Supports learning rate decay if the `decay` argument is provided.
"""
SGD(ps, η = 1; decay = 0) =
optimiser(ps, p -> invdecay(p, decay), p -> descent(p, η))
"""
Momentum(params, ρ, decay = 0)
SGD with momentum `ρ` and optional learning rate decay.
"""
Momentum(ps, ρ; decay = 0) =
optimiser(ps, p -> momentum(p, ρ), p -> invdecay(p, decay), p -> descent(p, 1))
"""
Nesterov(params, ρ, decay = 0)
SGD with Nesterov momentum `ρ` and optional learning rate decay.
"""
Nesterov(ps, ρ; decay = 0) =
optimiser(ps, p -> nesterov(p, ρ), p -> invdecay(p, decay), p -> descent(p, 1))
"""
RMSProp(params; η = 0.001, ρ = 0.9, ϵ = 1e-8, decay = 0)
[RMSProp](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
optimiser. Parameters other than learning rate don't need tuning. Often a good
choice for recurrent networks.
"""
RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) =
optimiser(ps, p -> rmsprop(p; η = η, ρ = ρ, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
"""
ADAM(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
"""
ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p -> adam(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
"""
ADAGrad(params; η = 0.01, ϵ = 1e-8, decay = 0)
[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
Parameters don't need tuning.
"""
ADAGrad(ps; η = 0.01, ϵ = 1e-8, decay = 0) =
optimiser(ps, p -> adagrad(p; η = η, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
"""
ADADelta(params; η = 0.01, ρ = 0.95, ϵ = 1e-8, decay = 0)
[ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
tuning.
"""
ADADelta(ps; η = 0.01, ρ = 0.95, ϵ = 1e-8, decay = 0) =
optimiser(ps, p -> adadelta(p; ρ = ρ, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))

View File

@@ -4,10 +4,19 @@ using Flux.Tracker: back!
tocb(f) = f
tocb(fs::AbstractVector) = () -> foreach(call, fs)
function train!(m, data, opt; cb = () -> ())
"""
train!(loss, data, opt; cb = () -> ())
For each datapoint `d` in `data`, computes the gradient of `loss(d...)` through
backpropagation and calls the optimiser `opt` and the callback `cb`
(i.e. `opt()` and `cb()`).
Multiple callbacks can be passed to `cb` as an array.
"""
function train!(loss, data, opt; cb = () -> ())
cb = tocb(cb)
@progress for x in data
l = m(x...)
@progress for d in data
l = loss(d...)
isinf(l.data[]) && error("Loss is Inf")
isnan(l.data[]) && error("Loss is NaN")
back!(l)

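A minimal usage sketch of the signature documented above (the model, data and learning rate are illustrative):
```julia
using Flux

m = Dense(10, 2)
loss(x, y) = Flux.mse(m(x), y)
data = [(rand(10), rand(2)) for _ = 1:3]
opt = SGD(params(m), 0.1)

Flux.train!(loss, data, opt; cb = () -> println("batch done"))
```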
View File

@@ -1,7 +1,5 @@
module Tracker
using Base: RefValue
export TrackedArray, param, back!
data(x) = x
@@ -16,11 +14,13 @@ Call(f, args...) = Call{typeof(f),typeof(args)}(f, args)
(c::Call)() = c.func(data.(c.args)...)
struct TrackedArray{T,N,A} <: AbstractArray{T,N}
ref::RefValue{UInt32}
mutable struct TrackedArray{T,N,A} <: AbstractArray{T,N}
ref::UInt32
f::Call
data::A
grad::RefValue{A}
grad::A
TrackedArray{T,N,A}(f::Call, data::A) where {T,N,A} = new(0, f, data)
TrackedArray{T,N,A}(f::Call, data::A, grad::A) where {T,N,A} = new(0, f, data, grad)
end
TrackedScalar{T,A} = TrackedArray{T,0,A}
@@ -28,19 +28,22 @@ TrackedVector{T,A} = TrackedArray{T,1,A}
TrackedMatrix{T,A} = TrackedArray{T,2,A}
TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}}
TrackedArray(c::Call, x::A, Δ::Ref{A}) where A <: AbstractArray =
TrackedArray{eltype(A),ndims(A),A}(Ref(UInt32(0)), c, x, Δ)
TrackedArray(c::Call, x::A) where A <: AbstractArray =
TrackedArray{eltype(A),ndims(A),A}(c, x)
TrackedArray(c::Call, x::AbstractArray) = TrackedArray(c, x, RefValue{typeof(x)}())
TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
TrackedArray{eltype(A),ndims(A),A}(c, x, Δ)
TrackedArray(c::Call) = TrackedArray(c, c())
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, RefValue(zeros(x)))
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x))
param(xs) = TrackedArray(AbstractFloat.(xs))
param(xs::Real) = param(fill(xs))
istracked(x::TrackedArray) = true
data(x::TrackedArray) = x.data
grad(x::TrackedArray) = x.grad[]
grad(x::TrackedArray) = x.grad
# Fallthrough methods
@@ -73,8 +76,6 @@ include("numeric.jl")
import NNlib.adapt
adapt(T, xs::TrackedArray) =
TrackedArray(xs.f, adapt(T, xs.data),
RefValue(adapt(T, grad(xs))))
adapt(T, xs::TrackedArray) = TrackedArray(xs.f, adapt(T, xs.data), adapt(T, xs.grad))
end

View File

@@ -3,11 +3,11 @@ scan(x) = nothing
scan(c::Call) = foreach(scan, c.args)
function scan(x::TrackedArray)
ref = x.ref[] += 1
ref = x.ref += 1
if ref == 1
scan(x.f)
else
isassigned(x.grad) || (x.grad[] = zeros(x.data))
isdefined(x, :grad) || (x.grad = zeros(x.data))
end
return
end
@@ -16,10 +16,10 @@ back(c::Call, Δ) = back(c.func, Δ, c.args...)
back(::Call{Void}, Δ) = nothing
function back(x::TrackedArray, Δ)
ref = x.ref[] -= 1
if isassigned(x.grad)
x.grad[] .+= Δ
ref == 0 && back(x.f, x.grad[])
ref = x.ref -= 1
if isdefined(x, :grad)
x.grad .+= Δ
ref == 0 && back(x.f, x.grad)
else
ref == 0 && back(x.f, Δ)
end

View File

@@ -8,18 +8,27 @@ function treelike(T, fs = fieldnames(T))
end
end
# TODO: prewalk/postwalk with correct caching
# This is only correct in general for idempotent functions
isleaf(x) = isempty(children(x))
mapparams(f, x::AbstractArray) = f(x)
mapparams(f, x) = mapchildren(x -> mapparams(f, x), x)
function mapleaves(f, x; cache = ObjectIdDict())
haskey(cache, x) && return cache[x]
cache[x] = isleaf(x) ? f(x) : mapchildren(x -> mapleaves(f, x, cache = cache), x)
end
forparams(f, x) = (mapparams(x -> (f(x); x), x); return)
export mapparams
@deprecate mapparams(f, x) mapleaves(f, x)
using DataFlow: OSet
function params(m)
ps = OSet()
forparams(p -> push!(ps, p), m)
return collect(ps)
function prefor(f, x; seen = OSet())
x ∈ seen && return
f(x)
foreach(x -> prefor(f, x, seen = seen), children(x))
return
end
function params(m)
ps = []
prefor(p -> p isa TrackedArray && push!(ps, p), m)
return ps
end

View File

@@ -9,6 +9,67 @@ unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...)
stack(xs, dim) = cat(dim, unsqueeze.(xs, dim)...)
unstack(xs, dim) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)]
"""
chunk(xs, n)
Split `xs` into `n` parts.
```julia
julia> chunk(1:10, 3)
3-element Array{Array{Int64,1},1}:
[1, 2, 3, 4]
[5, 6, 7, 8]
[9, 10]
```
"""
chunk(xs, n) = collect(Iterators.partition(xs, ceil(Int, length(xs)/n)))
batchindex(xs, i) = (reverse(Base.tail(reverse(indices(xs))))..., i)
"""
batch(xs)
Batch the arrays in `xs` into a single array.
```julia
julia> batch([[1,2,3],[4,5,6]])
3×2 Array{Int64,2}:
1 4
2 5
3 6
```
"""
function batch(xs)
data = first(xs) isa AbstractArray ?
similar(first(xs), size(first(xs))..., length(xs)) :
Vector{eltype(xs)}(length(xs))
for (i, x) in enumerate(xs)
data[batchindex(data, i)...] = x
end
return data
end
Base.rpad(v::AbstractVector, n::Integer, p) = [v; fill(p, max(n - length(v), 0))]
"""
batchseq(seqs, pad)
Take a list of `N` sequences, and turn them into a single sequence where each
item is a batch of `N`. Short sequences will be padded by `pad`.
```julia
julia> batchseq([[1, 2, 3], [4, 5]], 0)
3-element Array{Array{Int64,1},1}:
[1, 4]
[2, 5]
[3, 0]
```
"""
function batchseq(xs, pad = nothing, n = maximum(length(x) for x in xs))
xs_ = [rpad(x, n, pad) for x in xs]
[batch([xs_[j][i] for j = 1:length(xs_)]) for i = 1:n]
end
# Other
function accuracy(m, data)

View File

@@ -15,7 +15,7 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@test gradtest(x -> softmax(x).*(1:3), (3,5))
@test gradtest(Flux.mse, rand(5,5), rand(5, 5))
@test gradtest(Flux.logloss, rand(5,5), rand(5, 5))
@test gradtest(Flux.crossentropy, rand(5,5), rand(5, 5))
@test gradtest(x -> x', rand(5))
@@ -27,4 +27,18 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
2y + x
end
for T in [Float32, Float64]
@test isa(param(T(1)), TrackedArray{T, 0})
@test isa(param(rand(T, 2)), TrackedArray{T, 1})
@test isa(param(rand(T, 2,2)), TrackedArray{T, 2})
end
# TODO: do we want this behaviour ??
F = typeof(AbstractFloat(1))
for T in [Int32, Int64]
@test isa(param(T(1)), TrackedArray{F, 0})
@test isa(param(rand(T, 2)), TrackedArray{F, 1})
@test isa(param(rand(T, 2,2)), TrackedArray{F, 2})
end
end #testset