Merge branch 'master' into depthwiseconv

This commit is contained in:
Avik Pal 2018-06-09 11:06:07 +05:30 committed by GitHub
commit 7f3d11cae0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 103 additions and 30 deletions

View File

@ -7,6 +7,7 @@ add the result to the overall loss.
For example, say we have a simple regression.
```julia
using Flux: crossentropy
m = Dense(10, 5)
loss(x, y) = crossentropy(softmax(m(x)), y)
```

View File

@ -23,7 +23,7 @@ include("optimise/Optimise.jl")
using .Optimise
using .Optimise: @epochs
export SGD, ADAM, AdaMax, Momentum, Nesterov,
RMSProp, ADAGrad, ADADelta, AMSGrad
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
include("utils.jl")
include("onehot.jl")

View File

@ -10,7 +10,7 @@ Standard convolutional layer. `size` should be a tuple like `(2, 2)`.
Data should be stored in WHCN order. In other words, a 100×100 RGB image would
be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
Takes the keyword arguments `pad` and `stride`.
Takes the keyword arguments `pad`, `stride` and `dilation`.
"""
struct Conv{N,F,A,V}
σ::F
@ -18,17 +18,19 @@ struct Conv{N,F,A,V}
bias::V
stride::NTuple{N,Int}
pad::NTuple{N,Int}
dilation::NTuple{N,Int}
end
Conv(w::AbstractArray{T}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0) where T =
Conv(σ, w, b, stride, pad)
stride = 1, pad = 0, dilation=1) where T =
Conv(σ, w, b, stride, pad, dilation)
Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
stride::NTuple{N,Integer} = map(_->1,k),
pad::NTuple{N,Integer} = map(_->0,k)) where N =
pad::NTuple{N,Integer} = map(_->0,k),
dilation::NTuple{N,Integer} = map(_->0,k)) where N =
Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ,
stride = stride, pad = pad)
stride = stride, pad = pad, dilation = dilation)
Flux.treelike(Conv)
@ -36,7 +38,7 @@ function (c::Conv)(x)
# TODO: breaks gpu broadcast :(
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
σ.(conv(x, c.weight, stride = c.stride, pad = c.pad) .+ b)
σ.(conv(x, c.weight, stride = c.stride, pad = c.pad, dilation = c.dilation) .+ b)
end
function Base.show(io::IO, l::Conv)

View File

@ -31,15 +31,14 @@ function Dropout(p)
Dropout{typeof(p)}(p, true)
end
_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
function (a::Dropout)(x)
a.active || return x
y = similar(x)
rand!(y)
q = 1 - a.p
@inbounds for i=1:length(y)
y[i] = y[i] > a.p ? 1 / q : 0
end
return y .* x
y .= _dropout_kernel.(y, a.p, 1 - a.p)
return x .* y
end
_testmode!(a::Dropout, test) = (a.active = !test)

View File

@ -1,7 +1,8 @@
module Optimise
export train!,
SGD, ADAM, AdaMax, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad
SGD, ADAM, AdaMax, Momentum, Nesterov,
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
struct Param{T}
x::T

View File

@ -91,3 +91,12 @@ tuning.
"""
AMSGrad(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p -> amsgrad(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
"""
NADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[NADAM](https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ) optimiser. Parameters other
than learning rate don't need tuning.
"""
NADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p->nadam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))

View File

@ -27,7 +27,7 @@ function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
acc = zeros(p.x)
function ()
@. acc = ρ * acc + (1 - ρ) * p.Δ^2
@. p.Δ *= η / (acc + ϵ)
@. p.Δ *= η / (acc + ϵ)
end
end
@ -35,7 +35,7 @@ function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
acc = zeros(p.x) .+ ϵ
function ()
@. acc += p.Δ^2
@. p.Δ *= η / acc
@. p.Δ *= η / (acc + ϵ)
end
end
@ -56,7 +56,7 @@ function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ
function ()
@. mt = β1 * mt + (1 - β1) * p.Δ
@. vt = β2 * vt + (1 - β2) * p.Δ^2
@. p.Δ = mt / (1 - β1p) / ((vt / (1 - β2p)) + ϵ) * η
@. p.Δ = mt / (1 - β1p) / (vt / (1 - β2p) + ϵ) * η
β1p *= β1
β2p *= β2
end
@ -86,6 +86,19 @@ function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999,
end
end
function nadam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
mt = zeros(p.x)
vt = zeros(p.x)
β1p, β2p = β1, β2
function ()
@. mt = β1 * mt + (1 - β1) * p.Δ
@. vt = β2 * vt + (1 - β2) * p.Δ^2
@. p.Δ = (β1 * mt / (1 - β1 * β1p) + (1 - β1) * p.Δ / (1 - β1p)) / (vt * β2 / (1 - β2p) + ϵ) * η
β1p *= β1
β2p *= β2
end
end
clip(p::Param, thresh::Real) = () -> clamp!(p.Δ, -thresh, thresh)
function expdecay(p::Param, γ::Real)

View File

@ -93,6 +93,26 @@ function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1)
back(xs, Δ′)
end
_repeat(A, inner, outer) = Base.repeat(A; inner=inner, outer=outer)
Base.repeat(A::TrackedArray; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A))) = track(_repeat, A, inner, outer)
function back(::typeof(_repeat), Δ, xs::TrackedArray, inner, outer)
Δ′ = similar(xs.data)
Δ′ .= 0
S = size(xs.data)
# Loop through each element of Δ, calculate source dimensions, accumulate into Δ′
for (dest_idx, val) in enumerate(IndexCartesian(), Δ)
# First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then
# wrap around based on original size S.
src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)]
Δ′[src_idx...] += val
end
back(xs, Δ′)
end
for f in [:vcat, :hcat]
@eval begin
# This section is a bit of a hack since julia doesn't have a standardised
@ -314,18 +334,18 @@ logsoftmax(xs::TrackedArray) = track(logsoftmax, xs)
back(::typeof(logsoftmax), Δ, xs) = @back(xs, ∇logsoftmax(Δ, data(xs)))
# TODO: can store kwargs efficiently in namedtuples
_conv(x, w, stride, pad) = conv(x, w, stride = stride, pad = pad)
_conv(x, w, stride, pad, dilation) = conv(x, w, stride = stride, pad = pad, dilation = dilation)
conv(x::TrackedArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0) where N =
track(_conv, x, w, stride, pad)
conv(x::AbstractArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0) where N =
track(_conv, x, w, stride, pad)
conv(x::TrackedArray{<:Real,N}, w::AbstractArray{<:Real,N}; stride = 1, pad = 0) where N =
track(_conv, x, w, stride, pad)
conv(x::TrackedArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N =
track(_conv, x, w, stride, pad, dilation)
conv(x::AbstractArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N =
track(_conv, x, w, stride, pad, dilation)
conv(x::TrackedArray{<:Real,N}, w::AbstractArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N =
track(_conv, x, w, stride, pad, dilation)
function back(::typeof(_conv), Δ, x, w, stride, pad)
@back(x, NNlib.∇conv_data(Δ, data(x), data(w); stride = stride, pad = pad))
@back(w, NNlib.∇conv_filter(Δ, data(x), data(w); stride = stride, pad = pad))
function back(::typeof(_conv), Δ, x, w, stride, pad, dilation)
@back(x, NNlib.∇conv_data(Δ, data(x), data(w); stride = stride, pad = pad, dilation = dilation))
@back(w, NNlib.∇conv_filter(Δ, data(x), data(w); stride = stride, pad = pad, dilation = dilation))
end
_depthwiseconv(x, w, stride, pad) = depthwiseconv(x, w, stride = stride, pad = pad)

View File

@ -19,8 +19,9 @@ Base.decompose(x::TrackedReal) = Base.decompose(data(x))
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{T}) where T = x
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal) where T =
TrackedReal(Tracked(x.tracker.f, convert(T, x.tracker.data)))
# This cuts derivatives, fix if needed.
# Base.convert(::Type{TrackedReal{T}}, x::TrackedReal) where T =
# TrackedReal(Tracked(x.tracker.f, convert(T, x.tracker.data)))
Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x))
@ -91,3 +92,18 @@ Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i)
back(::typeof(getindex), Δ, t, i) =
back(t, ntuple(j -> i == j ? Δ : 0, length(t)))
# Array collection
function collect(xs)
xs = Base.collect(xs)
track(Call(collect, xs), data.(xs))
end
function scan(c::Call{typeof(collect)})
foreach(scan, c.args[1])
end
function back(::typeof(collect), Δ, xs)
foreach((x, Δ) -> @back(x, Δ), xs, Δ)
end

View File

@ -3,7 +3,7 @@ using Flux.Tracker
@testset "Optimise" begin
w = randn(10, 10)
@testset for Opt in [SGD, Nesterov, Momentum, ADAM, AdaMax, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad]
@testset for Opt in [SGD, Nesterov, Momentum, ADAM, AdaMax, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad, NADAM]
w = param(randn(10, 10))
loss(x) = Flux.mse(w*x, w*x)
opt = Opt([w])

View File

@ -1,5 +1,5 @@
using Flux.Tracker, Base.Test, NNlib
using Flux.Tracker: TrackedReal, gradcheck
using Flux.Tracker: TrackedReal, gradcheck, grad
using NNlib: conv, depthwiseconv
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
@ -114,6 +114,9 @@ end
@test gradtest(x -> repmat(x, 5,5), rand(4,5))
@test gradtest(x -> repmat(x, 5), rand(4,5))
@test gradtest(x -> repeat(x; inner=2, outer=3), rand(5))
@test gradtest(x -> repeat(x; inner=(2,2,1), outer=(1,1,3)), rand(5,4,3))
@test gradtest(kron, rand(5), rand(3))
@test gradtest(kron, rand(5), rand(3), rand(8))
@test gradtest(kron, rand(5,1), rand(3,1))
@ -222,4 +225,13 @@ b = param(rand())
Tracker.back!(b)
@test Tracker.grad(b) == 1
@testset "collect" begin
x, y = param(2), param(3)
xy = Tracker.collect([x, y])
@test xy isa TrackedArray{Float64}
z = xy[1]*xy[2]
back!(z)
@test grad.((x,y)) == (3, 2)
end
end #testset