Merge pull request #718 from FluxML/sf/asymmetric_padding
Add asymmetric padding
This commit is contained in:
commit
13cfcb5ffa
|
@ -27,6 +27,12 @@ git-tree-sha1 = "055eb2690182ebc31087859c3dd8598371d3ef9e"
|
|||
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
|
||||
version = "0.5.3"
|
||||
|
||||
[[CSTParser]]
|
||||
deps = ["LibGit2", "Test", "Tokenize"]
|
||||
git-tree-sha1 = "437c93bc191cd55957b3f8dee7794b6131997c56"
|
||||
uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f"
|
||||
version = "0.5.2"
|
||||
|
||||
[[CodecZlib]]
|
||||
deps = ["BinaryProvider", "Libdl", "Test", "TranscodingStreams"]
|
||||
git-tree-sha1 = "36bbf5374c661054d41410dc53ff752972583b9b"
|
||||
|
@ -53,9 +59,15 @@ version = "0.2.0"
|
|||
|
||||
[[Compat]]
|
||||
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
|
||||
git-tree-sha1 = "195a3ffcb8b0762684b6821de18f83a16455c6ea"
|
||||
git-tree-sha1 = "84aa74986c5b9b898b0d1acaf3258741ee64754f"
|
||||
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
|
||||
version = "2.0.0"
|
||||
version = "2.1.0"
|
||||
|
||||
[[Crayons]]
|
||||
deps = ["Test"]
|
||||
git-tree-sha1 = "f621b8ef51fd2004c7cf157ea47f027fdeac5523"
|
||||
uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
|
||||
version = "4.0.0"
|
||||
|
||||
[[DataStructures]]
|
||||
deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
|
||||
|
@ -84,7 +96,7 @@ uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
|
|||
version = "0.0.10"
|
||||
|
||||
[[Distributed]]
|
||||
deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"]
|
||||
deps = ["Random", "Serialization", "Sockets"]
|
||||
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
|
||||
|
||||
[[FixedPointNumbers]]
|
||||
|
@ -100,7 +112,7 @@ uuid = "f6369f11-7733-5829-9624-2563aa707210"
|
|||
version = "0.10.3"
|
||||
|
||||
[[InteractiveUtils]]
|
||||
deps = ["LinearAlgebra", "Markdown"]
|
||||
deps = ["Markdown"]
|
||||
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
|
||||
|
||||
[[Juno]]
|
||||
|
@ -123,10 +135,10 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
|
|||
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
|
||||
|
||||
[[MacroTools]]
|
||||
deps = ["Compat"]
|
||||
git-tree-sha1 = "3fd1a3022952128935b449c33552eb65895380c1"
|
||||
deps = ["CSTParser", "Compat", "DataStructures", "Test"]
|
||||
git-tree-sha1 = "daecd9e452f38297c686eba90dba2a6d5da52162"
|
||||
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
|
||||
version = "0.4.5"
|
||||
version = "0.5.0"
|
||||
|
||||
[[Markdown]]
|
||||
deps = ["Base64"]
|
||||
|
@ -148,10 +160,10 @@ version = "0.4.0"
|
|||
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
|
||||
|
||||
[[NNlib]]
|
||||
deps = ["Libdl", "LinearAlgebra", "MacroTools", "Requires", "Test"]
|
||||
git-tree-sha1 = "d07ac0bfd3c71c3a29bc9c22becbba19227bbeb5"
|
||||
deps = ["Libdl", "LinearAlgebra", "Requires", "Statistics", "TimerOutputs"]
|
||||
git-tree-sha1 = "0c667371391fc6bb31f7f12f96a56a17098b3de8"
|
||||
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
|
||||
version = "0.5.0"
|
||||
version = "0.6.0"
|
||||
|
||||
[[NaNMath]]
|
||||
deps = ["Compat"]
|
||||
|
@ -161,9 +173,9 @@ version = "0.3.2"
|
|||
|
||||
[[OrderedCollections]]
|
||||
deps = ["Random", "Serialization", "Test"]
|
||||
git-tree-sha1 = "85619a3f3e17bb4761fe1b1fd47f0e979f964d5b"
|
||||
git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1"
|
||||
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
|
||||
version = "1.0.2"
|
||||
version = "1.1.0"
|
||||
|
||||
[[Pkg]]
|
||||
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
|
||||
|
@ -237,26 +249,38 @@ deps = ["LinearAlgebra", "SparseArrays"]
|
|||
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
|
||||
|
||||
[[StatsBase]]
|
||||
deps = ["DataStructures", "DelimitedFiles", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "Test"]
|
||||
git-tree-sha1 = "435707791dc85a67d98d671c1c3fcf1b20b00f94"
|
||||
deps = ["DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"]
|
||||
git-tree-sha1 = "8a0f4b09c7426478ab677245ab2b0b68552143c7"
|
||||
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
|
||||
version = "0.29.0"
|
||||
version = "0.30.0"
|
||||
|
||||
[[Test]]
|
||||
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
|
||||
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
|
||||
|
||||
[[TimerOutputs]]
|
||||
deps = ["Crayons", "Printf", "Test", "Unicode"]
|
||||
git-tree-sha1 = "b80671c06f8f8bae08c55d67b5ce292c5ae2660c"
|
||||
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
|
||||
version = "0.5.0"
|
||||
|
||||
[[Tokenize]]
|
||||
deps = ["Printf", "Test"]
|
||||
git-tree-sha1 = "3e83f60b74911d3042d3550884ca2776386a02b8"
|
||||
uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624"
|
||||
version = "0.5.3"
|
||||
|
||||
[[Tracker]]
|
||||
deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics", "Test"]
|
||||
git-tree-sha1 = "4eeea9f0ef9b8c7d1c5c5b1f8f68cb9b7f45d7df"
|
||||
git-tree-sha1 = "0bec1b68c63a0e8a58d3944261cbf4cc9577c8a1"
|
||||
uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
|
||||
version = "0.1.0"
|
||||
version = "0.2.0"
|
||||
|
||||
[[TranscodingStreams]]
|
||||
deps = ["Pkg", "Random", "Test"]
|
||||
git-tree-sha1 = "f42956022d8084539f1d7219f632542b0ea686ce"
|
||||
deps = ["Random", "Test"]
|
||||
git-tree-sha1 = "a25d8e5a28c3b1b06d3859f30757d43106791919"
|
||||
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
|
||||
version = "0.9.3"
|
||||
version = "0.9.4"
|
||||
|
||||
[[URIParser]]
|
||||
deps = ["Test", "Unicode"]
|
||||
|
@ -265,7 +289,7 @@ uuid = "30578b45-9adc-5946-b283-645ec420af67"
|
|||
version = "0.4.0"
|
||||
|
||||
[[UUIDs]]
|
||||
deps = ["Random"]
|
||||
deps = ["Random", "SHA"]
|
||||
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
|
||||
|
||||
[[Unicode]]
|
||||
|
@ -273,6 +297,6 @@ uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
|
|||
|
||||
[[ZipFile]]
|
||||
deps = ["BinaryProvider", "Libdl", "Printf", "Test"]
|
||||
git-tree-sha1 = "4000c633efe994b2e10b31b6d91382c4b7412dac"
|
||||
git-tree-sha1 = "5f6f663890dfb9bad6af75a86a43f67904e5050e"
|
||||
uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
|
||||
version = "0.8.0"
|
||||
version = "0.8.1"
|
||||
|
|
10
Project.toml
10
Project.toml
|
@ -22,13 +22,13 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
|
|||
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
|
||||
ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
|
||||
|
||||
[compat]
|
||||
NNlib = "0.6"
|
||||
Tracker = "0.2"
|
||||
julia = "0.7, 1"
|
||||
|
||||
[extras]
|
||||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
|
||||
|
||||
[targets]
|
||||
test = ["Test"]
|
||||
|
||||
[compat]
|
||||
julia = "0.7, 1"
|
||||
NNlib = "0.5"
|
||||
Tracker = "0.1"
|
||||
|
|
|
@ -63,6 +63,12 @@ git-tree-sha1 = "84aa74986c5b9b898b0d1acaf3258741ee64754f"
|
|||
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
|
||||
version = "2.1.0"
|
||||
|
||||
[[Crayons]]
|
||||
deps = ["Test"]
|
||||
git-tree-sha1 = "f621b8ef51fd2004c7cf157ea47f027fdeac5523"
|
||||
uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
|
||||
version = "4.0.0"
|
||||
|
||||
[[DataStructures]]
|
||||
deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
|
||||
git-tree-sha1 = "ca971f03e146cf144a9e2f2ce59674f5bf0e8038"
|
||||
|
@ -100,10 +106,10 @@ uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
|
|||
version = "0.7.0"
|
||||
|
||||
[[Documenter]]
|
||||
deps = ["Base64", "DocStringExtensions", "InteractiveUtils", "LibGit2", "Logging", "Markdown", "Pkg", "REPL", "Random", "Test", "Unicode"]
|
||||
git-tree-sha1 = "a8c41ba3d0861240dbec942ee1d0f86c57c37c1c"
|
||||
deps = ["Base64", "DocStringExtensions", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "Pkg", "REPL", "Random", "Test", "Unicode"]
|
||||
git-tree-sha1 = "13a6d15102410d8e70146533b759fc48d844a1d0"
|
||||
uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
|
||||
version = "0.21.5"
|
||||
version = "0.22.3"
|
||||
|
||||
[[FixedPointNumbers]]
|
||||
deps = ["Test"]
|
||||
|
@ -127,6 +133,12 @@ version = "0.10.3"
|
|||
deps = ["Markdown"]
|
||||
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
|
||||
|
||||
[[JSON]]
|
||||
deps = ["Dates", "Distributed", "Mmap", "Sockets", "Test", "Unicode"]
|
||||
git-tree-sha1 = "1f7a25b53ec67f5e9422f1f551ee216503f4a0fa"
|
||||
uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
|
||||
version = "0.20.0"
|
||||
|
||||
[[Juno]]
|
||||
deps = ["Base64", "Logging", "Media", "Profile", "Test"]
|
||||
git-tree-sha1 = "4e4a8d43aa7ecec66cadaf311fbd1e5c9d7b9175"
|
||||
|
@ -172,10 +184,10 @@ version = "0.4.0"
|
|||
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
|
||||
|
||||
[[NNlib]]
|
||||
deps = ["Libdl", "LinearAlgebra", "MacroTools", "Requires", "Test"]
|
||||
git-tree-sha1 = "9ac5cd21484189339b27840818c4882d1b6df7fd"
|
||||
deps = ["Libdl", "LinearAlgebra", "Requires", "Statistics", "TimerOutputs"]
|
||||
git-tree-sha1 = "0c667371391fc6bb31f7f12f96a56a17098b3de8"
|
||||
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
|
||||
version = "0.5.0"
|
||||
version = "0.6.0"
|
||||
|
||||
[[NaNMath]]
|
||||
deps = ["Compat"]
|
||||
|
@ -270,6 +282,12 @@ version = "0.30.0"
|
|||
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
|
||||
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
|
||||
|
||||
[[TimerOutputs]]
|
||||
deps = ["Crayons", "Printf", "Test", "Unicode"]
|
||||
git-tree-sha1 = "b80671c06f8f8bae08c55d67b5ce292c5ae2660c"
|
||||
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
|
||||
version = "0.5.0"
|
||||
|
||||
[[Tokenize]]
|
||||
deps = ["Printf", "Test"]
|
||||
git-tree-sha1 = "3e83f60b74911d3042d3550884ca2776386a02b8"
|
||||
|
@ -278,9 +296,9 @@ version = "0.5.3"
|
|||
|
||||
[[Tracker]]
|
||||
deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics", "Test"]
|
||||
git-tree-sha1 = "4eeea9f0ef9b8c7d1c5c5b1f8f68cb9b7f45d7df"
|
||||
git-tree-sha1 = "0bec1b68c63a0e8a58d3944261cbf4cc9577c8a1"
|
||||
uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
|
||||
version = "0.1.0"
|
||||
version = "0.2.0"
|
||||
|
||||
[[TranscodingStreams]]
|
||||
deps = ["Random", "Test"]
|
||||
|
|
|
@ -1,10 +1,7 @@
|
|||
using NNlib: conv, ∇conv_data, depthwiseconv
|
||||
|
||||
@generated sub2(::Val{N}) where N = :(Val($(N-2)))
|
||||
|
||||
expand(N, i::Tuple) = i
|
||||
expand(N, i::Integer) = ntuple(_ -> i, N)
|
||||
|
||||
"""
|
||||
Conv(size, in=>out)
|
||||
Conv(size, in=>out, relu)
|
||||
|
@ -26,18 +23,22 @@ and a batch of 50 would be a `100×100×3×50` array.
|
|||
|
||||
Takes the keyword arguments `pad`, `stride` and `dilation`.
|
||||
"""
|
||||
struct Conv{N,F,A,V}
|
||||
struct Conv{N,M,F,A,V}
|
||||
σ::F
|
||||
weight::A
|
||||
bias::V
|
||||
stride::NTuple{N,Int}
|
||||
pad::NTuple{N,Int}
|
||||
pad::NTuple{M,Int}
|
||||
dilation::NTuple{N,Int}
|
||||
end
|
||||
|
||||
Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
|
||||
stride = 1, pad = 0, dilation = 1) where {T,N} =
|
||||
Conv(σ, w, b, expand.(sub2(Val(N)), (stride, pad, dilation))...)
|
||||
function Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
|
||||
stride = 1, pad = 0, dilation = 1) where {T,N}
|
||||
stride = expand(Val(N-2), stride)
|
||||
pad = expand(Val(2*(N-2)), pad)
|
||||
dilation = expand(Val(N-2), dilation)
|
||||
return Conv(σ, w, b, stride, pad, dilation)
|
||||
end
|
||||
|
||||
Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
|
||||
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
|
||||
|
@ -50,7 +51,8 @@ function (c::Conv)(x::AbstractArray)
|
|||
# TODO: breaks gpu broadcast :(
|
||||
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
|
||||
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
|
||||
σ.(conv(x, c.weight, stride = c.stride, pad = c.pad, dilation = c.dilation) .+ b)
|
||||
cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
|
||||
σ.(conv(x, c.weight, cdims) .+ b)
|
||||
end
|
||||
|
||||
function Base.show(io::IO, l::Conv)
|
||||
|
@ -76,18 +78,22 @@ Data should be stored in WHCN order. In other words, a 100×100 RGB image would
|
|||
be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
|
||||
Takes the keyword arguments `pad`, `stride` and `dilation`.
|
||||
"""
|
||||
struct ConvTranspose{N,F,A,V}
|
||||
struct ConvTranspose{N,M,F,A,V}
|
||||
σ::F
|
||||
weight::A
|
||||
bias::V
|
||||
stride::NTuple{N,Int}
|
||||
pad::NTuple{N,Int}
|
||||
pad::NTuple{M,Int}
|
||||
dilation::NTuple{N,Int}
|
||||
end
|
||||
|
||||
ConvTranspose(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
|
||||
stride = 1, pad = 0, dilation = 1) where {T,N} =
|
||||
ConvTranspose(σ, w, b, expand.(sub2(Val(N)), (stride, pad, dilation))...)
|
||||
function ConvTranspose(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
|
||||
stride = 1, pad = 0, dilation = 1) where {T,N}
|
||||
stride = expand(Val(N-2), stride)
|
||||
pad = expand(Val(2*(N-2)), pad)
|
||||
dilation = expand(Val(N-2), dilation)
|
||||
return ConvTranspose(σ, w, b, stride, pad, dilation)
|
||||
end
|
||||
|
||||
ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
|
||||
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
|
||||
|
@ -96,10 +102,25 @@ ConvTranspose(param(init(k..., reverse(ch)...)), param(zeros(ch[2])), σ,
|
|||
|
||||
@treelike ConvTranspose
|
||||
|
||||
function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
|
||||
# Calculate size of "input", from ∇conv_data()'s perspective...
|
||||
combined_pad = (c.pad[1:2:end] .+ c.pad[2:2:end])
|
||||
I = (size(x)[1:end-2] .- 1).*c.stride .+ 1 .+ (size(c.weight)[1:end-2] .- 1).*c.dilation .- combined_pad
|
||||
C_in = size(c.weight)[end-1]
|
||||
batch_size = size(x)[end]
|
||||
# Create DenseConvDims() that looks like the corresponding conv()
|
||||
return DenseConvDims((I..., C_in, batch_size), size(c.weight);
|
||||
stride=c.stride,
|
||||
padding=c.pad,
|
||||
dilation=c.dilation,
|
||||
)
|
||||
end
|
||||
|
||||
function (c::ConvTranspose)(x::AbstractArray)
|
||||
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
|
||||
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
|
||||
σ.(∇conv_data(x, c.weight, stride = c.stride, pad = c.pad, dilation = c.dilation) .+ b)
|
||||
cdims = conv_transpose_dims(c, x)
|
||||
return σ.(∇conv_data(x, c.weight, cdims) .+ b)
|
||||
end
|
||||
|
||||
function Base.show(io::IO, l::ConvTranspose)
|
||||
|
@ -128,26 +149,32 @@ be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
|
|||
|
||||
Takes the keyword arguments `pad` and `stride`.
|
||||
"""
|
||||
struct DepthwiseConv{N,F,A,V}
|
||||
struct DepthwiseConv{N,M,F,A,V}
|
||||
σ::F
|
||||
weight::A
|
||||
bias::V
|
||||
stride::NTuple{N,Int}
|
||||
pad::NTuple{N,Int}
|
||||
pad::NTuple{M,Int}
|
||||
dilation::NTuple{N,Int}
|
||||
end
|
||||
|
||||
DepthwiseConv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
|
||||
stride = 1, pad = 0) where {T,N} =
|
||||
DepthwiseConv(σ, w, b, expand.(sub2(Val(N)), (stride, pad))...)
|
||||
function DepthwiseConv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
|
||||
stride = 1, pad = 0, dilation = 1) where {T,N}
|
||||
stride = expand(Val(N-2), stride)
|
||||
pad = expand(Val(2*(N-2)), pad)
|
||||
dilation = expand(Val(N-2), dilation)
|
||||
return DepthwiseConv(σ, w, b, stride, pad, dilation)
|
||||
end
|
||||
|
||||
DepthwiseConv(k::NTuple{N,Integer}, ch::Integer, σ = identity; init = glorot_uniform,
|
||||
stride = 1, pad = 0) where N =
|
||||
stride = 1, pad = 0, dilation = 1) where N =
|
||||
DepthwiseConv(param(init(k..., 1, ch)), param(zeros(ch)), σ,
|
||||
stride = stride, pad = pad)
|
||||
stride = stride, pad = pad, dilation=dilation)
|
||||
|
||||
DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = glorot_uniform,
|
||||
stride::NTuple{N,Integer} = map(_->1,k),
|
||||
pad::NTuple{N,Integer} = map(_->0,k)) where N =
|
||||
pad::NTuple{N,Integer} = map(_->0,2 .* k),
|
||||
dilation::NTuple{N,Integer} = map(_->1,k)) where N =
|
||||
DepthwiseConv(param(init(k..., ch[2], ch[1])), param(zeros(ch[2]*ch[1])), σ,
|
||||
stride = stride, pad = pad)
|
||||
|
||||
|
@ -155,7 +182,8 @@ DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity
|
|||
|
||||
function (c::DepthwiseConv)(x)
|
||||
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
|
||||
σ.(depthwiseconv(x, c.weight, stride = c.stride, pad = c.pad) .+ b)
|
||||
cdims = DepthwiseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
|
||||
σ.(depthwiseconv(x, c.weight, cdims) .+ b)
|
||||
end
|
||||
|
||||
function Base.show(io::IO, l::DepthwiseConv)
|
||||
|
@ -178,16 +206,23 @@ Max pooling layer. `k` stands for the size of the window for each dimension of t
|
|||
|
||||
Takes the keyword arguments `pad` and `stride`.
|
||||
"""
|
||||
struct MaxPool{N}
|
||||
struct MaxPool{N,M}
|
||||
k::NTuple{N,Int}
|
||||
pad::NTuple{N,Int}
|
||||
pad::NTuple{M,Int}
|
||||
stride::NTuple{N,Int}
|
||||
end
|
||||
|
||||
MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =
|
||||
MaxPool(k, expand(Val(N), pad), expand(Val(N), stride))
|
||||
function MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N
|
||||
stride = expand(Val(N), stride)
|
||||
pad = expand(Val(2*N), pad)
|
||||
|
||||
(m::MaxPool)(x) = maxpool(x, m.k; pad = m.pad, stride = m.stride)
|
||||
return MaxPool(k, pad, stride)
|
||||
end
|
||||
|
||||
function (m::MaxPool)(x)
|
||||
pdims = PoolDims(x, m.k; padding=m.pad, stride=m.stride)
|
||||
return maxpool(x, pdims)
|
||||
end
|
||||
|
||||
function Base.show(io::IO, m::MaxPool)
|
||||
print(io, "MaxPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
|
||||
|
@ -200,16 +235,22 @@ Mean pooling layer. `k` stands for the size of the window for each dimension of
|
|||
|
||||
Takes the keyword arguments `pad` and `stride`.
|
||||
"""
|
||||
struct MeanPool{N}
|
||||
struct MeanPool{N,M}
|
||||
k::NTuple{N,Int}
|
||||
pad::NTuple{N,Int}
|
||||
pad::NTuple{M,Int}
|
||||
stride::NTuple{N,Int}
|
||||
end
|
||||
|
||||
MeanPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =
|
||||
MeanPool(k, expand(Val(N), pad), expand(Val(N), stride))
|
||||
function MeanPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N
|
||||
stride = expand(Val(N), stride)
|
||||
pad = expand(Val(2*N), pad)
|
||||
return MeanPool(k, pad, stride)
|
||||
end
|
||||
|
||||
(m::MeanPool)(x) = meanpool(x, m.k; pad = m.pad, stride = m.stride)
|
||||
function (m::MeanPool)(x)
|
||||
pdims = PoolDims(x, m.k; padding=m.pad, stride=m.stride)
|
||||
return meanpool(x, pdims)
|
||||
end
|
||||
|
||||
function Base.show(io::IO, m::MeanPool)
|
||||
print(io, "MeanPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
|
||||
|
|
|
@ -4,9 +4,9 @@ using Flux: maxpool, meanpool
|
|||
@testset "Pooling" begin
|
||||
x = randn(Float32, 10, 10, 3, 2)
|
||||
mp = MaxPool((2, 2))
|
||||
@test mp(x) == maxpool(x, (2,2))
|
||||
@test mp(x) == maxpool(x, PoolDims(x, 2))
|
||||
mp = MeanPool((2, 2))
|
||||
@test mp(x) == meanpool(x, (2,2))
|
||||
@test mp(x) == meanpool(x, PoolDims(x, 2))
|
||||
end
|
||||
|
||||
@testset "CNN" begin
|
||||
|
@ -22,15 +22,26 @@ end
|
|||
@test size(m(r)) == (10, 5)
|
||||
end
|
||||
|
||||
@testset "asymmetric padding" begin
|
||||
r = ones(Float32, 28, 28, 1, 1)
|
||||
m = Conv((3, 3), 1=>1, relu; pad=(0,1,1,2))
|
||||
m.weight.data[:] .= 1.0
|
||||
m.bias.data[:] .= 0.0
|
||||
y_hat = Flux.data(m(r))[:,:,1,1]
|
||||
@test size(y_hat) == (27, 29)
|
||||
@test y_hat[1, 1] ≈ 6.0
|
||||
@test y_hat[2, 2] ≈ 9.0
|
||||
@test y_hat[end, 1] ≈ 4.0
|
||||
@test y_hat[1, end] ≈ 3.0
|
||||
@test y_hat[1, end-1] ≈ 6.0
|
||||
@test y_hat[end, end] ≈ 2.0
|
||||
end
|
||||
|
||||
@testset "Depthwise Conv" begin
|
||||
r = zeros(Float32, 28, 28, 3, 5)
|
||||
|
||||
m1 = DepthwiseConv((2, 2), 3=>5)
|
||||
|
||||
@test size(m1(r), 3) == 15
|
||||
|
||||
m2 = DepthwiseConv((2, 2), 3)
|
||||
|
||||
@test size(m2(r), 3) == 3
|
||||
|
||||
x = zeros(Float64, 28, 28, 3, 5)
|
||||
|
@ -43,3 +54,10 @@ end
|
|||
|
||||
@test size(m4(r), 3) == 3
|
||||
end
|
||||
|
||||
@testset "ConvTranspose" begin
|
||||
x = zeros(Float32, 28, 28, 1, 1)
|
||||
y = Conv((3,3), 1 => 1)(x)
|
||||
x_hat = ConvTranspose((3, 3), 1 => 1)(y)
|
||||
@test size(x_hat) == size(x)
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue