From e6efca4bf4e90d2e66c630229450277b677f13ed Mon Sep 17 00:00:00 2001
From: Elliot Saba
Date: Mon, 21 May 2018 12:20:43 -0700
Subject: [PATCH] Add `dilation` kwarg to `Conv`

Now that we have dilated convolution support in `NNlib`, this enables
support in Flux's `Conv` layer.
---
 src/layers/conv.jl   | 14 ++++++++------
 src/tracker/array.jl | 20 ++++++++++----------
 2 files changed, 18 insertions(+), 16 deletions(-)

diff --git a/src/layers/conv.jl b/src/layers/conv.jl
index 39d3394d..cae4229d 100644
--- a/src/layers/conv.jl
+++ b/src/layers/conv.jl
@@ -10,7 +10,7 @@ Standard convolutional layer. `size` should be a tuple like `(2, 2)`.
 Data should be stored in WHCN order. In other words, a 100×100 RGB image would
 be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
 
-Takes the keyword arguments `pad` and `stride`.
+Takes the keyword arguments `pad`, `stride` and `dilation`.
 """
 struct Conv{N,F,A,V}
   σ::F
@@ -18,17 +18,19 @@ struct Conv{N,F,A,V}
   bias::V
   stride::NTuple{N,Int}
   pad::NTuple{N,Int}
+  dilation::NTuple{N,Int}
 end
 
 Conv(w::AbstractArray{T}, b::AbstractVector{T}, σ = identity;
-     stride = 1, pad = 0) where T =
-  Conv(σ, w, b, stride, pad)
+     stride = 1, pad = 0, dilation = 1) where T =
+  Conv(σ, w, b, stride, pad, dilation)
 
 Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
      stride::NTuple{N,Integer} = map(_->1,k),
-     pad::NTuple{N,Integer} = map(_->0,k)) where N =
+     pad::NTuple{N,Integer} = map(_->0,k),
+     dilation::NTuple{N,Integer} = map(_->1,k)) where N =
   Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ,
-       stride = stride, pad = pad)
+       stride = stride, pad = pad, dilation = dilation)
 
 Flux.treelike(Conv)
 
@@ -36,7 +38,7 @@ function (c::Conv)(x)
   # TODO: breaks gpu broadcast :(
   # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
   σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
-  σ.(conv(x, c.weight, stride = c.stride, pad = c.pad) .+ b)
+  σ.(conv(x, c.weight, stride = c.stride, pad = c.pad, dilation = c.dilation) .+ b)
 end
 
 function Base.show(io::IO, l::Conv)
diff --git a/src/tracker/array.jl b/src/tracker/array.jl
index e11296ab..75b1ebb4 100644
--- a/src/tracker/array.jl
+++ b/src/tracker/array.jl
@@ -314,18 +314,18 @@ logsoftmax(xs::TrackedArray) = track(logsoftmax, xs)
 back(::typeof(logsoftmax), Δ, xs) = @back(xs, ∇logsoftmax(Δ, data(xs)))
 
 # TODO: can store kwargs efficiently in namedtuples
-_conv(x, w, stride, pad) = conv(x, w, stride = stride, pad = pad)
+_conv(x, w, stride, pad, dilation) = conv(x, w, stride = stride, pad = pad, dilation = dilation)
 
-conv(x::TrackedArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0) where N =
-  track(_conv, x, w, stride, pad)
-conv(x::AbstractArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0) where N =
-  track(_conv, x, w, stride, pad)
-conv(x::TrackedArray{<:Real,N}, w::AbstractArray{<:Real,N}; stride = 1, pad = 0) where N =
-  track(_conv, x, w, stride, pad)
+conv(x::TrackedArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N =
+  track(_conv, x, w, stride, pad, dilation)
+conv(x::AbstractArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N =
+  track(_conv, x, w, stride, pad, dilation)
+conv(x::TrackedArray{<:Real,N}, w::AbstractArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N =
+  track(_conv, x, w, stride, pad, dilation)
 
-function back(::typeof(_conv), Δ, x, w, stride, pad)
-  @back(x, NNlib.∇conv_data(Δ, data(x), data(w); stride = stride, pad = pad))
-  @back(w, NNlib.∇conv_filter(Δ, data(x), data(w); stride = stride, pad = pad))
+function back(::typeof(_conv), Δ, x, w, stride, pad, dilation)
+  @back(x, NNlib.∇conv_data(Δ, data(x), data(w); stride = stride, pad = pad, dilation = dilation))
+  @back(w, NNlib.∇conv_filter(Δ, data(x), data(w); stride = stride, pad = pad, dilation = dilation))
 end
 
 _maxpool(x, k, pad, stride) = maxpool(x, k; pad = pad, stride = stride)
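
For reference, a minimal end-to-end sketch of the new keyword, using only the constructors and tracked `conv` methods defined in this patch. The kernel size, channel counts, input size and `relu` activation are illustrative choices, not anything the patch prescribes:

    using Flux

    # 3x3 kernel, 1 input channel, 16 output channels, with taps spaced
    # 2 pixels apart, so the effective kernel extent is 1 + 2*(3-1) = 5.
    c = Conv((3, 3), 1=>16, relu, dilation = (2, 2))

    x = rand(100, 100, 1, 1)  # WHCN: one 100x100 single-channel image
    y = c(x)                  # size(y) == (96, 96, 16, 1)

    # The tracked `conv` methods forward the same dilation to
    # NNlib.∇conv_data / NNlib.∇conv_filter on the backward pass:
    Flux.back!(sum(y))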