From 930ebaf2170a77e3f3117b94f3810f2f47f751b2 Mon Sep 17 00:00:00 2001 From: Josh Whittemore Date: Wed, 6 Feb 2019 16:17:59 -0800 Subject: [PATCH 01/23] Add module to make iris dataset available. --- Project.toml | 1 + src/data/Data.jl | 3 ++ src/data/iris.jl | 88 ++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 92 insertions(+) create mode 100644 src/data/iris.jl diff --git a/Project.toml b/Project.toml index 331d6839..71ff4877 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193" Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" +DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d" diff --git a/src/data/Data.jl b/src/data/Data.jl index ab78f416..d7cd0303 100644 --- a/src/data/Data.jl +++ b/src/data/Data.jl @@ -39,4 +39,7 @@ include("tree.jl") include("sentiment.jl") using .Sentiment +include("iris.jl") +export Iris + end diff --git a/src/data/iris.jl b/src/data/iris.jl new file mode 100644 index 00000000..c432f847 --- /dev/null +++ b/src/data/iris.jl @@ -0,0 +1,88 @@ + +""" + + Iris + +Fisher's classic iris dataset. + +Measurements from 3 different species of iris: setosa, versicolor and +virginica. There are 50 examples of each species. + +There are 4 measurements for each example: sepal length, sepal width, petal +length and petal width. The measurements are in centimeters. + +The module retrieves the data from the [UCI Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/iris). + +""" +module Iris + +using DelimitedFiles +using ..Data: deps, download_and_verify + +const cache_prefix = "" + +# Uncomment if the iris.data file is cached to cache.julialang.org. +# const cache_prefix = "https://cache.julialang.org/" + +function load() + isfile(deps("iris.data")) && return + + @info "Downloading iris dataset." + download_and_verify("$(cache_prefix)https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data", + deps("iris.data"), + "6f608b71a7317216319b4d27b4d9bc84e6abd734eda7872b71a458569e2656c0") +end + +""" + + labels() + +Get the labels of the iris dataset, a 150 element array of strings listing the +species of each example. + +```jldoctest +julia> labels = Flux.Data.Iris.labels(); + +julia> summary(labels) +"150-element Array{String,1}" + +julia> labels[1] +"Iris-setosa" +``` +""" +function labels() + load() + iris = readdlm(deps("iris.data"), ',') + Vector{String}(iris[1:end, end]) +end + +""" + + features() + +Get the features of the iris dataset. This is a 4x150 matrix of Float64 +elements. It has a row for each feature (sepal length, sepal width, +petal length, petal width) and a column for each example. 
+ +```jldoctest +julia> features = Flux.Data.Iris.features(); + +julia> summary(features) +"4×150 Array{Float64,2}" + +julia> features[:, 1] +4-element Array{Float64,1}: + 5.1 + 3.5 + 1.4 + 0.2 +``` +""" +function features() + load() + iris = readdlm(deps("iris.data"), ',') + Matrix{Float64}(iris[1:end, 1:4]') +end +end + + From 0cac3735398ad2d158e3fee24be328be2096d150 Mon Sep 17 00:00:00 2001 From: Joshua Whittemore Date: Sat, 9 Mar 2019 13:02:59 -0800 Subject: [PATCH 02/23] add tests for Data.Iris module --- test/data.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/data.jl b/test/data.jl index a73d1ec3..6b777873 100644 --- a/test/data.jl +++ b/test/data.jl @@ -14,3 +14,9 @@ using Test @test FashionMNIST.labels() isa Vector{Int64} @test Data.Sentiment.train() isa Vector{Data.Tree{Any}} + +@test Iris.features() isa Matrix +@test size(Iris.features()) == (4,150) + +@test Iris.labels() isa Vector{String} +@test size(Iris.labels()) == (150,) From 61588f72eff4fb4c2d796d03c9c0b6727c82e749 Mon Sep 17 00:00:00 2001 From: Joshua Whittemore Date: Sat, 9 Mar 2019 13:20:35 -0800 Subject: [PATCH 03/23] add item to NEWS.md describing Data.Iris module --- NEWS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/NEWS.md b/NEWS.md index 6f4758fe..8be589a2 100644 --- a/NEWS.md +++ b/NEWS.md @@ -9,6 +9,7 @@ * New "performance tips" [section of the docs](https://github.com/FluxML/Flux.jl/pull/615). * The training loop is [now more readable](https://github.com/FluxML/Flux.jl/pull/651) and better shows how to use the lower-level APIs. * New [AlphaDropout](https://github.com/FluxML/Flux.jl/pull/656). +* [Data.Iris](https://github.com/FluxML/Flux.jl/pull/652) makes Fisher's Iris dataset available with `Iris.labels` and `Iris.features`. AD Changes: From 6654aae1ec5e13f183cdf1a34f262d55fdee8ff5 Mon Sep 17 00:00:00 2001 From: David Pollack Date: Sun, 10 Mar 2019 11:11:43 +0100 Subject: [PATCH 04/23] update NEWS.md with InstanceNorm --- NEWS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/NEWS.md b/NEWS.md index 6f4758fe..bdbfb445 100644 --- a/NEWS.md +++ b/NEWS.md @@ -9,6 +9,7 @@ * New "performance tips" [section of the docs](https://github.com/FluxML/Flux.jl/pull/615). * The training loop is [now more readable](https://github.com/FluxML/Flux.jl/pull/651) and better shows how to use the lower-level APIs. * New [AlphaDropout](https://github.com/FluxML/Flux.jl/pull/656). +* New [InstanceNorm](https://github.com/FluxML/Flux.jl/pull/634), as popularized by [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022). 
AD Changes: From 79de829fdc272f81adda4cc725288ddb430c3255 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 27 Feb 2019 11:46:20 +0000 Subject: [PATCH 05/23] move Dense's overloads to be near its defn --- src/layers/basic.jl | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 758aa0a9..a0399411 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -88,6 +88,14 @@ function Base.show(io::IO, l::Dense) print(io, ")") end +# Try to avoid hitting generic matmul in some simple cases +# Base's matmul is so slow that it's worth the extra conversion to hit BLAS +(a::Dense{<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} = + invoke(a, Tuple{AbstractArray}, x) + +(a::Dense{<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} = + a(T.(x)) + """ Diagonal(in::Integer) @@ -117,10 +125,3 @@ function Base.show(io::IO, l::Diagonal) print(io, "Diagonal(", length(l.α), ")") end -# Try to avoid hitting generic matmul in some simple cases -# Base's matmul is so slow that it's worth the extra conversion to hit BLAS -(a::Dense{<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} = - invoke(a, Tuple{AbstractArray}, x) - -(a::Dense{<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} = - a(T.(x)) From fcc3ec471a5de52dec99d741f441541882066448 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 27 Feb 2019 12:04:59 +0000 Subject: [PATCH 06/23] Add MaxOut layer --- src/Flux.jl | 6 ++++-- src/layers/basic.jl | 44 ++++++++++++++++++++++++++++++++++++++++++++ test/layers/basic.jl | 27 +++++++++++++++++++++++++++ 3 files changed, 75 insertions(+), 2 deletions(-) diff --git a/src/Flux.jl b/src/Flux.jl index 3a862fb5..5e6c5081 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -6,8 +6,10 @@ using Base: tail using MacroTools, Juno, Requires, Reexport, Statistics, Random using MacroTools: @forward -export Chain, Dense, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool, - DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, +export Chain, Dense, MaxOut, + RNN, LSTM, GRU, + Conv, ConvTranspose, MaxPool, MeanPool, DepthwiseConv, + Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, params, mapleaves, cpu, gpu, f32, f64 @reexport using NNlib diff --git a/src/layers/basic.jl b/src/layers/basic.jl index a0399411..d850acd3 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -125,3 +125,47 @@ function Base.show(io::IO, l::Diagonal) print(io, "Diagonal(", length(l.α), ")") end + +""" + MaxOut(over) + +MaxOut is a neural network layer, which has a number of internal layers, +which all have the same input, and the max out returns the elementwise maximium +of the internal layers' outputs. + +Maxout over linear dense layers satisfies the univeral approximation theorem. + +Reference: +Ian J. Goodfellow, David Warde-Farley, Mehdi Mirza, Aaron Courville, and Yoshua Bengio. +2013. Maxout networks. +In Proceedings of the 30th International Conference on International Conference on Machine Learning - Volume 28 (ICML'13), +Sanjoy Dasgupta and David McAllester (Eds.), Vol. 28. JMLR.org III-1319-III-1327. +https://arxiv.org/pdf/1302.4389.pdf +""" +struct MaxOut{FS<:Tuple} + over::FS +end + +""" + MaxOut(f, n_alts, args...; kwargs...) + +Constructs a MaxOut layer over `n_alts` instances of the layer given by `f`. 
+All other arguements (`args` & `kwargs`) are passed to the constructor `f`. + +For example the followeExample usage +will construct a MaxOut layer over 4 dense linear layers, +each identical in structure (784 inputs, 128 outputs). +```julia + insize = 784 + outsie = 128 + MaxOut(Dense, 4, insize, outsize) +``` +""" +function MaxOut(f, n_alts, args...; kwargs...) + over = Tuple(f(args...; kwargs...) for _ in 1:n_alts) + return MaxOut(over) +end + +function (mo::MaxOut)(input::AbstractArray) + mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over) +end diff --git a/test/layers/basic.jl b/test/layers/basic.jl index b8d9efd1..846fb3ee 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -30,4 +30,31 @@ using Test, Random @test Flux.Diagonal(2)([1,2]) == [1,2] @test Flux.Diagonal(2)([1 2; 3 4]) == [1 2; 3 4] end + + @testset "MaxOut" begin + # Note that the normal common usage of MaxOut is as per the docstring + # These are abnormal constructors used for testing purposes + + @testset "Constructor" begin + mo = MaxOut(() -> identity, 4) + + input = rand(40) + @test mo(input) == input + end + + @testset "simple alternatives" begin + mo = MaxOut((x -> x, x -> 2x, x -> 0.5x)) + + input = rand(40) + @test mo(input) == 2*input + end + + @testset "complex alternatives" begin + mo = MaxOut((x -> [0.5; 0.1]*x, x -> [0.2; 0.7]*x)) + + input = [3.0 2.0] + target = [0.5, 0.7].*input + @test mo(input) == target + end + end end From c1a33c556fe2b2a82cd9a3f1cf2a0d4e7aeb7528 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 27 Feb 2019 12:20:44 +0000 Subject: [PATCH 07/23] do things to docs --- docs/src/models/layers.md | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md index e904ed65..65d07173 100644 --- a/docs/src/models/layers.md +++ b/docs/src/models/layers.md @@ -5,14 +5,16 @@ These core layers form the foundation of almost all neural networks. ```@docs Chain Dense +``` + +## Convolution and Pooling Layers + +These layers are used to build convolutional neural networks (CNNs). + +```@docs Conv MaxPool MeanPool -``` - -## Additional Convolution Layers - -```@docs DepthwiseConv ConvTranspose ``` @@ -28,6 +30,13 @@ GRU Flux.Recur ``` +## Hipster Layers +These are marginally more obscure layers that you probably haven't heard of. + +```@docs +MaxOut +``` + ## Activation Functions Non-linearities that go between layers of your model. Most of these functions are defined in [NNlib](https://github.com/FluxML/NNlib.jl) but are available by default in Flux. From b84a60e74e4c1192c71957bc0ac1808c0f30c1a3 Mon Sep 17 00:00:00 2001 From: Kristoffer Carlsson Date: Wed, 27 Feb 2019 15:11:24 +0000 Subject: [PATCH 08/23] Update src/layers/basic.jl Co-Authored-By: oxinabox --- src/layers/basic.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index d850acd3..e456185a 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -129,7 +129,7 @@ end """ MaxOut(over) -MaxOut is a neural network layer, which has a number of internal layers, +`MaxOut` is a neural network layer, which has a number of internal layers, which all have the same input, and the max out returns the elementwise maximium of the internal layers' outputs. 
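The `MaxOut` docstring and the tests above both describe the forward pass as an elementwise maximum over the outputs of the internal layers. A minimal sketch of that property, in the same closure style as the test set, is shown below; it is illustrative only and assumes the `MaxOut` tuple constructor added in this patch (renamed to `Maxout` later in this series).

```julia
using Flux

# Three branches share the same input; MaxOut returns their elementwise maximum.
mo = MaxOut((x -> 0.5 .* x, x -> -x, x -> 2 .* x))

x = rand(5)                                   # non-negative inputs in [0, 1)
@assert mo(x) == max.(0.5 .* x, -x, 2 .* x)   # maximum taken across all branches
@assert mo(x) == 2 .* x                       # the 2x branch dominates for x >= 0
```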
From 838047f708ce07ae48e513a075e93e2783996a00 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 27 Feb 2019 15:19:10 +0000 Subject: [PATCH 09/23] fix docs --- docs/src/models/layers.md | 4 ++-- src/layers/basic.jl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md index 65d07173..88a1252a 100644 --- a/docs/src/models/layers.md +++ b/docs/src/models/layers.md @@ -30,8 +30,8 @@ GRU Flux.Recur ``` -## Hipster Layers -These are marginally more obscure layers that you probably haven't heard of. +## Esoteric Layers +These are marginally more obscure layers. ```@docs MaxOut diff --git a/src/layers/basic.jl b/src/layers/basic.jl index e456185a..254ac4fb 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -152,8 +152,8 @@ end Constructs a MaxOut layer over `n_alts` instances of the layer given by `f`. All other arguements (`args` & `kwargs`) are passed to the constructor `f`. -For example the followeExample usage -will construct a MaxOut layer over 4 dense linear layers, +For example the following example which +will construct a `MaxOut` layer over 4 dense linear layers, each identical in structure (784 inputs, 128 outputs). ```julia insize = 784 From c76b9c7e2c0f622d806998b0ff4f0aeffce065bd Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 28 Feb 2019 11:35:28 +0000 Subject: [PATCH 10/23] fix docs --- docs/src/models/layers.md | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md index 88a1252a..be99d015 100644 --- a/docs/src/models/layers.md +++ b/docs/src/models/layers.md @@ -30,13 +30,25 @@ GRU Flux.Recur ``` -## Esoteric Layers -These are marginally more obscure layers. +## Other General Purpose Layers +These are marginally more obscure than the Basic Layers. +But incontrast to the layers described in the other sections are not readily grouped around a paparticular purpose (e.g. CNNs or RNNs). ```@docs MaxOut ``` +# Normalisation & Regularisation + +These layers don't affect the structure of the network but may improve training times or reduce overfitting. + +```@docs +Flux.testmode! +BatchNorm +Dropout +LayerNorm +``` + ## Activation Functions Non-linearities that go between layers of your model. Most of these functions are defined in [NNlib](https://github.com/FluxML/NNlib.jl) but are available by default in Flux. 
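The "Normalisation & Regularisation" section added to the docs above lists `Flux.testmode!`, `BatchNorm`, `Dropout` and `LayerNorm` without a usage example. A short sketch of how these layers are typically combined is given below; the `testmode!(m, test)` form is assumed from the Flux API of this period rather than introduced by these patches.

```julia
using Flux

# Dropout and BatchNorm behave differently during training and inference.
m = Chain(Dense(28^2, 64), BatchNorm(64, relu), Dropout(0.5), Dense(64, 10), softmax)

x = rand(28^2, 16)        # a batch of 16 dummy inputs
y_train = m(x)            # training mode: dropout active, batch statistics used

Flux.testmode!(m)         # switch dropout/batchnorm to inference behaviour
y_test = m(x)

Flux.testmode!(m, false)  # back to training mode
```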
From e23c8ddd13f3514c8e53f7ab78732e52cbbee49c Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 28 Feb 2019 15:49:49 +0000 Subject: [PATCH 11/23] take zero-arge closure --- test/layers/basic.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/layers/basic.jl b/test/layers/basic.jl index 846fb3ee..3c8c47e8 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -37,21 +37,18 @@ using Test, Random @testset "Constructor" begin mo = MaxOut(() -> identity, 4) - input = rand(40) @test mo(input) == input end @testset "simple alternatives" begin mo = MaxOut((x -> x, x -> 2x, x -> 0.5x)) - input = rand(40) @test mo(input) == 2*input end @testset "complex alternatives" begin mo = MaxOut((x -> [0.5; 0.1]*x, x -> [0.2; 0.7]*x)) - input = [3.0 2.0] target = [0.5, 0.7].*input @test mo(input) == target From ca68bf9bec1b6acb8728c8b782217710dd2d0e75 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 6 Mar 2019 10:22:46 -0800 Subject: [PATCH 12/23] correct casing --- src/layers/basic.jl | 20 ++++++++++---------- test/layers/basic.jl | 10 +++++----- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 254ac4fb..786cf32e 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -127,9 +127,9 @@ end """ - MaxOut(over) + Maxout(over) -`MaxOut` is a neural network layer, which has a number of internal layers, +`Maxout` is a neural network layer, which has a number of internal layers, which all have the same input, and the max out returns the elementwise maximium of the internal layers' outputs. @@ -142,30 +142,30 @@ In Proceedings of the 30th International Conference on International Conference Sanjoy Dasgupta and David McAllester (Eds.), Vol. 28. JMLR.org III-1319-III-1327. https://arxiv.org/pdf/1302.4389.pdf """ -struct MaxOut{FS<:Tuple} +struct Maxout{FS<:Tuple} over::FS end """ - MaxOut(f, n_alts, args...; kwargs...) + Maxout(f, n_alts, args...; kwargs...) -Constructs a MaxOut layer over `n_alts` instances of the layer given by `f`. +Constructs a Maxout layer over `n_alts` instances of the layer given by `f`. All other arguements (`args` & `kwargs`) are passed to the constructor `f`. For example the following example which -will construct a `MaxOut` layer over 4 dense linear layers, +will construct a `Maxout` layer over 4 dense linear layers, each identical in structure (784 inputs, 128 outputs). ```julia insize = 784 outsie = 128 - MaxOut(Dense, 4, insize, outsize) + Maxout(Dense, 4, insize, outsize) ``` """ -function MaxOut(f, n_alts, args...; kwargs...) +function Maxout(f, n_alts, args...; kwargs...) over = Tuple(f(args...; kwargs...) 
for _ in 1:n_alts) - return MaxOut(over) + return Maxout(over) end -function (mo::MaxOut)(input::AbstractArray) +function (mo::Maxout)(input::AbstractArray) mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over) end diff --git a/test/layers/basic.jl b/test/layers/basic.jl index 3c8c47e8..3a3b1695 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -31,24 +31,24 @@ using Test, Random @test Flux.Diagonal(2)([1 2; 3 4]) == [1 2; 3 4] end - @testset "MaxOut" begin - # Note that the normal common usage of MaxOut is as per the docstring + @testset "Maxout" begin + # Note that the normal common usage of Maxout is as per the docstring # These are abnormal constructors used for testing purposes @testset "Constructor" begin - mo = MaxOut(() -> identity, 4) + mo = Maxout(() -> identity, 4) input = rand(40) @test mo(input) == input end @testset "simple alternatives" begin - mo = MaxOut((x -> x, x -> 2x, x -> 0.5x)) + mo = Maxout((x -> x, x -> 2x, x -> 0.5x)) input = rand(40) @test mo(input) == 2*input end @testset "complex alternatives" begin - mo = MaxOut((x -> [0.5; 0.1]*x, x -> [0.2; 0.7]*x)) + mo = Maxout((x -> [0.5; 0.1]*x, x -> [0.2; 0.7]*x)) input = [3.0 2.0] target = [0.5, 0.7].*input @test mo(input) == target From 2bc4b8d1a42bbedf559894c5b59e3b43b0fde79f Mon Sep 17 00:00:00 2001 From: Nick Robinson Date: Thu, 7 Mar 2019 03:44:13 -0800 Subject: [PATCH 13/23] Update docs/src/models/layers.md Co-Authored-By: oxinabox --- docs/src/models/layers.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md index be99d015..36274c6d 100644 --- a/docs/src/models/layers.md +++ b/docs/src/models/layers.md @@ -35,7 +35,7 @@ These are marginally more obscure than the Basic Layers. But incontrast to the layers described in the other sections are not readily grouped around a paparticular purpose (e.g. CNNs or RNNs). ```@docs -MaxOut +Maxout ``` # Normalisation & Regularisation From f222555deb7fbc1bcf005158dd89a0e2ed575d0d Mon Sep 17 00:00:00 2001 From: Nick Robinson Date: Thu, 7 Mar 2019 07:44:29 -0800 Subject: [PATCH 14/23] Update src/Flux.jl Co-Authored-By: oxinabox --- src/Flux.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Flux.jl b/src/Flux.jl index 5e6c5081..ff178450 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -6,7 +6,7 @@ using Base: tail using MacroTools, Juno, Requires, Reexport, Statistics, Random using MacroTools: @forward -export Chain, Dense, MaxOut, +export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool, DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, From 025d9b678dc491a8d8d02b2e7451c249c8d4e24d Mon Sep 17 00:00:00 2001 From: Nick Robinson Date: Thu, 7 Mar 2019 07:44:46 -0800 Subject: [PATCH 15/23] Update docs/src/models/layers.md Co-Authored-By: oxinabox --- docs/src/models/layers.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md index 36274c6d..5d5bc5d9 100644 --- a/docs/src/models/layers.md +++ b/docs/src/models/layers.md @@ -32,7 +32,7 @@ Flux.Recur ## Other General Purpose Layers These are marginally more obscure than the Basic Layers. -But incontrast to the layers described in the other sections are not readily grouped around a paparticular purpose (e.g. CNNs or RNNs). +But in contrast to the layers described in the other sections are not readily grouped around a particular purpose (e.g. CNNs or RNNs). 
```@docs Maxout From 7d247ea25bccfd196be50b1dca5d11ad1b8a96bc Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 11 Mar 2019 18:40:29 -0300 Subject: [PATCH 16/23] update docstring --- src/layers/basic.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 786cf32e..1e2ab891 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -130,7 +130,7 @@ end Maxout(over) `Maxout` is a neural network layer, which has a number of internal layers, -which all have the same input, and the max out returns the elementwise maximium +which all have the same input, and the maxout returns the elementwise maximium of the internal layers' outputs. Maxout over linear dense layers satisfies the univeral approximation theorem. @@ -150,15 +150,16 @@ end Maxout(f, n_alts, args...; kwargs...) Constructs a Maxout layer over `n_alts` instances of the layer given by `f`. -All other arguements (`args` & `kwargs`) are passed to the constructor `f`. +The function takes no arguement and should return some callable layer. +Conventionally this is a linear dense layer. For example the following example which -will construct a `Maxout` layer over 4 dense linear layers, +will construct a `Maxout` layer over 4 internal dense linear layers, each identical in structure (784 inputs, 128 outputs). ```julia insize = 784 outsie = 128 - Maxout(Dense, 4, insize, outsize) + Maxout(()->Dense(insize, outsize), 4) ``` """ function Maxout(f, n_alts, args...; kwargs...) From 401d3da8846a898082244ccdf7a385b33effef6d Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 21 Mar 2019 17:04:52 +0000 Subject: [PATCH 17/23] no arg closures --- src/layers/basic.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 1e2ab891..b39a0de2 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -147,7 +147,7 @@ struct Maxout{FS<:Tuple} end """ - Maxout(f, n_alts, args...; kwargs...) + Maxout(f, n_alts) Constructs a Maxout layer over `n_alts` instances of the layer given by `f`. The function takes no arguement and should return some callable layer. @@ -162,8 +162,8 @@ each identical in structure (784 inputs, 128 outputs). Maxout(()->Dense(insize, outsize), 4) ``` """ -function Maxout(f, n_alts, args...; kwargs...) - over = Tuple(f(args...; kwargs...) for _ in 1:n_alts) +function Maxout(f, n_alts) + over = Tuple(f() for _ in 1:n_alts) return Maxout(over) end From df509ce9f0504e475e5e59deae9de571f029a75e Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Wed, 6 Feb 2019 15:01:01 +0100 Subject: [PATCH 18/23] Adapt to the new CUDAdrv.CuPtr pointer type. --- src/cuda/cuda.jl | 1 + src/cuda/cudnn.jl | 36 ++++++++++++++++++------------------ src/cuda/curnn.jl | 34 +++++++++++++++++----------------- 3 files changed, 36 insertions(+), 35 deletions(-) diff --git a/src/cuda/cuda.jl b/src/cuda/cuda.jl index 070c9228..762b9b2e 100644 --- a/src/cuda/cuda.jl +++ b/src/cuda/cuda.jl @@ -1,5 +1,6 @@ module CUDA +import CUDAdrv: CuPtr, CU_NULL using ..CuArrays using Pkg.TOML diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl index 8bd8135e..8671d166 100644 --- a/src/cuda/cudnn.jl +++ b/src/cuda/cudnn.jl @@ -17,7 +17,7 @@ function DropoutDesc(ρ::Real; seed::Integer=0) @check ccall((:cudnnDropoutGetStatesSize,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Csize_t}),handle(),s) states = CuArray{UInt8}(undef, s[]) # TODO: can we drop this when ρ=0? 
desc = DropoutDesc(d[], states) - @check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Nothing},Cfloat,Ptr{Nothing},Csize_t,Culonglong), + @check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Nothing},Cfloat,CuPtr{Nothing},Csize_t,Culonglong), desc,handle(),ρ,states,length(states),seed) finalizer(desc) do x @check ccall((:cudnnDestroyDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x) @@ -79,18 +79,18 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray mean = zeros(CuArray{T}, dims...) ivar = ones(CuArray{T}, dims...) else - mean = C_NULL - ivar = C_NULL + mean = CU_NULL + ivar = CU_NULL end @check ccall((:cudnnBatchNormalizationForwardTraining, libcudnn), cudnnStatus_t, (cudnnHandle_t,cudnnBatchNormMode_t, Ptr{T}, Ptr{T}, - Ptr{Nothing}, Ptr{T}, - Ptr{Nothing}, Ptr{T}, - Ptr{Nothing}, Ptr{T}, Ptr{T}, - Cdouble, Ptr{T}, Ptr{T}, - Cdouble, Ptr{T}, Ptr{T}), + Ptr{Nothing}, CuPtr{T}, + Ptr{Nothing}, CuPtr{T}, + Ptr{Nothing}, CuPtr{T}, CuPtr{T}, + Cdouble, CuPtr{T}, CuPtr{T}, + Cdouble, CuPtr{T}, CuPtr{T}), handle(), BATCHNORM_SPATIAL, Ref(T(alpha)), Ref(T(beta)), xd, x, @@ -107,10 +107,10 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray @check ccall((:cudnnBatchNormalizationForwardInference, libcudnn), cudnnStatus_t, (Ptr{cudnnHandle_t},cudnnBatchNormMode_t, Ptr{T}, Ptr{T}, - Ptr{Nothing}, Ptr{T}, - Ptr{Nothing}, Ptr{T}, - Ptr{Nothing}, Ptr{T}, Ptr{T}, - Ptr{T}, Ptr{T}, + Ptr{Nothing}, CuPtr{T}, + Ptr{Nothing}, CuPtr{T}, + Ptr{Nothing}, CuPtr{T}, CuPtr{T}, + CuPtr{T}, CuPtr{T}, Cdouble), handle(), BATCHNORM_SPATIAL, Ref(T(alpha)), Ref(T(beta)), @@ -159,7 +159,7 @@ function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T}, mean, ivar = cache.mean, cache.ivar info("mean and ivar are fetched from the cache") else - mean, ivar = C_NULL, C_NULL + mean, ivar = CU_NULL, CU_NULL end if eps < BATCHNORM_MIN_EPS @@ -170,11 +170,11 @@ function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T}, (cudnnHandle_t,cudnnBatchNormMode_t, Ptr{T}, Ptr{T}, Ptr{T}, Ptr{T}, - Ptr{Nothing}, Ptr{T}, - Ptr{Nothing}, Ptr{T}, - Ptr{Nothing}, Ptr{T}, - Ptr{Nothing}, Ptr{T}, Ptr{T}, Ptr{T}, - Cdouble, Ptr{T}, Ptr{T}), + Ptr{Nothing}, CuPtr{T}, + Ptr{Nothing}, CuPtr{T}, + Ptr{Nothing}, CuPtr{T}, + Ptr{Nothing}, CuPtr{T}, CuPtr{T}, CuPtr{T}, + Cdouble, CuPtr{T}, CuPtr{T}), handle(), BATCHNORM_SPATIAL, Ref(T(alpha)), Ref(T(beta)), Ref(T(dalpha)), Ref(T(dbeta)), diff --git a/src/cuda/curnn.jl b/src/cuda/curnn.jl index 210ddd7c..09f6d43c 100644 --- a/src/cuda/curnn.jl +++ b/src/cuda/curnn.jl @@ -101,18 +101,18 @@ function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd if reserve == nothing @check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t, (Ptr{Nothing}, Ptr{Nothing}, Cint, - Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, - Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, - Ptr{Nothing}, Ptr{T}, - Ptr{Nothing}, Csize_t), + Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, + Ptr{Nothing}, CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, + Ptr{Nothing}, CuPtr{T}, + CuPtr{Nothing}, Csize_t), handle(), rnn, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co, workspace, length(workspace)) else @check ccall((:cudnnRNNForwardTraining, libcudnn), cudnnStatus_t, (Ptr{Nothing}, Ptr{Nothing}, Cint, - Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, 
Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, - Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t), + Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, + CuPtr{Nothing}, Csize_t, CuPtr{Nothing}, Csize_t), handle(), rnn, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co, workspace, length(workspace), reserve, length(reserve)) @@ -121,7 +121,7 @@ end xDesc(x) = [TensorDesc(eltype(x), (1, size(x, 1), size(x, 2)))] -hDesc(h::Nothing) = C_NULL, C_NULL +hDesc(h::Nothing) = C_NULL, CU_NULL hDesc(x::Integer) = (@assert x == 0; hDesc(nothing)) function hDesc(h::CuArray) TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h @@ -169,10 +169,10 @@ function cudnnRNNBackwardData(rnn::RNNDesc{T}, seqlen, yd, y, dyd, dy, dhod, dho wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, rs) where T @check ccall((:cudnnRNNBackwardData,libcudnn),cudnnStatus_t, (Ptr{Nothing}, Ptr{Nothing}, Cint, - Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, - Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, - Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, - Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t), + Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, + Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, + CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, + CuPtr{Nothing}, Csize_t, CuPtr{Nothing}, Csize_t), handle(), rnn, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco, wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, length(ws), rs, length(rs)) end @@ -199,12 +199,12 @@ function cudnnRNNBackwardWeights(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, yd, y, d workspace, reserve) where T @check ccall((:cudnnRNNBackwardWeights,libcudnn), cudnnStatus_t, (Ptr{Nothing}, Ptr{Nothing}, Cint, # handle, rnnDesc, seqLength - Ptr{Ptr{Nothing}}, Ptr{T}, #x - Ptr{Nothing}, Ptr{T}, #hx - Ptr{Ptr{Nothing}}, Ptr{T}, #y - Ptr{Nothing}, Csize_t, #ws - Ptr{Nothing}, Ptr{T}, #dw - Ptr{Nothing}, Csize_t), #rs + Ptr{Ptr{Nothing}}, CuPtr{T}, #x + Ptr{Nothing}, CuPtr{T}, #hx + Ptr{Ptr{Nothing}}, CuPtr{T}, #y + CuPtr{Nothing}, Csize_t, #ws + Ptr{Nothing}, CuPtr{T}, #dw + CuPtr{Nothing}, Csize_t), #rs handle(), rnn, seqlen, xd, x, hd, h, yd, y, workspace, length(workspace), dwd, dw, reserve, length(reserve)) end From 959dd247bfebbc829f043c19ada70de99df7df31 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Fri, 8 Feb 2019 16:29:12 +0100 Subject: [PATCH 19/23] Import CUDAdrv stuff through CuArrays. 
--- src/cuda/cuda.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cuda/cuda.jl b/src/cuda/cuda.jl index 762b9b2e..28b6db09 100644 --- a/src/cuda/cuda.jl +++ b/src/cuda/cuda.jl @@ -1,7 +1,7 @@ module CUDA -import CUDAdrv: CuPtr, CU_NULL using ..CuArrays +import CuArrays.CUDAdrv: CuPtr, CU_NULL using Pkg.TOML function version_check() From bc068613209b56e686916b14d3c5c60596e5128f Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Thu, 21 Mar 2019 21:44:10 +0530 Subject: [PATCH 20/23] fix indirect import --- src/cuda/cuda.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cuda/cuda.jl b/src/cuda/cuda.jl index 28b6db09..9cbc7e76 100644 --- a/src/cuda/cuda.jl +++ b/src/cuda/cuda.jl @@ -1,7 +1,7 @@ module CUDA using ..CuArrays -import CuArrays.CUDAdrv: CuPtr, CU_NULL +import ..CuArrays.CUDAdrv: CuPtr, CU_NULL using Pkg.TOML function version_check() From 0734eeb50eba8c163fc4e0b309170fbc49987ff8 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Fri, 22 Mar 2019 14:12:04 +0100 Subject: [PATCH 21/23] Check CuArrays major version. --- src/cuda/cuda.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/cuda/cuda.jl b/src/cuda/cuda.jl index 9cbc7e76..89caf0d3 100644 --- a/src/cuda/cuda.jl +++ b/src/cuda/cuda.jl @@ -5,14 +5,14 @@ import ..CuArrays.CUDAdrv: CuPtr, CU_NULL using Pkg.TOML function version_check() - minor_version = 9 + major_version = 1 project = joinpath(dirname(pathof(CuArrays)), "../Project.toml") project = TOML.parse(String(read(project))) version = VersionNumber(get(project, "version", "0.0.0")) - if !(version.major == 0 && version.minor == minor_version) + if version.major != major_version @warn """ - Flux is only supported with CuArrays v0.$minor_version. - Try running `] pin CuArrays@0.$minor_version`. + Flux is only supported with CuArrays v$major_version.x. + Try running `] pin CuArrays@$major_version`. 
""" end end From db7f1a52dbecca1021448b99951d1f7ec5a88ac2 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Fri, 22 Mar 2019 21:51:04 +0530 Subject: [PATCH 22/23] update nnlib --- Manifest.toml | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index 91ed508a..46e56ad3 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -84,7 +84,7 @@ uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" version = "0.0.10" [[Distributed]] -deps = ["Random", "Serialization", "Sockets"] +deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" [[FixedPointNumbers]] @@ -100,14 +100,14 @@ uuid = "f6369f11-7733-5829-9624-2563aa707210" version = "0.10.3" [[InteractiveUtils]] -deps = ["Markdown"] +deps = ["LinearAlgebra", "Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" [[Juno]] deps = ["Base64", "Logging", "Media", "Profile", "Test"] -git-tree-sha1 = "ce6246e19061e36cbdce954caaae717498daeed8" +git-tree-sha1 = "4e4a8d43aa7ecec66cadaf311fbd1e5c9d7b9175" uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d" -version = "0.5.4" +version = "0.7.0" [[LibGit2]] uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" @@ -149,11 +149,9 @@ uuid = "a63ad114-7e13-5084-954f-fe012c677804" [[NNlib]] deps = ["Libdl", "LinearAlgebra", "MacroTools", "Requires", "Test"] -git-tree-sha1 = "9ac5cd21484189339b27840818c4882d1b6df7fd" -repo-rev = "master" -repo-url = "https://github.com/FluxML/NNlib.jl.git" +git-tree-sha1 = "d07ac0bfd3c71c3a29bc9c22becbba19227bbeb5" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.4.3+" +version = "0.5.0" [[NaNMath]] deps = ["Compat"] @@ -256,9 +254,9 @@ version = "0.1.0" [[TranscodingStreams]] deps = ["Pkg", "Random", "Test"] -git-tree-sha1 = "90f845c65c50bc57d6ffc815dbab2a4003ccf75c" +git-tree-sha1 = "f42956022d8084539f1d7219f632542b0ea686ce" uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" -version = "0.9.1" +version = "0.9.3" [[URIParser]] deps = ["Test", "Unicode"] @@ -267,7 +265,7 @@ uuid = "30578b45-9adc-5946-b283-645ec420af67" version = "0.4.0" [[UUIDs]] -deps = ["Random", "SHA"] +deps = ["Random"] uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [[Unicode]] From 9249f64e1d86b66859fad945f0b0b8b81f0a8f7a Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Fri, 22 Mar 2019 23:35:29 +0530 Subject: [PATCH 23/23] add Tracker to REQUIRE --- REQUIRE | 1 + 1 file changed, 1 insertion(+) diff --git a/REQUIRE b/REQUIRE index 455a1c15..3e8e9066 100644 --- a/REQUIRE +++ b/REQUIRE @@ -10,3 +10,4 @@ ZipFile AbstractTrees Reexport StatsBase +Tracker