Merge pull request #279 from avik-pal/depthwiseconv

Adds support for Depthwise Convolutions
Mike J Innes 2018-10-23 17:22:15 +01:00 committed by GitHub
commit bbccdb3eec
4 changed files with 71 additions and 3 deletions


@@ -10,6 +10,12 @@ MaxPool
 MeanPool
 ```
 
+## Additional Convolution Layers
+
+```@docs
+DepthwiseConv
+```
+
 ## Recurrent Layers
 
 Much like the core layers above, but can be used to process sequence data (as well as other kinds of structured data).


@@ -1,4 +1,4 @@
-using NNlib: conv
+using NNlib: conv, depthwiseconv
 
 @generated sub2(::Val{N}) where N = :(Val($(N-2)))
@@ -51,6 +51,56 @@ function Base.show(io::IO, l::Conv)
   print(io, ")")
 end
 
+"""
+    DepthwiseConv(size, in)
+    DepthwiseConv(size, in=>mul)
+    DepthwiseConv(size, in=>mul, relu)
+
+Depthwise convolutional layer. `size` should be a tuple like `(2, 2)`.
+`in` and `mul` specify the number of input channels and the channel multiplier respectively.
+If `mul` is not specified, it defaults to 1.
+
+Data should be stored in WHCN order. In other words, a 100×100 RGB image would
+be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
+
+Takes the keyword arguments `pad` and `stride`.
+"""
+struct DepthwiseConv{N,F,A,V}
+  σ::F
+  weight::A
+  bias::V
+  stride::NTuple{N,Int}
+  pad::NTuple{N,Int}
+end
+
+DepthwiseConv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
+              stride = 1, pad = 0) where {T,N} =
+  DepthwiseConv(σ, w, b, expand.(sub2(Val(N)), (stride, pad))...)
+
+DepthwiseConv(k::NTuple{N,Integer}, ch::Integer, σ = identity; init = initn,
+              stride = 1, pad = 0) where N =
+  DepthwiseConv(param(init(k..., 1, ch)), param(zeros(ch)), σ,
+                stride = stride, pad = pad)
+
+DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
+              stride::NTuple{N,Integer} = map(_->1,k),
+              pad::NTuple{N,Integer} = map(_->0,k)) where N =
+  DepthwiseConv(param(init(k..., ch[2], ch[1])), param(zeros(ch[2]*ch[1])), σ,
+                stride = stride, pad = pad)
+
+@treelike DepthwiseConv
+
+function (c::DepthwiseConv)(x)
+  σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
+  σ.(depthwiseconv(x, c.weight, stride = c.stride, pad = c.pad) .+ b)
+end
+
+function Base.show(io::IO, l::DepthwiseConv)
+  print(io, "DepthwiseConv(", size(l.weight)[1:ndims(l.weight)-2])
+  print(io, ", ", size(l.weight, ndims(l.weight)), "=>", size(l.weight, ndims(l.weight)-1))
+  l.σ == identity || print(io, ", ", l.σ)
+  print(io, ")")
+end
+
 """
     MaxPool(k)
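
Not part of the diff: a minimal usage sketch of the new layer, assuming Flux with this PR applied. Per the pair constructor above, the filter is initialised as `init(k..., mul, in)`, so a `3=>2` layer carries a `3×3×2×3` weight and produces `3*2 = 6` output channels.

```julia
# Hypothetical usage sketch (not from the PR).
using Flux

# 3 input channels, channel multiplier 2 => 6 output channels.
d = DepthwiseConv((3, 3), 3=>2, relu, stride = (1, 1), pad = (1, 1))

x = rand(100, 100, 3, 50)  # WHCN: 100×100 RGB images, batch of 50
y = d(x)
size(y)                    # (100, 100, 6, 50); pad = 1 keeps the spatial size
```

The `Base.show` method above would print this layer as `DepthwiseConv((3, 3), 3=>2, relu)`, recovering `in=>mul` from the last two weight dimensions.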


@@ -330,7 +330,7 @@ x::TrackedVector * y::TrackedVector = track(*, x, y)
 # NNlib
 
 using NNlib
-import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, maxpool, meanpool
+import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, depthwiseconv, maxpool, meanpool
 
 softmax(xs::TrackedArray) = track(softmax, xs)
 
@@ -340,6 +340,16 @@ logsoftmax(xs::TrackedArray) = track(logsoftmax, xs)
 @grad logsoftmax(xs) = logsoftmax(data(xs)), Δ -> (nobacksies(:logsoftmax, ∇logsoftmax(data(Δ), data(xs))),)
 
+depthwiseconv(x::TrackedArray, w::TrackedArray; kw...) = track(depthwiseconv, x, w; kw...)
+depthwiseconv(x::AbstractArray, w::TrackedArray; kw...) = track(depthwiseconv, x, w; kw...)
+depthwiseconv(x::TrackedArray, w::AbstractArray; kw...) = track(depthwiseconv, x, w; kw...)
+
+@grad depthwiseconv(x, w; kw...) =
+  depthwiseconv(data(x), data(w); kw...),
+    Δ -> nobacksies(:depthwiseconv,
+      (NNlib.∇depthwiseconv_data(data.((Δ, x, w))...; kw...),
+       NNlib.∇depthwiseconv_filter(data.((Δ, x, w))...; kw...)))
+
 conv(x::TrackedArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
 conv(x::AbstractArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
 conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...)
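
Not part of the diff: a sketch of how these tracked methods behave, using the Tracker API of this Flux era (`param`, `back!`, `grad`). The backward pass goes through the `@grad` rule above, so the input and filter gradients come from `∇depthwiseconv_data` and `∇depthwiseconv_filter` respectively.

```julia
# Hypothetical gradient walk-through (not from the PR).
using Flux                       # for param
using Flux.Tracker: back!, grad
using NNlib: depthwiseconv

x = param(rand(10, 10, 3, 2))    # tracked input, WHCN
w = param(randn(2, 2, 2, 3))     # tracked filter, k × k × mul × in

y = sum(depthwiseconv(x, w))     # tracked scalar; forward pass runs on data(x), data(w)
back!(y)                         # pulls back through the @grad rule defined above

grad(x)  # filled by ∇depthwiseconv_data(Δ, x, w)
grad(w)  # filled by ∇depthwiseconv_filter(Δ, x, w)
```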


@@ -1,7 +1,7 @@
 using Flux
 using Flux.Tracker, Test, NNlib
 using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint
-using NNlib: conv
+using NNlib: conv, depthwiseconv
 using Printf: @sprintf
 using LinearAlgebra: Diagonal, dot, LowerTriangular, norm
 using Statistics: mean, std
@@ -181,6 +181,8 @@ end
 @test gradtest(conv, rand(10, 10, 3, 2), randn(Float64,2, 2, 3, 2))
 @test gradtest(conv, rand(10, 10, 10, 3, 2), randn(Float64,2, 2, 2, 3, 2))
 
+@test gradtest(depthwiseconv, rand(10,10,3,2), randn(2, 2, 2, 3))
+
 @test gradtest(x -> maxpool(x, (2,2)), rand(10, 10, 3, 2))
 @test gradtest(x -> maxpool(x, (2,2,2)), rand(10, 10, 10, 3, 2))
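
Not part of the diff: roughly what the new `gradtest` line checks, written out as a one-entry finite-difference comparison. The `∇depthwiseconv_filter(Δ, x, w)` call and its argument order come from the `@grad` rule earlier in the diff; the step size and tolerance are arbitrary illustrative choices.

```julia
# Hypothetical finite-difference check (not from the PR).
using NNlib: depthwiseconv, ∇depthwiseconv_filter

x = rand(10, 10, 3, 2)
w = randn(2, 2, 2, 3)                # k × k × multiplier × in_channels

f(w) = sum(depthwiseconv(x, w))
Δ = ones(size(depthwiseconv(x, w)))  # sensitivity of sum(y) w.r.t. y
g = ∇depthwiseconv_filter(Δ, x, w)   # analytic filter gradient

ε = 1e-6
w2 = copy(w); w2[1] += ε             # perturb one filter entry
@assert isapprox(g[1], (f(w2) - f(w)) / ε; atol = 1e-4)
```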