diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md
index c2056bb4..4bbb2ba0 100644
--- a/docs/src/models/layers.md
+++ b/docs/src/models/layers.md
@@ -6,6 +6,8 @@ These core layers form the foundation of almost all neural networks.
 Chain
 Dense
 Conv
+MaxPool
+MeanPool
 ```
 
 ## Recurrent Layers
diff --git a/src/Flux.jl b/src/Flux.jl
index 614eeaf7..8c959fec 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -5,7 +5,7 @@ module Flux
 using MacroTools, Juno, Requires, Reexport, Statistics, Random
 using MacroTools: @forward
 
-export Chain, Dense, RNN, LSTM, GRU, Conv,
+export Chain, Dense, RNN, LSTM, GRU, Conv, MaxPool, MeanPool,
   Dropout, LayerNorm, BatchNorm,
   params, mapleaves, cpu, gpu
 
diff --git a/src/layers/conv.jl b/src/layers/conv.jl
index 78509c84..dbf8ccf9 100644
--- a/src/layers/conv.jl
+++ b/src/layers/conv.jl
@@ -1,6 +1,6 @@
 using NNlib: conv
 
-@generated sub2(::Type{Val{N}}) where N = :(Val($(N-2)))
+@generated sub2(::Val{N}) where N = :(Val($(N-2)))
 
 expand(N, i::Tuple) = i
 expand(N, i::Integer) = ntuple(_ -> i, N)
@@ -28,7 +28,7 @@ end
 
 Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
      stride = 1, pad = 0, dilation = 1) where {T,N} =
-  Conv(σ, w, b, expand.(sub2(Val{N}), (stride, pad, dilation))...)
+  Conv(σ, w, b, expand.(sub2(Val(N)), (stride, pad, dilation))...)
 
 Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
      init = initn, stride = 1, pad = 0, dilation = 1) where N =
@@ -50,3 +50,56 @@ function Base.show(io::IO, l::Conv)
   l.σ == identity || print(io, ", ", l.σ)
   print(io, ")")
 end
+
+
+"""
+    MaxPool(k; pad = 0, stride = k)
+
+Max pooling layer. `k` is a tuple giving the pooling window size in each spatial dimension of the input.
+
+Accepts the keyword arguments `pad` and `stride`; by default there is no padding and the stride equals the window size `k`.
+"""
+struct MaxPool{N}
+  k::NTuple{N,Int}
+  pad::NTuple{N,Int}
+  stride::NTuple{N,Int}
+end
+
+MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =
+  MaxPool(k, expand(Val(N), pad), expand(Val(N), stride))
+
+(m::MaxPool)(x) = maxpool(x, m.k; pad = m.pad, stride = m.stride)
+
+function Base.show(io::IO, m::MaxPool)
+  print(io, "MaxPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
+end
+
+"""
+    MeanPool(k; pad = 0, stride = k)
+
+Mean pooling layer. `k` is a tuple giving the pooling window size in each spatial dimension of the input.
+
+Accepts the keyword arguments `pad` and `stride`; by default there is no padding and the stride equals the window size `k`.
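+
+For example, with the default `stride = k` a 2×2 window halves each spatial
+dimension of a 4×4 single-channel, single-image input:
+
+    julia> m = MeanPool((2, 2));
+
+    julia> size(m(rand(4, 4, 1, 1)))
+    (2, 2, 1, 1)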
+""" +struct MeanPool{N} + k::NTuple{N,Int} + pad::NTuple{N,Int} + stride::NTuple{N,Int} +end + +MeanPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N = + MeanPool(k, expand(Val(N), pad), expand(Val(N), stride)) + +(m::MeanPool)(x) = meanpool(x, m.k; pad = m.pad, stride = m.stride) + +function Base.show(io::IO, m::MeanPool) + print(io, "MeanPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")") +end diff --git a/test/layers/conv.jl b/test/layers/conv.jl new file mode 100644 index 00000000..5928bd75 --- /dev/null +++ b/test/layers/conv.jl @@ -0,0 +1,23 @@ +using Flux, Test +using Flux: maxpool, meanpool + +@testset "Pooling" begin + x = randn(10, 10, 3, 2) + mp = MaxPool((2, 2)) + @test mp(x) == maxpool(x, (2,2)) + mp = MeanPool((2, 2)) + @test mp(x) == meanpool(x, (2,2)) +end + +@testset "CNN" begin + r = zeros(28, 28, 1, 5) + m = Chain( + Conv((2, 2), 1=>16, relu), + MaxPool((2,2)), + Conv((2, 2), 16=>8, relu), + MaxPool((2,2)), + x -> reshape(x, :, size(x, 4)), + Dense(288, 10), softmax) + + @test size(m(r)) == (10, 5) +end diff --git a/test/runtests.jl b/test/runtests.jl index fd48e547..70d929bf 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -28,6 +28,7 @@ include("onehot.jl") include("tracker.jl") include("layers/normalisation.jl") include("layers/stateless.jl") +include("layers/conv.jl") include("optimise.jl") include("data.jl")