Merge pull request #128 from FluxML/conv

Convolution API
Mike J Innes 2017-12-18 18:09:27 +00:00 committed by GitHub
commit 6b6974e14a
9 changed files with 86 additions and 11 deletions


@@ -5,6 +5,7 @@ These core layers form the foundation of almost all neural networks.
```@docs
Chain
Dense
+Conv2D
```
## Recurrent Layers


@@ -7,13 +7,14 @@ module Flux
using Juno, Requires
using Lazy: @forward
-export Chain, Dense, RNN, LSTM,
+export Chain, Dense, RNN, LSTM, Conv2D,
  Dropout, LayerNorm, BatchNorm,
  SGD, ADAM, Momentum, Nesterov, AMSGrad,
  param, params, mapleaves
using NNlib
-export σ, sigmoid, relu, leakyrelu, elu, swish, softmax
+export σ, sigmoid, relu, leakyrelu, elu, swish, softmax,
+  conv2d, maxpool2d, avgpool2d
include("tracker/Tracker.jl")
using .Tracker
@@ -27,6 +28,7 @@ include("treelike.jl")
include("layers/stateless.jl")
include("layers/basic.jl")
+include("layers/conv.jl")
include("layers/recurrent.jl")
include("layers/normalisation.jl")

src/layers/conv.jl (new file)

@@ -0,0 +1,33 @@
"""
    Conv2D(size, in=>out)
    Conv2D(size, in=>out, relu)

Standard convolutional layer. `size` should be a tuple like `(2, 2)`.
`in` and `out` specify the number of input and output channels respectively.

Data should be stored in HWCN order. In other words, a 100×100 RGB image would
be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.

Takes the keyword arguments `pad` and `stride`.
"""
struct Conv2D{F,A}
  σ::F
  weight::A
  stride::Int
  pad::Int
end

Conv2D(k::NTuple{2,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
       init = initn, stride = 1, pad = 0) =
  Conv2D(σ, param(init(k..., ch...)), stride, pad)

Flux.treelike(Conv2D)

(c::Conv2D)(x) = c.σ.(conv2d(x, c.weight, stride = c.stride, padding = c.pad))

function Base.show(io::IO, l::Conv2D)
  print(io, "Conv2D((", size(l.weight, 1), ", ", size(l.weight, 2), ")")
  print(io, ", ", size(l.weight, 3), "=>", size(l.weight, 4))
  l.σ == identity || print(io, ", ", l.σ)
  print(io, ")")
end
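Going by the docstring above, usage of the new layer looks roughly like this (the shapes and hyperparameters here are illustrative, not from the patch):

```julia
using Flux

# A 5×5 convolution from 3 to 16 channels with a relu nonlinearity,
# stride 2 and one pixel of zero padding.
c = Conv2D((5, 5), 3=>16, relu, stride = 2, pad = 1)

x = rand(100, 100, 3, 50)  # a batch of 50 100×100 RGB images, HWCN order
y = c(x)                   # tracked array of size (49, 49, 16, 50)
```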


@@ -18,7 +18,9 @@ end
Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data))
-Base.getindex(xs::OneHotMatrix, i::Int, j::Int) = xs.data[j][i]
+Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs.data[j][i]
+Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i]
+Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i])
A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)]
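With the new `getindex` methods and the `*` rule, a `OneHotMatrix` behaves like a plain matrix without ever being materialised: multiplying by it is just column selection. A small sketch, assuming Flux's `onehotbatch` constructor:

```julia
using Flux: onehotbatch

W = rand(5, 3)
B = onehotbatch([:b, :a, :c], [:a, :b, :c])  # a 3×3 OneHotMatrix

W * B == W[:, [2, 1, 3]]  # true: `*` picks out the indexed columns of W
B[:, 1]                   # the first one-hot column, via the new Colon method
```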


@@ -12,16 +12,17 @@ function scan(x::TrackedArray)
  return
end

-back(c::Call, Δ) = back(c.func, Δ, c.args...)
-back(::Call{Void}, Δ) = nothing
+back_(f, y, args...) = back(f, args...)
+back_(c::Call, y, Δ) = back_(c.func, y, Δ, c.args...)
+back_(::Call{Void}, y, Δ) = nothing

function back(x::TrackedArray, Δ)
  ref = x.ref -= 1
  if isdefined(x, :grad)
    x.grad .+= Δ
-    ref == 0 && back(x.f, x.grad)
+    ref == 0 && back_(x.f, x.data, x.grad)
  else
-    ref == 0 && back(x.f, Δ)
+    ref == 0 && back_(x.f, x.data, Δ)
  end
  return
end
@@ -35,6 +36,9 @@ end
# Interface methods

# TODO: if an error occurs in `back` the refcounts will be broken
# and `back` will silently fail to update.
+function back!(x::TrackedArray, Δ)
+  scan(x)
+  back(x, Δ)
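The `back_` rename threads the forward output `y` through to each backward rule, so gradients that depend on the output (like pooling below) can reuse it instead of recomputing it. A hypothetical sketch of a primitive exploiting this; `myexp` is invented for illustration and is not part of the patch:

```julia
using Flux
using Flux.Tracker: TrackedArray, Call, back
import Flux.Tracker: back_

myexp(x::AbstractArray) = exp.(x)                      # forward pass
myexp(x::TrackedArray) = TrackedArray(Call(myexp, x))  # record the call

# For y = exp.(x) the derivative is y itself, so the backward rule can
# reuse the saved forward output rather than recomputing exp.(x).
back_(::typeof(myexp), y, Δ, x) = back(x, Δ .* y)
```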


@@ -44,6 +44,12 @@ function back(::typeof(vcat), Δ, xs, ys)
  @back(ys, Δ[size(xs,1)+1:end, i...])
end

+Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) =
+  TrackedArray(Call(reshape, xs, dims...))
+
+back(::typeof(reshape), Δ, xs::TrackedArray, _...) =
+  back(xs, reshape(Δ, size(xs)))

# Reductions

Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim))
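The backward rule for `reshape` just reshapes the incoming gradient back to the argument's original size. This can be checked against numerical gradients with the `gradcheck` helper shown further down:

```julia
using Flux.Tracker: gradcheck

gradcheck(x -> sum(reshape(x, 6)), rand(2, 3))     # true
gradcheck(x -> sum(reshape(x, 3, :)), rand(2, 3))  # Colon dims are covered too
```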
@@ -123,12 +129,36 @@ end
# NNlib

-import NNlib: softmax, ∇softmax
+using NNlib
+import NNlib: softmax, ∇softmax, conv2d, pool

softmax(xs::TrackedArray) = TrackedArray(Call(softmax, xs))

back(::typeof(softmax), Δ, xs) = @back(xs, ∇softmax(Δ, data(xs)))

+# TODO: can store kwargs efficiently in namedtuples
+_conv2d(x, w, stride, pad) = conv2d(x, w, stride = stride, padding = pad)
+
+conv2d(x::TrackedArray{<:Any,4}, w::TrackedArray{<:Any,4}; stride = 1, padding = 0) =
+  TrackedArray(Call(_conv2d, x, w, stride, padding))
+conv2d(x::AbstractArray{<:Any,4}, w::TrackedArray{<:Any,4}; stride = 1, padding = 0) =
+  TrackedArray(Call(_conv2d, x, w, stride, padding))
+conv2d(x::TrackedArray{<:Any,4}, w::AbstractArray{<:Any,4}; stride = 1, padding = 0) =
+  TrackedArray(Call(_conv2d, x, w, stride, padding))
+
+function back(::typeof(_conv2d), Δ, x, w, stride, pad)
+  @back(x, NNlib.conv2d_grad_x(data(x), data(w), Δ; stride = stride, padding = pad))
+  @back(w, NNlib.conv2d_grad_w(data(x), data(w), Δ; stride = stride, padding = pad))
+end
+
+_pool(x, k, mode) = pool(x, window = k, mode = mode)
+
+pool(x::TrackedArray{<:Any,4}; window = 2, mode = 0) =
+  TrackedArray(Call(_pool, x, window, mode))
+
+back_(::typeof(_pool), y, Δ, x, k, mode) =
+  back(x, NNlib.pool_grad(data(x), y, Δ, window = k, mode = mode))

# Broadcasting

using ForwardDiff: Dual, partials
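These methods let gradients flow through convolution and pooling. A quick numerical check, mirroring the new tests at the bottom of this diff:

```julia
using Flux, NNlib
using Flux.Tracker: gradcheck

x = rand(10, 10, 3, 2)
w = randn(2, 2, 3, 2)

gradcheck((x, w) -> sum(conv2d(x, w)), x, w)  # gradients w.r.t. input and kernel
gradcheck(x -> sum(maxpool2d(x, 2)), x)       # pooling reuses the saved forward output
```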


@@ -19,4 +19,4 @@ function ngradient(f, xs::AbstractArray...)
  return grads
end

-gradcheck(f, xs...) = all(isapprox.(ngradient(f, xs...), gradient(f, xs...), rtol = 1e-6))
+gradcheck(f, xs...) = all(isapprox.(ngradient(f, xs...), gradient(f, xs...), rtol = 1e-5))


@@ -4,8 +4,6 @@ initn(dims...) = randn(dims...)/100
glorot_uniform(dims...) = (rand(dims...) - 0.5)*sqrt(24.0/(sum(dims)))
glorot_normal(dims...) = (randn(dims...)*sqrt(2.0/sum(dims)))
flatten(xs) = reshape(xs, size(xs, 1), :)
unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
stack(xs, dim) = cat(dim, unsqueeze.(xs, dim)...)
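For reference, a small sketch of what the two (unexported) helpers do:

```julia
using Flux

xs = [rand(3) for _ = 1:4]  # four length-3 vectors

Flux.unsqueeze(xs[1], 2)    # 3×1 matrix: inserts a singleton dimension at position 2
Flux.stack(xs, 2)           # 3×4 matrix: unsqueeze each element, then cat along dim 2
```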


@@ -1,5 +1,6 @@
using Flux.Tracker, Base.Test, NNlib
using Flux.Tracker: gradcheck
+using NNlib

gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(f(xs...)), xs...)
gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@@ -45,4 +46,8 @@ end
  2y + x
end

+@test gradtest(conv2d, rand(10, 10, 3, 2), randn(2, 2, 3, 2))
+@test gradtest(x -> maxpool2d(x, 2), rand(10, 10, 3, 2))
+@test gradtest(x -> avgpool2d(x, 2), rand(10, 10, 3, 2))

end #testset