Merge pull request #339 from yuehhua/master
Add Maxpool and Meanpool for convention.
This commit is contained in:
commit
2005247d5a
@ -6,6 +6,8 @@ These core layers form the foundation of almost all neural networks.
|
|||||||
Chain
|
Chain
|
||||||
Dense
|
Dense
|
||||||
Conv
|
Conv
|
||||||
|
MaxPool
|
||||||
|
MeanPool
|
||||||
```
|
```
|
||||||
|
|
||||||
## Recurrent Layers
|
## Recurrent Layers
|
||||||
|
@ -5,7 +5,7 @@ module Flux
|
|||||||
using MacroTools, Juno, Requires, Reexport, Statistics, Random
|
using MacroTools, Juno, Requires, Reexport, Statistics, Random
|
||||||
using MacroTools: @forward
|
using MacroTools: @forward
|
||||||
|
|
||||||
export Chain, Dense, RNN, LSTM, GRU, Conv,
|
export Chain, Dense, RNN, LSTM, GRU, Conv, MaxPool, MeanPool,
|
||||||
Dropout, LayerNorm, BatchNorm,
|
Dropout, LayerNorm, BatchNorm,
|
||||||
params, mapleaves, cpu, gpu
|
params, mapleaves, cpu, gpu
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
using NNlib: conv
|
using NNlib: conv
|
||||||
|
|
||||||
@generated sub2(::Type{Val{N}}) where N = :(Val($(N-2)))
|
@generated sub2(::Val{N}) where N = :(Val($(N-2)))
|
||||||
|
|
||||||
expand(N, i::Tuple) = i
|
expand(N, i::Tuple) = i
|
||||||
expand(N, i::Integer) = ntuple(_ -> i, N)
|
expand(N, i::Integer) = ntuple(_ -> i, N)
|
||||||
@ -28,7 +28,7 @@ end
|
|||||||
|
|
||||||
Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
|
Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
|
||||||
stride = 1, pad = 0, dilation = 1) where {T,N} =
|
stride = 1, pad = 0, dilation = 1) where {T,N} =
|
||||||
Conv(σ, w, b, expand.(sub2(Val{N}), (stride, pad, dilation))...)
|
Conv(σ, w, b, expand.(sub2(Val(N)), (stride, pad, dilation))...)
|
||||||
|
|
||||||
Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
|
Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
|
||||||
stride = 1, pad = 0, dilation = 1) where N =
|
stride = 1, pad = 0, dilation = 1) where N =
|
||||||
@ -50,3 +50,48 @@ function Base.show(io::IO, l::Conv)
|
|||||||
l.σ == identity || print(io, ", ", l.σ)
|
l.σ == identity || print(io, ", ", l.σ)
|
||||||
print(io, ")")
|
print(io, ")")
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
MaxPool(k)
|
||||||
|
|
||||||
|
Max pooling layer. `k` stands for the size of the window for each dimension of the input.
|
||||||
|
|
||||||
|
Takes the keyword arguments `pad` and `stride`.
|
||||||
|
"""
|
||||||
|
struct MaxPool{N}
|
||||||
|
k::NTuple{N,Int}
|
||||||
|
pad::NTuple{N,Int}
|
||||||
|
stride::NTuple{N,Int}
|
||||||
|
end
|
||||||
|
|
||||||
|
MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =
|
||||||
|
MaxPool(k, expand(Val(N), pad), expand(Val(N), stride))
|
||||||
|
|
||||||
|
(m::MaxPool)(x) = maxpool(x, m.k; pad = m.pad, stride = m.stride)
|
||||||
|
|
||||||
|
function Base.show(io::IO, m::MaxPool)
|
||||||
|
print(io, "MaxPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
|
||||||
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
MeanPool(k)
|
||||||
|
|
||||||
|
Mean pooling layer. `k` stands for the size of the window for each dimension of the input.
|
||||||
|
|
||||||
|
Takes the keyword arguments `pad` and `stride`.
|
||||||
|
"""
|
||||||
|
struct MeanPool{N}
|
||||||
|
k::NTuple{N,Int}
|
||||||
|
pad::NTuple{N,Int}
|
||||||
|
stride::NTuple{N,Int}
|
||||||
|
end
|
||||||
|
|
||||||
|
MeanPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =
|
||||||
|
MeanPool(k, expand(Val(N), pad), expand(Val(N), stride))
|
||||||
|
|
||||||
|
(m::MeanPool)(x) = meanpool(x, m.k; pad = m.pad, stride = m.stride)
|
||||||
|
|
||||||
|
function Base.show(io::IO, m::MeanPool)
|
||||||
|
print(io, "MeanPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
|
||||||
|
end
|
||||||
|
23
test/layers/conv.jl
Normal file
23
test/layers/conv.jl
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
using Flux, Test
|
||||||
|
using Flux: maxpool, meanpool
|
||||||
|
|
||||||
|
@testset "Pooling" begin
|
||||||
|
x = randn(10, 10, 3, 2)
|
||||||
|
mp = MaxPool((2, 2))
|
||||||
|
@test mp(x) == maxpool(x, (2,2))
|
||||||
|
mp = MeanPool((2, 2))
|
||||||
|
@test mp(x) == meanpool(x, (2,2))
|
||||||
|
end
|
||||||
|
|
||||||
|
@testset "CNN" begin
|
||||||
|
r = zeros(28, 28, 1, 5)
|
||||||
|
m = Chain(
|
||||||
|
Conv((2, 2), 1=>16, relu),
|
||||||
|
MaxPool((2,2)),
|
||||||
|
Conv((2, 2), 16=>8, relu),
|
||||||
|
MaxPool((2,2)),
|
||||||
|
x -> reshape(x, :, size(x, 4)),
|
||||||
|
Dense(288, 10), softmax)
|
||||||
|
|
||||||
|
@test size(m(r)) == (10, 5)
|
||||||
|
end
|
@ -28,6 +28,7 @@ include("onehot.jl")
|
|||||||
include("tracker.jl")
|
include("tracker.jl")
|
||||||
include("layers/normalisation.jl")
|
include("layers/normalisation.jl")
|
||||||
include("layers/stateless.jl")
|
include("layers/stateless.jl")
|
||||||
|
include("layers/conv.jl")
|
||||||
include("optimise.jl")
|
include("optimise.jl")
|
||||||
include("data.jl")
|
include("data.jl")
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user