use `expand`
This commit is contained in:
parent
e6be639436
commit
1e0fd07b09
|
@ -5,7 +5,7 @@ module Flux
|
|||
using MacroTools, Juno, Requires, Reexport, Statistics, Random
|
||||
using MacroTools: @forward
|
||||
|
||||
export Chain, Dense, RNN, LSTM, GRU, Conv,
|
||||
export Chain, Dense, RNN, LSTM, GRU, Conv, MaxPool, MeanPool,
|
||||
Dropout, LayerNorm, BatchNorm,
|
||||
params, mapleaves, cpu, gpu
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
using NNlib: conv
|
||||
|
||||
@generated sub2(::Type{Val{N}}) where N = :(Val($(N-2)))
|
||||
@generated sub2(::Val{N}) where N = :(Val($(N-2)))
|
||||
|
||||
expand(N, i::Tuple) = i
|
||||
expand(N, i::Integer) = ntuple(_ -> i, N)
|
||||
|
@ -28,7 +28,7 @@ end
|
|||
|
||||
Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
|
||||
stride = 1, pad = 0, dilation = 1) where {T,N} =
|
||||
Conv(σ, w, b, expand.(sub2(Val{N}), (stride, pad, dilation))...)
|
||||
Conv(σ, w, b, expand.(sub2(Val(N)), (stride, pad, dilation))...)
|
||||
|
||||
Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
|
||||
stride = 1, pad = 0, dilation = 1) where N =
|
||||
|
@ -55,7 +55,7 @@ end
|
|||
"""
|
||||
MaxPool(k)
|
||||
|
||||
Maxpooling layer. `k` stands for the size of the window for each dimension of the input.
|
||||
Max pooling layer. `k` stands for the size of the window for each dimension of the input.
|
||||
|
||||
Takes the keyword arguments `pad` and `stride`.
|
||||
"""
|
||||
|
@ -63,25 +63,21 @@ struct MaxPool{N}
|
|||
k::NTuple{N,Int}
|
||||
pad::NTuple{N,Int}
|
||||
stride::NTuple{N,Int}
|
||||
MaxPool(k::NTuple{N,Int}; pad = map(_->0,k), stride = k) where N = new{N}(k, pad, stride)
|
||||
end
|
||||
|
||||
function MaxPool{N}(k::Int; pad = 0, stride = k) where N
|
||||
k_ = Tuple(repeat([k, ], N))
|
||||
MaxPool(k_; pad = map(_->pad,k_), stride=map(_->stride,k_))
|
||||
end
|
||||
MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =
|
||||
MaxPool(k, expand(Val(N), pad), expand(Val(N), stride))
|
||||
|
||||
(m::MaxPool)(x) = maxpool(x, m.k; pad = m.pad, stride = m.stride)
|
||||
|
||||
function Base.show(io::IO, m::MaxPool)
|
||||
print(io, "MaxPool(", m.k, ", ", m.pad, ", ", m.stride, ")")
|
||||
print(io, "MaxPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
|
||||
end
|
||||
|
||||
|
||||
"""
|
||||
MeanPool(k)
|
||||
|
||||
Meanpooling layer. `k` stands for the size of the window for each dimension of the input.
|
||||
Mean pooling layer. `k` stands for the size of the window for each dimension of the input.
|
||||
|
||||
Takes the keyword arguments `pad` and `stride`.
|
||||
"""
|
||||
|
@ -89,16 +85,13 @@ struct MeanPool{N}
|
|||
k::NTuple{N,Int}
|
||||
pad::NTuple{N,Int}
|
||||
stride::NTuple{N,Int}
|
||||
MeanPool(k::NTuple{N,Int}; pad = map(_->0,k), stride = k) where N = new{N}(k, pad, stride)
|
||||
end
|
||||
|
||||
function MeanPool{N}(k::Int; pad = 0, stride = k) where N
|
||||
k_ = Tuple(repeat([k, ], N))
|
||||
MeanPool(k_; pad = map(_->pad,k_), stride=map(_->stride,k_))
|
||||
end
|
||||
MeanPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =
|
||||
MeanPool(k, expand(Val(N), pad), expand(Val(N), stride))
|
||||
|
||||
(m::MeanPool)(x) = meanpool(x, m.k; pad = m.pad, stride = m.stride)
|
||||
|
||||
function Base.show(io::IO, m::MeanPool)
|
||||
print(io, "MeanPool(", m.k, ", ", m.pad, ", ", m.stride, ")")
|
||||
print(io, "MeanPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
|
||||
end
|
||||
|
|
|
@ -1,21 +1,13 @@
|
|||
using Test
|
||||
using Flux: Chain, Conv, MaxPool, MeanPool
|
||||
using Flux: Chain, Conv, MaxPool, MeanPool, maxpool, meanpool
|
||||
using Base.conv
|
||||
|
||||
@testset "pooling" begin
|
||||
x = randn(10, 10, 3, 2)
|
||||
mp = MaxPool((2, 2))
|
||||
|
||||
@testset "maxpooling" begin
|
||||
@test MaxPool{2}(2) == mp
|
||||
@test MaxPool{2}(2; pad=1, stride=3) == MaxPool((2, 2); pad=(1, 1), stride=(3, 3))
|
||||
end
|
||||
|
||||
@test mp(x) == maxpool(x, (2,2))
|
||||
mp = MeanPool((2, 2))
|
||||
|
||||
@testset "meanpooling" begin
|
||||
@test MeanPool{2}(2) == mp
|
||||
@test MeanPool{2}(2; pad=1, stride=3) == MeanPool((2, 2); pad=(1, 1), stride=(3, 3))
|
||||
end
|
||||
@test mp(x) == meanpool(x, (2,2))
|
||||
end
|
||||
|
||||
@testset "cnn" begin
|
||||
|
|
Loading…
Reference in New Issue