Add SamePad for pooling layers
This commit is contained in:
parent
fc123d6279
commit
411ce5dbd8
|
@ -308,6 +308,8 @@ end
|
|||
Max pooling layer. `k` stands for the size of the window for each dimension of the input.
|
||||
|
||||
Takes the keyword arguments `pad` and `stride`.
|
||||
|
||||
Use `pad=SamePad()` to apply padding so that outputsize == inputsize / stride
|
||||
"""
|
||||
struct MaxPool{N,M}
|
||||
k::NTuple{N,Int}
|
||||
|
@ -317,8 +319,7 @@ end
|
|||
|
||||
function MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N
|
||||
stride = expand(Val(N), stride)
|
||||
pad = expand(Val(2*N), pad)
|
||||
|
||||
pad = calc_padding(pad, k, 1, stride)
|
||||
return MaxPool(k, pad, stride)
|
||||
end
|
||||
|
||||
|
@ -337,6 +338,8 @@ end
|
|||
Mean pooling layer. `k` stands for the size of the window for each dimension of the input.
|
||||
|
||||
Takes the keyword arguments `pad` and `stride`.
|
||||
|
||||
Use `pad=SamePad()` to apply padding so that outputsize == inputsize / stride
|
||||
"""
|
||||
struct MeanPool{N,M}
|
||||
k::NTuple{N,Int}
|
||||
|
@ -346,7 +349,7 @@ end
|
|||
|
||||
function MeanPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N
|
||||
stride = expand(Val(N), stride)
|
||||
pad = expand(Val(2*N), pad)
|
||||
pad = calc_padding(pad, k, 1, stride)
|
||||
return MeanPool(k, pad, stride)
|
||||
end
|
||||
|
||||
|
|
|
@ -114,8 +114,15 @@ end
|
|||
stride = 3
|
||||
l = ltype(k, 1=>1, pad=SamePad(), stride = stride)
|
||||
if ltype == ConvTranspose
|
||||
@test size(l(data))[1:end-2] == stride .* size(data)[1:end-2] .- stride .- 1
|
||||
@test size(l(data))[1:end-2] == stride .* size(data)[1:end-2] .- stride .+ 1
|
||||
else
|
||||
@test size(l(data))[1:end-2] == ceil.(Int, size(data)[1:end-2] ./ stride)
|
||||
end
|
||||
end
|
||||
|
||||
@testset "$ltype SamePad windowsize $k" for ltype in (MeanPool, MaxPool), k in ( (1,), (2,), (3,), (4,5), (6,7,8))
|
||||
data = ones(Float32, (k .+ 3)..., 1,1)
|
||||
|
||||
l = ltype(k, pad=SamePad())
|
||||
@test size(l(data))[1:end-2] == ceil.(Int, size(data)[1:end-2] ./ k)
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue