Merge pull request #311 from tejank10/conv_transpose

2D Conv transpose support
Mike J Innes · 2019-02-06 14:14:14 +00:00 · committed by GitHub
commit e8b2ec6f67
5 changed files with 84 additions and 18 deletions

Manifest.toml

@@ -1,5 +1,3 @@
 # This file is machine-generated - editing it directly is not advised

 [[AbstractTrees]]
 deps = ["Markdown", "Test"]
 git-tree-sha1 = "6621d9645702c1c4e6970cc6a3eae440c768000b"
@@ -53,9 +51,9 @@ version = "0.2.0"
 [[Compat]]
 deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
-git-tree-sha1 = "ec61a16eed883ad0cfa002d7489b3ce6d039bb9a"
+git-tree-sha1 = "49269e311ffe11ac5b334681d212329002a9832a"
 uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
-version = "1.4.0"
+version = "1.5.1"

 [[DataStructures]]
 deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
@@ -84,7 +82,7 @@ uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
 version = "0.0.8"

 [[Distributed]]
-deps = ["Random", "Serialization", "Sockets"]
+deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"]
 uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

 [[FixedPointNumbers]]
@@ -100,7 +98,7 @@ uuid = "f6369f11-7733-5829-9624-2563aa707210"
 version = "0.10.2"

 [[InteractiveUtils]]
-deps = ["Markdown"]
+deps = ["LinearAlgebra", "Markdown"]
 uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

 [[Juno]]
@@ -149,9 +147,11 @@ uuid = "a63ad114-7e13-5084-954f-fe012c677804"
 [[NNlib]]
 deps = ["Libdl", "LinearAlgebra", "MacroTools", "Requires", "Test"]
-git-tree-sha1 = "51330bb45927379007e089997bf548fbe232589d"
+git-tree-sha1 = "5a8ed87d61b1ccb71d99235c2a96287addebbb9f"
+repo-rev = "master"
+repo-url = "https://github.com/FluxML/NNlib.jl.git"
 uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-version = "0.4.3"
+version = "0.4.3+"

 [[NaNMath]]
 deps = ["Compat"]
@@ -259,7 +259,7 @@ uuid = "30578b45-9adc-5946-b283-645ec420af67"
 version = "0.4.0"

 [[UUIDs]]
-deps = ["Random", "SHA"]
+deps = ["Random"]
 uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

 [[Unicode]]
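Note the NNlib entry above: it now points at an unregistered revision (repo-rev = "master", version "0.4.3+") because this PR needs ∇conv_data features that had not yet been tagged. As a hedged sketch of how a Manifest entry like that is produced (the exact command used here is not recorded in the commit):

```julia
using Pkg

# Track NNlib straight from its GitHub master branch. Pkg records this in
# the Manifest as repo-url/repo-rev and marks the version "0.4.3+".
Pkg.add(PackageSpec(url = "https://github.com/FluxML/NNlib.jl.git", rev = "master"))
```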

src/Flux.jl

@@ -6,7 +6,7 @@ using Base: tail
 using MacroTools, Juno, Requires, Reexport, Statistics, Random
 using MacroTools: @forward

-export Chain, Dense, RNN, LSTM, GRU, Conv, MaxPool, MeanPool,
+export Chain, Dense, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
        DepthwiseConv, Dropout, LayerNorm, BatchNorm,
        params, mapleaves, cpu, gpu, f32, f64

src/layers/conv.jl

@@ -1,4 +1,4 @@
-using NNlib: conv, depthwiseconv
+using NNlib: conv, ∇conv_data, depthwiseconv

 @generated sub2(::Val{N}) where N = :(Val($(N-2)))
@@ -57,6 +57,54 @@ end
 (a::Conv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
   a(T.(x))

+"""
+    ConvTranspose(size, in=>out)
+    ConvTranspose(size, in=>out, relu)
+
+Standard convolutional transpose layer. `size` should be a tuple like `(2, 2)`.
+`in` and `out` specify the number of input and output channels respectively.
+Data should be stored in WHCN order. In other words, a 100×100 RGB image would
+be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
+Takes the keyword arguments `pad`, `stride` and `dilation`.
+"""
+struct ConvTranspose{N,F,A,V}
+  σ::F
+  weight::A
+  bias::V
+  stride::NTuple{N,Int}
+  pad::NTuple{N,Int}
+  dilation::NTuple{N,Int}
+end
+
+ConvTranspose(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
+              stride = 1, pad = 0, dilation = 1) where {T,N} =
+  ConvTranspose(σ, w, b, expand.(sub2(Val(N)), (stride, pad, dilation))...)
+
+ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
+              init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
+  ConvTranspose(param(init(k..., reverse(ch)...)), param(zeros(ch[2])), σ,
+                stride = stride, pad = pad, dilation = dilation)
+
+@treelike ConvTranspose
+
+function (c::ConvTranspose)(x::AbstractArray)
+  # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
+  σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
+  σ.(∇conv_data(x, c.weight, stride = c.stride, pad = c.pad, dilation = c.dilation) .+ b)
+end
+
+function Base.show(io::IO, l::ConvTranspose)
+  print(io, "ConvTranspose(", size(l.weight)[1:ndims(l.weight)-2])
+  print(io, ", ", size(l.weight, ndims(l.weight)), "=>", size(l.weight, ndims(l.weight)-1))
+  l.σ == identity || print(io, ", ", l.σ)
+  print(io, ")")
+end
+
+(a::ConvTranspose{<:Any,<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
+  invoke(a, Tuple{AbstractArray}, x)
+
+(a::ConvTranspose{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
+  a(T.(x))

 """
     DepthwiseConv(size, in)
     DepthwiseConv(size, in=>mul)
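To make the new layer concrete, here is a minimal usage sketch (not from the commit; the shapes and the stride = 2 upsampling factor are illustrative choices):

```julia
using Flux

# A 4×4 single-channel input, batch size 1, in WHCN layout.
x = rand(Float32, 4, 4, 1, 1)

# 2×2 kernel, 1 input channel => 3 output channels, stride 2. Because the
# layer runs a convolution's backward pass as its forward pass, it
# upsamples: each spatial dimension grows from 4 to (4 - 1)*2 + 2 = 8.
layer = ConvTranspose((2, 2), 1=>3, relu, stride = 2)

size(layer(x))  # (8, 8, 3, 1)
```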

src/tracker/lib/array.jl

@@ -364,7 +364,7 @@ x::TrackedVector * y::TrackedVector = track(*, x, y)
 # NNlib

 using NNlib
-import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, depthwiseconv, maxpool, meanpool
+import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, ∇conv_data, depthwiseconv, maxpool, meanpool

 softmax(xs::TrackedArray) = track(softmax, xs)
@@ -391,8 +391,18 @@ conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...)
 @grad conv(x, w; kw...) =
   conv(data(x), data(w); kw...),
   Δ -> nobacksies(:conv,
-    (NNlib.∇conv_data(data.((Δ, x, w))...; kw...),
-     NNlib.∇conv_filter(data.((Δ, x, w))...; kw...)))
+    (NNlib.∇conv_data(data.((Δ, w))...; size=size(x), kw...),
+     NNlib.∇conv_filter(data.((Δ, x))...; size=size(w), kw...)))
+
+∇conv_data(x::TrackedArray, w::TrackedArray; kw...) = track(∇conv_data, x, w; kw...)
+∇conv_data(x::AbstractArray, w::TrackedArray; kw...) = track(∇conv_data, x, w; kw...)
+∇conv_data(x::TrackedArray, w::AbstractArray; kw...) = track(∇conv_data, x, w; kw...)
+
+@grad ∇conv_data(x, w; kw...) =
+  ∇conv_data(data(x), data(w); kw...),
+  Δ -> nobacksies(:conv,
+    (NNlib.conv(data.((Δ, w))...; size=size(x), kw...),
+     NNlib.∇conv_filter(data.((x, Δ))...; size=size(w), kw...)))

 maxpool(x::TrackedArray, k; kw...) = track(maxpool, x, k; kw...)
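With these methods in place, ∇conv_data participates in the AD graph like any other tracked op. A small sketch (assuming the Tracker API in this tree; the shapes mirror the tests below):

```julia
using Flux, NNlib
using Flux.Tracker: gradient
using NNlib: ∇conv_data

dy = rand(Float32, 4, 4, 3, 1)  # plays the role of an upstream gradient
w  = rand(Float32, 2, 2, 2, 3)  # 2×2 kernel relating 2 data channels to 3 conv channels

# Differentiating *through* ∇conv_data: per the @grad rule above, the
# pullback uses a forward conv for dy's gradient and ∇conv_filter for w's.
gdy, gw = gradient((dy, w) -> sum(∇conv_data(dy, w)), dy, w)
```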

test/tracker.jl

@@ -1,7 +1,7 @@
 using Flux
 using Flux.Tracker, Test, NNlib
 using Flux.Tracker: TrackedReal, gradient, gradcheck, grad, checkpoint, forwarddiff
-using NNlib: conv, depthwiseconv
+using NNlib: conv, ∇conv_data, depthwiseconv
 using Printf: @sprintf
 using LinearAlgebra: diagm, dot, LowerTriangular, norm
 using Statistics: mean, std
@@ -189,8 +189,16 @@ end
 @test gradtest(conv, rand(10, 10, 3, 2), randn(Float64, 2, 2, 3, 2))
 @test gradtest(conv, rand(10, 10, 10, 3, 2), randn(Float64, 2, 2, 2, 3, 2))
+@test gradtest(∇conv_data, rand(10, 3, 2), randn(Float64, 2, 2, 3))
+@test gradtest(∇conv_data, rand(10, 10, 3, 2), randn(Float64, 2, 2, 2, 3))
+@test gradtest(∇conv_data, rand(10, 10, 10, 3, 2), randn(Float64, 2, 2, 2, 2, 3))

 @test gradtest(depthwiseconv, rand(10,10,3,2), randn(2, 2, 2, 3))

 @test gradtest(x -> maxpool(x, (2,2)), rand(10, 10, 3, 2))
 @test gradtest(x -> maxpool(x, (2,2,2)), rand(10, 10, 10, 3, 2))
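These test shapes encode the duality that ConvTranspose relies on: for a fixed weight array, ∇conv_data maps a convolution's output shape back to its input shape. A quick sanity check along those lines (not part of the test suite; names are local):

```julia
using NNlib: conv, ∇conv_data

x = rand(10, 10, 3, 2)   # WHCN input
w = randn(2, 2, 3, 4)    # 2×2 kernel, 3 input => 4 output channels

y = conv(x, w)           # size (9, 9, 4, 2)
x̄ = ∇conv_data(y, w)     # back to size (10, 10, 3, 2)
size(x̄) == size(x)       # true
```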