diff --git a/Manifest.toml b/Manifest.toml
index 91ed508a..46e56ad3 100644
--- a/Manifest.toml
+++ b/Manifest.toml
@@ -84,7 +84,7 @@ uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
 version = "0.0.10"
 
 [[Distributed]]
-deps = ["Random", "Serialization", "Sockets"]
+deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"]
 uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
 
 [[FixedPointNumbers]]
@@ -100,14 +100,14 @@ uuid = "f6369f11-7733-5829-9624-2563aa707210"
 version = "0.10.3"
 
 [[InteractiveUtils]]
-deps = ["Markdown"]
+deps = ["LinearAlgebra", "Markdown"]
 uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
 
 [[Juno]]
 deps = ["Base64", "Logging", "Media", "Profile", "Test"]
-git-tree-sha1 = "ce6246e19061e36cbdce954caaae717498daeed8"
+git-tree-sha1 = "4e4a8d43aa7ecec66cadaf311fbd1e5c9d7b9175"
 uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
-version = "0.5.4"
+version = "0.7.0"
 
 [[LibGit2]]
 uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
@@ -149,11 +149,9 @@ uuid = "a63ad114-7e13-5084-954f-fe012c677804"
 
 [[NNlib]]
 deps = ["Libdl", "LinearAlgebra", "MacroTools", "Requires", "Test"]
-git-tree-sha1 = "9ac5cd21484189339b27840818c4882d1b6df7fd"
-repo-rev = "master"
-repo-url = "https://github.com/FluxML/NNlib.jl.git"
+git-tree-sha1 = "d07ac0bfd3c71c3a29bc9c22becbba19227bbeb5"
 uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-version = "0.4.3+"
+version = "0.5.0"
 
 [[NaNMath]]
 deps = ["Compat"]
@@ -256,9 +254,9 @@ version = "0.1.0"
 
 [[TranscodingStreams]]
 deps = ["Pkg", "Random", "Test"]
-git-tree-sha1 = "90f845c65c50bc57d6ffc815dbab2a4003ccf75c"
+git-tree-sha1 = "f42956022d8084539f1d7219f632542b0ea686ce"
 uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
-version = "0.9.1"
+version = "0.9.3"
 
 [[URIParser]]
 deps = ["Test", "Unicode"]
@@ -267,7 +265,7 @@ uuid = "30578b45-9adc-5946-b283-645ec420af67"
 version = "0.4.0"
 
 [[UUIDs]]
-deps = ["Random", "SHA"]
+deps = ["Random"]
 uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
 
 [[Unicode]]
diff --git a/NEWS.md b/NEWS.md
index 6f4758fe..4cf755e7 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -9,6 +9,8 @@
 * New "performance tips" [section of the docs](https://github.com/FluxML/Flux.jl/pull/615).
 * The training loop is [now more readable](https://github.com/FluxML/Flux.jl/pull/651) and better shows how to use the lower-level APIs.
 * New [AlphaDropout](https://github.com/FluxML/Flux.jl/pull/656).
+* [Data.Iris](https://github.com/FluxML/Flux.jl/pull/652) makes Fisher's Iris dataset available with `Iris.labels` and `Iris.features`.
+* New [InstanceNorm](https://github.com/FluxML/Flux.jl/pull/634), as popularized by [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022).
 
 AD Changes:
 
diff --git a/Project.toml b/Project.toml
index ebb26701..08b15332 100644
--- a/Project.toml
+++ b/Project.toml
@@ -6,6 +6,7 @@ AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193"
 Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
+DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
 Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
diff --git a/REQUIRE b/REQUIRE
index 455a1c15..3e8e9066 100644
--- a/REQUIRE
+++ b/REQUIRE
@@ -10,3 +10,4 @@ ZipFile
 AbstractTrees
 Reexport
 StatsBase
+Tracker
diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md
index e904ed65..5d5bc5d9 100644
--- a/docs/src/models/layers.md
+++ b/docs/src/models/layers.md
@@ -5,14 +5,16 @@ These core layers form the foundation of almost all neural networks.
 ```@docs
 Chain
 Dense
+```
+
+## Convolution and Pooling Layers
+
+These layers are used to build convolutional neural networks (CNNs).
+
+```@docs
 Conv
 MaxPool
 MeanPool
-```
-
-## Additional Convolution Layers
-
-```@docs
 DepthwiseConv
 ConvTranspose
 ```
@@ -28,6 +30,25 @@ GRU
 Flux.Recur
 ```
 
+## Other General Purpose Layers
+These layers are marginally more obscure than the Basic Layers.
+In contrast to the layers described in the other sections, they are not readily grouped around a particular purpose (e.g. CNNs or RNNs).
+
+```@docs
+Maxout
+```
+
+## Normalisation & Regularisation
+
+These layers don't affect the structure of the network but may improve training times or reduce overfitting.
+
+```@docs
+Flux.testmode!
+BatchNorm
+Dropout
+LayerNorm
+```
+
 ## Activation Functions
 
 Non-linearities that go between layers of your model. Most of these functions are defined in [NNlib](https://github.com/FluxML/NNlib.jl) but are available by default in Flux.
diff --git a/src/Flux.jl b/src/Flux.jl
index 3a862fb5..ff178450 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -6,8 +6,10 @@ using Base: tail
 using MacroTools, Juno, Requires, Reexport, Statistics, Random
 using MacroTools: @forward
 
-export Chain, Dense, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
-       DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm,
+export Chain, Dense, Maxout,
+       RNN, LSTM, GRU,
+       Conv, ConvTranspose, MaxPool, MeanPool, DepthwiseConv,
+       Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm,
        params, mapleaves, cpu, gpu, f32, f64
 
 @reexport using NNlib
diff --git a/src/cuda/cuda.jl b/src/cuda/cuda.jl
index 070c9228..89caf0d3 100644
--- a/src/cuda/cuda.jl
+++ b/src/cuda/cuda.jl
@@ -1,17 +1,18 @@
 module CUDA
 
 using ..CuArrays
+import ..CuArrays.CUDAdrv: CuPtr, CU_NULL
 using Pkg.TOML
 
 function version_check()
-  minor_version = 9
+  major_version = 1
   project = joinpath(dirname(pathof(CuArrays)), "../Project.toml")
   project = TOML.parse(String(read(project)))
   version = VersionNumber(get(project, "version", "0.0.0"))
-  if !(version.major == 0 && version.minor == minor_version)
+  if version.major != major_version
     @warn """
-      Flux is only supported with CuArrays v0.$minor_version.
-      Try running `] pin CuArrays@0.$minor_version`.
+      Flux is only supported with CuArrays v$major_version.x.
+      Try running `] pin CuArrays@$major_version`.
""" end end diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl index 8bd8135e..8671d166 100644 --- a/src/cuda/cudnn.jl +++ b/src/cuda/cudnn.jl @@ -17,7 +17,7 @@ function DropoutDesc(ρ::Real; seed::Integer=0) @check ccall((:cudnnDropoutGetStatesSize,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Csize_t}),handle(),s) states = CuArray{UInt8}(undef, s[]) # TODO: can we drop this when ρ=0? desc = DropoutDesc(d[], states) - @check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Nothing},Cfloat,Ptr{Nothing},Csize_t,Culonglong), + @check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Nothing},Cfloat,CuPtr{Nothing},Csize_t,Culonglong), desc,handle(),ρ,states,length(states),seed) finalizer(desc) do x @check ccall((:cudnnDestroyDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x) @@ -79,18 +79,18 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray mean = zeros(CuArray{T}, dims...) ivar = ones(CuArray{T}, dims...) else - mean = C_NULL - ivar = C_NULL + mean = CU_NULL + ivar = CU_NULL end @check ccall((:cudnnBatchNormalizationForwardTraining, libcudnn), cudnnStatus_t, (cudnnHandle_t,cudnnBatchNormMode_t, Ptr{T}, Ptr{T}, - Ptr{Nothing}, Ptr{T}, - Ptr{Nothing}, Ptr{T}, - Ptr{Nothing}, Ptr{T}, Ptr{T}, - Cdouble, Ptr{T}, Ptr{T}, - Cdouble, Ptr{T}, Ptr{T}), + Ptr{Nothing}, CuPtr{T}, + Ptr{Nothing}, CuPtr{T}, + Ptr{Nothing}, CuPtr{T}, CuPtr{T}, + Cdouble, CuPtr{T}, CuPtr{T}, + Cdouble, CuPtr{T}, CuPtr{T}), handle(), BATCHNORM_SPATIAL, Ref(T(alpha)), Ref(T(beta)), xd, x, @@ -107,10 +107,10 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray @check ccall((:cudnnBatchNormalizationForwardInference, libcudnn), cudnnStatus_t, (Ptr{cudnnHandle_t},cudnnBatchNormMode_t, Ptr{T}, Ptr{T}, - Ptr{Nothing}, Ptr{T}, - Ptr{Nothing}, Ptr{T}, - Ptr{Nothing}, Ptr{T}, Ptr{T}, - Ptr{T}, Ptr{T}, + Ptr{Nothing}, CuPtr{T}, + Ptr{Nothing}, CuPtr{T}, + Ptr{Nothing}, CuPtr{T}, CuPtr{T}, + CuPtr{T}, CuPtr{T}, Cdouble), handle(), BATCHNORM_SPATIAL, Ref(T(alpha)), Ref(T(beta)), @@ -159,7 +159,7 @@ function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T}, mean, ivar = cache.mean, cache.ivar info("mean and ivar are fetched from the cache") else - mean, ivar = C_NULL, C_NULL + mean, ivar = CU_NULL, CU_NULL end if eps < BATCHNORM_MIN_EPS @@ -170,11 +170,11 @@ function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T}, (cudnnHandle_t,cudnnBatchNormMode_t, Ptr{T}, Ptr{T}, Ptr{T}, Ptr{T}, - Ptr{Nothing}, Ptr{T}, - Ptr{Nothing}, Ptr{T}, - Ptr{Nothing}, Ptr{T}, - Ptr{Nothing}, Ptr{T}, Ptr{T}, Ptr{T}, - Cdouble, Ptr{T}, Ptr{T}), + Ptr{Nothing}, CuPtr{T}, + Ptr{Nothing}, CuPtr{T}, + Ptr{Nothing}, CuPtr{T}, + Ptr{Nothing}, CuPtr{T}, CuPtr{T}, CuPtr{T}, + Cdouble, CuPtr{T}, CuPtr{T}), handle(), BATCHNORM_SPATIAL, Ref(T(alpha)), Ref(T(beta)), Ref(T(dalpha)), Ref(T(dbeta)), diff --git a/src/cuda/curnn.jl b/src/cuda/curnn.jl index 210ddd7c..09f6d43c 100644 --- a/src/cuda/curnn.jl +++ b/src/cuda/curnn.jl @@ -101,18 +101,18 @@ function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd if reserve == nothing @check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t, (Ptr{Nothing}, Ptr{Nothing}, Cint, - Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, - Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, - Ptr{Nothing}, Ptr{T}, - Ptr{Nothing}, Csize_t), + Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, + 
+                  Ptr{Nothing}, CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
+                  Ptr{Nothing}, CuPtr{T},
+                  CuPtr{Nothing}, Csize_t),
                  handle(), rnn, seqlen,
                  xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
                  workspace, length(workspace))
   else
     @check ccall((:cudnnRNNForwardTraining, libcudnn), cudnnStatus_t,
                  (Ptr{Nothing}, Ptr{Nothing}, Cint,
-                  Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
-                  Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t),
+                  Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
+                  CuPtr{Nothing}, Csize_t, CuPtr{Nothing}, Csize_t),
                  handle(), rnn, seqlen,
                  xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
                  workspace, length(workspace), reserve, length(reserve))
@@ -121,7 +121,7 @@ end
 
 xDesc(x) = [TensorDesc(eltype(x), (1, size(x, 1), size(x, 2)))]
 
-hDesc(h::Nothing) = C_NULL, C_NULL
+hDesc(h::Nothing) = C_NULL, CU_NULL
 hDesc(x::Integer) = (@assert x == 0; hDesc(nothing))
 function hDesc(h::CuArray)
   TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h
@@ -169,10 +169,10 @@ function cudnnRNNBackwardData(rnn::RNNDesc{T}, seqlen, yd, y, dyd, dy, dhod, dho
                               wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, rs) where T
   @check ccall((:cudnnRNNBackwardData,libcudnn),cudnnStatus_t,
                (Ptr{Nothing}, Ptr{Nothing}, Cint,
-                Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T},
-                Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing},
-                Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
-                Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t),
+                Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
+                Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing},
+                CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
+                CuPtr{Nothing}, Csize_t, CuPtr{Nothing}, Csize_t),
               handle(), rnn, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
              wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, length(ws), rs, length(rs))
 end
@@ -199,12 +199,12 @@ function cudnnRNNBackwardWeights(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, yd, y, d
                                  workspace, reserve) where T
   @check ccall((:cudnnRNNBackwardWeights,libcudnn), cudnnStatus_t,
                (Ptr{Nothing}, Ptr{Nothing}, Cint,  # handle, rnnDesc, seqLength
-                Ptr{Ptr{Nothing}}, Ptr{T}, #x
-                Ptr{Nothing}, Ptr{T}, #hx
-                Ptr{Ptr{Nothing}}, Ptr{T}, #y
-                Ptr{Nothing}, Csize_t, #ws
-                Ptr{Nothing}, Ptr{T}, #dw
-                Ptr{Nothing}, Csize_t), #rs
+                Ptr{Ptr{Nothing}}, CuPtr{T}, #x
+                Ptr{Nothing}, CuPtr{T}, #hx
+                Ptr{Ptr{Nothing}}, CuPtr{T}, #y
+                CuPtr{Nothing}, Csize_t, #ws
+                Ptr{Nothing}, CuPtr{T}, #dw
+                CuPtr{Nothing}, Csize_t), #rs
               handle(), rnn, seqlen, xd, x, hd, h, yd, y,
              workspace, length(workspace), dwd, dw, reserve, length(reserve))
 end
diff --git a/src/data/Data.jl b/src/data/Data.jl
index ab78f416..d7cd0303 100644
--- a/src/data/Data.jl
+++ b/src/data/Data.jl
@@ -39,4 +39,7 @@ include("tree.jl")
 include("sentiment.jl")
 using .Sentiment
 
+include("iris.jl")
+export Iris
+
 end
diff --git a/src/data/iris.jl b/src/data/iris.jl
new file mode 100644
index 00000000..c432f847
--- /dev/null
+++ b/src/data/iris.jl
@@ -0,0 +1,88 @@
+
+"""
+
+    Iris
+
+Fisher's classic iris dataset.
+
+Measurements from 3 different species of iris: setosa, versicolor and
+virginica. There are 50 examples of each species.
+
+There are 4 measurements for each example: sepal length, sepal width, petal
+length and petal width. The measurements are in centimeters.
+
+The module retrieves the data from the [UCI Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/iris).
+
+"""
+module Iris
+
+using DelimitedFiles
+using ..Data: deps, download_and_verify
+
+const cache_prefix = ""
+
+# Uncomment if the iris.data file is cached to cache.julialang.org.
+# const cache_prefix = "https://cache.julialang.org/"
+
+function load()
+  isfile(deps("iris.data")) && return
+
+  @info "Downloading iris dataset."
+  download_and_verify("$(cache_prefix)https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data",
+                      deps("iris.data"),
+                      "6f608b71a7317216319b4d27b4d9bc84e6abd734eda7872b71a458569e2656c0")
+end
+
+"""
+
+    labels()
+
+Get the labels of the iris dataset, a 150-element array of strings listing the
+species of each example.
+
+```jldoctest
+julia> labels = Flux.Data.Iris.labels();
+
+julia> summary(labels)
+"150-element Array{String,1}"
+
+julia> labels[1]
+"Iris-setosa"
+```
+"""
+function labels()
+  load()
+  iris = readdlm(deps("iris.data"), ',')
+  Vector{String}(iris[1:end, end])
+end
+
+"""
+
+    features()
+
+Get the features of the iris dataset. This is a 4×150 matrix of Float64
+elements. It has a row for each feature (sepal length, sepal width,
+petal length, petal width) and a column for each example.
+
+```jldoctest
+julia> features = Flux.Data.Iris.features();
+
+julia> summary(features)
+"4×150 Array{Float64,2}"
+
+julia> features[:, 1]
+4-element Array{Float64,1}:
+ 5.1
+ 3.5
+ 1.4
+ 0.2
+```
+"""
+function features()
+  load()
+  iris = readdlm(deps("iris.data"), ',')
+  Matrix{Float64}(iris[1:end, 1:4]')
+end
+end
+
+
diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index 758aa0a9..b39a0de2 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -88,6 +88,14 @@ function Base.show(io::IO, l::Dense)
   print(io, ")")
 end
 
+# Try to avoid hitting generic matmul in some simple cases
+# Base's matmul is so slow that it's worth the extra conversion to hit BLAS
+(a::Dense{<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
+  invoke(a, Tuple{AbstractArray}, x)
+
+(a::Dense{<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
+  a(T.(x))
+
 """
     Diagonal(in::Integer)
 
@@ -117,10 +125,48 @@ function Base.show(io::IO, l::Diagonal)
   print(io, "Diagonal(", length(l.α), ")")
 end
 
-# Try to avoid hitting generic matmul in some simple cases
-# Base's matmul is so slow that it's worth the extra conversion to hit BLAS
-(a::Dense{<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
-  invoke(a, Tuple{AbstractArray}, x)
-
-(a::Dense{<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
-  a(T.(x))
+
+"""
+    Maxout(over)
+
+`Maxout` is a neural network layer that has a number of internal layers,
+each of which receives the same input. The layer returns the elementwise maximum
+of the internal layers' outputs.
+
+Maxout over linear dense layers satisfies the universal approximation theorem.
+
+Reference:
+Ian J. Goodfellow, David Warde-Farley, Mehdi Mirza, Aaron Courville, and Yoshua Bengio.
+2013. Maxout networks.
+In Proceedings of the 30th International Conference on International Conference on Machine Learning - Volume 28 (ICML'13),
+Sanjoy Dasgupta and David McAllester (Eds.), Vol. 28. JMLR.org III-1319-III-1327.
+https://arxiv.org/pdf/1302.4389.pdf
+"""
+struct Maxout{FS<:Tuple}
+  over::FS
+end
+
+"""
+    Maxout(f, n_alts)
+
+Constructs a Maxout layer over `n_alts` instances of the layer given by `f`.
+The function takes no argument and should return some callable layer.
+Conventionally this is a linear dense layer.
+
+For example, the following constructs
+a `Maxout` layer over 4 internal dense linear layers,
+each identical in structure (784 inputs, 128 outputs):
+```julia
+  insize = 784
+  outsize = 128
+  Maxout(()->Dense(insize, outsize), 4)
+```
+"""
+function Maxout(f, n_alts)
+  over = Tuple(f() for _ in 1:n_alts)
+  return Maxout(over)
+end
+
+function (mo::Maxout)(input::AbstractArray)
+  mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
+end
diff --git a/test/data.jl b/test/data.jl
index a73d1ec3..6b777873 100644
--- a/test/data.jl
+++ b/test/data.jl
@@ -14,3 +14,9 @@ using Test
 @test FashionMNIST.labels() isa Vector{Int64}
 
 @test Data.Sentiment.train() isa Vector{Data.Tree{Any}}
+
+@test Iris.features() isa Matrix
+@test size(Iris.features()) == (4,150)
+
+@test Iris.labels() isa Vector{String}
+@test size(Iris.labels()) == (150,)
diff --git a/test/layers/basic.jl b/test/layers/basic.jl
index b8d9efd1..3a3b1695 100644
--- a/test/layers/basic.jl
+++ b/test/layers/basic.jl
@@ -30,4 +30,28 @@ using Test, Random
     @test Flux.Diagonal(2)([1,2]) == [1,2]
     @test Flux.Diagonal(2)([1 2; 3 4]) == [1 2; 3 4]
   end
+
+  @testset "Maxout" begin
+    # Note that the normal usage of Maxout is as per the docstring above;
+    # these are unusual constructions used only for testing purposes.
+
+    @testset "Constructor" begin
+      mo = Maxout(() -> identity, 4)
+      input = rand(40)
+      @test mo(input) == input
+    end
+
+    @testset "simple alternatives" begin
+      mo = Maxout((x -> x, x -> 2x, x -> 0.5x))
+      input = rand(40)
+      @test mo(input) == 2*input
+    end
+
+    @testset "complex alternatives" begin
+      mo = Maxout((x -> [0.5; 0.1]*x, x -> [0.2; 0.7]*x))
+      input = [3.0 2.0]
+      target = [0.5, 0.7].*input
+      @test mo(input) == target
+    end
+  end
 end
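
Usage sketch (not part of the diff above): the snippet below ties together the two user-facing additions, the `Iris` loaders and the `Maxout` layer, using only names exported in this diff. The model shape and layer sizes are illustrative assumptions, not anything prescribed by the changes.

```julia
using Flux
using Flux.Data: Iris

X = Float32.(Iris.features())   # 4×150 feature matrix, one column per example
y = Iris.labels()               # 150-element Vector{String} of species names

# Maxout(() -> Dense(4, 8), 2) takes the elementwise max over two Dense(4, 8) layers.
m = Chain(Maxout(() -> Dense(4, 8), 2),
          Dense(8, 3),
          softmax)

scores = m(X)                   # 3×150 matrix of class scores
```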
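
A second minimal sketch (also illustrative, not from the diff) of what the relocated `Dense` fast path does: a layer with `Float32` weights called on integer or other `Real` input first converts the input to the weight element type, so the multiplication hits BLAS rather than generic matmul.

```julia
using Flux

d = Dense(3, 2)          # weights default to Float32
x = rand(1:10, 3, 5)     # Int input

# The second method converts x with T.(x) (here Float32.(x)), then the first
# method invokes the generic Dense call, so BLAS handles the matrix multiply.
d(x)                     # same result as d(Float32.(x)), but without generic matmul
```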