Add new constructors and test
parent 5b37319289
commit 634d34686e

@@ -6,8 +6,8 @@ These core layers form the foundation of almost all neural networks.
 Chain
 Dense
 Conv
-Maxpool
-Meanpool
+MaxPool
+MeanPool
 ```

 ## Recurrent Layers
@@ -53,42 +53,52 @@ end
 """
-    Maxpool(k)
+    MaxPool(k)

 Maxpooling layer. `k` stands for the size of the window for each dimension of the input.

 Takes the keyword arguments `pad` and `stride`.
 """
-struct Maxpool{N}
+struct MaxPool{N}
   k::NTuple{N,Int}
   pad::NTuple{N,Int}
   stride::NTuple{N,Int}
-  Maxpool(k::NTuple{N,Int}; pad = map(_->0,k), stride = k) where N = new{N}(k, pad, stride)
+  MaxPool(k::NTuple{N,Int}; pad = map(_->0,k), stride = k) where N = new{N}(k, pad, stride)
 end

-(m::Maxpool)(x) = maxpool(x, m.k; pad = m.pad, stride = m.stride)
+function MaxPool{N}(k::Int; pad = 0, stride = k) where N
+  k_ = Tuple(repeat([k, ], N))
+  MaxPool(k_; pad = map(_->pad,k_), stride=map(_->stride,k_))
+end

-function Base.show(io::IO, m::Maxpool)
-  print(io, "Maxpool(", m.k, ", ", m.pad, ", ", m.stride, ")")
+(m::MaxPool)(x) = maxpool(x, m.k; pad = m.pad, stride = m.stride)
+
+function Base.show(io::IO, m::MaxPool)
+  print(io, "MaxPool(", m.k, ", ", m.pad, ", ", m.stride, ")")
 end


 """
-    Meanpool(k)
+    MeanPool(k)

 Meanpooling layer. `k` stands for the size of the window for each dimension of the input.

 Takes the keyword arguments `pad` and `stride`.
 """
-struct Meanpool{N}
+struct MeanPool{N}
   k::NTuple{N,Int}
   pad::NTuple{N,Int}
   stride::NTuple{N,Int}
-  Meanpool(k::NTuple{N,Int}; pad = map(_->0,k), stride = k) where N = new{N}(k, pad, stride)
+  MeanPool(k::NTuple{N,Int}; pad = map(_->0,k), stride = k) where N = new{N}(k, pad, stride)
 end

-(m::Meanpool)(x) = meanpool(x, m.k; pad = m.pad, stride = m.stride)
-
-function Base.show(io::IO, m::Meanpool)
-  print(io, "Meanpool(", m.k, ", ", m.pad, ", ", m.stride, ")")
+function MeanPool{N}(k::Int; pad = 0, stride = k) where N
+  k_ = Tuple(repeat([k, ], N))
+  MeanPool(k_; pad = map(_->pad,k_), stride=map(_->stride,k_))
+end
+
+(m::MeanPool)(x) = meanpool(x, m.k; pad = m.pad, stride = m.stride)
+
+function Base.show(io::IO, m::MeanPool)
+  print(io, "MeanPool(", m.k, ", ", m.pad, ", ", m.stride, ")")
 end
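As a quick usage sketch of the constructors above (not part of the diff; it assumes the definitions are loaded via Flux and uses a made-up input `x`), the new `MaxPool{N}(k::Int)` and `MeanPool{N}(k::Int)` methods simply repeat a scalar window size, `pad`, and `stride` into `N`-tuples:

```julia
using Flux: MaxPool, MeanPool

# The scalar constructor repeats k (and the pad/stride keywords) N times,
# so these layers have the same fields as their tuple-constructed equivalents:
MaxPool{2}(3)                    # same as MaxPool((3, 3); pad=(0, 0), stride=(3, 3))
MeanPool{2}(2; pad=1, stride=2)  # same as MeanPool((2, 2); pad=(1, 1), stride=(2, 2))

# Calling a layer forwards to maxpool/meanpool with the stored settings:
x = rand(Float32, 28, 28, 1, 1)  # WHCN input (width, height, channels, batch)
size(MaxPool{2}(2)(x))           # (14, 14, 1, 1)
size(MeanPool{2}(2)(x))          # (14, 14, 1, 1)
```

The inner tuple constructor remains the single place where the `pad` and `stride` defaults live; the integer method only covers the common square-window case with a shorter spelling.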
@@ -0,0 +1,34 @@
+using Test
+using Flux: Chain, Conv, MaxPool, MeanPool
+using Base.conv
+
+@testset "pooling" begin
+  mp = MaxPool((2, 2))
+
+  @testset "maxpooling" begin
+    @test MaxPool{2}(2) == mp
+    @test MaxPool{2}(2; pad=1, stride=3) == MaxPool((2, 2); pad=(1, 1), stride=(3, 3))
+  end
+
+  mp = MeanPool((2, 2))
+
+  @testset "meanpooling" begin
+    @test MeanPool{2}(2) == mp
+    @test MeanPool{2}(2; pad=1, stride=3) == MeanPool((2, 2); pad=(1, 1), stride=(3, 3))
+  end
+end
+
+@testset "cnn" begin
+  r = zeros(28, 28)
+  m = Chain(
+    Conv((2, 2), 1=>16, relu),
+    MaxPool{2}(2),
+    Conv((2, 2), 16=>8, relu),
+    MaxPool{2}(2),
+    x -> reshape(x, :, size(x, 4)),
+    Dense(288, 10), softmax)
+
+  @testset "inference" begin
+    @test size(m(r)) == (10, )
+  end
+end
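For reference, the `Dense(288, 10)` width in the `cnn` testset can be checked with a bit of shape arithmetic. The sketch below uses ad-hoc helper names (`conv_out`, `pool_out`) and assumes a 28x28 spatial input in Flux's WHCN layout, stride-1 unpadded convolutions, and the 2x2/stride-2 pooling configured above:

```julia
# Per-dimension output sizes for the layers used in the "cnn" testset.
conv_out(n, k) = n - k + 1   # Conv with stride 1 and no padding
pool_out(n, k) = div(n, k)   # MaxPool{2}(2): window 2, stride 2, no padding

w = 28                # input width (and height)
w = conv_out(w, 2)    # 27 after Conv((2, 2), 1=>16, relu)
w = pool_out(w, 2)    # 13 after MaxPool{2}(2)
w = conv_out(w, 2)    # 12 after Conv((2, 2), 16=>8, relu)
w = pool_out(w, 2)    # 6  after MaxPool{2}(2)

w * w * 8             # 288 features flattened into Dense(288, 10)
```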