Merge pull request #128 from FluxML/conv

Convolution API
Mike J Innes 2017-12-18 18:09:27 +00:00 committed by GitHub
commit 6b6974e14a
9 changed files with 86 additions and 11 deletions


@@ -5,6 +5,7 @@ These core layers form the foundation of almost all neural networks.
 ```@docs
 Chain
 Dense
+Conv2D
 ```

 ## Recurrent Layers


@@ -7,13 +7,14 @@ module Flux
 using Juno, Requires
 using Lazy: @forward

-export Chain, Dense, RNN, LSTM,
+export Chain, Dense, RNN, LSTM, Conv2D,
   Dropout, LayerNorm, BatchNorm,
   SGD, ADAM, Momentum, Nesterov, AMSGrad,
   param, params, mapleaves

 using NNlib
-export σ, sigmoid, relu, leakyrelu, elu, swish, softmax
+export σ, sigmoid, relu, leakyrelu, elu, swish, softmax,
+  conv2d, maxpool2d, avgpool2d

 include("tracker/Tracker.jl")
 using .Tracker
@@ -27,6 +28,7 @@ include("treelike.jl")
 include("layers/stateless.jl")
 include("layers/basic.jl")
+include("layers/conv.jl")
 include("layers/recurrent.jl")
 include("layers/normalisation.jl")

src/layers/conv.jl (new file, 33 lines)

@@ -0,0 +1,33 @@
+"""
+    Conv2D(size, in=>out)
+    Conv2D(size, in=>out, relu)
+
+Standard convolutional layer. `size` should be a tuple like `(2, 2)`.
+`in` and `out` specify the number of input and output channels respectively.
+
+Data should be stored in HWCN order. In other words, a 100×100 RGB image would
+be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
+
+Takes the keyword arguments `pad` and `stride`.
+"""
+struct Conv2D{F,A}
+  σ::F
+  weight::A
+  stride::Int
+  pad::Int
+end
+
+Conv2D(k::NTuple{2,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
+       init = initn, stride = 1, pad = 0) =
+  Conv2D(σ, param(init(k..., ch...)), stride, pad)
+
+Flux.treelike(Conv2D)
+
+(c::Conv2D)(x) = c.σ.(conv2d(x, c.weight, stride = c.stride, padding = c.pad))
+
+function Base.show(io::IO, l::Conv2D)
+  print(io, "Conv2D((", size(l.weight, 1), ", ", size(l.weight, 2), ")")
+  print(io, ", ", size(l.weight, 3), "=>", size(l.weight, 4))
+  l.σ == identity || print(io, ", ", l.σ)
+  print(io, ")")
+end
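A quick usage sketch of the layer above; the output size assumes a standard valid convolution (no padding, stride 1), so 2×2 filters reduce each spatial dimension by one:

```julia
using Flux

m = Conv2D((2, 2), 3 => 16, relu)   # 2×2 filters, 3 input channels, 16 output channels
x = rand(100, 100, 3, 50)           # batch of 50 100×100 RGB images, HWCN order
size(m(x))                          # (99, 99, 16, 50)
```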


@@ -18,7 +18,9 @@ end
 Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data))

-Base.getindex(xs::OneHotMatrix, i::Int, j::Int) = xs.data[j][i]
+Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs.data[j][i]
+Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i]
+Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i])

 A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)]
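A small sketch of the new colon indexing. `onehotbatch` is assumed here; it lives elsewhere in the library and is not part of this diff:

```julia
using Flux: onehotbatch

m = onehotbatch([:b, :a, :c], [:a, :b, :c])   # 3×3 OneHotMatrix
m[:, 2]        # the stored OneHotVector for column 2
m[:, [1, 3]]   # a new OneHotMatrix holding columns 1 and 3
```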


@@ -12,16 +12,17 @@ function scan(x::TrackedArray)
   return
 end

-back(c::Call, Δ) = back(c.func, Δ, c.args...)
-back(::Call{Void}, Δ) = nothing
+back_(f, y, args...) = back(f, args...)
+back_(c::Call, y, Δ) = back_(c.func, y, Δ, c.args...)
+back_(::Call{Void}, y, Δ) = nothing

 function back(x::TrackedArray, Δ)
   ref = x.ref -= 1
   if isdefined(x, :grad)
     x.grad .+= Δ
-    ref == 0 && back(x.f, x.grad)
+    ref == 0 && back_(x.f, x.data, x.grad)
   else
-    ref == 0 && back(x.f, Δ)
+    ref == 0 && back_(x.f, x.data, Δ)
   end
   return
 end

@@ -35,6 +36,9 @@ end
 # Interface methods

+# TODO: if an error occurs in `back` the refcounts will be broken
+# and `back` will silently fail to update.
+
 function back!(x::TrackedArray, Δ)
   scan(x)
   back(x, Δ)
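Threading the forward value through `back_` lets gradients that need the op's output (pooling, below) avoid recomputing it. A hypothetical sketch of a rule written against the new signature; `_myexp` is illustrative only and not part of this PR:

```julia
using Flux.Tracker: TrackedArray, Call, @back
import Flux.Tracker: back_

_myexp(x) = exp.(x)
_myexp(x::TrackedArray) = TrackedArray(Call(_myexp, x))

# The derivative of exp is exp itself, so the cached forward output `y`
# is exactly what the backward pass needs; nothing is recomputed.
back_(::typeof(_myexp), y, Δ, x) = @back(x, Δ .* y)
```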


@@ -44,6 +44,12 @@ function back(::typeof(vcat), Δ, xs, ys)
   @back(ys, Δ[size(xs,1)+1:end, i...])
 end

+Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) =
+  TrackedArray(Call(reshape, xs, dims...))
+
+back(::typeof(reshape), Δ, xs::TrackedArray, _...) =
+  back(xs, reshape(Δ, size(xs)))
+
 # Reductions

 Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim))
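A quick way to sanity-check the new reshape rule, reusing the `gradcheck` helper from the tracker tests (a sketch, not part of the diff):

```julia
using Flux.Tracker: gradcheck

# reshape records a Call node; its gradient is simply Δ reshaped back to size(xs)
gradcheck(x -> sum(reshape(x, 6, 2) .^ 2), rand(3, 4))   # should return true
```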
@@ -123,12 +129,36 @@ end
 # NNlib

-import NNlib: softmax, ∇softmax
+using NNlib
+import NNlib: softmax, ∇softmax, conv2d, pool

 softmax(xs::TrackedArray) = TrackedArray(Call(softmax, xs))

 back(::typeof(softmax), Δ, xs) = @back(xs, ∇softmax(Δ, data(xs)))

+# TODO: can store kwargs efficiently in namedtuples
+_conv2d(x, w, stride, pad) = conv2d(x, w, stride = stride, padding = pad)
+
+conv2d(x::TrackedArray{<:Any,4}, w::TrackedArray{<:Any,4}; stride = 1, padding = 0) =
+  TrackedArray(Call(_conv2d, x, w, stride, padding))
+conv2d(x::AbstractArray{<:Any,4}, w::TrackedArray{<:Any,4}; stride = 1, padding = 0) =
+  TrackedArray(Call(_conv2d, x, w, stride, padding))
+conv2d(x::TrackedArray{<:Any,4}, w::AbstractArray{<:Any,4}; stride = 1, padding = 0) =
+  TrackedArray(Call(_conv2d, x, w, stride, padding))
+
+function back(::typeof(_conv2d), Δ, x, w, stride, pad)
+  @back(x, NNlib.conv2d_grad_x(data(x), data(w), Δ; stride = stride, padding = pad))
+  @back(w, NNlib.conv2d_grad_w(data(x), data(w), Δ; stride = stride, padding = pad))
+end
+
+_pool(x, k, mode) = pool(x, window = k, mode = mode)
+
+pool(x::TrackedArray{<:Any,4}; window = 2, mode = 0) =
+  TrackedArray(Call(_pool, x, window, mode))
+
+back_(::typeof(_pool), y, Δ, x, k, mode) =
+  back(x, NNlib.pool_grad(data(x), y, Δ, window = k, mode = mode))
+
 # Broadcasting

 using ForwardDiff: Dual, partials
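A rough end-to-end sketch of how these hooks behave from user code, assuming the `param`/`back!` tracker API as it stands at this commit:

```julia
using Flux

x = param(rand(10, 10, 3, 1))         # tracked HWCN input
w = param(randn(2, 2, 3, 4) ./ 100)   # tracked 2×2 filters, 3 => 4 channels

y = conv2d(x, w)                      # tracked 9×9×4×1 output
Flux.Tracker.back!(sum(y), 1)         # seed the reverse pass with Δ = 1

size(x.grad), size(w.grad)            # ((10, 10, 3, 1), (2, 2, 3, 4))
```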


@@ -19,4 +19,4 @@ function ngradient(f, xs::AbstractArray...)
   return grads
 end

-gradcheck(f, xs...) = all(isapprox.(ngradient(f, xs...), gradient(f, xs...), rtol = 1e-6))
+gradcheck(f, xs...) = all(isapprox.(ngradient(f, xs...), gradient(f, xs...), rtol = 1e-5))


@@ -4,8 +4,6 @@ initn(dims...) = randn(dims...)/100
 glorot_uniform(dims...) = (rand(dims...) - 0.5)*sqrt(24.0/(sum(dims)))
 glorot_normal(dims...) = (randn(dims...)*sqrt(2.0/sum(dims)))

-flatten(xs) = reshape(xs, size(xs, 1), :)
-
 unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))

 stack(xs, dim) = cat(dim, unsqueeze.(xs, dim)...)


@@ -1,5 +1,6 @@
 using Flux.Tracker, Base.Test, NNlib
 using Flux.Tracker: gradcheck
+using NNlib

 gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(f(xs...)), xs...)
 gradtest(f, dims...) = gradtest(f, rand.(dims)...)

@@ -45,4 +46,8 @@
   2y + x
 end

+@test gradtest(conv2d, rand(10, 10, 3, 2), randn(2, 2, 3, 2))
+
+@test gradtest(x -> maxpool2d(x, 2), rand(10, 10, 3, 2))
+@test gradtest(x -> avgpool2d(x, 2), rand(10, 10, 3, 2))
+
 end #testset