Added tests for varying padding, stride, and dilation with outdims.
This commit is contained in:
parent
a64378b112
commit
0cdd11c0dc
@ -3,7 +3,7 @@ using NNlib: conv, ∇conv_data, depthwiseconv, output_size
|
|||||||
# pad dims of x with dims of y until ndims(x) == ndims(y)
|
# pad dims of x with dims of y until ndims(x) == ndims(y)
|
||||||
_paddims(x::Tuple, y::Tuple) = (x..., y[(end - (length(y) - length(x) - 1)):end]...)
|
_paddims(x::Tuple, y::Tuple) = (x..., y[(end - (length(y) - length(x) - 1)):end]...)
|
||||||
|
|
||||||
_convtransoutdims(isize, ksize, ssize, pad) = Int.(ssize .* (isize .- 1) .+ ksize .- 2 .* pad)
|
_convtransoutdims(isize, ksize, ssize, dsize, pad) = (isize .- 1).*ssize .+ 1 .+ (ksize .- 1).*dsize .- (pad[1:2:end] .+ pad[2:2:end])
|
||||||
|
|
||||||
expand(N, i::Tuple) = i
|
expand(N, i::Tuple) = i
|
||||||
expand(N, i::Integer) = ntuple(_ -> i, N)
|
expand(N, i::Integer) = ntuple(_ -> i, N)
|
||||||
@ -161,7 +161,7 @@ end
|
|||||||
(a::ConvTranspose{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
|
(a::ConvTranspose{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
|
||||||
a(T.(x))
|
a(T.(x))
|
||||||
|
|
||||||
outdims(l::ConvTranspose{N}, isize) where N = _convtransoutdims(isize[1:2], size(l.weight)[1:N], l.stride, l.pad[1:N])
|
outdims(l::ConvTranspose{N}, isize) where N = _convtransoutdims(isize[1:2], size(l.weight)[1:N], l.stride, l.dilation, l.pad)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
DepthwiseConv(size, in=>out)
|
DepthwiseConv(size, in=>out)
|
||||||
|
@ -111,19 +111,51 @@ end
|
|||||||
@testset "conv output dimensions" begin
|
@testset "conv output dimensions" begin
|
||||||
m = Conv((3, 3), 3 => 16)
|
m = Conv((3, 3), 3 => 16)
|
||||||
@test Flux.outdims(m, (10, 10)) == (8, 8)
|
@test Flux.outdims(m, (10, 10)) == (8, 8)
|
||||||
|
m = Conv((3, 3), 3 => 16; stride = 2)
|
||||||
|
@test Flux.outdims(m, (5, 5)) == (2, 2)
|
||||||
|
m = Conv((3, 3), 3 => 16; stride = 2, pad = 3)
|
||||||
|
@test Flux.outdims(m, (5, 5)) == (5, 5)
|
||||||
|
m = Conv((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2)
|
||||||
|
@test Flux.outdims(m, (5, 5)) == (4, 4)
|
||||||
|
|
||||||
m = ConvTranspose((3, 3), 3 => 16)
|
m = ConvTranspose((3, 3), 3 => 16)
|
||||||
@test Flux.outdims(m, (8, 8)) == (10, 10)
|
@test Flux.outdims(m, (8, 8)) == (10, 10)
|
||||||
|
m = ConvTranspose((3, 3), 3 => 16; stride = 2)
|
||||||
|
@test Flux.outdims(m, (2, 2)) == (5, 5)
|
||||||
|
m = ConvTranspose((3, 3), 3 => 16; stride = 2, pad = 3)
|
||||||
|
@test Flux.outdims(m, (5, 5)) == (5, 5)
|
||||||
|
m = ConvTranspose((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2)
|
||||||
|
@test Flux.outdims(m, (4, 4)) == (5, 5)
|
||||||
|
|
||||||
m = DepthwiseConv((3, 3), 3 => 6)
|
m = DepthwiseConv((3, 3), 3 => 6)
|
||||||
@test Flux.outdims(m, (10, 10)) == (8, 8)
|
@test Flux.outdims(m, (10, 10)) == (8, 8)
|
||||||
|
m = DepthwiseConv((3, 3), 3 => 6; stride = 2)
|
||||||
|
@test Flux.outdims(m, (5, 5)) == (2, 2)
|
||||||
|
m = DepthwiseConv((3, 3), 3 => 6; stride = 2, pad = 3)
|
||||||
|
@test Flux.outdims(m, (5, 5)) == (5, 5)
|
||||||
|
m = DepthwiseConv((3, 3), 3 => 6; stride = 2, pad = 3, dilation = 2)
|
||||||
|
@test Flux.outdims(m, (5, 5)) == (4, 4)
|
||||||
|
|
||||||
m = CrossCor((3, 3), 3 => 16)
|
m = CrossCor((3, 3), 3 => 16)
|
||||||
@test Flux.outdims(m, (10, 10)) == (8, 8)
|
@test Flux.outdims(m, (10, 10)) == (8, 8)
|
||||||
|
m = CrossCor((3, 3), 3 => 16; stride = 2)
|
||||||
|
@test Flux.outdims(m, (5, 5)) == (2, 2)
|
||||||
|
m = CrossCor((3, 3), 3 => 16; stride = 2, pad = 3)
|
||||||
|
@test Flux.outdims(m, (5, 5)) == (5, 5)
|
||||||
|
m = CrossCor((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2)
|
||||||
|
@test Flux.outdims(m, (5, 5)) == (4, 4)
|
||||||
|
|
||||||
m = MaxPool((2, 2))
|
m = MaxPool((2, 2))
|
||||||
@test Flux.outdims(m, (10, 10)) == (5, 5)
|
@test Flux.outdims(m, (10, 10)) == (5, 5)
|
||||||
|
m = MaxPool((2, 2); stride = 1)
|
||||||
|
@test Flux.outdims(m, (5, 5)) == (4, 4)
|
||||||
|
m = MaxPool((2, 2); stride = 2, pad = 3)
|
||||||
|
@test Flux.outdims(m, (5, 5)) == (5, 5)
|
||||||
|
|
||||||
m = MeanPool((2, 2))
|
m = MeanPool((2, 2))
|
||||||
@test Flux.outdims(m, (10, 10)) == (5, 5)
|
@test Flux.outdims(m, (10, 10)) == (5, 5)
|
||||||
|
m = MeanPool((2, 2); stride = 1)
|
||||||
|
@test Flux.outdims(m, (5, 5)) == (4, 4)
|
||||||
|
m = MeanPool((2, 2); stride = 2, pad = 3)
|
||||||
|
@test Flux.outdims(m, (5, 5)) == (5, 5)
|
||||||
end
|
end
|
Loading…
Reference in New Issue
Block a user