Merge branch 'master' into batchnorm

Mike J Innes 2017-12-08 19:29:49 +00:00
commit 6f997e798a
21 changed files with 288 additions and 106 deletions

@@ -151,3 +151,13 @@ m = Chain(x -> x^2, x -> x+1)
 m(5) # => 26
 ```
+
+## Layer helpers
+
+Flux provides a set of helpers for custom layers, which you can enable by calling
+
+```julia
+Flux.treelike(Affine)
+```
+
+This enables a useful extra set of functionality for our `Affine` layer, such as [collecting its parameters](../training/optimisers.md) or [moving it to the GPU](../gpu.md).
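For context, this is roughly how the helper is meant to be used end to end. The `Affine` definition below is an approximation of the layer built earlier on that docs page, not part of this diff:

```julia
using Flux

# Approximate re-statement of the docs' custom Affine layer.
struct Affine
  W
  b
end

Affine(in::Integer, out::Integer) =
  Affine(param(randn(out, in)), param(randn(out)))

(m::Affine)(x) = m.W * x .+ m.b

Flux.treelike(Affine)   # opt in to the layer helpers

a = Affine(10, 5)
params(a)               # now collects a.W and a.b
```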

@@ -36,6 +36,8 @@ swish
 These layers don't affect the structure of the network but may improve training times or reduce overfitting.

 ```@docs
+Flux.testmode!
 BatchNorm
 Dropout
+LayerNorm
 ```

@@ -58,8 +58,5 @@ All optimisers return a function that, when called, will update the parameters p
 SGD
 Momentum
 Nesterov
-RMSProp
 ADAM
-ADAGrad
-ADADelta
 ```

@@ -7,12 +7,13 @@ module Flux
 using Juno, Requires
 using Lazy: @forward

-export BatchNorm, Chain, Dense, RNN, LSTM, Dropout,
-  SGD, ADAM, Momentum, Nesterov,
+export Chain, Dense, RNN, LSTM,
+  Dropout, LayerNorm, BatchNorm,
+  SGD, ADAM, Momentum, Nesterov, AMSGrad,
   param, params, mapleaves

 using NNlib
-export σ, relu, leakyrelu, elu, swish, softmax
+export σ, sigmoid, relu, leakyrelu, elu, swish, softmax

 include("tracker/Tracker.jl")
 using .Tracker
@@ -22,7 +23,7 @@ using .Optimise

 include("utils.jl")
 include("onehot.jl")
-include("tree.jl")
+include("treelike.jl")

 include("layers/stateless.jl")
 include("layers/basic.jl")
@@ -31,4 +32,6 @@ include("layers/normalisation.jl")

 include("data/Data.jl")

+include("batches/Batches.jl")
+
 end # module

src/batches/Batches.jl (new file)
@@ -0,0 +1,7 @@
+module Batches
+
+import ..Flux
+
+include("batch.jl")
+
+end

src/batches/batch.jl (new file)
@@ -0,0 +1,8 @@
+struct Batch{T,A,M}
+  data::A
+  mask::M
+end
+
+Batch{T}(data, mask) where T = Batch{T,typeof(data),typeof(mask)}(data, mask)
+
+Batch(xs) = Batch{typeof(first(xs))}(Flux.batch(xs),trues(length(xs)))
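A rough sketch of what the new `Batch` container holds; the shapes below are invented for illustration, and `Flux.batch` is assumed (from the code above) to stack the individual samples into one array:

```julia
using Flux

# Three equally-sized "samples"; every slot is real data, so the mask is all true.
xs = [rand(2), rand(2), rand(2)]
b  = Flux.Batches.Batch(xs)          # data = Flux.batch(xs), mask = trues(3)

# The typed constructor can also be called directly with explicit data and mask.
b2 = Flux.Batches.Batch{Vector{Float64}}(rand(2, 3), [true, true, false])
b2.mask                              # marks which slots hold real data
```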

@@ -33,7 +33,7 @@ function rawdict()
   filter(!isempty, split.(split(readstring(deps("CMUDict", "cmudict")), "\n"))))
 end

-validword(s) = ismatch(r"^[\w-\.]+$", s)
+validword(s) = ismatch(r"^[\w\-\.]+$", s)

 cmudict() = filter((s, ps) -> validword(s), rawdict())

@@ -78,3 +78,32 @@ function Base.show(io::IO, l::Dense)
   l.σ == identity || print(io, ", ", l.σ)
   print(io, ")")
 end
+
+"""
+    Diagonal(in::Integer)
+
+Creates an element-wise linear transformation layer with learnable
+vectors `α` and `β`:
+
+    y = α .* x .+ β
+
+The input `x` must be an array where `size(x, 1) == in`.
+"""
+struct Diagonal{T}
+  α::T
+  β::T
+end
+
+Diagonal(in::Integer; initα = ones, initβ = zeros) =
+  Diagonal(param(initα(in)), param(initβ(in)))
+
+treelike(Diagonal)
+
+function (a::Diagonal)(x)
+  α, β = a.α, a.β
+  α.*x .+ β
+end
+
+function Base.show(io::IO, l::Diagonal)
+  print(io, "Diagonal(", length(l.α), ")")
+end
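A quick sketch of what the new `Diagonal` layer computes, using only what the diff defines (sizes chosen arbitrarily; `Diagonal` is not exported, so it is qualified here):

```julia
using Flux

d = Flux.Diagonal(5)      # α = ones(5), β = zeros(5) by default, wrapped as params
x = rand(5, 3)            # size(x, 1) == 5; three columns

y = d(x)                  # element-wise: α .* x .+ β
size(y)                   # (5, 3)
```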

@@ -44,6 +44,29 @@ end
 _testmode!(a::Dropout, test) = (a.active = !test)

+"""
+    LayerNorm(h::Integer)
+
+A [normalisation layer](https://arxiv.org/pdf/1607.06450.pdf) designed to be
+used with recurrent hidden states of size `h`. Normalises the mean/stddev of
+each input before applying a per-neuron gain/bias.
+"""
+struct LayerNorm{T}
+  diag::Diagonal{T}
+end
+
+LayerNorm(h::Integer) =
+  LayerNorm(Diagonal(h))
+
+treelike(LayerNorm)
+
+(a::LayerNorm)(x) = a.diag(normalise(x))
+
+function Base.show(io::IO, l::LayerNorm)
+  print(io, "LayerNorm(", length(l.diag.α), ")")
+end
+
 """
     BatchNorm(dims...; λ = identity,
               initβ = zeros, initγ = ones, ϵ = 1e-8, momentum = .1)
@@ -65,8 +88,6 @@ julia> m = Chain(
   BatchNorm(10),
   softmax)

 Chain(Dense(784, 64), BatchNorm(64, λ = NNlib.relu), Dense(64, 10), BatchNorm(10), NNlib.softmax)
-
-julia> opt = SGD(params(m), 10, decay = .1) # a crazy learning rate
 ```
 """

 mutable struct BatchNorm{F,V,N}
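For orientation, a small sketch of how the new `LayerNorm` layer slots in next to the existing layers (layer sizes here are arbitrary):

```julia
using Flux

# LayerNorm(10) normalises each input column via `normalise`, then applies
# the learnable per-neuron gain/bias of a Diagonal(10).
m = Chain(
  Dense(20, 10, relu),
  LayerNorm(10),
  Dense(10, 2),
  softmax)

x = rand(20, 5)   # a batch of 5 samples
m(x)
```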

@@ -1,14 +1,27 @@
+using NNlib: log_fast
+
 # Cost functions

 mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)

 crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat) =
-  -sum(y .* log.(ŷ)) / size(y, 2)
+  -sum(y .* log_fast.(ŷ)) / size(y, 2)

 @deprecate logloss(x, y) crossentropy(x, y)

 function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat)
   logŷ = logŷ .- maximum(logŷ, 1)
-  ypred = logŷ .- log.(sum(exp.(logŷ), 1))
+  ypred = logŷ .- log_fast.(sum(exp.(logŷ), 1))
   -sum(y .* ypred) / size(y, 2)
 end
+
+"""
+    normalise(x::AbstractVecOrMat)
+
+Normalise each column of `x` to mean 0 and standard deviation 1.
+"""
+function normalise(x::AbstractVecOrMat)
+  μ′ = mean(x, 1)
+  σ = std(x, 1, mean = μ′)
+  return (x .- μ′) ./ σ
+end
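As a sanity check on `normalise` (not part of the diff; exact floating-point values will differ slightly):

```julia
using Flux

x = [1.0 2.0;
     3.0 4.0;
     5.0 6.0]

y = Flux.normalise(x)   # each column treated independently

mean(y, 1)              # ≈ [0.0 0.0]
std(y, 1)               # ≈ [1.0 1.0]
```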

@@ -1,3 +1,5 @@
+import Base: *
+
 struct OneHotVector <: AbstractVector{Bool}
   ix::UInt32
   of::UInt32
@@ -7,7 +9,7 @@ Base.size(xs::OneHotVector) = (Int64(xs.of),)
 Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix

-Base.:*(A::AbstractMatrix, b::OneHotVector) = A[:, b.ix]
+A::AbstractMatrix * b::OneHotVector = A[:, b.ix]

 struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool}
   height::Int
@@ -18,7 +20,7 @@ Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data))
 Base.getindex(xs::OneHotMatrix, i::Int, j::Int) = xs.data[j][i]

-Base.:*(A::AbstractMatrix, B::OneHotMatrix) = A[:, map(x->x.ix, B.data)]
+A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)]

 Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...])
@@ -40,10 +42,22 @@ function onehot(l, labels)
   OneHotVector(i, length(labels))
 end

-onehotbatch(ls, labels) = OneHotMatrix(length(labels), [onehot(l, labels) for l in ls])
+function onehot(l, labels, unk)
+  i = findfirst(labels, l)
+  i > 0 || return onehot(unk, labels)
+  OneHotVector(i, length(labels))
+end
+
+onehotbatch(ls, labels, unk...) =
+  OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])

 argmax(y::AbstractVector, labels = 1:length(y)) =
   labels[findfirst(y, maximum(y))]

 argmax(y::AbstractMatrix, l...) =
   squeeze(mapslices(y -> argmax(y, l...), y, 1), 1)
+
+# Ambiguity hack
+a::TrackedMatrix * b::OneHotVector = TrackedArray(Tracker.Call(*, a, b))
+a::TrackedMatrix * b::OneHotMatrix = TrackedArray(Tracker.Call(*, a, b))
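A short sketch of the new `unk` fallback (the vocabulary and tokens below are invented; `onehot`/`onehotbatch` are qualified since the diff doesn't export them):

```julia
using Flux

vocab = ["a", "b", "c", "unk"]

Flux.onehot("b", vocab)                          # 1 in position 2
Flux.onehot("z", vocab, "unk")                   # unknown label falls back to "unk"
Flux.onehotbatch(["a", "z", "c"], vocab, "unk")  # 4×3 OneHotMatrix; "z" maps to "unk"
```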

@@ -1,7 +1,7 @@
 module Optimise

 export update!, params, train!,
-  SGD, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta
+  SGD, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad

 struct Param{T}
   x::T

@@ -1,5 +1,7 @@
 call(f, xs...) = f(xs...)

+# note for optimisers: set to zero
+# p.Δ at the end of the weights update
 function optimiser(ps, fs...)
   ps = [Param(p) for p in ps]
   fs = map(ps) do p
@@ -10,64 +12,73 @@ function optimiser(ps, fs...)
 end

 """
-    SGD(params, η = 1; decay = 0)
+    SGD(params, η = 0.1; decay = 0)

-Classic gradient descent optimiser. For each parameter `p` and its
-gradient `δp`, this runs `p -= η*δp`.
+Classic gradient descent optimiser with learning rate `η`.
+For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`.

-Supports decayed learning rate decay if the `decay` argument is provided.
+Supports inverse decaying learning rate if the `decay` argument is provided.
 """
-SGD(ps, η = 1; decay = 0) =
-  optimiser(ps, p -> invdecay(p, decay), p -> descent(p, η))
+SGD(ps, η = 0.1; decay = 0) =
+  optimiser(ps, p -> invdecay(p, decay), p -> descent(p,η))

 """
-    Momentum(params, ρ, decay = 0)
+    Momentum(params, η = 0.01; ρ = 0.9, decay = 0)

-SGD with momentum `ρ` and optional learning rate decay.
+SGD with learning rate `η`, momentum `ρ` and optional learning rate inverse decay.
 """
-Momentum(ps, ρ; decay = 0) =
-  optimiser(ps, p -> momentum(p, ρ), p -> invdecay(p, decay), p -> descent(p, 1))
+Momentum(ps, η = 0.01; ρ = 0.9, decay = 0) =
+  optimiser(ps, p->invdecay(p,decay), p->momentum(p, ρ, η), p->descent(p,1))

 """
-    Nesterov(params, ρ, decay = 0)
+    Nesterov(params, η = 0.01; ρ = 0.9, decay = 0)

-SGD with Nesterov momentum `ρ` and optional learning rate decay.
+SGD with learning rate `η`, Nesterov momentum `ρ` and optional learning rate inverse decay.
 """
-Nesterov(ps, ρ; decay = 0) =
-  optimiser(ps, p -> nesterov(p, ρ), p -> invdecay(p, decay), p -> descent(p, 1))
+Nesterov(ps, η = 0.01; ρ = 0.9, decay = 0) =
  optimiser(ps, p->invdecay(p,decay), p->nesterov(p, ρ, η), p->descent(p,1))

 """
-    RMSProp(params; η = 0.001, ρ = 0.9, ϵ = 1e-8, decay = 0)
+    RMSProp(params, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0)

 [RMSProp](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
 optimiser. Parameters other than learning rate don't need tuning. Often a good
 choice for recurrent networks.
 """
 RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) =
-  optimiser(ps, p -> rmsprop(p; η = η, ρ = ρ, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
+  optimiser(ps, p->rmsprop(p; η=η, ρ=ρ, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))

 """
-    ADAM(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
+    ADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)

 [ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
 """
 ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
-  optimiser(ps, p -> adam(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
+  optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))

 """
-    ADAGrad(params; η = 0.01, ϵ = 1e-8, decay = 0)
+    ADAGrad(params, η = 0.01; ϵ = 1e-8, decay = 0)

 [ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
 Parameters don't need tuning.
 """
-ADAGrad(ps; η = 0.01, ϵ = 1e-8, decay = 0) =
-  optimiser(ps, p -> adagrad(p; η = η, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
+ADAGrad(ps, η = 0.01; ϵ = 1e-8, decay = 0) =
+  optimiser(ps, p->adagrad(p; η=η, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))

 """
-    ADADelta(params; η = 0.01, ρ = 0.95, ϵ = 1e-8, decay = 0)
+    ADADelta(params; ρ = 0.9, ϵ = 1e-8, decay = 0)

 [ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
 tuning.
 """
-ADADelta(ps; η = 0.01, ρ = 0.95, ϵ = 1e-8, decay = 0) =
-  optimiser(ps, p -> adadelta(p; ρ = ρ, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
+ADADelta(ps; ρ = 0.9, ϵ = 1e-8, decay = 0) =
+  optimiser(ps, p->adadelta(p; ρ=ρ, ϵ=ϵ), p->descent(p,1))
+
+"""
+    AMSGrad(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
+
+[AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need
+tuning.
+"""
+AMSGrad(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
+  optimiser(ps, p -> amsgrad(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
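To make the updated calling convention concrete, a hedged usage sketch (the model is a placeholder; defaults are taken from the docstrings above):

```julia
using Flux

m = Dense(10, 2)

opt1 = SGD(params(m), 0.1)                 # default η changed from 1 to 0.1
opt2 = Momentum(params(m), 0.01, ρ = 0.9)  # η is now positional, ρ a keyword
opt3 = AMSGrad(params(m))                  # η = 0.001, β1 = 0.9, β2 = 0.999

# Each optimiser is a closure over its parameters: calling it applies one
# update from the currently accumulated gradients, then zeroes p.Δ.
opt1()
```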

@@ -1,74 +1,97 @@
 function descent(p::Param, η::Real)
   function ()
-    p.x .-= p.Δ .* η
-    p.Δ .= 0
+    @. p.x -= η * p.Δ
+    @. p.Δ = 0
   end
 end

-function momentum(p::Param, ρ::Real)
-  mo = zeros(p.x)
-  () -> p.Δ .= mo .= ρ .* mo .+ p.Δ
-end
-
-function nesterov(p::Param, ρ::Real)
-  mo = zeros(p.x)
+function momentum(p::Param, ρ, η)
+  v = zeros(p.x)
   function ()
-    mo .= ρ .* mo .+ p.Δ
-    p.Δ .= ρ .* mo .+ p.Δ
+    @. v = ρ * v - η * p.Δ
+    @. p.Δ = -v
   end
 end

-function clip(p::Param, thresh::Real)
-  () -> clamp!(p.Δ, -thresh, thresh)
-end
-
-function weightdecay(p::Param, γ::Real)
-  () -> p.Δ .+= γ .* p.x
-end
-
-function invdecay(p::Param, γ::Real)
-  n = 0
+# Ref. https://arxiv.org/pdf/1212.0901.pdf
+function nesterov(p::Param, ρ, η)
+  v = zeros(p.x)
   function ()
-    p.Δ .*= 1 / (1 + γ * n)
-    n += 1
+    d = @. ρ^2 * v - (1+ρ) * η * p.Δ
+    @. v = ρ*v - η*p.Δ
+    @. p.Δ = -d
   end
 end

 function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
-  acc = zeros(p.x) .+ ϵ
+  acc = zeros(p.x)
   function ()
-    @. acc = ρ * acc + (1 - ρ) * p.Δ ^ 2
-    @. p.Δ /= acc * η
+    @. acc = ρ * acc + (1 - ρ) * p.Δ^2
+    @. p.Δ *= η / (acc + ϵ)
   end
 end

 function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
   acc = zeros(p.x) .+ ϵ
   function ()
-    @. acc += p.Δ ^ 2
-    @. p.Δ /= acc * η
+    @. acc += p.Δ^2
+    @. p.Δ *= η / acc
   end
 end

-function adadelta(p::Param; ρ::Real = 0.95, ϵ::Real = 1e-8)
-  acc = zeros(p.x) .+ ϵ
-  Δacc = zeros(p.x) .+ ϵ
+function adadelta(p::Param; ρ::Real = 0.9, ϵ::Real = 1e-8)
+  acc = zeros(p.x)
+  Δacc = zeros(p.x)
   function ()
-    @. acc = ρ * acc + (1 - ρ) * p.Δ ^ 2
-    @. p.Δ *= Δacc / acc
-    @. Δacc = ρ * Δacc + (1 - ρ) * p.Δ ^ 2
+    @. acc = ρ * acc + (1 - ρ) * p.Δ^2
+    @. p.Δ *= (Δacc + ϵ) / (acc + ϵ)
+    @. Δacc = ρ * Δacc + (1 - ρ) * p.Δ^2
   end
 end

 function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
   mt = zeros(p.x)
-  vt = zeros(p.x) .+ ϵ
+  vt = zeros(p.x)
   β1p, β2p = β1, β2
   function ()
     @. mt = β1 * mt + (1 - β1) * p.Δ
-    @. vt = β2 * vt + (1 - β2) * p.Δ ^ 2
-    @. p.Δ = (1 - β2p) / (1 - β1p) * mt / vt * η
+    @. vt = β2 * vt + (1 - β2) * p.Δ^2
+    @. p.Δ = mt / (1 - β1p) / ((vt / (1 - β2p)) + ϵ) * η
     β1p *= β1
     β2p *= β2
   end
 end

+function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
+  mt = zeros(p.x)
+  vt = zeros(p.x) .+ ϵ
+  v̂t = zeros(p.x) .+ ϵ
+  function ()
+    @. mt = β1 * mt + (1 - β1) * p.Δ
+    @. vt = β2 * vt + (1 - β2) * p.Δ ^ 2
+    @. v̂t = max.(v̂t, vt)
+    @. p.Δ = η * mt / v̂t
+  end
+end
+
+clip(p::Param, thresh::Real) = () -> clamp!(p.Δ, -thresh, thresh)
+
+function expdecay(p::Param, γ::Real)
+  if γ != 0
+    return () -> p.Δ .+= γ .* p.x
+  else
+    return () -> nothing
+  end
+end
+
+function invdecay(p::Param, γ::Real)
+  if γ != 0
+    n = 0
+    return () -> begin
+      p.Δ .*= 1 / (1 + γ * n)
+      n += 1
+    end
+  else
+    return () -> nothing
+  end
+end
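To illustrate the closure pattern all of these optimisers share, a stripped-down sketch; `FakeParam` below is a stand-in, not Flux's `Param` type:

```julia
# Minimal stand-in: a value plus its accumulated gradient.
mutable struct FakeParam
  x::Vector{Float64}
  Δ::Vector{Float64}
end

# Same shape as `descent` above: capture the parameter and a learning rate,
# return a closure that applies one step and clears the gradient.
function fake_descent(p::FakeParam, η::Real)
  function ()
    p.x .-= η .* p.Δ
    p.Δ .= 0
  end
end

p = FakeParam([1.0, 2.0], [0.5, 0.5])
step! = fake_descent(p, 0.1)
step!()   # p.x == [0.95, 1.95]; p.Δ is reset to zero
```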

@@ -1,8 +1,8 @@
 using Juno
 using Flux.Tracker: back!

-tocb(f) = f
-tocb(fs::AbstractVector) = () -> foreach(call, fs)
+runall(f) = f
+runall(fs::AbstractVector) = () -> foreach(call, fs)

 """
     train!(loss, data, opt; cb = () -> ())
@@ -11,10 +11,11 @@ For each datapoint `d` in `data` computes the gradient of `loss(d...)` through
 backpropagation and calls the optimizer `opt` and the callback `cb`
 (i.e. `opt()` and `cb()`).

-Multiple callbacks can be passed to `cb` as an array.
+Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
 """
 function train!(loss, data, opt; cb = () -> ())
-  cb = tocb(cb)
+  cb = runall(cb)
+  opt = runall(opt)
   @progress for d in data
     l = loss(d...)
     isinf(l.data[]) && error("Loss is Inf")
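A sketch of the new `train!` behaviour described in the docstring above; the model, loss and data are placeholders:

```julia
using Flux

m = Dense(5, 1)
loss(x, y) = Flux.mse(m(x), y)
data = [(rand(5), rand(1)) for _ = 1:100]

opt = SGD(params(m), 0.1)

# `opt` and `cb` may now each be arrays; every element is called on each step.
Flux.train!(loss, data, [opt],
            cb = [() -> print("."),
                  () -> nothing])
```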

@@ -1,6 +1,6 @@
 module Tracker

-export TrackedArray, param, back!
+export TrackedArray, TrackedVector, TrackedMatrix, param, back!

 data(x) = x
 istracked(x) = false
@@ -38,7 +38,9 @@ TrackedArray(c::Call) = TrackedArray(c, c())

 TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x))

-param(xs) = TrackedArray(AbstractFloat.(xs))
+isleaf(x::TrackedArray) = x.f == Call(nothing)
+
+param(xs) = TrackedArray(map(x -> AbstractFloat(x), xs))
 param(xs::Real) = param(fill(xs))

 istracked(x::TrackedArray) = true
@@ -56,6 +58,7 @@ Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
 Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)

+# TODO decide if keeping both data and value. The problem is TrackedScalar
 value(x) = x
 value(x::TrackedArray) = data(x)
 value(x::TrackedScalar) = data(x)[]

@@ -67,6 +70,7 @@ Base.:(==)(x::TrackedArray, y::TrackedArray) = value(x) == value(x)
 Base.isless(x::TrackedScalar, y) = isless(value(x), y)
 Base.isless(x, y::TrackedScalar) = isless(x, value(y))
 Base.isless(x::TrackedScalar, y::TrackedScalar) = isless(value(x), value(y))
+Base.isapprox(x::TrackedScalar, y; kws...) = isapprox(x.data[], y; kws...)

 Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
   print(io, "TrackedArray{…,$A}")

@@ -1,5 +1,3 @@
-import Base: *
-
 toarray(xs::AbstractArray, ys::AbstractArray) = ys
 toarray(xs::AbstractArray, y) = similar(xs, typeof(y), ()) .= y
@@ -60,25 +58,55 @@ Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
 Base.mean(xs::TrackedArray) = TrackedArray(Call(mean, xs), toarray(xs.data, mean(xs.data)))
 Base.mean(xs::TrackedArray, region) = TrackedArray(Call(mean, xs, region))

+# Hacks to get std working
+Base.std(x::TrackedArray; mean = Base.mean(x)) =
+  sqrt.(sum((x .- mean).^2) ./ (length(x)-1))
+Base.std(x::TrackedArray, dim; mean = Base.mean(x, dim)) =
+  sqrt.(sum((x .- mean).^2, dim) ./ (size(x, dim)-1))
+
 back(::typeof(mean), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= Δ ./ length(xs.data))
 back(::typeof(mean), Δ, xs::TrackedArray, region) =
   back(xs, similar(xs.data) .= Δ ./ prod(size(xs.data, region...)))

 # BLAS

-a::TrackedMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b))
-a::TrackedMatrix * b::AbstractMatrix = TrackedArray(Call(*, a, b))
-a::AbstractMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b))
-
-a::TrackedMatrix * b::TrackedVector = TrackedArray(Call(*, a, b))
-a::TrackedMatrix * b::AbstractVector = TrackedArray(Call(*, a, b))
-a::AbstractMatrix * b::TrackedVector = TrackedArray(Call(*, a, b))
+for f in :[*, Ac_mul_B].args
+  @eval begin
+    import Base.$f
+    $f(a::TrackedMatrix, b::TrackedMatrix) = TrackedArray(Call($f, a, b))
+    $f(a::TrackedMatrix, b::AbstractMatrix) = TrackedArray(Call($f, a, b))
+    $f(a::AbstractMatrix, b::TrackedMatrix) = TrackedArray(Call($f, a, b))
+
+    $f(a::TrackedMatrix, b::TrackedVector) = TrackedArray(Call($f, a, b))
+    $f(a::TrackedMatrix, b::AbstractVector) = TrackedArray(Call($f, a, b))
+    $f(a::AbstractMatrix, b::TrackedVector) = TrackedArray(Call($f, a, b))
+
+    $f(a::TrackedVector, b::TrackedVector) = TrackedArray(Call($f, a, b))
+    $f(a::TrackedVector, b::AbstractVector) = TrackedArray(Call($f, a, b))
+    $f(a::AbstractVector, b::TrackedVector) = TrackedArray(Call($f, a, b))
+  end
+end

 function back(::typeof(*), Δ, a::AbstractMatrix, b::AbstractVecOrMat)
   @back(a, A_mul_Bt(Δ, data(b)))
   @back(b, At_mul_B(data(a), Δ))
 end

+function back(::typeof(Ac_mul_B), Δ, a::AbstractVecOrMat{<:Real}, b::AbstractVecOrMat{<:Real})
+  @back(a, A_mul_Bt(Δ, data(b))')
+  @back(b, *(data(a), Δ))
+end
+
+# Fast path for matrix-vector
+function back(::typeof(*), Δ::AbstractVector, W::TrackedMatrix, x::AbstractVector)
+  if isleaf(W)
+    W.grad .+= Δ .* data(x).'
+  else
+    back(W, A_mul_Bt(Δ, data(x)))
+  end
+  @back(x, At_mul_B(data(W), Δ))
+end
+
 # NNlib

 import NNlib: softmax, ∇softmax

@@ -35,3 +35,5 @@ function params(m)
   prefor(p -> p isa TrackedArray && push!(ps, p), m)
   return ps
 end
+
+params(m...) = params(m)
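The added `params(m...)` method lets several objects be walked in one call; a brief sketch (assuming `prefor` descends into the tuple as it does for container layers):

```julia
using Flux

d1, d2 = Dense(10, 5), Dense(5, 2)

ps = params(d1, d2)   # equivalent to params((d1, d2))
length(ps)            # 4 tracked arrays: two weight matrices, two biases
```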

test/optimise.jl (new file)
@@ -0,0 +1,17 @@
+using Flux.Optimise
+using Flux.Tracker
+
+@testset "Optimise" begin
+  w = randn(10, 10)
+  for Opt in [SGD, Nesterov, Momentum, ADAM, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad]
+    w′ = param(randn(10, 10))
+    loss(x) = Flux.mse(w*x, w′*x)
+    opt = Opt([w′])
+    for t=1:10^5
+      l = loss(rand(10))
+      back!(l)
+      opt()
+    end
+    @test Flux.mse(w, w′) < 0.01
+  end
+end

@@ -5,5 +5,6 @@ using Flux, Base.Test
 include("utils.jl")
 include("tracker.jl")
 include("layers/normalisation.jl")
+include("optimise.jl")

 end

@@ -9,6 +9,8 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
 @test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2)
 @test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)

+@test gradtest((w, x) -> w'*x, randn(10, 2), randn(10))
+
 @test gradtest(x -> sin.(sum(x, (2, 3))), (3,4,5))

 @test gradtest(x -> softmax(x).*(1:3), 3)
@@ -32,23 +34,12 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
   @test gradtest(x -> mean(x, [1, 2]), rand(2, 3, 4))
 end

+@test gradtest(x -> std(x), rand(5,5))
+@test gradtest(x -> std(x, 1), rand(5,5))
+
 @test gradtest(rand(5)) do x
   y = x.^2
   2y + x
 end

-for T in [Float32, Float64]
-  @test isa(param(T(1)), TrackedArray{T, 0})
-  @test isa(param(rand(T, 2)), TrackedArray{T, 1})
-  @test isa(param(rand(T, 2,2)), TrackedArray{T, 2})
-end
-
-# TODO: do we wand this behaviour ??
-F = typeof(AbstractFloat(1))
-for T in [Int32, Int64]
-  @test isa(param(T(1)), TrackedArray{F, 0})
-  @test isa(param(rand(T, 2)), TrackedArray{F, 1})
-  @test isa(param(rand(T, 2,2)), TrackedArray{F, 2})
-end
-
 end #testset