Update `Flux` code for new NNlib branch
This commit is contained in:
parent
bc2999b5a7
commit
113ddc8760
|
@ -53,9 +53,9 @@ version = "0.2.0"
|
||||||
|
|
||||||
[[Compat]]
|
[[Compat]]
|
||||||
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
|
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
|
||||||
git-tree-sha1 = "195a3ffcb8b0762684b6821de18f83a16455c6ea"
|
git-tree-sha1 = "84aa74986c5b9b898b0d1acaf3258741ee64754f"
|
||||||
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
|
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
|
||||||
version = "2.0.0"
|
version = "2.1.0"
|
||||||
|
|
||||||
[[DataStructures]]
|
[[DataStructures]]
|
||||||
deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
|
deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
|
||||||
|
@ -84,7 +84,7 @@ uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
|
||||||
version = "0.0.10"
|
version = "0.0.10"
|
||||||
|
|
||||||
[[Distributed]]
|
[[Distributed]]
|
||||||
deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"]
|
deps = ["Random", "Serialization", "Sockets"]
|
||||||
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
|
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
|
||||||
|
|
||||||
[[FixedPointNumbers]]
|
[[FixedPointNumbers]]
|
||||||
|
@ -100,7 +100,7 @@ uuid = "f6369f11-7733-5829-9624-2563aa707210"
|
||||||
version = "0.10.3"
|
version = "0.10.3"
|
||||||
|
|
||||||
[[InteractiveUtils]]
|
[[InteractiveUtils]]
|
||||||
deps = ["LinearAlgebra", "Markdown"]
|
deps = ["Markdown"]
|
||||||
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
|
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
|
||||||
|
|
||||||
[[Juno]]
|
[[Juno]]
|
||||||
|
@ -149,7 +149,7 @@ uuid = "a63ad114-7e13-5084-954f-fe012c677804"
|
||||||
|
|
||||||
[[NNlib]]
|
[[NNlib]]
|
||||||
deps = ["Libdl", "LinearAlgebra", "MacroTools", "Requires", "Test"]
|
deps = ["Libdl", "LinearAlgebra", "MacroTools", "Requires", "Test"]
|
||||||
git-tree-sha1 = "d07ac0bfd3c71c3a29bc9c22becbba19227bbeb5"
|
git-tree-sha1 = "9ac5cd21484189339b27840818c4882d1b6df7fd"
|
||||||
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
|
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
|
||||||
version = "0.5.0"
|
version = "0.5.0"
|
||||||
|
|
||||||
|
@ -265,7 +265,7 @@ uuid = "30578b45-9adc-5946-b283-645ec420af67"
|
||||||
version = "0.4.0"
|
version = "0.4.0"
|
||||||
|
|
||||||
[[UUIDs]]
|
[[UUIDs]]
|
||||||
deps = ["Random"]
|
deps = ["Random", "SHA"]
|
||||||
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
|
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
|
||||||
|
|
||||||
[[Unicode]]
|
[[Unicode]]
|
||||||
|
|
|
@ -50,7 +50,8 @@ function (c::Conv)(x::AbstractArray)
|
||||||
# TODO: breaks gpu broadcast :(
|
# TODO: breaks gpu broadcast :(
|
||||||
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
|
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
|
||||||
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
|
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
|
||||||
σ.(conv(x, c.weight, stride = c.stride, pad = c.pad, dilation = c.dilation) .+ b)
|
cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
|
||||||
|
σ.(conv(x, c.weight, cdims) .+ b)
|
||||||
end
|
end
|
||||||
|
|
||||||
function Base.show(io::IO, l::Conv)
|
function Base.show(io::IO, l::Conv)
|
||||||
|
@ -99,7 +100,17 @@ ConvTranspose(param(init(k..., reverse(ch)...)), param(zeros(ch[2])), σ,
|
||||||
function (c::ConvTranspose)(x::AbstractArray)
|
function (c::ConvTranspose)(x::AbstractArray)
|
||||||
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
|
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
|
||||||
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
|
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
|
||||||
σ.(∇conv_data(x, c.weight, stride = c.stride, pad = c.pad, dilation = c.dilation) .+ b)
|
# Calculate size of "input", from ∇conv_data()'s perspective...
|
||||||
|
I = (size(x)[1:end-2] .- 1).*c.stride .+ 1 .+ (size(c.weight)[1:end-2] .- 1).*c.dilation .- 2 .* c.pad
|
||||||
|
C_in = size(c.weight)[end-1]
|
||||||
|
batch_size = size(x)[end]
|
||||||
|
# Create DenseConvDims() that looks like the corresponding conv()
|
||||||
|
cdims = DenseConvDims((I..., C_in, batch_size), size(c.weight);
|
||||||
|
stride=c.stride,
|
||||||
|
padding=c.pad,
|
||||||
|
dilation=c.dilation,
|
||||||
|
)
|
||||||
|
return σ.(∇conv_data(x, c.weight, cdims) .+ b)
|
||||||
end
|
end
|
||||||
|
|
||||||
function Base.show(io::IO, l::ConvTranspose)
|
function Base.show(io::IO, l::ConvTranspose)
|
||||||
|
@ -134,20 +145,22 @@ struct DepthwiseConv{N,F,A,V}
|
||||||
bias::V
|
bias::V
|
||||||
stride::NTuple{N,Int}
|
stride::NTuple{N,Int}
|
||||||
pad::NTuple{N,Int}
|
pad::NTuple{N,Int}
|
||||||
|
dilation::NTuple{N,Int}
|
||||||
end
|
end
|
||||||
|
|
||||||
DepthwiseConv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
|
DepthwiseConv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
|
||||||
stride = 1, pad = 0) where {T,N} =
|
stride = 1, pad = 0, dilation = 1) where {T,N} =
|
||||||
DepthwiseConv(σ, w, b, expand.(sub2(Val(N)), (stride, pad))...)
|
DepthwiseConv(σ, w, b, expand.(sub2(Val(N)), (stride, pad, dilation))...)
|
||||||
|
|
||||||
DepthwiseConv(k::NTuple{N,Integer}, ch::Integer, σ = identity; init = glorot_uniform,
|
DepthwiseConv(k::NTuple{N,Integer}, ch::Integer, σ = identity; init = glorot_uniform,
|
||||||
stride = 1, pad = 0) where N =
|
stride = 1, pad = 0, dilation = 1) where N =
|
||||||
DepthwiseConv(param(init(k..., 1, ch)), param(zeros(ch)), σ,
|
DepthwiseConv(param(init(k..., 1, ch)), param(zeros(ch)), σ,
|
||||||
stride = stride, pad = pad)
|
stride = stride, pad = pad, dilation=dilation)
|
||||||
|
|
||||||
DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = glorot_uniform,
|
DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = glorot_uniform,
|
||||||
stride::NTuple{N,Integer} = map(_->1,k),
|
stride::NTuple{N,Integer} = map(_->1,k),
|
||||||
pad::NTuple{N,Integer} = map(_->0,k)) where N =
|
pad::NTuple{N,Integer} = map(_->0,k),
|
||||||
|
dilation::NTuple{N,Integer} = map(_->1,k)) where N =
|
||||||
DepthwiseConv(param(init(k..., ch[2], ch[1])), param(zeros(ch[2]*ch[1])), σ,
|
DepthwiseConv(param(init(k..., ch[2], ch[1])), param(zeros(ch[2]*ch[1])), σ,
|
||||||
stride = stride, pad = pad)
|
stride = stride, pad = pad)
|
||||||
|
|
||||||
|
@ -155,7 +168,8 @@ DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity
|
||||||
|
|
||||||
function (c::DepthwiseConv)(x)
|
function (c::DepthwiseConv)(x)
|
||||||
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
|
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
|
||||||
σ.(depthwiseconv(x, c.weight, stride = c.stride, pad = c.pad) .+ b)
|
cdims = DepthwiseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
|
||||||
|
σ.(depthwiseconv(x, c.weight, cdims) .+ b)
|
||||||
end
|
end
|
||||||
|
|
||||||
function Base.show(io::IO, l::DepthwiseConv)
|
function Base.show(io::IO, l::DepthwiseConv)
|
||||||
|
@ -187,7 +201,10 @@ end
|
||||||
MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =
|
MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =
|
||||||
MaxPool(k, expand(Val(N), pad), expand(Val(N), stride))
|
MaxPool(k, expand(Val(N), pad), expand(Val(N), stride))
|
||||||
|
|
||||||
(m::MaxPool)(x) = maxpool(x, m.k; pad = m.pad, stride = m.stride)
|
function (m::MaxPool)(x)
|
||||||
|
pdims = PoolDims(x, m.k; padding=m.pad, stride=m.stride)
|
||||||
|
return maxpool(x, pdims)
|
||||||
|
end
|
||||||
|
|
||||||
function Base.show(io::IO, m::MaxPool)
|
function Base.show(io::IO, m::MaxPool)
|
||||||
print(io, "MaxPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
|
print(io, "MaxPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
|
||||||
|
@ -209,7 +226,10 @@ end
|
||||||
MeanPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =
|
MeanPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =
|
||||||
MeanPool(k, expand(Val(N), pad), expand(Val(N), stride))
|
MeanPool(k, expand(Val(N), pad), expand(Val(N), stride))
|
||||||
|
|
||||||
(m::MeanPool)(x) = meanpool(x, m.k; pad = m.pad, stride = m.stride)
|
function (m::MeanPool)(x)
|
||||||
|
pdims = PoolDims(x, m.k; padding=m.pad, stride=m.stride)
|
||||||
|
return meanpool(x, pdims)
|
||||||
|
end
|
||||||
|
|
||||||
function Base.show(io::IO, m::MeanPool)
|
function Base.show(io::IO, m::MeanPool)
|
||||||
print(io, "MeanPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
|
print(io, "MeanPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
|
||||||
|
|
|
@ -4,9 +4,9 @@ using Flux: maxpool, meanpool
|
||||||
@testset "Pooling" begin
|
@testset "Pooling" begin
|
||||||
x = randn(Float32, 10, 10, 3, 2)
|
x = randn(Float32, 10, 10, 3, 2)
|
||||||
mp = MaxPool((2, 2))
|
mp = MaxPool((2, 2))
|
||||||
@test mp(x) == maxpool(x, (2,2))
|
@test mp(x) == maxpool(x, PoolDims(x, 2))
|
||||||
mp = MeanPool((2, 2))
|
mp = MeanPool((2, 2))
|
||||||
@test mp(x) == meanpool(x, (2,2))
|
@test mp(x) == meanpool(x, PoolDims(x, 2))
|
||||||
end
|
end
|
||||||
|
|
||||||
@testset "CNN" begin
|
@testset "CNN" begin
|
||||||
|
|
Loading…
Reference in New Issue