From 5b37319289dbd0b439d6b53fbbaebba77448b87b Mon Sep 17 00:00:00 2001 From: Yueh-Hua Tu Date: Wed, 1 Aug 2018 00:10:53 +0800 Subject: [PATCH 01/16] Add Maxpool and Meanpool --- docs/src/models/layers.md | 2 ++ src/layers/conv.jl | 42 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md index c2056bb4..070f6737 100644 --- a/docs/src/models/layers.md +++ b/docs/src/models/layers.md @@ -6,6 +6,8 @@ These core layers form the foundation of almost all neural networks. Chain Dense Conv +Maxpool +Meanpool ``` ## Recurrent Layers diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 38310aad..f074e77f 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -50,3 +50,45 @@ function Base.show(io::IO, l::Conv) l.σ == identity || print(io, ", ", l.σ) print(io, ")") end + + +""" + Maxpool(k) + +Maxpooling layer. `k` stands for the size of the window for each dimension of the input. + +Takes the keyword arguments `pad` and `stride`. +""" +struct Maxpool{N} + k::NTuple{N,Int} + pad::NTuple{N,Int} + stride::NTuple{N,Int} + Maxpool(k::NTuple{N,Int}; pad = map(_->0,k), stride = k) where N = new{N}(k, pad, stride) +end + +(m::Maxpool)(x) = maxpool(x, m.k; pad = m.pad, stride = m.stride) + +function Base.show(io::IO, m::Maxpool) + print(io, "Maxpool(", m.k, ", ", m.pad, ", ", m.stride, ")") +end + + +""" + Meanpool(k) + +Meanpooling layer. `k` stands for the size of the window for each dimension of the input. + +Takes the keyword arguments `pad` and `stride`. +""" +struct Meanpool{N} + k::NTuple{N,Int} + pad::NTuple{N,Int} + stride::NTuple{N,Int} + Meanpool(k::NTuple{N,Int}; pad = map(_->0,k), stride = k) where N = new{N}(k, pad, stride) +end + +(m::Meanpool)(x) = meanpool(x, m.k; pad = m.pad, stride = m.stride) + +function Base.show(io::IO, m::Meanpool) + print(io, "Meanpool(", m.k, ", ", m.pad, ", ", m.stride, ")") +end From 7bfe4313211c6f38f034a5659932f033e37a0f79 Mon Sep 17 00:00:00 2001 From: Johnny Chen Date: Thu, 23 Aug 2018 20:58:58 +0800 Subject: [PATCH 02/16] Fix issue #323 --- src/tracker/array.jl | 6 +++++- src/tracker/scalar.jl | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index ce72755d..35d2c39f 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -1,4 +1,4 @@ -import Base: *, == +import Base: *, ==, ≈ import LinearAlgebra using Statistics @@ -64,6 +64,10 @@ x::TrackedArray == y = data(x) == y y == x::TrackedArray = y == data(x) x::TrackedArray == y::TrackedArray = data(x) == data(y) +x::TrackedArray ≈ y = data(x) ≈ y +y ≈ x::TrackedArray = y ≈ data(x) +x::TrackedArray ≈ y::TrackedArray = data(x) ≈ data(y) + # Array Stdlib Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...) diff --git a/src/tracker/scalar.jl b/src/tracker/scalar.jl index 9ff1895a..03892c46 100644 --- a/src/tracker/scalar.jl +++ b/src/tracker/scalar.jl @@ -32,6 +32,7 @@ Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} = Base.:(<)(x::TrackedReal, y::TrackedReal) = data(x) < data(y) Base.:(==)(x::TrackedReal, y::TrackedReal) = data(x) == data(y) +Base.:(≈)(x::TrackedReal, y::TrackedReal) = data(x) ≈ data(y) Base.eps(x::TrackedReal) = eps(data(x)) From 634d34686ee2278f61ac62cf7e93d21cfdf6980c Mon Sep 17 00:00:00 2001 From: Yueh-Hua Tu Date: Fri, 24 Aug 2018 10:31:13 +0800 Subject: [PATCH 03/16] Add new constructors and test --- docs/src/models/layers.md | 4 ++-- src/layers/conv.jl | 36 +++++++++++++++++++++++------------- test/layers/conv.jl | 34 ++++++++++++++++++++++++++++++++++ 3 files changed, 59 insertions(+), 15 deletions(-) create mode 100644 test/layers/conv.jl diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md index 070f6737..4bbb2ba0 100644 --- a/docs/src/models/layers.md +++ b/docs/src/models/layers.md @@ -6,8 +6,8 @@ These core layers form the foundation of almost all neural networks. Chain Dense Conv -Maxpool -Meanpool +MaxPool +MeanPool ``` ## Recurrent Layers diff --git a/src/layers/conv.jl b/src/layers/conv.jl index f074e77f..0f9243ef 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -53,42 +53,52 @@ end """ - Maxpool(k) + MaxPool(k) Maxpooling layer. `k` stands for the size of the window for each dimension of the input. Takes the keyword arguments `pad` and `stride`. """ -struct Maxpool{N} +struct MaxPool{N} k::NTuple{N,Int} pad::NTuple{N,Int} stride::NTuple{N,Int} - Maxpool(k::NTuple{N,Int}; pad = map(_->0,k), stride = k) where N = new{N}(k, pad, stride) + MaxPool(k::NTuple{N,Int}; pad = map(_->0,k), stride = k) where N = new{N}(k, pad, stride) end -(m::Maxpool)(x) = maxpool(x, m.k; pad = m.pad, stride = m.stride) +function MaxPool{N}(k::Int; pad = 0, stride = k) where N + k_ = Tuple(repeat([k, ], N)) + MaxPool(k_; pad = map(_->pad,k_), stride=map(_->stride,k_)) +end -function Base.show(io::IO, m::Maxpool) - print(io, "Maxpool(", m.k, ", ", m.pad, ", ", m.stride, ")") +(m::MaxPool)(x) = maxpool(x, m.k; pad = m.pad, stride = m.stride) + +function Base.show(io::IO, m::MaxPool) + print(io, "MaxPool(", m.k, ", ", m.pad, ", ", m.stride, ")") end """ - Meanpool(k) + MeanPool(k) Meanpooling layer. `k` stands for the size of the window for each dimension of the input. Takes the keyword arguments `pad` and `stride`. """ -struct Meanpool{N} +struct MeanPool{N} k::NTuple{N,Int} pad::NTuple{N,Int} stride::NTuple{N,Int} - Meanpool(k::NTuple{N,Int}; pad = map(_->0,k), stride = k) where N = new{N}(k, pad, stride) + MeanPool(k::NTuple{N,Int}; pad = map(_->0,k), stride = k) where N = new{N}(k, pad, stride) end -(m::Meanpool)(x) = meanpool(x, m.k; pad = m.pad, stride = m.stride) - -function Base.show(io::IO, m::Meanpool) - print(io, "Meanpool(", m.k, ", ", m.pad, ", ", m.stride, ")") +function MeanPool{N}(k::Int; pad = 0, stride = k) where N + k_ = Tuple(repeat([k, ], N)) + MeanPool(k_; pad = map(_->pad,k_), stride=map(_->stride,k_)) +end + +(m::MeanPool)(x) = meanpool(x, m.k; pad = m.pad, stride = m.stride) + +function Base.show(io::IO, m::MeanPool) + print(io, "MeanPool(", m.k, ", ", m.pad, ", ", m.stride, ")") end diff --git a/test/layers/conv.jl b/test/layers/conv.jl new file mode 100644 index 00000000..2e5e63dd --- /dev/null +++ b/test/layers/conv.jl @@ -0,0 +1,34 @@ +using Test +using Flux: Chain, Conv, MaxPool, MeanPool +using Base.conv + +@testset "pooling" begin + mp = MaxPool((2, 2)) + + @testset "maxpooling" begin + @test MaxPool{2}(2) == mp + @test MaxPool{2}(2; pad=1, stride=3) == MaxPool((2, 2); pad=(1, 1), stride=(3, 3)) + end + + mp = MeanPool((2, 2)) + + @testset "meanpooling" begin + @test MeanPool{2}(2) == mp + @test MeanPool{2}(2; pad=1, stride=3) == MeanPool((2, 2); pad=(1, 1), stride=(3, 3)) + end +end + +@testset "cnn" begin + r = zeros(28, 28) + m = Chain( + Conv((2, 2), 1=>16, relu), + MaxPool{2}(2), + Conv((2, 2), 16=>8, relu), + MaxPool{2}(2), + x -> reshape(x, :, size(x, 4)), + Dense(288, 10), softmax) + + @testset "inference" begin + @test size(m(r)) == (10, ) + end +end From 4ac76c35b0ede5d9c7dc1134f732190543eb499f Mon Sep 17 00:00:00 2001 From: Johnny Chen Date: Sat, 25 Aug 2018 14:51:40 +0800 Subject: [PATCH 04/16] =?UTF-8?q?fix=20MethodError=20for=20=3D=3D=20and=20?= =?UTF-8?q?=E2=89=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ```julia param([2]).^2 == [4.0] ERROR: MethodError: ==(::TrackedArray{…,Array{Float64,1}}, ::Array{Float64,1}) is ambiguous. Candidates: ==(x::TrackedArray, y) in Main.Flux.Tracker at /Users/jc/.julia/dev/Flux/src/tracker/array.jl:63 ==(A::AbstractArray, B::AbstractArray) in Base at abstractarray.jl:1686 Possible fix, define ==(::TrackedArray, ::AbstractArray) ``` --- src/tracker/array.jl | 14 ++++++-------- src/tracker/scalar.jl | 8 +++++--- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 35d2c39f..923b925c 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -1,4 +1,4 @@ -import Base: *, ==, ≈ +import Base: * import LinearAlgebra using Statistics @@ -60,13 +60,11 @@ Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) = Base.similar(x::TrackedArray, T::Type) = similar(data(x), T) -x::TrackedArray == y = data(x) == y -y == x::TrackedArray = y == data(x) -x::TrackedArray == y::TrackedArray = data(x) == data(y) - -x::TrackedArray ≈ y = data(x) ≈ y -y ≈ x::TrackedArray = y ≈ data(x) -x::TrackedArray ≈ y::TrackedArray = data(x) ≈ data(y) +for op in [:(==), :≈] + @eval Base.$op(x::TrackedArray, y::AbstractArray) = Base.$op(data(x), y) + @eval Base.$op(x::AbstractArray, y::TrackedArray) = Base.$op(x, data(y)) + @eval Base.$op(x::TrackedArray, y::TrackedArray) = Base.$op(data(x), data(y)) +end # Array Stdlib diff --git a/src/tracker/scalar.jl b/src/tracker/scalar.jl index 03892c46..9e987333 100644 --- a/src/tracker/scalar.jl +++ b/src/tracker/scalar.jl @@ -30,9 +30,11 @@ Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} = error("Not implemented: convert tracked $S to tracked $T") -Base.:(<)(x::TrackedReal, y::TrackedReal) = data(x) < data(y) -Base.:(==)(x::TrackedReal, y::TrackedReal) = data(x) == data(y) -Base.:(≈)(x::TrackedReal, y::TrackedReal) = data(x) ≈ data(y) +for op in [:(==), :≈, :<] + @eval Base.$op(x::TrackedReal, y::Number) = Base.$op(data(x), y) + @eval Base.$op(x::Number, y::TrackedReal) = Base.$op(x, data(y)) + @eval Base.$op(x::TrackedReal, y::TrackedReal) = Base.$op(data(x), data(y)) +end Base.eps(x::TrackedReal) = eps(data(x)) From 81811a01ce920af9e99cda1d642773afc673fc73 Mon Sep 17 00:00:00 2001 From: Johnny Chen Date: Sat, 25 Aug 2018 14:52:08 +0800 Subject: [PATCH 05/16] =?UTF-8?q?Update=20testset=20for=20=3D=3D,=20?= =?UTF-8?q?=E2=89=88,=20and=20 meanpool(x, (2,2)), rand(10, 10, 3, 2)) @test gradtest(x -> meanpool(x, (2,2,2)), rand(5, 5, 5, 3, 2)) -@test (param([1,2,3]) .< 2) == [true, false, false] +@testset "equality & order" begin + # TrackedReal + @test param(2)^2 == param(4) + @test param(2)^2 == 4 + @test 4 == param(2)^2 -@test param(2)^2 == 4.0 + @test param(2)^2 ≈ param(4) + @test param(2)^2 ≈ 4 + @test 4 ≈ param(2)^2 + + @test (param([1,2,3]) .< 2) == [true, false, false] + @test (param([1,2,3]) .<= 2) == [true, true, false] + @test (2 .> param([1,2,3])) == [true, false, false] + @test (2 .>= param([1,2,3])) == [true, true, false] + + # TrackedArray + @test param([1,2,3]).^2 == param([1,4,9]) + @test [1,2,3].^2 == param([1,4,9]) + @test param([1,2,3]).^2 == [1,4,9] + + @test param([1,2,3]).^2 ≈ param([1,4,9]) + @test [1,2,3].^2 ≈ param([1,4,9]) + @test param([1,2,3]).^2 ≈ [1,4,9] +end @testset "reshape" begin x = reshape(param(rand(2,2,2)), 4, 2) From 0c4fb9655a20030c93efd2bac8671d1c55ee2a5d Mon Sep 17 00:00:00 2001 From: Johnny Chen Date: Sat, 25 Aug 2018 15:12:01 +0800 Subject: [PATCH 06/16] Fix a bug --- src/tracker/scalar.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tracker/scalar.jl b/src/tracker/scalar.jl index 9e987333..81ccb9a3 100644 --- a/src/tracker/scalar.jl +++ b/src/tracker/scalar.jl @@ -31,8 +31,8 @@ Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} = error("Not implemented: convert tracked $S to tracked $T") for op in [:(==), :≈, :<] - @eval Base.$op(x::TrackedReal, y::Number) = Base.$op(data(x), y) - @eval Base.$op(x::Number, y::TrackedReal) = Base.$op(x, data(y)) + @eval Base.$op(x::TrackedReal, y::Real) = Base.$op(data(x), y) + @eval Base.$op(x::Real, y::TrackedReal) = Base.$op(x, data(y)) @eval Base.$op(x::TrackedReal, y::TrackedReal) = Base.$op(data(x), data(y)) end From abcefb8ae30fcd745d1ba313a04f68b584fc5879 Mon Sep 17 00:00:00 2001 From: Pietro Vertechi Date: Wed, 29 Aug 2018 18:36:24 +0100 Subject: [PATCH 07/16] fix foldl in tutorial --- docs/src/models/basics.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/models/basics.md b/docs/src/models/basics.md index da2a125b..88fa0a05 100644 --- a/docs/src/models/basics.md +++ b/docs/src/models/basics.md @@ -172,7 +172,7 @@ using Flux layers = [Dense(10, 5, σ), Dense(5, 2), softmax] -model(x) = foldl((x, m) -> m(x), x, layers) +model(x) = foldl((x, m) -> m(x), layers, init = x) model(rand(10)) # => 2-element vector ``` From a012d0bd513ef7e9ae56c72970aad943b0f1c572 Mon Sep 17 00:00:00 2001 From: Pietro Vertechi Date: Wed, 29 Aug 2018 23:34:41 +0100 Subject: [PATCH 08/16] fix vecnorm in docs --- docs/src/models/regularisation.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/src/models/regularisation.md b/docs/src/models/regularisation.md index cd53544f..370a53d9 100644 --- a/docs/src/models/regularisation.md +++ b/docs/src/models/regularisation.md @@ -1,7 +1,7 @@ # Regularisation Applying regularisation to model parameters is straightforward. We just need to -apply an appropriate regulariser, such as `vecnorm`, to each model parameter and +apply an appropriate regulariser, such as `norm`, to each model parameter and add the result to the overall loss. For example, say we have a simple regression. @@ -15,12 +15,12 @@ loss(x, y) = crossentropy(softmax(m(x)), y) We can regularise this by taking the (L2) norm of the parameters, `m.W` and `m.b`. ```julia -penalty() = vecnorm(m.W) + vecnorm(m.b) +penalty() = norm(m.W) + norm(m.b) loss(x, y) = crossentropy(softmax(m(x)), y) + penalty() ``` When working with layers, Flux provides the `params` function to grab all -parameters at once. We can easily penalise everything with `sum(vecnorm, params)`. +parameters at once. We can easily penalise everything with `sum(norm, params)`. ```julia julia> params(m) @@ -28,7 +28,7 @@ julia> params(m) param([0.355408 0.533092; … 0.430459 0.171498]) param([0.0, 0.0, 0.0, 0.0, 0.0]) -julia> sum(vecnorm, params(m)) +julia> sum(norm, params(m)) 26.01749952921026 (tracked) ``` @@ -40,7 +40,7 @@ m = Chain( Dense(128, 32, relu), Dense(32, 10), softmax) -loss(x, y) = crossentropy(m(x), y) + sum(vecnorm, params(m)) +loss(x, y) = crossentropy(m(x), y) + sum(norm, params(m)) loss(rand(28^2), rand(10)) ``` @@ -57,6 +57,6 @@ julia> activations(c, rand(10)) param([0.0330606, -0.456104]) param([0.61991, 0.38009]) -julia> sum(vecnorm, ans) +julia> sum(norm, ans) 2.639678767773633 (tracked) ``` From 93c4a6b4b5c660956f345d8e0e1871ad880afb8a Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 4 Sep 2018 13:37:54 +0100 Subject: [PATCH 09/16] fixes #343 --- REQUIRE | 2 +- src/data/mnist.jl | 10 ++++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/REQUIRE b/REQUIRE index 7164de5a..ad3306d6 100644 --- a/REQUIRE +++ b/REQUIRE @@ -4,7 +4,7 @@ MacroTools 0.3.3 NNlib Requires Adapt -GZip +CodecZlib Colors ZipFile AbstractTrees diff --git a/src/data/mnist.jl b/src/data/mnist.jl index c068bc7d..4397618d 100644 --- a/src/data/mnist.jl +++ b/src/data/mnist.jl @@ -1,11 +1,17 @@ module MNIST -using GZip, Colors +using CodecZlib, Colors const Gray = Colors.Gray{Colors.N0f8} const dir = joinpath(@__DIR__, "../../deps/mnist") +function gzopen(f, file) + open(file) do io + f(GzipDecompressorStream(io)) + end +end + function load() mkpath(dir) cd(dir) do @@ -17,7 +23,7 @@ function load() @info "Downloading MNIST dataset" download("https://cache.julialang.org/http://yann.lecun.com/exdb/mnist/$file.gz", "$file.gz") open(file, "w") do io - write(io, GZip.open(read, "$file.gz")) + write(io, gzopen(read, "$file.gz")) end end end From 1e0fd07b097f36ff1d70675256ed6e7c7ed66287 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 4 Sep 2018 14:30:02 +0100 Subject: [PATCH 10/16] use `expand` --- src/Flux.jl | 2 +- src/layers/conv.jl | 27 ++++++++++----------------- test/layers/conv.jl | 16 ++++------------ 3 files changed, 15 insertions(+), 30 deletions(-) diff --git a/src/Flux.jl b/src/Flux.jl index 614eeaf7..8c959fec 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -5,7 +5,7 @@ module Flux using MacroTools, Juno, Requires, Reexport, Statistics, Random using MacroTools: @forward -export Chain, Dense, RNN, LSTM, GRU, Conv, +export Chain, Dense, RNN, LSTM, GRU, Conv, MaxPool, MeanPool, Dropout, LayerNorm, BatchNorm, params, mapleaves, cpu, gpu diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 5b239751..dbf8ccf9 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1,6 +1,6 @@ using NNlib: conv -@generated sub2(::Type{Val{N}}) where N = :(Val($(N-2))) +@generated sub2(::Val{N}) where N = :(Val($(N-2))) expand(N, i::Tuple) = i expand(N, i::Integer) = ntuple(_ -> i, N) @@ -28,7 +28,7 @@ end Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity; stride = 1, pad = 0, dilation = 1) where {T,N} = - Conv(σ, w, b, expand.(sub2(Val{N}), (stride, pad, dilation))...) + Conv(σ, w, b, expand.(sub2(Val(N)), (stride, pad, dilation))...) Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn, stride = 1, pad = 0, dilation = 1) where N = @@ -55,7 +55,7 @@ end """ MaxPool(k) -Maxpooling layer. `k` stands for the size of the window for each dimension of the input. +Max pooling layer. `k` stands for the size of the window for each dimension of the input. Takes the keyword arguments `pad` and `stride`. """ @@ -63,25 +63,21 @@ struct MaxPool{N} k::NTuple{N,Int} pad::NTuple{N,Int} stride::NTuple{N,Int} - MaxPool(k::NTuple{N,Int}; pad = map(_->0,k), stride = k) where N = new{N}(k, pad, stride) end -function MaxPool{N}(k::Int; pad = 0, stride = k) where N - k_ = Tuple(repeat([k, ], N)) - MaxPool(k_; pad = map(_->pad,k_), stride=map(_->stride,k_)) -end +MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N = + MaxPool(k, expand(Val(N), pad), expand(Val(N), stride)) (m::MaxPool)(x) = maxpool(x, m.k; pad = m.pad, stride = m.stride) function Base.show(io::IO, m::MaxPool) - print(io, "MaxPool(", m.k, ", ", m.pad, ", ", m.stride, ")") + print(io, "MaxPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")") end - """ MeanPool(k) -Meanpooling layer. `k` stands for the size of the window for each dimension of the input. +Mean pooling layer. `k` stands for the size of the window for each dimension of the input. Takes the keyword arguments `pad` and `stride`. """ @@ -89,16 +85,13 @@ struct MeanPool{N} k::NTuple{N,Int} pad::NTuple{N,Int} stride::NTuple{N,Int} - MeanPool(k::NTuple{N,Int}; pad = map(_->0,k), stride = k) where N = new{N}(k, pad, stride) end -function MeanPool{N}(k::Int; pad = 0, stride = k) where N - k_ = Tuple(repeat([k, ], N)) - MeanPool(k_; pad = map(_->pad,k_), stride=map(_->stride,k_)) -end +MeanPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N = + MeanPool(k, expand(Val(N), pad), expand(Val(N), stride)) (m::MeanPool)(x) = meanpool(x, m.k; pad = m.pad, stride = m.stride) function Base.show(io::IO, m::MeanPool) - print(io, "MeanPool(", m.k, ", ", m.pad, ", ", m.stride, ")") + print(io, "MeanPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")") end diff --git a/test/layers/conv.jl b/test/layers/conv.jl index 2e5e63dd..07b8c290 100644 --- a/test/layers/conv.jl +++ b/test/layers/conv.jl @@ -1,21 +1,13 @@ using Test -using Flux: Chain, Conv, MaxPool, MeanPool +using Flux: Chain, Conv, MaxPool, MeanPool, maxpool, meanpool using Base.conv @testset "pooling" begin + x = randn(10, 10, 3, 2) mp = MaxPool((2, 2)) - - @testset "maxpooling" begin - @test MaxPool{2}(2) == mp - @test MaxPool{2}(2; pad=1, stride=3) == MaxPool((2, 2); pad=(1, 1), stride=(3, 3)) - end - + @test mp(x) == maxpool(x, (2,2)) mp = MeanPool((2, 2)) - - @testset "meanpooling" begin - @test MeanPool{2}(2) == mp - @test MeanPool{2}(2; pad=1, stride=3) == MeanPool((2, 2); pad=(1, 1), stride=(3, 3)) - end + @test mp(x) == meanpool(x, (2,2)) end @testset "cnn" begin From 1e90226077457249af527f69d9fb6018f21dc2e4 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 4 Sep 2018 14:35:20 +0100 Subject: [PATCH 11/16] actually run tests --- test/layers/conv.jl | 29 +++++++++++++---------------- test/runtests.jl | 1 + 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/test/layers/conv.jl b/test/layers/conv.jl index 07b8c290..5928bd75 100644 --- a/test/layers/conv.jl +++ b/test/layers/conv.jl @@ -1,8 +1,7 @@ -using Test -using Flux: Chain, Conv, MaxPool, MeanPool, maxpool, meanpool -using Base.conv +using Flux, Test +using Flux: maxpool, meanpool -@testset "pooling" begin +@testset "Pooling" begin x = randn(10, 10, 3, 2) mp = MaxPool((2, 2)) @test mp(x) == maxpool(x, (2,2)) @@ -10,17 +9,15 @@ using Base.conv @test mp(x) == meanpool(x, (2,2)) end -@testset "cnn" begin - r = zeros(28, 28) - m = Chain( - Conv((2, 2), 1=>16, relu), - MaxPool{2}(2), - Conv((2, 2), 16=>8, relu), - MaxPool{2}(2), - x -> reshape(x, :, size(x, 4)), - Dense(288, 10), softmax) +@testset "CNN" begin + r = zeros(28, 28, 1, 5) + m = Chain( + Conv((2, 2), 1=>16, relu), + MaxPool((2,2)), + Conv((2, 2), 16=>8, relu), + MaxPool((2,2)), + x -> reshape(x, :, size(x, 4)), + Dense(288, 10), softmax) - @testset "inference" begin - @test size(m(r)) == (10, ) - end + @test size(m(r)) == (10, 5) end diff --git a/test/runtests.jl b/test/runtests.jl index fd48e547..70d929bf 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -28,6 +28,7 @@ include("onehot.jl") include("tracker.jl") include("layers/normalisation.jl") include("layers/stateless.jl") +include("layers/conv.jl") include("optimise.jl") include("data.jl") From 8b71350878667538fd3024d81dad760c92988b1b Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 5 Sep 2018 15:39:00 +0100 Subject: [PATCH 12/16] make travis happy maybe --- test/cuda/cuda.jl | 2 +- test/runtests.jl | 14 +++++++++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/test/cuda/cuda.jl b/test/cuda/cuda.jl index c9ee95c6..16f90e89 100644 --- a/test/cuda/cuda.jl +++ b/test/cuda/cuda.jl @@ -1,7 +1,7 @@ using Flux, Flux.Tracker, CuArrays, Test using Flux: gpu -@info "Testing Flux/GPU" +@info "Testing GPU Support" @testset "CuArrays" begin diff --git a/test/runtests.jl b/test/runtests.jl index 70d929bf..7a55dca6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -23,14 +23,22 @@ insert!(LOAD_PATH, 2, "@v#.#") @testset "Flux" begin +@info "Testing Basics" + include("utils.jl") include("onehot.jl") -include("tracker.jl") +include("optimise.jl") +include("data.jl") + +@info "Testing Layers" + include("layers/normalisation.jl") include("layers/stateless.jl") include("layers/conv.jl") -include("optimise.jl") -include("data.jl") + +@info "Running Gradient Checks" + +include("tracker.jl") if Base.find_package("CuArrays") != nothing include("cuda/cuda.jl") From ec16a2c77dbf6ab8b92b0eecd11661be7a62feef Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 5 Sep 2018 15:55:08 +0100 Subject: [PATCH 13/16] todone: nicer syntax on 0.7 --- src/layers/recurrent.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 4064ed7b..3b40af04 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -128,8 +128,7 @@ function LSTMCell(in::Integer, out::Integer; return cell end -function (m::LSTMCell)(h_, x) - h, c = h_ # TODO: nicer syntax on 0.7 +function (m::LSTMCell)((h, c), x) b, o = m.b, size(h, 1) g = m.Wi*x .+ m.Wh*h .+ b input = σ.(gate(g, o, 1)) From b7eaf393fc5cd3a77a5b5959c25813edec947661 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 5 Sep 2018 16:01:57 +0100 Subject: [PATCH 14/16] docs updates --- docs/src/index.md | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index afeb2075..d381b194 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -1,18 +1,17 @@ # Flux: The Julia Machine Learning Library -Flux is a library for machine learning. It comes "batteries-included" with many useful tools built in, but also lets you use the full power of the Julia language where you need it. The whole stack is implemented in clean Julia code (right down to the [GPU kernels](https://github.com/FluxML/CuArrays.jl)) and any part can be tweaked to your liking. +Flux is a library for machine learning. It comes "batteries-included" with many useful tools built in, but also lets you use the full power of the Julia language where you need it. We follow a few key principles: + +* **Doing the obvious thing**. Flux has relatively few explicit APIs for features like regularisation or embeddings. Instead, writing down the mathematical form will work – and be fast. +* **You could have written Flux**. All of it, from [LSTMs](https://github.com/FluxML/Flux.jl/blob/ec16a2c77dbf6ab8b92b0eecd11661be7a62feef/src/layers/recurrent.jl#L131) to [GPU kernels](https://github.com/JuliaGPU/CuArrays.jl), is straightforward Julia code. When it doubt, it’s well worth looking at [the source](https://github.com/FluxML/Flux.jl/). If you need something different, you can easily roll your own. +* **Play nicely with others**. Flux works well with Julia libraries from [data frames](https://github.com/JuliaComputing/JuliaDB.jl) and [images](https://github.com/JuliaImages/Images.jl) to [differential equation solvers](https://github.com/JuliaDiffEq/DifferentialEquations.jl), so you can easily build complex data processing pipelines that integrate Flux models. # Installation -Install [Julia 0.6.0 or later](https://julialang.org/downloads/), if you haven't already. +Download [Julia 1.0](https://julialang.org/) or later, if you haven't already. You can add Flux from using Julia's package manager, by typing `] add Flux` in the Julia prompt. -```julia -Pkg.add("Flux") -# Optional but recommended -Pkg.update() # Keep your packages up to date -Pkg.test("Flux") # Check things installed correctly -``` +If you have CUDA you can also run `] add CuArrays` to get GPU support; see [here](gpu.md) for more details. -Start with the [basics](models/basics.md). The [model zoo](https://github.com/FluxML/model-zoo/) is also a good starting point for many common kinds of models. +# Learning Flux -See [GPU support](gpu.md) for more details on installing and using Flux with GPUs. +There are several different ways to learn Flux. If you just want to get started writing models, the [model zoo](https://github.com/FluxML/model-zoo/) gives good starting points for many common ones. This documentation provides a reference to all of Flux's APIs, as well as a from-scratch introduction to Flux's take on models and how they work. Once you understand these docs, congratulations, you also understand [Flux's source code](https://github.com/FluxML/Flux.jl), which is intended to be concise, legible and a good reference for more advanced concepts. From 193c4ded19290197fb27a4b058cffd34891073b6 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 5 Sep 2018 16:52:50 +0100 Subject: [PATCH 15/16] make docs on 1.0 --- .travis.yml | 4 ++-- docs/make.jl | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.travis.yml b/.travis.yml index 9bf07dd6..b26597e9 100644 --- a/.travis.yml +++ b/.travis.yml @@ -15,5 +15,5 @@ matrix: allow_failures: - julia: nightly after_success: - - julia -e 'Pkg.add("Documenter")' - - julia -e 'cd(Pkg.dir("Flux")); include(joinpath("docs", "make.jl"))' + - julia -e 'using Pkg; Pkg.add("Documenter"); Pkg.add("NNlib")' + - julia -e 'using Pkg; cd(Pkg.dir("Flux")); include(joinpath("docs", "make.jl"))' diff --git a/docs/make.jl b/docs/make.jl index ed6a8c8b..b35beb3c 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -26,6 +26,6 @@ deploydocs( repo = "github.com/FluxML/Flux.jl.git", target = "build", osname = "linux", - julia = "0.6", + julia = "1.0", deps = nothing, make = nothing) From 395a35d137eccc5fc97d43b6b468c50953b77517 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 5 Sep 2018 17:03:41 +0100 Subject: [PATCH 16/16] better headings --- docs/src/index.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index d381b194..4fc58f72 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -6,12 +6,12 @@ Flux is a library for machine learning. It comes "batteries-included" with many * **You could have written Flux**. All of it, from [LSTMs](https://github.com/FluxML/Flux.jl/blob/ec16a2c77dbf6ab8b92b0eecd11661be7a62feef/src/layers/recurrent.jl#L131) to [GPU kernels](https://github.com/JuliaGPU/CuArrays.jl), is straightforward Julia code. When it doubt, it’s well worth looking at [the source](https://github.com/FluxML/Flux.jl/). If you need something different, you can easily roll your own. * **Play nicely with others**. Flux works well with Julia libraries from [data frames](https://github.com/JuliaComputing/JuliaDB.jl) and [images](https://github.com/JuliaImages/Images.jl) to [differential equation solvers](https://github.com/JuliaDiffEq/DifferentialEquations.jl), so you can easily build complex data processing pipelines that integrate Flux models. -# Installation +## Installation Download [Julia 1.0](https://julialang.org/) or later, if you haven't already. You can add Flux from using Julia's package manager, by typing `] add Flux` in the Julia prompt. If you have CUDA you can also run `] add CuArrays` to get GPU support; see [here](gpu.md) for more details. -# Learning Flux +## Learning Flux There are several different ways to learn Flux. If you just want to get started writing models, the [model zoo](https://github.com/FluxML/model-zoo/) gives good starting points for many common ones. This documentation provides a reference to all of Flux's APIs, as well as a from-scratch introduction to Flux's take on models and how they work. Once you understand these docs, congratulations, you also understand [Flux's source code](https://github.com/FluxML/Flux.jl), which is intended to be concise, legible and a good reference for more advanced concepts.