conv api updates

Mike J Innes 2018-02-26 22:43:07 +00:00
parent 54919b8dca
commit 15d1d3256b
4 changed files with 53 additions and 38 deletions

View File

@@ -7,7 +7,7 @@ module Flux
 using Juno, Requires, Reexport
 using MacroTools: @forward
 
-export Chain, Dense, RNN, LSTM, GRU, Conv2D,
+export Chain, Dense, RNN, LSTM, GRU, Conv, Conv2D,
   Dropout, LayerNorm, BatchNorm,
   SGD, ADAM, Momentum, Nesterov, AMSGrad,
   param, params, mapleaves
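
With this change both names are exported: Conv is the new generic layer, while Conv2D is kept as a deprecated alias (see the @deprecate in the conv layer file below). A minimal sketch of the transition, assuming Flux at this commit:

using Flux

m_new = Conv((2, 2), 3 => 4)     # new N-dimensional constructor
m_old = Conv2D((2, 2), 3 => 4)   # still exported; forwards to Conv with a deprecation warning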

View File

@@ -1,6 +1,8 @@
+using NNlib: conv
+
 """
-    Conv2D(size, in=>out)
-    Conv2d(size, in=>out, relu)
+    Conv(size, in=>out)
+    Conv(size, in=>out, relu)
 
 Standard convolutional layer. `size` should be a tuple like `(2, 2)`.
 `in` and `out` specify the number of input and output channels respectively.
@@ -10,32 +12,37 @@ be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
 Takes the keyword arguments `pad` and `stride`.
 """
-struct Conv2D{F,A,V}
+struct Conv{N,F,A,V}
   σ::F
   weight::A
   bias::V
-  stride::Int
-  pad::Int
+  stride::NTuple{N,Int}
+  pad::NTuple{N,Int}
 end
 
-Conv2D(w::AbstractArray{T,4}, b::AbstractVector{T}, σ = identity;
-       stride = 1, pad = 0) where T =
-  Conv2D(σ, w, b, stride, pad)
+Conv(w::AbstractArray{T}, b::AbstractVector{T}, σ = identity;
+     stride = 1, pad = 0) where T =
+  Conv(σ, w, b, stride, pad)
 
-Conv2D(k::NTuple{2,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
-       init = initn, stride = 1, pad = 0) =
-  Conv2D(param(init(k..., ch...)), param(zeros(ch[2])), σ, stride = stride, pad = pad)
+Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
+     stride::NTuple{N,Integer} = map(_->1,k),
+     pad::NTuple{N,Integer} = map(_->0,k)) where N =
+  Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ,
+       stride = stride, pad = pad)
 
-Flux.treelike(Conv2D)
+Flux.treelike(Conv)
 
-function (c::Conv2D)(x)
-  σ, b = c.σ, reshape(c.bias, 1, 1, :, 1)
-  σ.(conv2d(x, c.weight, stride = c.stride, padding = c.pad) .+ b)
+function (c::Conv)(x)
+  σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
+  σ.(conv(x, c.weight, stride = c.stride, pad = c.pad) .+ b)
 end
 
-function Base.show(io::IO, l::Conv2D)
-  print(io, "Conv2D((", size(l.weight, 1), ", ", size(l.weight, 2), ")")
-  print(io, ", ", size(l.weight, 3), "=>", size(l.weight, 4))
+function Base.show(io::IO, l::Conv)
+  print(io, "Conv(", size(l.weight)[1:ndims(l.weight)-2])
+  print(io, ", ", size(l.weight, ndims(l.weight)-1), "=>", size(l.weight, ndims(l.weight)))
   l.σ == identity || print(io, ", ", l.σ)
   print(io, ")")
 end
+
+# v0.5
+@deprecate Conv2D(args...; kw...) Conv(args...; kw...)
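
For reference, a minimal usage sketch of the generalised constructor introduced above; the WHCN input layout (width × height × channels × batch) follows the docstring, and the output size assumes the default stride and padding:

using Flux

# 2-d convolution: 2×2 filters, 3 input channels, 8 output channels, relu nonlinearity
m = Conv((2, 2), 3 => 8, relu)

x = rand(10, 10, 3, 5)   # a batch of five 10×10 inputs with 3 channels
y = m(x)                 # size(y) == (9, 9, 8, 5) with stride (1, 1) and pad (0, 0)

# Keyword arguments are passed through to NNlib's conv:
m2 = Conv((3, 3), 3 => 8, stride = (2, 2), pad = (1, 1))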

View File

@@ -217,7 +217,7 @@ end
 # NNlib
 
 using NNlib
-import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv2d, pool
+import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, maxpool, meanpool
 
 softmax(xs::TrackedArray) = track(softmax, xs)
@@ -228,27 +228,35 @@ logsoftmax(xs::TrackedArray) = track(logsoftmax, xs)
 back(::typeof(logsoftmax), Δ, xs) = @back(xs, ∇logsoftmax(Δ, data(xs)))
 
 # TODO: can store kwargs efficiently in namedtuples
-_conv2d(x, w, stride, pad) = conv2d(x, w, stride = stride, padding = pad)
+_conv(x, w, stride, pad) = conv(x, w, stride = stride, pad = pad)
 
-conv2d(x::TrackedArray{<:Any,4}, w::TrackedArray{<:Any,4}; stride = 1, padding = 0) =
-  track(_conv2d, x, w, stride, padding)
-conv2d(x::AbstractArray{<:Any,4}, w::TrackedArray{<:Any,4}; stride = 1, padding = 0) =
-  track(_conv2d, x, w, stride, padding)
-conv2d(x::TrackedArray{<:Any,4}, w::AbstractArray{<:Any,4}; stride = 1, padding = 0) =
-  track(_conv2d, x, w, stride, padding)
+conv(x::TrackedArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0) where N =
+  track(_conv, x, w, stride, pad)
+conv(x::AbstractArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0) where N =
+  track(_conv, x, w, stride, pad)
+conv(x::TrackedArray{<:Real,N}, w::AbstractArray{<:Real,N}; stride = 1, pad = 0) where N =
  track(_conv, x, w, stride, pad)
 
-function back(::typeof(_conv2d), Δ, x, w, stride, pad)
-  @back(x, NNlib.conv2d_grad_x(data(x), data(w), Δ; stride = stride, padding = pad))
-  @back(w, NNlib.conv2d_grad_w(data(x), data(w), Δ; stride = stride, padding = pad))
+function back(::typeof(_conv), Δ, x, w, stride, pad)
+  @back(x, NNlib.∇conv_data(Δ, data(x), data(w); stride = stride, pad = pad))
+  @back(w, NNlib.∇conv_filter(Δ, data(x), data(w); stride = stride, pad = pad))
 end
 
-_pool(x, k, pad, mode) = pool(x, window = k, mode = mode, padding = pad)
+_maxpool(x, k, pad) = maxpool(x, k; pad = pad)
 
-pool(x::TrackedArray{<:Any,4}; window = 2, mode = 0, padding = 0) =
-  track(_pool, x, window, padding, mode)
+maxpool(x::TrackedArray, k; pad = map(_->0,k)) =
+  track(_maxpool, x, k, pad)
 
-back_(::typeof(_pool), y, Δ, x, k, pad, mode) =
-  back(x, NNlib.pool_grad(data(x), y, Δ, window=k, mode=mode, padding=pad))
+back_(::typeof(_maxpool), y, Δ, x, k, pad) =
+  back(x, NNlib.∇maxpool(Δ, y, data(x), k, pad=pad))
+
+_meanpool(x, k, pad) = meanpool(x, k; pad = pad)
+
+meanpool(x::TrackedArray, k; pad = map(_->0,k)) =
+  track(_meanpool, x, k, pad)
+
+back_(::typeof(_meanpool), y, Δ, x, k, pad) =
+  back(x, NNlib.∇meanpool(Δ, y, data(x), k, pad=pad))
 
 # Broadcasting
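
With these wrappers, gradients flow through NNlib's N-dimensional conv, maxpool, and meanpool on tracked arrays. A quick numerical check in the style of the test file below, using the gradcheck helper it imports (a sketch, not part of the commit):

using Flux.Tracker: gradcheck
using NNlib: conv, maxpool

# Compare tracked gradients against numerical ones for a conv + pool composite
gradcheck((x, w) -> sum(maxpool(conv(x, w), (2, 2))),
          rand(10, 10, 3, 2), randn(2, 2, 3, 2))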

View File

@@ -1,6 +1,6 @@
 using Flux.Tracker, Base.Test, NNlib
 using Flux.Tracker: TrackedReal, gradcheck
-using NNlib
+using NNlib: conv
 
 gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
 gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@@ -60,9 +60,9 @@ end
   2y + x
 end
 
-@test gradtest(conv2d, rand(10, 10, 3, 2), randn(2, 2, 3, 2))
-@test gradtest(x -> maxpool2d(x, 2), rand(10, 10, 3, 2))
-@test gradtest(x -> avgpool2d(x, 2), rand(10, 10, 3, 2))
+@test gradtest(conv, rand(10, 10, 3, 2), randn(2, 2, 3, 2))
+@test gradtest(x -> maxpool(x, (2,2)), rand(10, 10, 3, 2))
+@test gradtest(x -> meanpool(x, (2,2)), rand(10, 10, 3, 2))
 
 @test (param([1,2,3]) .< 2) == [true, false, false]
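
Taken together, the renames exercised by these tests map old calls to new ones roughly as follows (a hypothetical migration sketch, based only on names appearing in this diff):

using NNlib: conv, maxpool, meanpool

x = rand(10, 10, 3, 2); w = randn(2, 2, 3, 2)

conv(x, w; pad = 1)   # was: conv2d(x, w; padding = 1)
maxpool(x, (2, 2))    # was: maxpool2d(x, 2)
meanpool(x, (2, 2))   # was: avgpool2d(x, 2)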