From cd6a0856d5dc06694d6e39f2fb62bc53c6698c4f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 30 May 2018 15:53:57 +0530 Subject: [PATCH] Adds support for Depthwise Convolutions --- src/layers/conv.jl | 46 +++++++++++++++++++++++++++++++++++++++++++- src/tracker/array.jl | 16 ++++++++++++++- 2 files changed, 60 insertions(+), 2 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 994648c2..237b5b7c 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1,4 +1,4 @@ -using NNlib: conv +using NNlib: conv, depthwiseconv """ Conv(size, in=>out) @@ -46,5 +46,49 @@ function Base.show(io::IO, l::Conv) print(io, ")") end +""" + DepthwiseConv(size, in=>mul) + DepthwiseConv(size, in=>mul, relu) + +Depthwise convolutional layer. `size` should be a tuple like `(2, 2)`. +`in` and `mul` specify the number of input channels and channel multiplier respectively. + +Data should be stored in WHCN order. In other words, a 100×100 RGB image would +be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array. + +Takes the keyword arguments `pad` and `stride`. 
+""" +struct DepthwiseConv{N,F,A,V} + σ::F + weight::A + bias::V + stride::NTuple{N,Int} + pad::NTuple{N,Int} +end + +DepthwiseConv(w::AbstractArray{T}, b::AbstractVector{T}, σ = identity; + stride = 1, pad = 0) where T = + DepthwiseConv(σ, w, b, stride, pad) + +DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn, + stride::NTuple{N,Integer} = map(_->1,k), + pad::NTuple{N,Integer} = map(_->0,k)) where N = + DepthwiseConv(param(init(k..., ch[2], ch[1])), param(zeros(ch[2]*ch[1])), σ, + stride = stride, pad = pad) + +Flux.treelike(DepthwiseConv) + +function (c::DepthwiseConv)(x) + σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1) + σ.(depthwiseconv(x, c.weight, stride = c.stride, pad = c.pad) .+ b) +end + +function Base.show(io::IO, l::Conv) + print(io, "DepthwiseConv(", size(l.weight)[1:ndims(l.weight)-2]) + print(io, ", ", size(l.weight, ndims(l.weight)), "=>", size(l.weight, ndims(l.weight)-1)) + l.σ == identity || print(io, ", ", l.σ) + print(io, ")") +end + # v0.5 @deprecate Conv2D(args...; kw...) Conv(args...; kw...) 
diff --git a/src/tracker/array.jl b/src/tracker/array.jl index bb55ef73..0bc63c63 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -234,7 +234,7 @@ end # NNlib using NNlib -import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, maxpool, meanpool +import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, depthwiseconv, maxpool, meanpool softmax(xs::TrackedArray) = track(softmax, xs) @@ -259,6 +259,20 @@ function back(::typeof(_conv), Δ, x, w, stride, pad) @back(w, NNlib.∇conv_filter(Δ, data(x), data(w); stride = stride, pad = pad)) end +_depthwiseconv(x, w, stride, pad) = depthwiseconv(x, w, stride = stride, pad = pad) + +depthwiseconv(x::TrackedArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0) where N = + track(_depthwiseconv, x, w, stride, pad) +depthwiseconv(x::AbstractArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0) where N = + track(_depthwiseconv, x, w, stride, pad) +depthwiseconv(x::TrackedArray{<:Real,N}, w::AbstractArray{<:Real,N}; stride = 1, pad = 0) where N = + track(_depthwiseconv, x, w, stride, pad) + +function back(::typeof(_depthwiseconv), Δ, x, w, stride, pad) + @back(x, NNlib.∇depthwiseconv_data(Δ, data(x), data(w), stride = stride, pad = pad)) + @back(w, NNlib.∇depthwiseconv_filter(Δ, data(x), data(w), stride = stride, pad = pad)) +end + _maxpool(x, k, pad, stride) = maxpool(x, k; pad = pad, stride = stride) maxpool(x::TrackedArray, k; pad = map(_->0,k), stride = k) =