use expand

This commit is contained in:
Mike J Innes 2018-09-04 14:30:02 +01:00
parent e6be639436
commit 1e0fd07b09
3 changed files with 15 additions and 30 deletions

View File

@ -5,7 +5,7 @@ module Flux
using MacroTools, Juno, Requires, Reexport, Statistics, Random using MacroTools, Juno, Requires, Reexport, Statistics, Random
using MacroTools: @forward using MacroTools: @forward
export Chain, Dense, RNN, LSTM, GRU, Conv, export Chain, Dense, RNN, LSTM, GRU, Conv, MaxPool, MeanPool,
Dropout, LayerNorm, BatchNorm, Dropout, LayerNorm, BatchNorm,
params, mapleaves, cpu, gpu params, mapleaves, cpu, gpu

View File

@ -1,6 +1,6 @@
using NNlib: conv using NNlib: conv
@generated sub2(::Type{Val{N}}) where N = :(Val($(N-2))) @generated sub2(::Val{N}) where N = :(Val($(N-2)))
expand(N, i::Tuple) = i expand(N, i::Tuple) = i
expand(N, i::Integer) = ntuple(_ -> i, N) expand(N, i::Integer) = ntuple(_ -> i, N)
@ -28,7 +28,7 @@ end
Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity; Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N} = stride = 1, pad = 0, dilation = 1) where {T,N} =
Conv(σ, w, b, expand.(sub2(Val{N}), (stride, pad, dilation))...) Conv(σ, w, b, expand.(sub2(Val(N)), (stride, pad, dilation))...)
Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn, Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
stride = 1, pad = 0, dilation = 1) where N = stride = 1, pad = 0, dilation = 1) where N =
@ -55,7 +55,7 @@ end
""" """
MaxPool(k) MaxPool(k)
Maxpooling layer. `k` stands for the size of the window for each dimension of the input. Max pooling layer. `k` stands for the size of the window for each dimension of the input.
Takes the keyword arguments `pad` and `stride`. Takes the keyword arguments `pad` and `stride`.
""" """
@ -63,25 +63,21 @@ struct MaxPool{N}
k::NTuple{N,Int} k::NTuple{N,Int}
pad::NTuple{N,Int} pad::NTuple{N,Int}
stride::NTuple{N,Int} stride::NTuple{N,Int}
MaxPool(k::NTuple{N,Int}; pad = map(_->0,k), stride = k) where N = new{N}(k, pad, stride)
end end
function MaxPool{N}(k::Int; pad = 0, stride = k) where N MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =
k_ = Tuple(repeat([k, ], N)) MaxPool(k, expand(Val(N), pad), expand(Val(N), stride))
MaxPool(k_; pad = map(_->pad,k_), stride=map(_->stride,k_))
end
(m::MaxPool)(x) = maxpool(x, m.k; pad = m.pad, stride = m.stride) (m::MaxPool)(x) = maxpool(x, m.k; pad = m.pad, stride = m.stride)
function Base.show(io::IO, m::MaxPool) function Base.show(io::IO, m::MaxPool)
print(io, "MaxPool(", m.k, ", ", m.pad, ", ", m.stride, ")") print(io, "MaxPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
end end
""" """
MeanPool(k) MeanPool(k)
Meanpooling layer. `k` stands for the size of the window for each dimension of the input. Mean pooling layer. `k` stands for the size of the window for each dimension of the input.
Takes the keyword arguments `pad` and `stride`. Takes the keyword arguments `pad` and `stride`.
""" """
@ -89,16 +85,13 @@ struct MeanPool{N}
k::NTuple{N,Int} k::NTuple{N,Int}
pad::NTuple{N,Int} pad::NTuple{N,Int}
stride::NTuple{N,Int} stride::NTuple{N,Int}
MeanPool(k::NTuple{N,Int}; pad = map(_->0,k), stride = k) where N = new{N}(k, pad, stride)
end end
function MeanPool{N}(k::Int; pad = 0, stride = k) where N MeanPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =
k_ = Tuple(repeat([k, ], N)) MeanPool(k, expand(Val(N), pad), expand(Val(N), stride))
MeanPool(k_; pad = map(_->pad,k_), stride=map(_->stride,k_))
end
(m::MeanPool)(x) = meanpool(x, m.k; pad = m.pad, stride = m.stride) (m::MeanPool)(x) = meanpool(x, m.k; pad = m.pad, stride = m.stride)
function Base.show(io::IO, m::MeanPool) function Base.show(io::IO, m::MeanPool)
print(io, "MeanPool(", m.k, ", ", m.pad, ", ", m.stride, ")") print(io, "MeanPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
end end

View File

@ -1,21 +1,13 @@
using Test using Test
using Flux: Chain, Conv, MaxPool, MeanPool using Flux: Chain, Conv, MaxPool, MeanPool, maxpool, meanpool
using Base.conv using Base.conv
@testset "pooling" begin @testset "pooling" begin
x = randn(10, 10, 3, 2)
mp = MaxPool((2, 2)) mp = MaxPool((2, 2))
@test mp(x) == maxpool(x, (2,2))
@testset "maxpooling" begin
@test MaxPool{2}(2) == mp
@test MaxPool{2}(2; pad=1, stride=3) == MaxPool((2, 2); pad=(1, 1), stride=(3, 3))
end
mp = MeanPool((2, 2)) mp = MeanPool((2, 2))
@test mp(x) == meanpool(x, (2,2))
@testset "meanpooling" begin
@test MeanPool{2}(2) == mp
@test MeanPool{2}(2; pad=1, stride=3) == MeanPool((2, 2); pad=(1, 1), stride=(3, 3))
end
end end
@testset "cnn" begin @testset "cnn" begin