Update `Flux` code for new NNlib branch

This commit is contained in:
Elliot Saba 2019-03-28 15:56:21 -07:00
parent bc2999b5a7
commit 113ddc8760
3 changed files with 38 additions and 18 deletions

View File

@ -53,9 +53,9 @@ version = "0.2.0"
[[Compat]] [[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "195a3ffcb8b0762684b6821de18f83a16455c6ea" git-tree-sha1 = "84aa74986c5b9b898b0d1acaf3258741ee64754f"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "2.0.0" version = "2.1.0"
[[DataStructures]] [[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"] deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
@ -84,7 +84,7 @@ uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "0.0.10" version = "0.0.10"
[[Distributed]] [[Distributed]]
deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"] deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
[[FixedPointNumbers]] [[FixedPointNumbers]]
@ -100,7 +100,7 @@ uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.3" version = "0.10.3"
[[InteractiveUtils]] [[InteractiveUtils]]
deps = ["LinearAlgebra", "Markdown"] deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
[[Juno]] [[Juno]]
@ -149,7 +149,7 @@ uuid = "a63ad114-7e13-5084-954f-fe012c677804"
[[NNlib]] [[NNlib]]
deps = ["Libdl", "LinearAlgebra", "MacroTools", "Requires", "Test"] deps = ["Libdl", "LinearAlgebra", "MacroTools", "Requires", "Test"]
git-tree-sha1 = "d07ac0bfd3c71c3a29bc9c22becbba19227bbeb5" git-tree-sha1 = "9ac5cd21484189339b27840818c4882d1b6df7fd"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.5.0" version = "0.5.0"
@ -265,7 +265,7 @@ uuid = "30578b45-9adc-5946-b283-645ec420af67"
version = "0.4.0" version = "0.4.0"
[[UUIDs]] [[UUIDs]]
deps = ["Random"] deps = ["Random", "SHA"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
[[Unicode]] [[Unicode]]

View File

@ -50,7 +50,8 @@ function (c::Conv)(x::AbstractArray)
# TODO: breaks gpu broadcast :( # TODO: breaks gpu broadcast :(
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1))) # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1) σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
σ.(conv(x, c.weight, stride = c.stride, pad = c.pad, dilation = c.dilation) .+ b) cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
σ.(conv(x, c.weight, cdims) .+ b)
end end
function Base.show(io::IO, l::Conv) function Base.show(io::IO, l::Conv)
@ -99,7 +100,17 @@ ConvTranspose(param(init(k..., reverse(ch)...)), param(zeros(ch[2])), σ,
function (c::ConvTranspose)(x::AbstractArray) function (c::ConvTranspose)(x::AbstractArray)
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1))) # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1) σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
σ.(∇conv_data(x, c.weight, stride = c.stride, pad = c.pad, dilation = c.dilation) .+ b) # Calculate size of "input", from ∇conv_data()'s perspective...
I = (size(x)[1:end-2] .- 1).*c.stride .+ 1 .+ (size(c.weight)[1:end-2] .- 1).*c.dilation .- 2 .* c.pad
C_in = size(c.weight)[end-1]
batch_size = size(x)[end]
# Create DenseConvDims() that looks like the corresponding conv()
cdims = DenseConvDims((I..., C_in, batch_size), size(c.weight);
stride=c.stride,
padding=c.pad,
dilation=c.dilation,
)
return σ.(∇conv_data(x, c.weight, cdims) .+ b)
end end
function Base.show(io::IO, l::ConvTranspose) function Base.show(io::IO, l::ConvTranspose)
@ -134,20 +145,22 @@ struct DepthwiseConv{N,F,A,V}
bias::V bias::V
stride::NTuple{N,Int} stride::NTuple{N,Int}
pad::NTuple{N,Int} pad::NTuple{N,Int}
dilation::NTuple{N,Int}
end end
DepthwiseConv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity; DepthwiseConv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0) where {T,N} = stride = 1, pad = 0, dilation = 1) where {T,N} =
DepthwiseConv(σ, w, b, expand.(sub2(Val(N)), (stride, pad))...) DepthwiseConv(σ, w, b, expand.(sub2(Val(N)), (stride, pad, dilation))...)
DepthwiseConv(k::NTuple{N,Integer}, ch::Integer, σ = identity; init = glorot_uniform, DepthwiseConv(k::NTuple{N,Integer}, ch::Integer, σ = identity; init = glorot_uniform,
stride = 1, pad = 0) where N = stride = 1, pad = 0, dilation = 1) where N =
DepthwiseConv(param(init(k..., 1, ch)), param(zeros(ch)), σ, DepthwiseConv(param(init(k..., 1, ch)), param(zeros(ch)), σ,
stride = stride, pad = pad) stride = stride, pad = pad, dilation=dilation)
DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = glorot_uniform, DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = glorot_uniform,
stride::NTuple{N,Integer} = map(_->1,k), stride::NTuple{N,Integer} = map(_->1,k),
pad::NTuple{N,Integer} = map(_->0,k)) where N = pad::NTuple{N,Integer} = map(_->0,k),
dilation::NTuple{N,Integer} = map(_->1,k)) where N =
DepthwiseConv(param(init(k..., ch[2], ch[1])), param(zeros(ch[2]*ch[1])), σ, DepthwiseConv(param(init(k..., ch[2], ch[1])), param(zeros(ch[2]*ch[1])), σ,
stride = stride, pad = pad) stride = stride, pad = pad)
@ -155,7 +168,8 @@ DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity
function (c::DepthwiseConv)(x) function (c::DepthwiseConv)(x)
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1) σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
σ.(depthwiseconv(x, c.weight, stride = c.stride, pad = c.pad) .+ b) cdims = DepthwiseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
σ.(depthwiseconv(x, c.weight, cdims) .+ b)
end end
function Base.show(io::IO, l::DepthwiseConv) function Base.show(io::IO, l::DepthwiseConv)
@ -187,7 +201,10 @@ end
MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N = MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =
MaxPool(k, expand(Val(N), pad), expand(Val(N), stride)) MaxPool(k, expand(Val(N), pad), expand(Val(N), stride))
(m::MaxPool)(x) = maxpool(x, m.k; pad = m.pad, stride = m.stride) function (m::MaxPool)(x)
pdims = PoolDims(x, m.k; padding=m.pad, stride=m.stride)
return maxpool(x, pdims)
end
function Base.show(io::IO, m::MaxPool) function Base.show(io::IO, m::MaxPool)
print(io, "MaxPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")") print(io, "MaxPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
@ -209,7 +226,10 @@ end
MeanPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N = MeanPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =
MeanPool(k, expand(Val(N), pad), expand(Val(N), stride)) MeanPool(k, expand(Val(N), pad), expand(Val(N), stride))
(m::MeanPool)(x) = meanpool(x, m.k; pad = m.pad, stride = m.stride) function (m::MeanPool)(x)
pdims = PoolDims(x, m.k; padding=m.pad, stride=m.stride)
return meanpool(x, pdims)
end
function Base.show(io::IO, m::MeanPool) function Base.show(io::IO, m::MeanPool)
print(io, "MeanPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")") print(io, "MeanPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")

View File

@ -4,9 +4,9 @@ using Flux: maxpool, meanpool
@testset "Pooling" begin @testset "Pooling" begin
x = randn(Float32, 10, 10, 3, 2) x = randn(Float32, 10, 10, 3, 2)
mp = MaxPool((2, 2)) mp = MaxPool((2, 2))
@test mp(x) == maxpool(x, (2,2)) @test mp(x) == maxpool(x, PoolDims(x, 2))
mp = MeanPool((2, 2)) mp = MeanPool((2, 2))
@test mp(x) == meanpool(x, (2,2)) @test mp(x) == meanpool(x, PoolDims(x, 2))
end end
@testset "CNN" begin @testset "CNN" begin