diff --git a/.travis.yml b/.travis.yml
index 9bf07dd6..b26597e9 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -15,5 +15,5 @@ matrix:
   allow_failures:
     - julia: nightly
 after_success:
-  - julia -e 'Pkg.add("Documenter")'
-  - julia -e 'cd(Pkg.dir("Flux")); include(joinpath("docs", "make.jl"))'
+  - julia -e 'using Pkg; Pkg.add("Documenter"); Pkg.add("NNlib")'
+  - julia -e 'using Pkg; cd(Pkg.dir("Flux")); include(joinpath("docs", "make.jl"))'
diff --git a/REQUIRE b/REQUIRE
index 7164de5a..ad3306d6 100644
--- a/REQUIRE
+++ b/REQUIRE
@@ -4,7 +4,7 @@ MacroTools 0.3.3
 NNlib
 Requires
 Adapt
-GZip
+CodecZlib
 Colors
 ZipFile
 AbstractTrees
diff --git a/docs/make.jl b/docs/make.jl
index ed6a8c8b..b35beb3c 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -26,6 +26,6 @@ deploydocs(
   repo = "github.com/FluxML/Flux.jl.git",
   target = "build",
   osname = "linux",
-  julia = "0.6",
+  julia = "1.0",
   deps = nothing,
   make = nothing)
diff --git a/docs/src/index.md b/docs/src/index.md
index afeb2075..4fc58f72 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -1,18 +1,17 @@
 # Flux: The Julia Machine Learning Library
 
-Flux is a library for machine learning. It comes "batteries-included" with many useful tools built in, but also lets you use the full power of the Julia language where you need it. The whole stack is implemented in clean Julia code (right down to the [GPU kernels](https://github.com/FluxML/CuArrays.jl)) and any part can be tweaked to your liking.
+Flux is a library for machine learning. It comes "batteries-included" with many useful tools built in, but also lets you use the full power of the Julia language where you need it. We follow a few key principles:
 
-# Installation
+* **Doing the obvious thing**. Flux has relatively few explicit APIs for features like regularisation or embeddings. Instead, writing down the mathematical form will work – and be fast.
+* **You could have written Flux**. All of it, from [LSTMs](https://github.com/FluxML/Flux.jl/blob/ec16a2c77dbf6ab8b92b0eecd11661be7a62feef/src/layers/recurrent.jl#L131) to [GPU kernels](https://github.com/JuliaGPU/CuArrays.jl), is straightforward Julia code. When in doubt, it’s well worth looking at [the source](https://github.com/FluxML/Flux.jl/). If you need something different, you can easily roll your own.
+* **Play nicely with others**. Flux works well with Julia libraries from [data frames](https://github.com/JuliaComputing/JuliaDB.jl) and [images](https://github.com/JuliaImages/Images.jl) to [differential equation solvers](https://github.com/JuliaDiffEq/DifferentialEquations.jl), so you can easily build complex data processing pipelines that integrate Flux models.
 
-Install [Julia 0.6.0 or later](https://julialang.org/downloads/), if you haven't already.
+## Installation
 
-```julia
-Pkg.add("Flux")
-# Optional but recommended
-Pkg.update() # Keep your packages up to date
-Pkg.test("Flux") # Check things installed correctly
-```
+Download [Julia 1.0](https://julialang.org/) or later, if you haven't already. You can add Flux using Julia's package manager by typing `] add Flux` in the Julia prompt.
 
-Start with the [basics](models/basics.md). The [model zoo](https://github.com/FluxML/model-zoo/) is also a good starting point for many common kinds of models.
+If you have CUDA you can also run `] add CuArrays` to get GPU support; see [here](gpu.md) for more details.
 
-See [GPU support](gpu.md) for more details on installing and using Flux with GPUs.
+## Learning Flux
+
+There are several different ways to learn Flux. If you just want to get started writing models, the [model zoo](https://github.com/FluxML/model-zoo/) gives good starting points for many common ones. This documentation provides a reference to all of Flux's APIs, as well as a from-scratch introduction to Flux's take on models and how they work. Once you understand these docs, congratulations, you also understand [Flux's source code](https://github.com/FluxML/Flux.jl), which is intended to be concise, legible and a good reference for more advanced concepts.
diff --git a/docs/src/models/basics.md b/docs/src/models/basics.md
index da2a125b..88fa0a05 100644
--- a/docs/src/models/basics.md
+++ b/docs/src/models/basics.md
@@ -172,7 +172,7 @@ using Flux
 
 layers = [Dense(10, 5, σ), Dense(5, 2), softmax]
 
-model(x) = foldl((x, m) -> m(x), x, layers)
+model(x) = foldl((x, m) -> m(x), layers, init = x)
 
 model(rand(10)) # => 2-element vector
 ```
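The `foldl` change in `docs/src/models/basics.md` above tracks the Julia 1.0 API, where the collection is the second positional argument and the starting value moves to the `init` keyword. A minimal usage sketch of the updated pattern, mirroring the documentation snippet (not part of the patch itself):

```julia
using Flux

layers = [Dense(10, 5, σ), Dense(5, 2), softmax]

# `init = x` seeds the fold; each layer is then applied to the running value,
# so this is equivalent to softmax(layers[2](layers[1](x))).
model(x) = foldl((x, m) -> m(x), layers, init = x)

model(rand(10)) # => 2-element vector
```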
diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md
index c2056bb4..4bbb2ba0 100644
--- a/docs/src/models/layers.md
+++ b/docs/src/models/layers.md
@@ -6,6 +6,8 @@ These core layers form the foundation of almost all neural networks.
 Chain
 Dense
 Conv
+MaxPool
+MeanPool
 ```
 
 ## Recurrent Layers
diff --git a/docs/src/models/regularisation.md b/docs/src/models/regularisation.md
index cd53544f..370a53d9 100644
--- a/docs/src/models/regularisation.md
+++ b/docs/src/models/regularisation.md
@@ -1,7 +1,7 @@
 # Regularisation
 
 Applying regularisation to model parameters is straightforward. We just need to
-apply an appropriate regulariser, such as `vecnorm`, to each model parameter and
+apply an appropriate regulariser, such as `norm`, to each model parameter and
 add the result to the overall loss.
 
 For example, say we have a simple regression.
@@ -15,12 +15,12 @@ loss(x, y) = crossentropy(softmax(m(x)), y)
 We can regularise this by taking the (L2) norm of the parameters, `m.W` and `m.b`.
 
 ```julia
-penalty() = vecnorm(m.W) + vecnorm(m.b)
+penalty() = norm(m.W) + norm(m.b)
 loss(x, y) = crossentropy(softmax(m(x)), y) + penalty()
 ```
 
 When working with layers, Flux provides the `params` function to grab all
-parameters at once. We can easily penalise everything with `sum(vecnorm, params)`.
+parameters at once. We can easily penalise everything with `sum(norm, params)`.
 
 ```julia
 julia> params(m)
@@ -28,7 +28,7 @@
  param([0.355408 0.533092; … 0.430459 0.171498])
  param([0.0, 0.0, 0.0, 0.0, 0.0])
 
-julia> sum(vecnorm, params(m))
+julia> sum(norm, params(m))
 26.01749952921026 (tracked)
 ```
 
@@ -40,7 +40,7 @@ m = Chain(
   Dense(128, 32, relu),
   Dense(32, 10), softmax)
 
-loss(x, y) = crossentropy(m(x), y) + sum(vecnorm, params(m))
+loss(x, y) = crossentropy(m(x), y) + sum(norm, params(m))
 
 loss(rand(28^2), rand(10))
 ```
@@ -57,6 +57,6 @@ julia> activations(c, rand(10))
  param([0.0330606, -0.456104])
  param([0.61991, 0.38009])
 
-julia> sum(vecnorm, ans)
+julia> sum(norm, ans)
 2.639678767773633 (tracked)
 ```
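The `vecnorm` → `norm` renames above follow Julia 1.0, where `vecnorm` no longer exists and the vector norm is provided by `norm` from the LinearAlgebra standard library. A minimal sketch of what user code following the updated docs needs (the `Dense(10, 5)` layer is only a placeholder):

```julia
using Flux, LinearAlgebra

m = Dense(10, 5)

# L2 penalty over every parameter, as in the updated regularisation docs;
# `norm` now comes from the LinearAlgebra stdlib rather than Base.
penalty() = sum(norm, params(m))
penalty() # => a tracked scalar
```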
diff --git a/src/Flux.jl b/src/Flux.jl
index 6ec849b0..e8ca9d75 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -5,7 +5,7 @@ module Flux
 using MacroTools, Juno, Requires, Reexport, Statistics, Random
 using MacroTools: @forward
 
-export Chain, Dense, RNN, LSTM, GRU, Conv,
+export Chain, Dense, RNN, LSTM, GRU, Conv, MaxPool, MeanPool,
   Dropout, LayerNorm, BatchNorm,
   params, mapleaves, cpu, gpu
 
diff --git a/src/data/mnist.jl b/src/data/mnist.jl
index c068bc7d..4397618d 100644
--- a/src/data/mnist.jl
+++ b/src/data/mnist.jl
@@ -1,11 +1,17 @@
 module MNIST
 
-using GZip, Colors
+using CodecZlib, Colors
 
 const Gray = Colors.Gray{Colors.N0f8}
 
 const dir = joinpath(@__DIR__, "../../deps/mnist")
 
+function gzopen(f, file)
+  open(file) do io
+    f(GzipDecompressorStream(io))
+  end
+end
+
 function load()
   mkpath(dir)
   cd(dir) do
@@ -17,7 +23,7 @@ function load()
       @info "Downloading MNIST dataset"
       download("https://cache.julialang.org/http://yann.lecun.com/exdb/mnist/$file.gz", "$file.gz")
       open(file, "w") do io
-        write(io, GZip.open(read, "$file.gz"))
+        write(io, gzopen(read, "$file.gz"))
       end
     end
   end
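The `gzopen` helper added in `src/data/mnist.jl` above wraps the raw file handle in CodecZlib's `GzipDecompressorStream` and hands it to the supplied function, replacing the old GZip.jl dependency. A self-contained round trip as a sketch; `"data.gz"` is a stand-in file name, not one of the MNIST files:

```julia
using CodecZlib

# Same helper as added above: hand `f` a stream that decompresses `file` on the fly.
function gzopen(f, file)
  open(file) do io
    f(GzipDecompressorStream(io))
  end
end

# Write a small gzip file, then read it back through the helper.
io = open("data.gz", "w")
stream = GzipCompressorStream(io)
write(stream, "hello")
close(stream) # flushes the gzip trailer and closes the underlying file

String(gzopen(read, "data.gz")) # => "hello"
```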
diff --git a/src/layers/conv.jl b/src/layers/conv.jl
index 78509c84..dbf8ccf9 100644
--- a/src/layers/conv.jl
+++ b/src/layers/conv.jl
@@ -1,6 +1,6 @@
 using NNlib: conv
 
-@generated sub2(::Type{Val{N}}) where N = :(Val($(N-2)))
+@generated sub2(::Val{N}) where N = :(Val($(N-2)))
 
 expand(N, i::Tuple) = i
 expand(N, i::Integer) = ntuple(_ -> i, N)
@@ -28,7 +28,7 @@ end
 
 Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
      stride = 1, pad = 0, dilation = 1) where {T,N} =
-  Conv(σ, w, b, expand.(sub2(Val{N}), (stride, pad, dilation))...)
+  Conv(σ, w, b, expand.(sub2(Val(N)), (stride, pad, dilation))...)
 
 Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
      init = initn, stride = 1, pad = 0, dilation = 1) where N =
@@ -50,3 +50,48 @@ function Base.show(io::IO, l::Conv)
   l.σ == identity || print(io, ", ", l.σ)
   print(io, ")")
 end
+
+
+"""
+    MaxPool(k)
+
+Max pooling layer. `k` stands for the size of the window for each dimension of the input.
+
+Takes the keyword arguments `pad` and `stride`.
+"""
+struct MaxPool{N}
+  k::NTuple{N,Int}
+  pad::NTuple{N,Int}
+  stride::NTuple{N,Int}
+end
+
+MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =
+  MaxPool(k, expand(Val(N), pad), expand(Val(N), stride))
+
+(m::MaxPool)(x) = maxpool(x, m.k; pad = m.pad, stride = m.stride)
+
+function Base.show(io::IO, m::MaxPool)
+  print(io, "MaxPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
+end
+
+"""
+    MeanPool(k)
+
+Mean pooling layer. `k` stands for the size of the window for each dimension of the input.
+
+Takes the keyword arguments `pad` and `stride`.
+"""
+struct MeanPool{N}
+  k::NTuple{N,Int}
+  pad::NTuple{N,Int}
+  stride::NTuple{N,Int}
+end
+
+MeanPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =
+  MeanPool(k, expand(Val(N), pad), expand(Val(N), stride))
+
+(m::MeanPool)(x) = meanpool(x, m.k; pad = m.pad, stride = m.stride)
+
+function Base.show(io::IO, m::MeanPool)
+  print(io, "MeanPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
+end
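The new pooling layers above follow the same window-tuple convention as `Conv`, with the stride defaulting to the window size. A small usage sketch with arbitrary sizes (illustrative only):

```julia
using Flux

x = randn(28, 28, 3, 8) # WHCN layout: 28×28 feature maps, 3 channels, batch of 8

size(MaxPool((2, 2))(x))                       # (14, 14, 3, 8); stride defaults to the window
size(MeanPool((3, 3), pad = 1, stride = 2)(x)) # (14, 14, 3, 8)
```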
diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index 4064ed7b..3b40af04 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -128,8 +128,7 @@ function LSTMCell(in::Integer, out::Integer;
   return cell
 end
 
-function (m::LSTMCell)(h_, x)
-  h, c = h_ # TODO: nicer syntax on 0.7
+function (m::LSTMCell)((h, c), x)
   b, o = m.b, size(h, 1)
   g = m.Wi*x .+ m.Wh*h .+ b
   input = σ.(gate(g, o, 1))
diff --git a/src/tracker/array.jl b/src/tracker/array.jl
index ffa3a89e..882a866c 100644
--- a/src/tracker/array.jl
+++ b/src/tracker/array.jl
@@ -1,4 +1,4 @@
-import Base: *, ==
+import Base: *
 
 import LinearAlgebra
 using Statistics
@@ -60,9 +60,11 @@ Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
 
 Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
 
-x::TrackedArray == y = data(x) == y
-y == x::TrackedArray = y == data(x)
-x::TrackedArray == y::TrackedArray = data(x) == data(y)
+for op in [:(==), :≈]
+  @eval Base.$op(x::TrackedArray, y::AbstractArray) = Base.$op(data(x), y)
+  @eval Base.$op(x::AbstractArray, y::TrackedArray) = Base.$op(x, data(y))
+  @eval Base.$op(x::TrackedArray, y::TrackedArray) = Base.$op(data(x), data(y))
+end
 
 # Array Stdlib
 
diff --git a/src/tracker/scalar.jl b/src/tracker/scalar.jl
index 9ff1895a..81ccb9a3 100644
--- a/src/tracker/scalar.jl
+++ b/src/tracker/scalar.jl
@@ -30,8 +30,11 @@ Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x
 
 Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} =
   error("Not implemented: convert tracked $S to tracked $T")
 
-Base.:(<)(x::TrackedReal, y::TrackedReal) = data(x) < data(y)
-Base.:(==)(x::TrackedReal, y::TrackedReal) = data(x) == data(y)
+for op in [:(==), :≈, :<]
+  @eval Base.$op(x::TrackedReal, y::Real) = Base.$op(data(x), y)
+  @eval Base.$op(x::Real, y::TrackedReal) = Base.$op(x, data(y))
+  @eval Base.$op(x::TrackedReal, y::TrackedReal) = Base.$op(data(x), data(y))
+end
 
 Base.eps(x::TrackedReal) = eps(data(x))
diff --git a/test/cuda/cuda.jl b/test/cuda/cuda.jl
index c9ee95c6..16f90e89 100644
--- a/test/cuda/cuda.jl
+++ b/test/cuda/cuda.jl
@@ -1,7 +1,7 @@
 using Flux, Flux.Tracker, CuArrays, Test
 using Flux: gpu
 
-@info "Testing Flux/GPU"
+@info "Testing GPU Support"
 
 @testset "CuArrays" begin
 
diff --git a/test/layers/conv.jl b/test/layers/conv.jl
new file mode 100644
index 00000000..5928bd75
--- /dev/null
+++ b/test/layers/conv.jl
@@ -0,0 +1,23 @@
+using Flux, Test
+using Flux: maxpool, meanpool
+
+@testset "Pooling" begin
+  x = randn(10, 10, 3, 2)
+  mp = MaxPool((2, 2))
+  @test mp(x) == maxpool(x, (2,2))
+  mp = MeanPool((2, 2))
+  @test mp(x) == meanpool(x, (2,2))
+end
+
+@testset "CNN" begin
+  r = zeros(28, 28, 1, 5)
+  m = Chain(
+    Conv((2, 2), 1=>16, relu),
+    MaxPool((2,2)),
+    Conv((2, 2), 16=>8, relu),
+    MaxPool((2,2)),
+    x -> reshape(x, :, size(x, 4)),
+    Dense(288, 10), softmax)
+
+  @test size(m(r)) == (10, 5)
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index fd48e547..7a55dca6 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -23,14 +23,23 @@ insert!(LOAD_PATH, 2, "@v#.#")
 
 @testset "Flux" begin
 
+@info "Testing Basics"
+
 include("utils.jl")
 include("onehot.jl")
-include("tracker.jl")
-include("layers/normalisation.jl")
-include("layers/stateless.jl")
 include("optimise.jl")
 include("data.jl")
 
+@info "Testing Layers"
+
+include("layers/normalisation.jl")
+include("layers/stateless.jl")
+include("layers/conv.jl")
+
+@info "Running Gradient Checks"
+
+include("tracker.jl")
+
 if Base.find_package("CuArrays") != nothing
   include("cuda/cuda.jl")
 end
diff --git a/test/tracker.jl b/test/tracker.jl
index 03d14c35..9a4cb793 100644
--- a/test/tracker.jl
+++ b/test/tracker.jl
@@ -182,9 +182,30 @@ end
 @test gradtest(x -> meanpool(x, (2,2)), rand(10, 10, 3, 2))
 @test gradtest(x -> meanpool(x, (2,2,2)), rand(5, 5, 5, 3, 2))
 
-@test (param([1,2,3]) .< 2) == [true, false, false]
+@testset "equality & order" begin
+  # TrackedReal
+  @test param(2)^2 == param(4)
+  @test param(2)^2 == 4
+  @test 4 == param(2)^2
 
-@test param(2)^2 == 4.0
+  @test param(2)^2 ≈ param(4)
+  @test param(2)^2 ≈ 4
+  @test 4 ≈ param(2)^2
+
+  @test (param([1,2,3]) .< 2) == [true, false, false]
+  @test (param([1,2,3]) .<= 2) == [true, true, false]
+  @test (2 .> param([1,2,3])) == [true, false, false]
+  @test (2 .>= param([1,2,3])) == [true, true, false]
+
+  # TrackedArray
+  @test param([1,2,3]).^2 == param([1,4,9])
+  @test [1,2,3].^2 == param([1,4,9])
+  @test param([1,2,3]).^2 == [1,4,9]
+
+  @test param([1,2,3]).^2 ≈ param([1,4,9])
+  @test [1,2,3].^2 ≈ param([1,4,9])
+  @test param([1,2,3]).^2 ≈ [1,4,9]
+end
 
 @testset "reshape" begin
   x = reshape(param(rand(2,2,2)), 4, 2)