Merge pull request #718 from FluxML/sf/asymmetric_padding

Add asymmetric padding
This commit is contained in:
Mike J Innes 2019-04-25 22:29:14 +01:00 committed by GitHub
commit 13cfcb5ffa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 177 additions and 76 deletions

View File

@ -27,6 +27,12 @@ git-tree-sha1 = "055eb2690182ebc31087859c3dd8598371d3ef9e"
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
version = "0.5.3"
[[CSTParser]]
deps = ["LibGit2", "Test", "Tokenize"]
git-tree-sha1 = "437c93bc191cd55957b3f8dee7794b6131997c56"
uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f"
version = "0.5.2"
[[CodecZlib]]
deps = ["BinaryProvider", "Libdl", "Test", "TranscodingStreams"]
git-tree-sha1 = "36bbf5374c661054d41410dc53ff752972583b9b"
@ -53,9 +59,15 @@ version = "0.2.0"
[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "195a3ffcb8b0762684b6821de18f83a16455c6ea"
git-tree-sha1 = "84aa74986c5b9b898b0d1acaf3258741ee64754f"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "2.0.0"
version = "2.1.0"
[[Crayons]]
deps = ["Test"]
git-tree-sha1 = "f621b8ef51fd2004c7cf157ea47f027fdeac5523"
uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
version = "4.0.0"
[[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
@ -84,7 +96,7 @@ uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "0.0.10"
[[Distributed]]
deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"]
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
[[FixedPointNumbers]]
@ -100,7 +112,7 @@ uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.3"
[[InteractiveUtils]]
deps = ["LinearAlgebra", "Markdown"]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
[[Juno]]
@ -123,10 +135,10 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
[[MacroTools]]
deps = ["Compat"]
git-tree-sha1 = "3fd1a3022952128935b449c33552eb65895380c1"
deps = ["CSTParser", "Compat", "DataStructures", "Test"]
git-tree-sha1 = "daecd9e452f38297c686eba90dba2a6d5da52162"
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
version = "0.4.5"
version = "0.5.0"
[[Markdown]]
deps = ["Base64"]
@ -148,10 +160,10 @@ version = "0.4.0"
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
[[NNlib]]
deps = ["Libdl", "LinearAlgebra", "MacroTools", "Requires", "Test"]
git-tree-sha1 = "d07ac0bfd3c71c3a29bc9c22becbba19227bbeb5"
deps = ["Libdl", "LinearAlgebra", "Requires", "Statistics", "TimerOutputs"]
git-tree-sha1 = "0c667371391fc6bb31f7f12f96a56a17098b3de8"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.5.0"
version = "0.6.0"
[[NaNMath]]
deps = ["Compat"]
@ -161,9 +173,9 @@ version = "0.3.2"
[[OrderedCollections]]
deps = ["Random", "Serialization", "Test"]
git-tree-sha1 = "85619a3f3e17bb4761fe1b1fd47f0e979f964d5b"
git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.0.2"
version = "1.1.0"
[[Pkg]]
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
@ -237,26 +249,38 @@ deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[[StatsBase]]
deps = ["DataStructures", "DelimitedFiles", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "Test"]
git-tree-sha1 = "435707791dc85a67d98d671c1c3fcf1b20b00f94"
deps = ["DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"]
git-tree-sha1 = "8a0f4b09c7426478ab677245ab2b0b68552143c7"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.29.0"
version = "0.30.0"
[[Test]]
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[[TimerOutputs]]
deps = ["Crayons", "Printf", "Test", "Unicode"]
git-tree-sha1 = "b80671c06f8f8bae08c55d67b5ce292c5ae2660c"
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
version = "0.5.0"
[[Tokenize]]
deps = ["Printf", "Test"]
git-tree-sha1 = "3e83f60b74911d3042d3550884ca2776386a02b8"
uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624"
version = "0.5.3"
[[Tracker]]
deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics", "Test"]
git-tree-sha1 = "4eeea9f0ef9b8c7d1c5c5b1f8f68cb9b7f45d7df"
git-tree-sha1 = "0bec1b68c63a0e8a58d3944261cbf4cc9577c8a1"
uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
version = "0.1.0"
version = "0.2.0"
[[TranscodingStreams]]
deps = ["Pkg", "Random", "Test"]
git-tree-sha1 = "f42956022d8084539f1d7219f632542b0ea686ce"
deps = ["Random", "Test"]
git-tree-sha1 = "a25d8e5a28c3b1b06d3859f30757d43106791919"
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
version = "0.9.3"
version = "0.9.4"
[[URIParser]]
deps = ["Test", "Unicode"]
@ -265,7 +289,7 @@ uuid = "30578b45-9adc-5946-b283-645ec420af67"
version = "0.4.0"
[[UUIDs]]
deps = ["Random"]
deps = ["Random", "SHA"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
[[Unicode]]
@ -273,6 +297,6 @@ uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
[[ZipFile]]
deps = ["BinaryProvider", "Libdl", "Printf", "Test"]
git-tree-sha1 = "4000c633efe994b2e10b31b6d91382c4b7412dac"
git-tree-sha1 = "5f6f663890dfb9bad6af75a86a43f67904e5050e"
uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
version = "0.8.0"
version = "0.8.1"

View File

@ -22,13 +22,13 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
[compat]
NNlib = "0.6"
Tracker = "0.2"
julia = "0.7, 1"
[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[targets]
test = ["Test"]
[compat]
julia = "0.7, 1"
NNlib = "0.5"
Tracker = "0.1"

View File

@ -63,6 +63,12 @@ git-tree-sha1 = "84aa74986c5b9b898b0d1acaf3258741ee64754f"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "2.1.0"
[[Crayons]]
deps = ["Test"]
git-tree-sha1 = "f621b8ef51fd2004c7cf157ea47f027fdeac5523"
uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
version = "4.0.0"
[[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
git-tree-sha1 = "ca971f03e146cf144a9e2f2ce59674f5bf0e8038"
@ -100,10 +106,10 @@ uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
version = "0.7.0"
[[Documenter]]
deps = ["Base64", "DocStringExtensions", "InteractiveUtils", "LibGit2", "Logging", "Markdown", "Pkg", "REPL", "Random", "Test", "Unicode"]
git-tree-sha1 = "a8c41ba3d0861240dbec942ee1d0f86c57c37c1c"
deps = ["Base64", "DocStringExtensions", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "Pkg", "REPL", "Random", "Test", "Unicode"]
git-tree-sha1 = "13a6d15102410d8e70146533b759fc48d844a1d0"
uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
version = "0.21.5"
version = "0.22.3"
[[FixedPointNumbers]]
deps = ["Test"]
@ -127,6 +133,12 @@ version = "0.10.3"
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
[[JSON]]
deps = ["Dates", "Distributed", "Mmap", "Sockets", "Test", "Unicode"]
git-tree-sha1 = "1f7a25b53ec67f5e9422f1f551ee216503f4a0fa"
uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
version = "0.20.0"
[[Juno]]
deps = ["Base64", "Logging", "Media", "Profile", "Test"]
git-tree-sha1 = "4e4a8d43aa7ecec66cadaf311fbd1e5c9d7b9175"
@ -172,10 +184,10 @@ version = "0.4.0"
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
[[NNlib]]
deps = ["Libdl", "LinearAlgebra", "MacroTools", "Requires", "Test"]
git-tree-sha1 = "9ac5cd21484189339b27840818c4882d1b6df7fd"
deps = ["Libdl", "LinearAlgebra", "Requires", "Statistics", "TimerOutputs"]
git-tree-sha1 = "0c667371391fc6bb31f7f12f96a56a17098b3de8"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.5.0"
version = "0.6.0"
[[NaNMath]]
deps = ["Compat"]
@ -270,6 +282,12 @@ version = "0.30.0"
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[[TimerOutputs]]
deps = ["Crayons", "Printf", "Test", "Unicode"]
git-tree-sha1 = "b80671c06f8f8bae08c55d67b5ce292c5ae2660c"
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
version = "0.5.0"
[[Tokenize]]
deps = ["Printf", "Test"]
git-tree-sha1 = "3e83f60b74911d3042d3550884ca2776386a02b8"
@ -278,9 +296,9 @@ version = "0.5.3"
[[Tracker]]
deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics", "Test"]
git-tree-sha1 = "4eeea9f0ef9b8c7d1c5c5b1f8f68cb9b7f45d7df"
git-tree-sha1 = "0bec1b68c63a0e8a58d3944261cbf4cc9577c8a1"
uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
version = "0.1.0"
version = "0.2.0"
[[TranscodingStreams]]
deps = ["Random", "Test"]

View File

@ -1,10 +1,7 @@
using NNlib: conv, ∇conv_data, depthwiseconv
@generated sub2(::Val{N}) where N = :(Val($(N-2)))
expand(N, i::Tuple) = i
expand(N, i::Integer) = ntuple(_ -> i, N)
"""
Conv(size, in=>out)
Conv(size, in=>out, relu)
@ -26,18 +23,22 @@ and a batch of 50 would be a `100×100×3×50` array.
Takes the keyword arguments `pad`, `stride` and `dilation`.
"""
struct Conv{N,F,A,V}
struct Conv{N,M,F,A,V}
σ::F
weight::A
bias::V
stride::NTuple{N,Int}
pad::NTuple{N,Int}
pad::NTuple{M,Int}
dilation::NTuple{N,Int}
end
Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N} =
Conv(σ, w, b, expand.(sub2(Val(N)), (stride, pad, dilation))...)
function Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N}
stride = expand(Val(N-2), stride)
pad = expand(Val(2*(N-2)), pad)
dilation = expand(Val(N-2), dilation)
return Conv(σ, w, b, stride, pad, dilation)
end
Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
@ -50,7 +51,8 @@ function (c::Conv)(x::AbstractArray)
# TODO: breaks gpu broadcast :(
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
σ.(conv(x, c.weight, stride = c.stride, pad = c.pad, dilation = c.dilation) .+ b)
cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
σ.(conv(x, c.weight, cdims) .+ b)
end
function Base.show(io::IO, l::Conv)
@ -76,18 +78,22 @@ Data should be stored in WHCN order. In other words, a 100×100 RGB image would
be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
Takes the keyword arguments `pad`, `stride` and `dilation`.
"""
struct ConvTranspose{N,F,A,V}
struct ConvTranspose{N,M,F,A,V}
σ::F
weight::A
bias::V
stride::NTuple{N,Int}
pad::NTuple{N,Int}
pad::NTuple{M,Int}
dilation::NTuple{N,Int}
end
ConvTranspose(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N} =
ConvTranspose(σ, w, b, expand.(sub2(Val(N)), (stride, pad, dilation))...)
function ConvTranspose(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N}
stride = expand(Val(N-2), stride)
pad = expand(Val(2*(N-2)), pad)
dilation = expand(Val(N-2), dilation)
return ConvTranspose(σ, w, b, stride, pad, dilation)
end
ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
@ -96,10 +102,25 @@ ConvTranspose(param(init(k..., reverse(ch)...)), param(zeros(ch[2])), σ,
@treelike ConvTranspose
function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
# Calculate size of "input", from ∇conv_data()'s perspective...
combined_pad = (c.pad[1:2:end] .+ c.pad[2:2:end])
I = (size(x)[1:end-2] .- 1).*c.stride .+ 1 .+ (size(c.weight)[1:end-2] .- 1).*c.dilation .- combined_pad
C_in = size(c.weight)[end-1]
batch_size = size(x)[end]
# Create DenseConvDims() that looks like the corresponding conv()
return DenseConvDims((I..., C_in, batch_size), size(c.weight);
stride=c.stride,
padding=c.pad,
dilation=c.dilation,
)
end
function (c::ConvTranspose)(x::AbstractArray)
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
σ.(∇conv_data(x, c.weight, stride = c.stride, pad = c.pad, dilation = c.dilation) .+ b)
cdims = conv_transpose_dims(c, x)
return σ.(∇conv_data(x, c.weight, cdims) .+ b)
end
function Base.show(io::IO, l::ConvTranspose)
@ -128,26 +149,32 @@ be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
Takes the keyword arguments `pad` and `stride`.
"""
struct DepthwiseConv{N,F,A,V}
struct DepthwiseConv{N,M,F,A,V}
σ::F
weight::A
bias::V
stride::NTuple{N,Int}
pad::NTuple{N,Int}
pad::NTuple{M,Int}
dilation::NTuple{N,Int}
end
DepthwiseConv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0) where {T,N} =
DepthwiseConv(σ, w, b, expand.(sub2(Val(N)), (stride, pad))...)
function DepthwiseConv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N}
stride = expand(Val(N-2), stride)
pad = expand(Val(2*(N-2)), pad)
dilation = expand(Val(N-2), dilation)
return DepthwiseConv(σ, w, b, stride, pad, dilation)
end
DepthwiseConv(k::NTuple{N,Integer}, ch::Integer, σ = identity; init = glorot_uniform,
stride = 1, pad = 0) where N =
stride = 1, pad = 0, dilation = 1) where N =
DepthwiseConv(param(init(k..., 1, ch)), param(zeros(ch)), σ,
stride = stride, pad = pad)
stride = stride, pad = pad, dilation=dilation)
DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = glorot_uniform,
stride::NTuple{N,Integer} = map(_->1,k),
pad::NTuple{N,Integer} = map(_->0,k)) where N =
pad::NTuple{N,Integer} = map(_->0,2 .* k),
dilation::NTuple{N,Integer} = map(_->1,k)) where N =
DepthwiseConv(param(init(k..., ch[2], ch[1])), param(zeros(ch[2]*ch[1])), σ,
stride = stride, pad = pad)
@ -155,7 +182,8 @@ DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity
function (c::DepthwiseConv)(x)
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
σ.(depthwiseconv(x, c.weight, stride = c.stride, pad = c.pad) .+ b)
cdims = DepthwiseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
σ.(depthwiseconv(x, c.weight, cdims) .+ b)
end
function Base.show(io::IO, l::DepthwiseConv)
@ -178,16 +206,23 @@ Max pooling layer. `k` stands for the size of the window for each dimension of t
Takes the keyword arguments `pad` and `stride`.
"""
struct MaxPool{N}
struct MaxPool{N,M}
k::NTuple{N,Int}
pad::NTuple{N,Int}
pad::NTuple{M,Int}
stride::NTuple{N,Int}
end
MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =
MaxPool(k, expand(Val(N), pad), expand(Val(N), stride))
function MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N
stride = expand(Val(N), stride)
pad = expand(Val(2*N), pad)
(m::MaxPool)(x) = maxpool(x, m.k; pad = m.pad, stride = m.stride)
return MaxPool(k, pad, stride)
end
function (m::MaxPool)(x)
pdims = PoolDims(x, m.k; padding=m.pad, stride=m.stride)
return maxpool(x, pdims)
end
function Base.show(io::IO, m::MaxPool)
print(io, "MaxPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
@ -200,16 +235,22 @@ Mean pooling layer. `k` stands for the size of the window for each dimension of
Takes the keyword arguments `pad` and `stride`.
"""
struct MeanPool{N}
struct MeanPool{N,M}
k::NTuple{N,Int}
pad::NTuple{N,Int}
pad::NTuple{M,Int}
stride::NTuple{N,Int}
end
MeanPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =
MeanPool(k, expand(Val(N), pad), expand(Val(N), stride))
function MeanPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N
stride = expand(Val(N), stride)
pad = expand(Val(2*N), pad)
return MeanPool(k, pad, stride)
end
(m::MeanPool)(x) = meanpool(x, m.k; pad = m.pad, stride = m.stride)
function (m::MeanPool)(x)
pdims = PoolDims(x, m.k; padding=m.pad, stride=m.stride)
return meanpool(x, pdims)
end
function Base.show(io::IO, m::MeanPool)
print(io, "MeanPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")

View File

@ -4,9 +4,9 @@ using Flux: maxpool, meanpool
@testset "Pooling" begin
x = randn(Float32, 10, 10, 3, 2)
mp = MaxPool((2, 2))
@test mp(x) == maxpool(x, (2,2))
@test mp(x) == maxpool(x, PoolDims(x, 2))
mp = MeanPool((2, 2))
@test mp(x) == meanpool(x, (2,2))
@test mp(x) == meanpool(x, PoolDims(x, 2))
end
@testset "CNN" begin
@ -22,15 +22,26 @@ end
@test size(m(r)) == (10, 5)
end
@testset "asymmetric padding" begin
r = ones(Float32, 28, 28, 1, 1)
m = Conv((3, 3), 1=>1, relu; pad=(0,1,1,2))
m.weight.data[:] .= 1.0
m.bias.data[:] .= 0.0
y_hat = Flux.data(m(r))[:,:,1,1]
@test size(y_hat) == (27, 29)
@test y_hat[1, 1] 6.0
@test y_hat[2, 2] 9.0
@test y_hat[end, 1] 4.0
@test y_hat[1, end] 3.0
@test y_hat[1, end-1] 6.0
@test y_hat[end, end] 2.0
end
@testset "Depthwise Conv" begin
r = zeros(Float32, 28, 28, 3, 5)
m1 = DepthwiseConv((2, 2), 3=>5)
@test size(m1(r), 3) == 15
m2 = DepthwiseConv((2, 2), 3)
@test size(m2(r), 3) == 3
x = zeros(Float64, 28, 28, 3, 5)
@ -43,3 +54,10 @@ end
@test size(m4(r), 3) == 3
end
@testset "ConvTranspose" begin
x = zeros(Float32, 28, 28, 1, 1)
y = Conv((3,3), 1 => 1)(x)
x_hat = ConvTranspose((3, 3), 1 => 1)(y)
@test size(x_hat) == size(x)
end