Merge branch 'master' of github.com:jarvist/Flux.jl into HEAD

This commit is contained in:
Jarvist Moore Frost 2018-07-03 11:15:43 +01:00
commit 344a750770
22 changed files with 180 additions and 53 deletions

View File

@ -30,7 +30,7 @@ Flux has powerful high-level features, and common architectures can be defined i
```julia
model = Chain(
Dense(768, 128),
Dense(768, 128, σ),
LSTM(128, 256)
LSTM(256, 128)
Dense(128, 10),

View File

@ -4,6 +4,8 @@ Support for array operations on other hardware backends, like GPUs, is provided
For example, we can use `CuArrays` (with the `cu` converter) to run our [basic example](models/basics.md) on an NVIDIA GPU.
(Note that you need to build Julia 0.6 from source and have CUDA available to use CuArrays please see the [CUDAnative.jl](https://github.com/JuliaGPU/CUDAnative.jl) instructions for more details.)
```julia
using CuArrays

View File

@ -14,3 +14,5 @@ Pkg.test("Flux") # Check things installed correctly
```
Start with the [basics](models/basics.md). The [model zoo](https://github.com/FluxML/model-zoo/) is also a good starting point for many common kinds of models.
See [GPU support](gpu.md) for more details on installing and using Flux with GPUs.

View File

@ -28,13 +28,15 @@ l = loss(x, y)
back!(l)
```
`loss(x, y)` returns the same number, but it's now a *tracked* value that records gradients as it goes along. Calling `back!` then calculates the gradient of `W` and `b`. We can see what this gradient is, and modify `W` to train the model.
`loss(x, y)` returns the same number, but it's now a *tracked* value that records gradients as it goes along. Calling `back!` then accumulates the gradient of `W` and `b`. We can see what this gradient is, and modify `W` to train the model.
```julia
W.grad
using Flux.Tracker: grad, update!
# Update the parameter
W.data .-= 0.1(W.grad)
Δ = grad(W)
# Update the parameter and reset the gradient
update!(W, -0.1Δ)
loss(x, y) # ~ 2.5
```

View File

@ -7,6 +7,7 @@ add the result to the overall loss.
For example, say we have a simple regression.
```julia
using Flux: crossentropy
m = Dense(10, 5)
loss(x, y) = crossentropy(softmax(m(x)), y)
```
@ -43,3 +44,19 @@ loss(x, y) = crossentropy(m(x), y) + sum(vecnorm, params(m))
loss(rand(28^2), rand(10))
```
One can also easily add per-layer regularisation via the `activations` function:
```julia
julia> c = Chain(Dense(10,5,σ),Dense(5,2),softmax)
Chain(Dense(10, 5, NNlib.σ), Dense(5, 2), NNlib.softmax)
julia> activations(c, rand(10))
3-element Array{Any,1}:
param([0.71068, 0.831145, 0.751219, 0.227116, 0.553074])
param([0.0330606, -0.456104])
param([0.61991, 0.38009])
julia> sum(vecnorm, ans)
2.639678767773633 (tracked)
```

View File

@ -17,16 +17,17 @@ back!(l)
We want to update each parameter, using the gradient, in order to improve (reduce) the loss. Here's one way to do that:
```julia
function update()
using Flux.Tracker: grad, update!
function sgd()
η = 0.1 # Learning Rate
for p in (W, b)
p.data .-= η .* p.grad # Apply the update
p.grad .= 0 # Clear the gradient
update!(p, -η * grad(p))
end
end
```
If we call `update`, the parameters `W` and `b` will change and our loss should go down.
If we call `sgd`, the parameters `W` and `b` will change and our loss should go down.
There are two pieces here: one is that we need a list of trainable parameters for the model (`[W, b]` in this case), and the other is the update step. In this case the update is simply gradient descent (`x .-= η .* Δ`), but we might choose to do something more advanced, like adding momentum.

View File

@ -23,7 +23,7 @@ include("optimise/Optimise.jl")
using .Optimise
using .Optimise: @epochs
export SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
RMSProp, ADAGrad, ADADelta, AMSGrad
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
include("utils.jl")
include("onehot.jl")

View File

@ -38,6 +38,11 @@ function Base.show(io::IO, c::Chain)
print(io, ")")
end
# Seem to need this for `accumulate`; try removing on 0.7
Base.rcum_promote_type(op, ::Type, ::Type{Any}) = Any
activations(c::Chain, x) = accumulate((x, m) -> m(x), x, c.layers)
"""
Dense(in::Integer, out::Integer, σ = identity)

View File

@ -1,5 +1,10 @@
using NNlib: conv
@generated sub2(::Type{Val{N}}) where N = :(Val{$(N-2)})
expand(N, i::Tuple) = i
expand(N, i::Integer) = ntuple(_ -> i, N)
"""
Conv(size, in=>out)
Conv(size, in=>out, relu)
@ -10,7 +15,7 @@ Standard convolutional layer. `size` should be a tuple like `(2, 2)`.
Data should be stored in WHCN order. In other words, a 100×100 RGB image would
be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
Takes the keyword arguments `pad` and `stride`.
Takes the keyword arguments `pad`, `stride` and `dilation`.
"""
struct Conv{N,F,A,V}
σ::F
@ -18,17 +23,17 @@ struct Conv{N,F,A,V}
bias::V
stride::NTuple{N,Int}
pad::NTuple{N,Int}
dilation::NTuple{N,Int}
end
Conv(w::AbstractArray{T}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0) where T =
Conv(σ, w, b, stride, pad)
Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N} =
Conv(σ, w, b, expand.(sub2(Val{N}), (stride, pad, dilation))...)
Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
stride::NTuple{N,Integer} = map(_->1,k),
pad::NTuple{N,Integer} = map(_->0,k)) where N =
stride = 1, pad = 0, dilation = 1) where N =
Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ,
stride = stride, pad = pad)
stride = stride, pad = pad, dilation = dilation)
Flux.treelike(Conv)
@ -36,7 +41,7 @@ function (c::Conv)(x)
# TODO: breaks gpu broadcast :(
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
σ.(conv(x, c.weight, stride = c.stride, pad = c.pad) .+ b)
σ.(conv(x, c.weight, stride = c.stride, pad = c.pad, dilation = c.dilation) .+ b)
end
function Base.show(io::IO, l::Conv)

View File

@ -31,15 +31,14 @@ function Dropout(p)
Dropout{typeof(p)}(p, true)
end
_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
function (a::Dropout)(x)
a.active || return x
y = similar(x)
rand!(y)
q = 1 - a.p
@inbounds for i=1:length(y)
y[i] = y[i] > a.p ? 1 / q : 0
end
return y .* x
y .= _dropout_kernel.(y, a.p, 1 - a.p)
return x .* y
end
_testmode!(a::Dropout, test) = (a.active = !test)

View File

@ -15,9 +15,9 @@ function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight
end
"""
binarycrossentropy(, y)
binarycrossentropy(, y; ϵ=eps())
Return `-y*log(ŷ) - (1-y)*log(1-ŷ)`.
Return `-y*log(ŷ + ϵ) - (1-y)*log(1-ŷ + ϵ)`. The ϵ term provides numerical stability.
julia> binarycrossentropy.(σ.([-1.1491, 0.8619, 0.3127]), [1, 1, 0.])
3-element Array{Float64,1}:
@ -25,7 +25,7 @@ Return `-y*log(ŷ) - (1-y)*log(1-ŷ)`.
0.352317
0.86167
"""
binarycrossentropy(, y) = -y*log() - (1 - y)*log(1 - )
binarycrossentropy(, y; ϵ=eps()) = -y*log( + ϵ) - (1 - y)*log(1 - + ϵ)
"""
logitbinarycrossentropy(logŷ, y)

View File

@ -1,7 +1,8 @@
module Optimise
export train!,
SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad
SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
struct Param{T}
x::T

View File

@ -99,3 +99,12 @@ tuning.
"""
AMSGrad(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p -> amsgrad(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
"""
NADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[NADAM](https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ) optimiser. Parameters other
than learning rate don't need tuning.
"""
NADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p->nadam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))

View File

@ -35,7 +35,7 @@ function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
acc = zeros(p.x)
function ()
@. acc = ρ * acc + (1 - ρ) * p.Δ^2
@. p.Δ *= η / (acc + ϵ)
@. p.Δ *= η / (acc + ϵ)
end
end
@ -43,7 +43,7 @@ function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
acc = zeros(p.x) .+ ϵ
function ()
@. acc += p.Δ^2
@. p.Δ *= η / acc
@. p.Δ *= η / (acc + ϵ)
end
end
@ -64,7 +64,7 @@ function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ
function ()
@. mt = β1 * mt + (1 - β1) * p.Δ
@. vt = β2 * vt + (1 - β2) * p.Δ^2
@. p.Δ = mt / (1 - β1p) / ((vt / (1 - β2p)) + ϵ) * η
@. p.Δ = mt / (1 - β1p) / (vt / (1 - β2p) + ϵ) * η
β1p *= β1
β2p *= β2
end
@ -94,6 +94,19 @@ function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999,
end
end
function nadam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
mt = zeros(p.x)
vt = zeros(p.x)
β1p, β2p = β1, β2
function ()
@. mt = β1 * mt + (1 - β1) * p.Δ
@. vt = β2 * vt + (1 - β2) * p.Δ^2
@. p.Δ = (β1 * mt / (1 - β1 * β1p) + (1 - β1) * p.Δ / (1 - β1p)) / (vt * β2 / (1 - β2p) + ϵ) * η
β1p *= β1
β2p *= β2
end
end
clip(p::Param, thresh::Real) = () -> clamp!(p.Δ, -thresh, thresh)
function expdecay(p::Param, γ::Real)

View File

@ -37,8 +37,6 @@ function train!(loss, data, opt; cb = () -> ())
opt = runall(opt)
@progress for d in data
l = loss(d...)
isinf(l) && error("Loss is Inf")
isnan(l) && error("Loss is NaN")
@interrupts back!(l)
opt()
cb() == :stop && break

View File

@ -10,6 +10,7 @@ istracked(x) = tracker(x) ≠ nothing
isleaf(x) = !istracked(x) || isleaf(tracker(x))
data(x) = istracked(x) ? data(tracker(x)) : x
grad(x) = grad(tracker(x))
grad(::Void) = nothing
struct Call{F,As<:Tuple}
func::F
@ -46,11 +47,27 @@ isleaf(x::Tracked) = x.f == Call(nothing)
data(x::Tracked) = x.data
grad(x::Tracked) = x.grad
function update!(x, Δ)
tracker(x).data += Δ
tracker(x).grad .= 0
return x
end
include("back.jl")
include("scalar.jl")
include("array.jl")
include("numeric.jl")
"""
hook(f, x) -> x
Hook into gradient backpropagation. `x` is unmodified, but when backpropagating
`f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse
the sign of the gradient applied to `x`.
"""
hook(f, x) = istracked(x) ? track(hook, f, x) : x
back(::typeof(hook), Δ, f, x) = @back(x, f(Δ))
param(x::Number) = TrackedReal(float(x))
param(xs::AbstractArray) = TrackedArray(float.(xs))

View File

@ -93,6 +93,26 @@ function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1)
back(xs, Δ′)
end
_repeat(A, inner, outer) = Base.repeat(A; inner=inner, outer=outer)
Base.repeat(A::TrackedArray; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A))) = track(_repeat, A, inner, outer)
function back(::typeof(_repeat), Δ, xs::TrackedArray, inner, outer)
Δ′ = similar(xs.data)
Δ′ .= 0
S = size(xs.data)
# Loop through each element of Δ, calculate source dimensions, accumulate into Δ′
for (dest_idx, val) in enumerate(IndexCartesian(), Δ)
# First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then
# wrap around based on original size S.
src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)]
Δ′[src_idx...] += val
end
back(xs, Δ′)
end
for f in [:vcat, :hcat]
@eval begin
# This section is a bit of a hack since julia doesn't have a standardised
@ -314,18 +334,18 @@ logsoftmax(xs::TrackedArray) = track(logsoftmax, xs)
back(::typeof(logsoftmax), Δ, xs) = @back(xs, ∇logsoftmax(Δ, data(xs)))
# TODO: can store kwargs efficiently in namedtuples
_conv(x, w, stride, pad) = conv(x, w, stride = stride, pad = pad)
_conv(x, w, stride, pad, dilation) = conv(x, w, stride = stride, pad = pad, dilation = dilation)
conv(x::TrackedArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0) where N =
track(_conv, x, w, stride, pad)
conv(x::AbstractArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0) where N =
track(_conv, x, w, stride, pad)
conv(x::TrackedArray{<:Real,N}, w::AbstractArray{<:Real,N}; stride = 1, pad = 0) where N =
track(_conv, x, w, stride, pad)
conv(x::TrackedArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N =
track(_conv, x, w, stride, pad, dilation)
conv(x::AbstractArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N =
track(_conv, x, w, stride, pad, dilation)
conv(x::TrackedArray{<:Real,N}, w::AbstractArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N =
track(_conv, x, w, stride, pad, dilation)
function back(::typeof(_conv), Δ, x, w, stride, pad)
@back(x, NNlib.∇conv_data(Δ, data(x), data(w); stride = stride, pad = pad))
@back(w, NNlib.∇conv_filter(Δ, data(x), data(w); stride = stride, pad = pad))
function back(::typeof(_conv), Δ, x, w, stride, pad, dilation)
@back(x, NNlib.∇conv_data(Δ, data(x), data(w); stride = stride, pad = pad, dilation = dilation))
@back(w, NNlib.∇conv_filter(Δ, data(x), data(w); stride = stride, pad = pad, dilation = dilation))
end
_maxpool(x, k, pad, stride) = maxpool(x, k; pad = pad, stride = stride)

View File

@ -1,4 +1,4 @@
function gradient(f, xs::AbstractArray...)
function gradient(f, xs...)
xs = param.(xs)
back!(f(xs...))
grad.(xs)

View File

@ -8,7 +8,11 @@ tracker(x::TrackedReal) = x.tracker
track(f::Call, x::Real) = TrackedReal(Tracked(f, x, zero(x)))
back!(x::TrackedReal) = back!(x, 1)
function back!(x::TrackedReal)
isinf(x) && error("Loss is Inf")
isnan(x) && error("Loss is NaN")
return back!(x, 1)
end
function Base.show(io::IO, x::TrackedReal)
show(io, data(x))
@ -19,14 +23,16 @@ Base.decompose(x::TrackedReal) = Base.decompose(data(x))
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{T}) where T = x
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal) where T =
TrackedReal(Tracked(x.tracker.f, convert(T, x.tracker.data)))
Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x))
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} =
error("Not implemented: convert tracked $S to tracked $T")
Base.:(<)(x::TrackedReal, y::TrackedReal) = data(x) < data(y)
Base.:(==)(x::TrackedReal, y::TrackedReal) = data(x) == data(y)
Base.eps(x::TrackedReal) = eps(data(x))
for f in :[isinf, isnan, isfinite].args
@eval Base.$f(x::TrackedReal) = Base.$f(data(x))
end
@ -91,3 +97,18 @@ Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i)
back(::typeof(getindex), Δ, t, i) =
back(t, ntuple(j -> i == j ? Δ : 0, length(t)))
# Array collection
function collect(xs)
xs = Base.collect(xs)
track(Call(collect, xs), data.(xs))
end
function scan(c::Call{typeof(collect)})
foreach(scan, c.args[1])
end
function back(::typeof(collect), Δ, xs)
foreach((x, Δ) -> @back(x, Δ), xs, Δ)
end

View File

@ -2,6 +2,8 @@ using Base.Test
using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
σ, binarycrossentropy, logitbinarycrossentropy
const ϵ = 1e-7
@testset "losses" begin
# First, regression-style y's
y = [1, 1, 0, 0]
@ -40,10 +42,11 @@ using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
logŷ, y = randn(3), rand(3)
@testset "binarycrossentropy" begin
@test binarycrossentropy.(σ.(logŷ), y) -y.*log.(σ.(logŷ)) - (1 - y).*log.(1 - σ.(logŷ))
@test binarycrossentropy.(σ.(logŷ), y; ϵ=0) -y.*log.(σ.(logŷ)) - (1 - y).*log.(1 - σ.(logŷ))
@test binarycrossentropy.(σ.(logŷ), y) -y.*log.(σ.(logŷ) .+ eps.(σ.(logŷ))) - (1 - y).*log.(1 - σ.(logŷ) .+ eps.(σ.(logŷ)))
end
@testset "logitbinarycrossentropy" begin
@test logitbinarycrossentropy.(logŷ, y) binarycrossentropy.(σ.(logŷ), y)
@test logitbinarycrossentropy.(logŷ, y) binarycrossentropy.(σ.(logŷ), y; ϵ=0)
end
end

View File

@ -3,7 +3,7 @@ using Flux.Tracker
@testset "Optimise" begin
w = randn(10, 10)
@testset for Opt in [SGD, Nesterov, Momentum, ADAM, AdaMax, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad]
@testset for Opt in [SGD, Nesterov, Momentum, ADAM, AdaMax, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad, NADAM]
w = param(randn(10, 10))
loss(x) = Flux.mse(w*x, w*x)
opt = Opt([w])

View File

@ -1,5 +1,5 @@
using Flux.Tracker, Base.Test, NNlib
using Flux.Tracker: TrackedReal, gradcheck
using Flux.Tracker: TrackedReal, gradcheck, grad
using NNlib: conv
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
@ -114,6 +114,9 @@ end
@test gradtest(x -> repmat(x, 5,5), rand(4,5))
@test gradtest(x -> repmat(x, 5), rand(4,5))
@test gradtest(x -> repeat(x; inner=2, outer=3), rand(5))
@test gradtest(x -> repeat(x; inner=(2,2,1), outer=(1,1,3)), rand(5,4,3))
@test gradtest(kron, rand(5), rand(3))
@test gradtest(kron, rand(5), rand(3), rand(8))
@test gradtest(kron, rand(5,1), rand(3,1))
@ -220,4 +223,13 @@ b = param(rand())
Tracker.back!(b)
@test Tracker.grad(b) == 1
@testset "collect" begin
x, y = param(2), param(3)
xy = Tracker.collect([x, y])
@test xy isa TrackedArray{Float64}
z = xy[1]*xy[2]
back!(z)
@test grad.((x,y)) == (3, 2)
end
end #testset