Merge branch 'master' into issue-#354

Commit 44049ce00c by Johnny Chen, 2018-09-06 09:39:31 -05:00 (committed by GitHub)
23 changed files with 242 additions and 91 deletions

View File

@ -15,5 +15,5 @@ matrix:
allow_failures:
- julia: nightly
after_success:
- julia -e 'Pkg.add("Documenter")'
- julia -e 'cd(Pkg.dir("Flux")); include(joinpath("docs", "make.jl"))'
- julia -e 'using Pkg; Pkg.add("Documenter"); Pkg.add("NNlib")'
- julia -e 'using Pkg; cd(Pkg.dir("Flux")); include(joinpath("docs", "make.jl"))'

View File

@ -4,7 +4,7 @@ MacroTools 0.3.3
NNlib
Requires
Adapt
GZip
CodecZlib
Colors
ZipFile
AbstractTrees

View File

@ -26,6 +26,6 @@ deploydocs(
repo = "github.com/FluxML/Flux.jl.git",
target = "build",
osname = "linux",
julia = "0.6",
julia = "1.0",
deps = nothing,
make = nothing)

View File

@ -3,7 +3,7 @@
It's common to encode categorical variables (like `true`, `false` or `cat`, `dog`) in "one-of-k" or ["one-hot"](https://en.wikipedia.org/wiki/One-hot) form. Flux provides the `onehot` function to make this easy.
```
julia> using Flux: onehot
julia> using Flux: onehot, onecold
julia> onehot(:b, [:a, :b, :c])
3-element Flux.OneHotVector:
@ -18,22 +18,22 @@ julia> onehot(:c, [:a, :b, :c])
true
```
The inverse is `argmax` (which can take a general probability distribution, as well as just booleans).
The inverse is `onecold` (which can take a general probability distribution, as well as just booleans).
```julia
julia> argmax(ans, [:a, :b, :c])
julia> onecold(ans, [:a, :b, :c])
:c
julia> argmax([true, false, false], [:a, :b, :c])
julia> onecold([true, false, false], [:a, :b, :c])
:a
julia> argmax([0.3, 0.2, 0.5], [:a, :b, :c])
julia> onecold([0.3, 0.2, 0.5], [:a, :b, :c])
:c
```
## Batches
`onehotbatch` creates a batch (matrix) of one-hot vectors, and `argmax` treats matrices as batches.
`onehotbatch` creates a batch (matrix) of one-hot vectors, and `onecold` treats matrices as batches.
```julia
julia> using Flux: onehotbatch

View File

@ -1,18 +1,17 @@
# Flux: The Julia Machine Learning Library
Flux is a library for machine learning. It comes "batteries-included" with many useful tools built in, but also lets you use the full power of the Julia language where you need it. The whole stack is implemented in clean Julia code (right down to the [GPU kernels](https://github.com/FluxML/CuArrays.jl)) and any part can be tweaked to your liking.
Flux is a library for machine learning. It comes "batteries-included" with many useful tools built in, but also lets you use the full power of the Julia language where you need it. We follow a few key principles:
# Installation
* **Doing the obvious thing**. Flux has relatively few explicit APIs for features like regularisation or embeddings. Instead, writing down the mathematical form will work and be fast.
* **You could have written Flux**. All of it, from [LSTMs](https://github.com/FluxML/Flux.jl/blob/ec16a2c77dbf6ab8b92b0eecd11661be7a62feef/src/layers/recurrent.jl#L131) to [GPU kernels](https://github.com/JuliaGPU/CuArrays.jl), is straightforward Julia code. When in doubt, it's well worth looking at [the source](https://github.com/FluxML/Flux.jl/). If you need something different, you can easily roll your own.
* **Play nicely with others**. Flux works well with Julia libraries from [data frames](https://github.com/JuliaComputing/JuliaDB.jl) and [images](https://github.com/JuliaImages/Images.jl) to [differential equation solvers](https://github.com/JuliaDiffEq/DifferentialEquations.jl), so you can easily build complex data processing pipelines that integrate Flux models.
Install [Julia 0.6.0 or later](https://julialang.org/downloads/), if you haven't already.
## Installation
```julia
Pkg.add("Flux")
# Optional but recommended
Pkg.update() # Keep your packages up to date
Pkg.test("Flux") # Check things installed correctly
```
Download [Julia 1.0](https://julialang.org/) or later, if you haven't already. You can add Flux using Julia's package manager, by typing `] add Flux` in the Julia prompt.
Start with the [basics](models/basics.md). The [model zoo](https://github.com/FluxML/model-zoo/) is also a good starting point for many common kinds of models.
If you have CUDA you can also run `] add CuArrays` to get GPU support; see [here](gpu.md) for more details.
See [GPU support](gpu.md) for more details on installing and using Flux with GPUs.
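For anyone used to the pre-1.0 workflow in the removed block above, the package-mode command has a programmatic equivalent; a minimal sketch using only the standard Pkg API:
```julia
using Pkg

Pkg.add("Flux")        # programmatic equivalent of typing `] add Flux`
# Pkg.add("CuArrays")  # optional, adds GPU support (see gpu.md)
```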
## Learning Flux
There are several different ways to learn Flux. If you just want to get started writing models, the [model zoo](https://github.com/FluxML/model-zoo/) gives good starting points for many common ones. This documentation provides a reference to all of Flux's APIs, as well as a from-scratch introduction to Flux's take on models and how they work. Once you understand these docs, congratulations, you also understand [Flux's source code](https://github.com/FluxML/Flux.jl), which is intended to be concise, legible and a good reference for more advanced concepts.

View File

@ -172,7 +172,7 @@ using Flux
layers = [Dense(10, 5, σ), Dense(5, 2), softmax]
model(x) = foldl((x, m) -> m(x), x, layers)
model(x) = foldl((x, m) -> m(x), layers, init = x)
model(rand(10)) # => 2-element vector
```
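The `init` keyword is how Julia 1.0's `foldl` takes its starting accumulator, which is why the rewritten line passes the input `x` that way; a tiny standalone illustration:
```julia
# foldl threads an accumulator left to right; `init` seeds it.
foldl((acc, d) -> acc * 10 + d, [1, 2, 3]; init = 0)   # 123
```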

View File

@ -6,6 +6,8 @@ These core layers form the foundation of almost all neural networks.
Chain
Dense
Conv
MaxPool
MeanPool
```
## Recurrent Layers

View File

@ -1,7 +1,7 @@
# Regularisation
Applying regularisation to model parameters is straightforward. We just need to
apply an appropriate regulariser, such as `vecnorm`, to each model parameter and
apply an appropriate regulariser, such as `norm`, to each model parameter and
add the result to the overall loss.
For example, say we have a simple regression.
@ -15,12 +15,12 @@ loss(x, y) = crossentropy(softmax(m(x)), y)
We can regularise this by taking the (L2) norm of the parameters, `m.W` and `m.b`.
```julia
penalty() = vecnorm(m.W) + vecnorm(m.b)
penalty() = norm(m.W) + norm(m.b)
loss(x, y) = crossentropy(softmax(m(x)), y) + penalty()
```
When working with layers, Flux provides the `params` function to grab all
parameters at once. We can easily penalise everything with `sum(vecnorm, params)`.
parameters at once. We can easily penalise everything with `sum(norm, params)`.
```julia
julia> params(m)
@ -28,7 +28,7 @@ julia> params(m)
param([0.355408 0.533092; … 0.430459 0.171498])
param([0.0, 0.0, 0.0, 0.0, 0.0])
julia> sum(vecnorm, params(m))
julia> sum(norm, params(m))
26.01749952921026 (tracked)
```
@ -40,7 +40,7 @@ m = Chain(
Dense(128, 32, relu),
Dense(32, 10), softmax)
loss(x, y) = crossentropy(m(x), y) + sum(vecnorm, params(m))
loss(x, y) = crossentropy(m(x), y) + sum(norm, params(m))
loss(rand(28^2), rand(10))
```
@ -57,6 +57,6 @@ julia> activations(c, rand(10))
param([0.0330606, -0.456104])
param([0.61991, 0.38009])
julia> sum(vecnorm, ans)
julia> sum(norm, ans)
2.639678767773633 (tracked)
```

View File

@ -5,7 +5,7 @@ module Flux
using MacroTools, Juno, Requires, Reexport, Statistics, Random
using MacroTools: @forward
export Chain, Dense, RNN, LSTM, GRU, Conv,
export Chain, Dense, RNN, LSTM, GRU, Conv, MaxPool, MeanPool,
Dropout, LayerNorm, BatchNorm,
params, mapleaves, cpu, gpu

View File

@ -1,11 +1,17 @@
module MNIST
using GZip, Colors
using CodecZlib, Colors
const Gray = Colors.Gray{Colors.N0f8}
const dir = joinpath(@__DIR__, "../../deps/mnist")
function gzopen(f, file)
open(file) do io
f(GzipDecompressorStream(io))
end
end
function load()
mkpath(dir)
cd(dir) do
@ -17,7 +23,7 @@ function load()
@info "Downloading MNIST dataset"
download("https://cache.julialang.org/http://yann.lecun.com/exdb/mnist/$file.gz", "$file.gz")
open(file, "w") do io
write(io, GZip.open(read, "$file.gz"))
write(io, gzopen(read, "$file.gz"))
end
end
end
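The swap from GZip to CodecZlib means decompression now goes through a `GzipDecompressorStream` wrapper, which is what the new `gzopen` helper does; a minimal standalone sketch, assuming a local `data.gz` file:
```julia
using CodecZlib

# Wrap the raw file handle in a GzipDecompressorStream and read it like any IO.
bytes = open("data.gz") do io
    read(GzipDecompressorStream(io))
end
```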

View File

@ -1,6 +1,6 @@
using NNlib: conv
@generated sub2(::Type{Val{N}}) where N = :(Val($(N-2)))
@generated sub2(::Val{N}) where N = :(Val($(N-2)))
expand(N, i::Tuple) = i
expand(N, i::Integer) = ntuple(_ -> i, N)
@ -28,7 +28,7 @@ end
Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N} =
Conv(σ, w, b, expand.(sub2(Val{N}), (stride, pad, dilation))...)
Conv(σ, w, b, expand.(sub2(Val(N)), (stride, pad, dilation))...)
Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
stride = 1, pad = 0, dilation = 1) where N =
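The `sub2` change reflects Julia 1.0's preference for dispatching on `Val` instances rather than `Val` types; a tiny illustration with a hypothetical function `f` (not part of Flux):
```julia
# Dispatch on a Val instance; N is recovered as a type parameter.
f(::Val{N}) where N = N + 1
f(Val(3))   # 4
```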
@ -50,3 +50,48 @@ function Base.show(io::IO, l::Conv)
l.σ == identity || print(io, ", ", l.σ)
print(io, ")")
end
"""
MaxPool(k)
Max pooling layer. `k` stands for the size of the window for each dimension of the input.
Takes the keyword arguments `pad` and `stride`.
"""
struct MaxPool{N}
k::NTuple{N,Int}
pad::NTuple{N,Int}
stride::NTuple{N,Int}
end
MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =
MaxPool(k, expand(Val(N), pad), expand(Val(N), stride))
(m::MaxPool)(x) = maxpool(x, m.k; pad = m.pad, stride = m.stride)
function Base.show(io::IO, m::MaxPool)
print(io, "MaxPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
end
"""
MeanPool(k)
Mean pooling layer. `k` stands for the size of the window for each dimension of the input.
Takes the keyword arguments `pad` and `stride`.
"""
struct MeanPool{N}
k::NTuple{N,Int}
pad::NTuple{N,Int}
stride::NTuple{N,Int}
end
MeanPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =
MeanPool(k, expand(Val(N), pad), expand(Val(N), stride))
(m::MeanPool)(x) = meanpool(x, m.k; pad = m.pad, stride = m.stride)
function Base.show(io::IO, m::MeanPool)
print(io, "MeanPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
end
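A minimal usage sketch of the two new pooling layers, mirroring the shape convention of Flux's conv layers (width, height, channels, batch):
```julia
using Flux

x = rand(Float32, 10, 10, 3, 2)     # width × height × channels × batch
size(MaxPool((2, 2))(x))            # (5, 5, 3, 2)
size(MeanPool((2, 2))(x))           # (5, 5, 3, 2)
```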

View File

@ -128,8 +128,7 @@ function LSTMCell(in::Integer, out::Integer;
return cell
end
function (m::LSTMCell)(h_, x)
h, c = h_ # TODO: nicer syntax on 0.7
function (m::LSTMCell)((h, c), x)
b, o = m.b, size(h, 1)
g = m.Wi*x .+ m.Wh*h .+ b
input = σ.(gate(g, o, 1))
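The new method signature relies on Julia 1.0's argument destructuring, which is what the removed `TODO` was waiting for; a tiny illustration with a hypothetical `addpair` function:
```julia
# A tuple argument can be unpacked directly in the signature:
addpair((a, b)) = a + b
addpair((2, 3))   # 5
```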

View File

@ -54,11 +54,15 @@ end
onehotbatch(ls, labels, unk...) =
OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])
argmax(y::AbstractVector, labels = 1:length(y)) =
labels[findfirst(y, maximum(y))]
onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)]
argmax(y::AbstractMatrix, l...) =
squeeze(mapslices(y -> argmax(y, l...), y, 1), 1)
onecold(y::AbstractMatrix, labels...) =
dropdims(mapslices(y -> onecold(y, labels...), y, dims=1), dims=1)
function argmax(xs...)
Base.depwarn("`argmax(...)` is deprecated, use `onecold(...)` instead.", :argmax)
return onecold(xs...)
end
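A short usage sketch of the vector and matrix `onecold` methods defined above:
```julia
using Flux: onecold

onecold([0.1, 0.7, 0.2], ["a", "b", "c"])                # "b"
onecold([0.1 0.9; 0.8 0.1; 0.1 0.0], ["a", "b", "c"])    # ["b", "a"], one label per column
```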
# Ambiguity hack

View File

@ -2,7 +2,7 @@ module Optimise
export train!,
SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM, stop, StopException
struct Param{T}
x::T

View File

@ -1,5 +1,6 @@
using Juno
using Flux.Tracker: back!
import Base.depwarn
runall(f) = f
runall(fs::AbstractVector) = () -> foreach(call, fs)
@ -14,6 +15,25 @@ macro interrupts(ex)
end)
end
struct StopException <: Exception end
"""
stop()
Call `Flux.stop()` in a callback to signal that a stopping condition has been met.
This triggers the training loop to stop and exit.
```julia
# Example callback:
cb = function ()
accuracy() > 0.9 && Flux.stop()
end
```
"""
function stop()
throw(StopException())
end
"""
train!(loss, data, opt)
@ -36,10 +56,21 @@ function train!(loss, data, opt; cb = () -> ())
cb = runall(cb)
opt = runall(opt)
@progress for d in data
l = loss(d...)
@interrupts back!(l)
opt()
cb() == :stop && break
try
l = loss(d...)
@interrupts back!(l)
opt()
if cb() == :stop
depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
break
end
catch ex
if ex isa StopException
break
else
rethrow(ex)
end
end
end
end
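The control flow above amounts to throwing and catching a sentinel exception; a stripped-down sketch of the same pattern, using hypothetical names (`StopTraining`, `stop_training`, `run_steps`) rather than Flux's API:
```julia
struct StopTraining <: Exception end

stop_training() = throw(StopTraining())

function run_steps(n; cb = i -> nothing)
    for i in 1:n
        try
            # ... one unit of work per step would go here ...
            cb(i)
        catch ex
            ex isa StopTraining || rethrow(ex)
            break     # a StopTraining thrown by the callback ends the loop cleanly
        end
    end
end

run_steps(100; cb = i -> i >= 3 && stop_training())   # stops after the third step
```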

View File

@ -1,4 +1,4 @@
import Base: *, ==
import Base: *
import LinearAlgebra
using Statistics
@ -48,7 +48,7 @@ back!(::TrackedArray) = error("Value is not scalar; use `back!(sum(x))` or `back
# Fallthrough methods
for f in :[Base.size, Base.ndims].args
for f in :[Base.size, Base.ndims, Base.collect].args
@eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...)
end
@ -60,9 +60,11 @@ Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
x::TrackedArray == y = data(x) == y
y == x::TrackedArray = y == data(x)
x::TrackedArray == y::TrackedArray = data(x) == data(y)
for op in [:(==), :≈]
@eval Base.$op(x::TrackedArray, y::AbstractArray) = Base.$op(data(x), y)
@eval Base.$op(x::AbstractArray, y::TrackedArray) = Base.$op(x, data(y))
@eval Base.$op(x::TrackedArray, y::TrackedArray) = Base.$op(data(x), data(y))
end
# Array Stdlib
@ -286,15 +288,6 @@ x::TrackedVector * y::TrackedVector = track(*, x, y)
@grad a::AbstractMatrix * b::AbstractVecOrMat =
data(a)*data(b), Δ -> (Δ * transpose(b), transpose(a) * Δ)
# @grad function (a::AbstractMatrix * b::AbstractVecOrMat)
# # @show size(a) size(b)
# data(a)*data(b), function (Δ)
# @show size(Δ) size(b) size(Δ*transpose(b)) size(Δ*transpose(data(b)))
# @show typeof(Δ) typeof(b)
# (Δ * transpose(b), transpose(a) * Δ)
# end
# end
# NNlib
using NNlib
@ -336,35 +329,33 @@ end
using ForwardDiff: Dual, partials, value
_size(x::AbstractArray) = size(x)
_size(x) = ()
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
dualify(xs, n) = xs
dualify(xs::AbstractArray, ps) = map(x -> Dual(x, ps), xs)
dualify(xs::Real, ps) = Dual(xs, ps)
unbroadcast(x::AbstractArray, Δ) =
size(x) == size(Δ) ? Δ :
length(x) == length(Δ) ? trim(x, Δ) :
trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
unbroadcast(x::Tuple, Δ) =
x == size(Δ) ? Δ :
reshape(sum(Δ, dims = filter(n -> n > length(x) || x[n] == 1, 1:ndims(Δ))), x)
unbroadcast(x::Number, Δ) = sum(Δ)
unbroadcast(x::Base.RefValue{<:Function}, _) = nothing
unbroadcast(x::Base.RefValue{<:Val}, _) = nothing
unbroadcast(x::Tuple{}, Δ) = sum(Δ)
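The `unbroadcast` methods reduce a gradient back to the shape of the original argument by summing over the dimensions that broadcasting expanded; a small shape-only illustration:
```julia
# x was broadcast from (3, 1) up to (3, 4); its gradient must be summed back
# over the expanded second dimension to recover a (3, 1) array.
x = ones(3, 1)
Δ = ones(3, 4)               # gradient of the broadcast result
size(sum(Δ, dims = 2))       # (3, 1), matching size(x)
```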
dual(x, p) = x
dual(x::Real, p) = Dual(x, p)
function getpartial(Δ, x, i)
@inbounds p = getindex(partials(x), i)
return Δ * p
function partial(f::F, Δ, i, args::Vararg{Any,N}) where {F,N}
dargs = ntuple(j -> dual(args[j], i==j), Val(N))
return Δ * f(dargs...).partials[1]
end
function ∇broadcast(f, args::Vararg{Any,N}) where N
sizes = _size.(args)
dargs = map((x,i) -> dualify(data(x), ntuple(j -> i==j, Val(N))), args, ntuple(identity, Val(N)))
out = broadcast(f, dargs...)
eltype(out) <: Dual || return out
y = value.(out)
back = function (Δ_)
Δ = data(Δ_)
Δargs = ntuple(i -> getpartial.(Δ, out, i), Val(N))
dxs = map((x, Δ) -> unbroadcast(x, Δ), sizes, Δargs)
nobacksies(:broadcast, dxs)
@inline function ∇broadcast(f::F, args::Vararg{Any,N}) where {F,N}
y = broadcast(f, data.(args)...)
eltype(y) <: Real || return y
eltype(y) == Bool && return y
function back(Δ)
Δargs = ntuple(i -> partial.(f, data(Δ), i, args...), Val(N))
dxs = unbroadcast.(args, Δargs)
return nobacksies(:broadcast, dxs)
end
# So we can return non-tracked arrays
track(Call(back, tracker.(args)), y)
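The rewritten `∇broadcast` recovers each partial derivative from a single forward-mode dual number, one argument at a time; a minimal sketch of that trick using ForwardDiff directly:
```julia
using ForwardDiff: Dual

# Seed x with partial 1 and leave y plain; the result carries ∂f/∂x.
f(x, y) = x * y
d = f(Dual(2.0, 1.0), 3.0)
d.value        # 6.0, the primal value
d.partials[1]  # 3.0, i.e. ∂f/∂x at (2, 3)
```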
@ -395,7 +386,7 @@ end
using Requires
# https://github.com/FluxML/Flux.jl/issues/353
@init @eval Base.Broadcast begin
@init Requires.isprecompiling() || @eval Base.Broadcast begin
function flatten(bc::Broadcasted{Style}) where {Style}
isflat(bc) && return bc
args = cat_nested(bc)

View File

@ -30,8 +30,11 @@ Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} =
error("Not implemented: convert tracked $S to tracked $T")
Base.:(<)(x::TrackedReal, y::TrackedReal) = data(x) < data(y)
Base.:(==)(x::TrackedReal, y::TrackedReal) = data(x) == data(y)
for op in [:(==), :≈, :<]
@eval Base.$op(x::TrackedReal, y::Real) = Base.$op(data(x), y)
@eval Base.$op(x::Real, y::TrackedReal) = Base.$op(x, data(y))
@eval Base.$op(x::TrackedReal, y::TrackedReal) = Base.$op(data(x), data(y))
end
Base.eps(x::TrackedReal) = eps(data(x))

View File

@ -1,7 +1,7 @@
using Flux, Flux.Tracker, CuArrays, Test
using Flux: gpu
@info "Testing Flux/GPU"
@info "Testing GPU Support"
@testset "CuArrays" begin
@ -26,6 +26,10 @@ x = [1,2,3]
cx = gpu(x)
@test Flux.crossentropy(x,x) ≈ Flux.crossentropy(cx,cx)
xs = param(rand(5,5))
ys = Flux.onehotbatch(1:5,1:5)
@test collect(cu(xs) .+ cu(ys)) ≈ collect(xs .+ ys)
c = gpu(Conv((2,2),3=>4))
l = c(gpu(rand(10,10,3,2)))
Flux.back!(sum(l))

test/layers/conv.jl (new file)
View File

@ -0,0 +1,23 @@
using Flux, Test
using Flux: maxpool, meanpool
@testset "Pooling" begin
x = randn(10, 10, 3, 2)
mp = MaxPool((2, 2))
@test mp(x) == maxpool(x, (2,2))
mp = MeanPool((2, 2))
@test mp(x) == meanpool(x, (2,2))
end
@testset "CNN" begin
r = zeros(28, 28, 1, 5)
m = Chain(
Conv((2, 2), 1=>16, relu),
MaxPool((2,2)),
Conv((2, 2), 16=>8, relu),
MaxPool((2,2)),
x -> reshape(x, :, size(x, 4)),
Dense(288, 10), softmax)
@test size(m(r)) == (10, 5)
end

test/onehot.jl (new file)
View File

@ -0,0 +1,13 @@
using Flux:onecold
using Test
@testset "onecold" begin
a = [1, 2, 5, 3.]
A = [1 20 5; 2 7 6; 3 9 10; 2 1 14]
labels = ['A', 'B', 'C', 'D']
@test onecold(a) == 3
@test onecold(A) == [3, 1, 4]
@test onecold(a, labels) == 'C'
@test onecold(A, labels) == ['C', 'A', 'D']
end

View File

@ -23,7 +23,7 @@ end
Flux.train!(() -> (sleep(0.1); i += 1; l),
Iterators.repeated((), 100),
()->(),
cb = Flux.throttle(() -> (i > 3 && :stop), 1))
cb = Flux.throttle(() -> (i > 3 && stop()), 1))
@test 3 < i < 50
end

View File

@ -23,13 +23,23 @@ insert!(LOAD_PATH, 2, "@v#.#")
@testset "Flux" begin
@info "Testing Basics"
include("utils.jl")
include("tracker.jl")
include("onehot.jl")
include("optimise.jl")
include("data.jl")
@info "Testing Layers"
include("layers/basic.jl")
include("layers/normalisation.jl")
include("layers/stateless.jl")
include("optimise.jl")
include("data.jl")
include("layers/conv.jl")
@info "Running Gradient Checks"
include("tracker.jl")
if Base.find_package("CuArrays") != nothing
include("cuda/cuda.jl")

View File

@ -182,9 +182,30 @@ end
@test gradtest(x -> meanpool(x, (2,2)), rand(10, 10, 3, 2))
@test gradtest(x -> meanpool(x, (2,2,2)), rand(5, 5, 5, 3, 2))
@test (param([1,2,3]) .< 2) == [true, false, false]
@testset "equality & order" begin
# TrackedReal
@test param(2)^2 == param(4)
@test param(2)^2 == 4
@test 4 == param(2)^2
@test param(2)^2 == 4.0
@test param(2)^2 ≈ param(4)
@test param(2)^2 ≈ 4
@test 4 ≈ param(2)^2
@test (param([1,2,3]) .< 2) == [true, false, false]
@test (param([1,2,3]) .<= 2) == [true, true, false]
@test (2 .> param([1,2,3])) == [true, false, false]
@test (2 .>= param([1,2,3])) == [true, true, false]
# TrackedArray
@test param([1,2,3]).^2 == param([1,4,9])
@test [1,2,3].^2 == param([1,4,9])
@test param([1,2,3]).^2 == [1,4,9]
@test param([1,2,3]).^2 ≈ param([1,4,9])
@test [1,2,3].^2 ≈ param([1,4,9])
@test param([1,2,3]).^2 ≈ [1,4,9]
end
@testset "reshape" begin
x = reshape(param(rand(2,2,2)), 4, 2)