Merge pull request #339 from yuehhua/master

Add Maxpool and Meanpool for convention.
This commit is contained in:
Mike J Innes 2018-09-04 14:52:10 +01:00 committed by GitHub
commit 2005247d5a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 74 additions and 3 deletions

View File

@ -6,6 +6,8 @@ These core layers form the foundation of almost all neural networks.
Chain Chain
Dense Dense
Conv Conv
MaxPool
MeanPool
``` ```
## Recurrent Layers ## Recurrent Layers

View File

@ -5,7 +5,7 @@ module Flux
using MacroTools, Juno, Requires, Reexport, Statistics, Random using MacroTools, Juno, Requires, Reexport, Statistics, Random
using MacroTools: @forward using MacroTools: @forward
export Chain, Dense, RNN, LSTM, GRU, Conv, export Chain, Dense, RNN, LSTM, GRU, Conv, MaxPool, MeanPool,
Dropout, LayerNorm, BatchNorm, Dropout, LayerNorm, BatchNorm,
params, mapleaves, cpu, gpu params, mapleaves, cpu, gpu

View File

@ -1,6 +1,6 @@
using NNlib: conv using NNlib: conv
@generated sub2(::Type{Val{N}}) where N = :(Val($(N-2))) @generated sub2(::Val{N}) where N = :(Val($(N-2)))
expand(N, i::Tuple) = i expand(N, i::Tuple) = i
expand(N, i::Integer) = ntuple(_ -> i, N) expand(N, i::Integer) = ntuple(_ -> i, N)
@ -28,7 +28,7 @@ end
Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity; Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N} = stride = 1, pad = 0, dilation = 1) where {T,N} =
Conv(σ, w, b, expand.(sub2(Val{N}), (stride, pad, dilation))...) Conv(σ, w, b, expand.(sub2(Val(N)), (stride, pad, dilation))...)
Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn, Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
stride = 1, pad = 0, dilation = 1) where N = stride = 1, pad = 0, dilation = 1) where N =
@ -50,3 +50,48 @@ function Base.show(io::IO, l::Conv)
l.σ == identity || print(io, ", ", l.σ) l.σ == identity || print(io, ", ", l.σ)
print(io, ")") print(io, ")")
end end
"""
MaxPool(k)
Max pooling layer. `k` stands for the size of the window for each dimension of the input.
Takes the keyword arguments `pad` and `stride`.
"""
struct MaxPool{N}
k::NTuple{N,Int}
pad::NTuple{N,Int}
stride::NTuple{N,Int}
end
MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =
MaxPool(k, expand(Val(N), pad), expand(Val(N), stride))
(m::MaxPool)(x) = maxpool(x, m.k; pad = m.pad, stride = m.stride)
function Base.show(io::IO, m::MaxPool)
print(io, "MaxPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
end
"""
MeanPool(k)
Mean pooling layer. `k` stands for the size of the window for each dimension of the input.
Takes the keyword arguments `pad` and `stride`.
"""
struct MeanPool{N}
k::NTuple{N,Int}
pad::NTuple{N,Int}
stride::NTuple{N,Int}
end
MeanPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =
MeanPool(k, expand(Val(N), pad), expand(Val(N), stride))
(m::MeanPool)(x) = meanpool(x, m.k; pad = m.pad, stride = m.stride)
function Base.show(io::IO, m::MeanPool)
print(io, "MeanPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
end

23
test/layers/conv.jl Normal file
View File

@ -0,0 +1,23 @@
using Flux, Test
using Flux: maxpool, meanpool
@testset "Pooling" begin
x = randn(10, 10, 3, 2)
mp = MaxPool((2, 2))
@test mp(x) == maxpool(x, (2,2))
mp = MeanPool((2, 2))
@test mp(x) == meanpool(x, (2,2))
end
@testset "CNN" begin
r = zeros(28, 28, 1, 5)
m = Chain(
Conv((2, 2), 1=>16, relu),
MaxPool((2,2)),
Conv((2, 2), 16=>8, relu),
MaxPool((2,2)),
x -> reshape(x, :, size(x, 4)),
Dense(288, 10), softmax)
@test size(m(r)) == (10, 5)
end

View File

@ -28,6 +28,7 @@ include("onehot.jl")
include("tracker.jl") include("tracker.jl")
include("layers/normalisation.jl") include("layers/normalisation.jl")
include("layers/stateless.jl") include("layers/stateless.jl")
include("layers/conv.jl")
include("optimise.jl") include("optimise.jl")
include("data.jl") include("data.jl")