Merge branch 'master' into batchnorm

Commit 6f997e798a by Mike J Innes, 2017-12-08 19:29:49 +00:00
21 changed files with 288 additions and 106 deletions


@ -151,3 +151,13 @@ m = Chain(x -> x^2, x -> x+1)
m(5) # => 26
```
## Layer helpers
Flux provides a set of helpers for custom layers, which you can enable by calling
```julia
Flux.treelike(Affine)
```
This enables a useful extra set of functionality for our `Affine` layer, such as [collecting its parameters](../training/optimisers.md) or [moving it to the GPU](../gpu.md).
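For example, with the `Affine(10, 5)` layer defined earlier on this page, the following becomes available (a sketch; the `cu` call assumes CuArrays is loaded for the GPU step):
```julia
a = Affine(10, 5)

params(a)         # => the tracked arrays `a.W` and `a.b`, ready to pass to an optimiser
mapleaves(cu, a)  # an Affine with its parameters moved to the GPU (assumes CuArrays' `cu`)
```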


@ -36,6 +36,8 @@ swish
These layers don't affect the structure of the network but may improve training times or reduce overfitting.
```@docs
Flux.testmode!
BatchNorm
Dropout
LayerNorm
```
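For instance, dropout can be switched off for evaluation and back on for training (a small sketch, assuming the two-argument `testmode!(m, test)` form):
```julia
m = Chain(Dense(10, 10), Dropout(0.5))

Flux.testmode!(m)         # deactivate Dropout for evaluation
Flux.testmode!(m, false)  # restore training-time behaviour
```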


@ -58,8 +58,5 @@ All optimisers return a function that, when called, will update the parameters p
SGD
Momentum
Nesterov
RMSProp
ADAM
ADAGrad
ADADelta
```


@ -7,12 +7,13 @@ module Flux
using Juno, Requires
using Lazy: @forward
export BatchNorm, Chain, Dense, RNN, LSTM, Dropout,
SGD, ADAM, Momentum, Nesterov,
export Chain, Dense, RNN, LSTM,
Dropout, LayerNorm, BatchNorm,
SGD, ADAM, Momentum, Nesterov, AMSGrad,
param, params, mapleaves
using NNlib
export σ, relu, leakyrelu, elu, swish, softmax
export σ, sigmoid, relu, leakyrelu, elu, swish, softmax
include("tracker/Tracker.jl")
using .Tracker
@ -22,7 +23,7 @@ using .Optimise
include("utils.jl")
include("onehot.jl")
include("tree.jl")
include("treelike.jl")
include("layers/stateless.jl")
include("layers/basic.jl")
@ -31,4 +32,6 @@ include("layers/normalisation.jl")
include("data/Data.jl")
include("batches/Batches.jl")
end # module

src/batches/Batches.jl (new file, 7 lines)

@ -0,0 +1,7 @@
module Batches
import ..Flux
include("batch.jl")
end

src/batches/batch.jl (new file, 8 lines)

@ -0,0 +1,8 @@
struct Batch{T,A,M}
data::A
mask::M
end
Batch{T}(data, mask) where T = Batch{T,typeof(data),typeof(mask)}(data, mask)
Batch(xs) = Batch{typeof(first(xs))}(Flux.batch(xs),trues(length(xs)))
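A hypothetical usage sketch (assuming `Flux.batch` stacks the elements into a single array, as the constructor suggests):
```julia
using Flux

xs = [rand(4), rand(4), rand(4)]
b = Flux.Batches.Batch(xs)

b.data   # the result of `Flux.batch(xs)`: the three vectors stacked into one array
b.mask   # `trues(3)`, one flag per element of the batch
```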


@ -33,7 +33,7 @@ function rawdict()
filter(!isempty, split.(split(readstring(deps("CMUDict", "cmudict")), "\n"))))
end
validword(s) = ismatch(r"^[\w-\.]+$", s)
validword(s) = ismatch(r"^[\w\-\.]+$", s)
cmudict() = filter((s, ps) -> validword(s), rawdict())


@ -78,3 +78,32 @@ function Base.show(io::IO, l::Dense)
l.σ == identity || print(io, ", ", l.σ)
print(io, ")")
end
"""
Diagonal(in::Integer)
Creates an element-wise linear transformation layer with learnable
vectors `α` and `β`:
y = α .* x .+ β
The input `x` must be an array where `size(x, 1) == in`.
"""
struct Diagonal{T}
α::T
β::T
end
Diagonal(in::Integer; initα = ones, initβ = zeros) =
Diagonal(param(initα(in)), param(initβ(in)))
treelike(Diagonal)
function (a::Diagonal)(x)
α, β = a.α, a.β
α.*x .+ β
end
function Base.show(io::IO, l::Diagonal)
print(io, "Diagonal(", length(l.α), ")")
end
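A minimal sketch of the layer in use (qualified as `Flux.Diagonal` here to avoid clashing with Base's `Diagonal`; values are illustrative):
```julia
using Flux

d = Flux.Diagonal(3)   # learnable α (ones) and β (zeros) of length 3
x = rand(3, 5)         # size(x, 1) must equal `in`
d(x)                   # α .* x .+ β, broadcast over the five columns
params(d)              # α and β, courtesy of `treelike(Diagonal)`
```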


@ -44,6 +44,29 @@ end
_testmode!(a::Dropout, test) = (a.active = !test)
"""
LayerNorm(h::Integer)
A [normalisation layer](https://arxiv.org/pdf/1607.06450.pdf) designed to be
used with recurrent hidden states of size `h`. Normalises the mean/stddev of
each input before applying a per-neuron gain/bias.
"""
struct LayerNorm{T}
diag::Diagonal{T}
end
LayerNorm(h::Integer) =
LayerNorm(Diagonal(h))
treelike(LayerNorm)
(a::LayerNorm)(x) = a.diag(normalise(x))
function Base.show(io::IO, l::LayerNorm)
print(io, "LayerNorm(", length(l.diag.α), ")")
end
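A short sketch of how it composes with the pieces above: `normalise` standardises each column, then the `Diagonal` applies the learnable gain and bias.
```julia
using Flux

l = LayerNorm(20)
m = Chain(Dense(10, 20), l)   # e.g. wrapped around a dense or recurrent hidden state
m(rand(10))                   # normalised to mean ≈ 0 / std ≈ 1, then scaled by the Diagonal
params(l)                     # the underlying α and β
```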
"""
BatchNorm(dims...; λ = identity,
initβ = zeros, initγ = ones, ϵ = 1e-8, momentum = .1)
@ -65,8 +88,6 @@ julia> m = Chain(
BatchNorm(10),
softmax)
Chain(Dense(784, 64), BatchNorm(64, λ = NNlib.relu), Dense(64, 10), BatchNorm(10), NNlib.softmax)
julia> opt = SGD(params(m), 10, decay = .1) # a crazy learning rate
```
"""
mutable struct BatchNorm{F,V,N}


@ -1,14 +1,27 @@
using NNlib: log_fast
# Cost functions
mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)
crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat) =
-sum(y .* log.(ŷ)) / size(y, 2)
-sum(y .* log_fast.(ŷ)) / size(y, 2)
@deprecate logloss(x, y) crossentropy(x, y)
function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat)
logŷ = logŷ .- maximum(logŷ, 1)
ypred = logŷ .- log.(sum(exp.(logŷ), 1))
ypred = logŷ .- log_fast.(sum(exp.(logŷ), 1))
-sum(y .* ypred) / size(y, 2)
end
"""
normalise(x::AbstractVecOrMat)
Normalise each column of `x` to mean 0 and standard deviation 1.
"""
function normalise(x::AbstractVecOrMat)
μ′ = mean(x, 1)
σ′ = std(x, 1, mean = μ′)
return (x .- μ′) ./ σ′
end
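A worked sketch of the column-wise behaviour (accessed here as `Flux.normalise`, mirroring how `Flux.mse` is used in this commit's tests):
```julia
using Flux

x = [1.0 10.0;
     2.0 20.0;
     3.0 30.0]

y = Flux.normalise(x)
mean(y, 1)   # ≈ [0.0 0.0]
std(y, 1)    # ≈ [1.0 1.0]
```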


@ -1,3 +1,5 @@
import Base: *
struct OneHotVector <: AbstractVector{Bool}
ix::UInt32
of::UInt32
@ -7,7 +9,7 @@ Base.size(xs::OneHotVector) = (Int64(xs.of),)
Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix
Base.:*(A::AbstractMatrix, b::OneHotVector) = A[:, b.ix]
A::AbstractMatrix * b::OneHotVector = A[:, b.ix]
struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool}
height::Int
@ -18,7 +20,7 @@ Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data))
Base.getindex(xs::OneHotMatrix, i::Int, j::Int) = xs.data[j][i]
Base.:*(A::AbstractMatrix, B::OneHotMatrix) = A[:, map(x->x.ix, B.data)]
A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)]
Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...])
@ -40,10 +42,22 @@ function onehot(l, labels)
OneHotVector(i, length(labels))
end
onehotbatch(ls, labels) = OneHotMatrix(length(labels), [onehot(l, labels) for l in ls])
function onehot(l, labels, unk)
i = findfirst(labels, l)
i > 0 || return onehot(unk, labels)
OneHotVector(i, length(labels))
end
onehotbatch(ls, labels, unk...) =
OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])
argmax(y::AbstractVector, labels = 1:length(y)) =
labels[findfirst(y, maximum(y))]
argmax(y::AbstractMatrix, l...) =
squeeze(mapslices(y -> argmax(y, l...), y, 1), 1)
# Ambiguity hack
a::TrackedMatrix * b::OneHotVector = TrackedArray(Tracker.Call(*, a, b))
a::TrackedMatrix * b::OneHotMatrix = TrackedArray(Tracker.Call(*, a, b))
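A short usage sketch of the new `unk` fallback and `argmax` methods (qualified with `Flux.` here, since the exports are not shown in this diff):
```julia
using Flux

Flux.onehot(:b, [:a, :b, :c])               # OneHotVector: true only at index 2
Flux.onehot(:z, [:a, :b, :c], :a)           # unknown label falls back to the `unk` entry
Flux.onehotbatch([:b, :a], [:a, :b, :c])    # a 3×2 OneHotMatrix
Flux.argmax([0.1, 0.7, 0.2], [:a, :b, :c])  # => :b
```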


@ -1,7 +1,7 @@
module Optimise
export update!, params, train!,
SGD, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta
SGD, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad
struct Param{T}
x::T


@ -1,5 +1,7 @@
call(f, xs...) = f(xs...)
# note for optimisers: set p.Δ to zero
# at the end of the weights update
function optimiser(ps, fs...)
ps = [Param(p) for p in ps]
fs = map(ps) do p
@ -10,64 +12,73 @@ function optimiser(ps, fs...)
end
"""
SGD(params, η = 1; decay = 0)
SGD(params, η = 0.1; decay = 0)
Classic gradient descent optimiser. For each parameter `p` and its
gradient `δp`, this runs `p -= η*δp`.
Classic gradient descent optimiser with learning rate `η`.
For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`.
Supports decayed learning rate decay if the `decay` argument is provided.
Supports inverse decaying learning rate if the `decay` argument is provided.
"""
SGD(ps, η = 1; decay = 0) =
optimiser(ps, p -> invdecay(p, decay), p -> descent(p, η))
SGD(ps, η = 0.1; decay = 0) =
optimiser(ps, p -> invdecay(p, decay), p -> descent(p,η))
"""
Momentum(params, ρ, decay = 0)
Momentum(params, η = 0.01; ρ = 0.9, decay = 0)
SGD with momentum `ρ` and optional learning rate decay.
SGD with learning rate `η`, momentum `ρ` and optional learning rate inverse decay.
"""
Momentum(ps, ρ; decay = 0) =
optimiser(ps, p -> momentum(p, ρ), p -> invdecay(p, decay), p -> descent(p, 1))
Momentum(ps, η = 0.01; ρ = 0.9, decay = 0) =
optimiser(ps, p->invdecay(p,decay), p->momentum(p, ρ, η), p->descent(p,1))
"""
Nesterov(params, ρ, decay = 0)
Nesterov(params, η = 0.01; ρ = 0.9, decay = 0)
SGD with Nesterov momentum `ρ` and optional learning rate decay.
SGD with learning rate `η`, Nesterov momentum `ρ` and optional learning rate inverse decay.
"""
Nesterov(ps, ρ; decay = 0) =
optimiser(ps, p -> nesterov(p, ρ), p -> invdecay(p, decay), p -> descent(p, 1))
Nesterov(ps, η = 0.01; ρ = 0.9, decay = 0) =
optimiser(ps, p->invdecay(p,decay), p->nesterov(p, ρ, η), p->descent(p,1))
"""
RMSProp(params; η = 0.001, ρ = 0.9, ϵ = 1e-8, decay = 0)
RMSProp(params, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0)
[RMSProp](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
optimiser. Parameters other than learning rate don't need tuning. Often a good
choice for recurrent networks.
"""
RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) =
optimiser(ps, p -> rmsprop(p; η = η, ρ = ρ, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
optimiser(ps, p->rmsprop(p; η=η, ρ=ρ, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
"""
ADAM(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
ADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
"""
ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p -> adam(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
"""
ADAGrad(params; η = 0.01, ϵ = 1e-8, decay = 0)
ADAGrad(params, η = 0.01; ϵ = 1e-8, decay = 0)
[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
Parameters don't need tuning.
"""
ADAGrad(ps; η = 0.01, ϵ = 1e-8, decay = 0) =
optimiser(ps, p -> adagrad(p; η = η, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
ADAGrad(ps, η = 0.01; ϵ = 1e-8, decay = 0) =
optimiser(ps, p->adagrad(p; η=η, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
"""
ADADelta(params; η = 0.01, ρ = 0.95, ϵ = 1e-8, decay = 0)
ADADelta(params; ρ = 0.9, ϵ = 1e-8, decay = 0)
[ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
tuning.
"""
ADADelta(ps; η = 0.01, ρ = 0.95, ϵ = 1e-8, decay = 0) =
optimiser(ps, p -> adadelta(p; ρ = ρ, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
ADADelta(ps; ρ = 0.9, ϵ = 1e-8, decay = 0) =
optimiser(ps, p->adadelta(p; ρ=ρ, ϵ=ϵ), p->descent(p,1))
"""
AMSGrad(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need
tuning.
"""
AMSGrad(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p -> amsgrad(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
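Each constructor above wraps `optimiser(...)` and returns a zero-argument update function; a single step (a sketch, following the pattern in this commit's `test/optimise.jl`) looks like:
```julia
using Flux, Flux.Tracker

W, b = param(randn(5, 10)), param(zeros(5))
opt = ADAM([W, b], 0.001)            # or SGD([W, b], 0.1; decay = 1e-4), AMSGrad([W, b]), …

l = Flux.mse(W * rand(10) .+ b, rand(5))
back!(l)                             # fill in the gradients
opt()                                # apply one update, then reset the gradients to zero
```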


@ -1,74 +1,97 @@
function descent(p::Param, η::Real)
function ()
p.x .-= p.Δ .* η
p.Δ .= 0
@. p.x -= η * p.Δ
@. p.Δ = 0
end
end
function momentum(p::Param, ρ::Real)
mo = zeros(p.x)
() -> p.Δ .= mo .= ρ .* mo .+ p.Δ
end
function nesterov(p::Param, ρ::Real)
mo = zeros(p.x)
function momentum(p::Param, ρ, η)
v = zeros(p.x)
function ()
mo .= ρ .* mo .+ p.Δ
p.Δ .= ρ .* mo .+ p.Δ
@. v = ρ * v - η * p.Δ
@. p.Δ = -v
end
end
function clip(p::Param, thresh::Real)
() -> clamp!(p.Δ, -thresh, thresh)
end
function weightdecay(p::Param, γ::Real)
() -> p.Δ .+= γ .* p.x
end
function invdecay(p::Param, γ::Real)
n = 0
# Ref. https://arxiv.org/pdf/1212.0901.pdf
function nesterov(p::Param, ρ, η)
v = zeros(p.x)
function ()
p.Δ .*= 1 / (1 + γ * n)
n += 1
d = @. ρ^2 * v - (1+ρ) * η * p.Δ
@. v = ρ*v - η*p.Δ
@. p.Δ = -d
end
end
function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
acc = zeros(p.x) .+ ϵ
acc = zeros(p.x)
function ()
@. acc = ρ * acc + (1 - ρ) * p.Δ ^ 2
@. p.Δ /= √acc * η
@. acc = ρ * acc + (1 - ρ) * p.Δ^2
@. p.Δ *= η / √(acc + ϵ)
end
end
function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
acc = zeros(p.x) .+ ϵ
function ()
@. acc += p.Δ ^ 2
@. p.Δ /= √acc * η
@. acc += p.Δ^2
@. p.Δ *= η / √acc
end
end
function adadelta(p::Param; ρ::Real = 0.95, ϵ::Real = 1e-8)
acc = zeros(p.x) .+ ϵ
Δacc = zeros(p.x) .+ ϵ
function adadelta(p::Param; ρ::Real = 0.9, ϵ::Real = 1e-8)
acc = zeros(p.x)
Δacc = zeros(p.x)
function ()
@. acc = ρ * acc + (1 - ρ) * p.Δ ^ 2
@. p.Δ *= √Δacc / √acc
@. Δacc = ρ * Δacc + (1 - ρ) * p.Δ ^ 2
end
@. acc = ρ * acc + (1 - ρ) * p.Δ^2
@. p.Δ *= √(Δacc + ϵ) / √(acc + ϵ)
@. Δacc = ρ * Δacc + (1 - ρ) * p.Δ^2
end
end
function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
mt = zeros(p.x)
vt = zeros(p.x) .+ ϵ
vt = zeros(p.x)
β1p, β2p = β1, β2
function ()
@. mt = β1 * mt + (1 - β1) * p.Δ
@. vt = β2 * vt + (1 - β2) * p.Δ ^ 2
@. p.Δ = √(1 - β2p) / (1 - β1p) * mt / √vt * η
@. vt = β2 * vt + (1 - β2) * p.Δ^2
@. p.Δ = mt / (1 - β1p) / (√(vt / (1 - β2p)) + ϵ) * η
β1p *= β1
β2p *= β2
end
end
function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
mt = zeros(p.x)
vt = zeros(p.x) .+ ϵ
v̂t = zeros(p.x) .+ ϵ
function ()
@. mt = β1 * mt + (1 - β1) * p.Δ
@. vt = β2 * vt + (1 - β2) * p.Δ ^ 2
@. v̂t = max.(v̂t, vt)
@. p.Δ = η * mt / √v̂t
end
end
clip(p::Param, thresh::Real) = () -> clamp!(p.Δ, -thresh, thresh)
function expdecay(p::Param, γ::Real)
if γ != 0
return () -> p.Δ .+= γ .* p.x
else
return () -> nothing
end
end
function invdecay(p::Param, γ::Real)
if γ != 0
n = 0
return () -> begin
p.Δ .*= 1 / (1 + γ * n)
n += 1
end
else
return () -> nothing
end
end
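Since `optimiser` simply threads each `Param` through these closures in order (the rule rewrites `p.Δ`, then `descent` applies it), a custom optimiser is just a new combination of them; a sketch using only the names defined above (the `myoptimiser` name is hypothetical):
```julia
# clip gradients, add weight decay, apply momentum, then take the (already η-scaled) step
myoptimiser(ps, η = 0.01) =
  optimiser(ps, p -> clip(p, 5), p -> expdecay(p, 1e-4),
                p -> momentum(p, 0.9, η), p -> descent(p, 1))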


@ -1,8 +1,8 @@
using Juno
using Flux.Tracker: back!
tocb(f) = f
tocb(fs::AbstractVector) = () -> foreach(call, fs)
runall(f) = f
runall(fs::AbstractVector) = () -> foreach(call, fs)
"""
train!(loss, data, opt; cb = () -> ())
@ -11,10 +11,11 @@ For each datapoint `d` in `data` computes the gradient of `loss(d...)` through
backpropagation and calls the optimizer `opt` and the callback `cb`
(i.e. `opt()` and `cb()`).
Multiple callbacks can be passed to `cb` as an array.
Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
"""
function train!(loss, data, opt; cb = () -> ())
cb = tocb(cb)
cb = runall(cb)
opt = runall(opt)
@progress for d in data
l = loss(d...)
isinf(l.data[]) && error("Loss is Inf")
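A usage sketch of the updated `train!` (the model, data and callbacks here are purely illustrative; `runall` is what lets both `opt` and `cb` be arrays):
```julia
using Flux

m = Dense(10, 5)
data = [(rand(10), rand(5)) for _ = 1:100]
loss(x, y) = Flux.mse(m(x), y)

opt = SGD(params(m), 0.1)
evalcb() = @show loss(data[1]...)

Flux.train!(loss, data, opt, cb = [evalcb, () -> println("another callback")])
```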


@ -1,6 +1,6 @@
module Tracker
export TrackedArray, param, back!
export TrackedArray, TrackedVector, TrackedMatrix, param, back!
data(x) = x
istracked(x) = false
@ -38,7 +38,9 @@ TrackedArray(c::Call) = TrackedArray(c, c())
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x))
param(xs) = TrackedArray(AbstractFloat.(xs))
isleaf(x::TrackedArray) = x.f == Call(nothing)
param(xs) = TrackedArray(map(x -> AbstractFloat(x), xs))
param(xs::Real) = param(fill(xs))
istracked(x::TrackedArray) = true
@ -56,6 +58,7 @@ Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
# TODO: decide whether to keep both `data` and `value`. The problem is TrackedScalar
value(x) = x
value(x::TrackedArray) = data(x)
value(x::TrackedScalar) = data(x)[]
Base.:(==)(x::TrackedArray, y::TrackedArray) = value(x) == value(y)
Base.isless(x::TrackedScalar, y) = isless(value(x), y)
Base.isless(x, y::TrackedScalar) = isless(x, value(y))
Base.isless(x::TrackedScalar, y::TrackedScalar) = isless(value(x), value(y))
Base.isapprox(x::TrackedScalar, y; kws...) = isapprox(x.data[], y; kws...)
Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
print(io, "TrackedArray{…,$A}")


@ -1,5 +1,3 @@
import Base: *
toarray(xs::AbstractArray, ys::AbstractArray) = ys
toarray(xs::AbstractArray, y) = similar(xs, typeof(y), ()) .= y
@ -60,25 +58,55 @@ Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
Base.mean(xs::TrackedArray) = TrackedArray(Call(mean, xs), toarray(xs.data, mean(xs.data)))
Base.mean(xs::TrackedArray, region) = TrackedArray(Call(mean, xs, region))
# Hacks to get std working
Base.std(x::TrackedArray; mean = Base.mean(x)) =
sqrt.(sum((x .- mean).^2) ./ (length(x)-1))
Base.std(x::TrackedArray, dim; mean = Base.mean(x, dim)) =
sqrt.(sum((x .- mean).^2, dim) ./ (size(x, dim)-1))
back(::typeof(mean), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= Δ ./ length(xs.data))
back(::typeof(mean), Δ, xs::TrackedArray, region) =
back(xs, similar(xs.data) .= Δ ./ prod(size(xs.data, region...)))
# BLAS
a::TrackedMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b))
a::TrackedMatrix * b::AbstractMatrix = TrackedArray(Call(*, a, b))
a::AbstractMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b))
for f in :[*, Ac_mul_B].args
@eval begin
import Base.$f
$f(a::TrackedMatrix, b::TrackedMatrix) = TrackedArray(Call($f, a, b))
$f(a::TrackedMatrix, b::AbstractMatrix) = TrackedArray(Call($f, a, b))
$f(a::AbstractMatrix, b::TrackedMatrix) = TrackedArray(Call($f, a, b))
a::TrackedMatrix * b::TrackedVector = TrackedArray(Call(*, a, b))
a::TrackedMatrix * b::AbstractVector = TrackedArray(Call(*, a, b))
a::AbstractMatrix * b::TrackedVector = TrackedArray(Call(*, a, b))
$f(a::TrackedMatrix, b::TrackedVector) = TrackedArray(Call($f, a, b))
$f(a::TrackedMatrix, b::AbstractVector) = TrackedArray(Call($f, a, b))
$f(a::AbstractMatrix, b::TrackedVector) = TrackedArray(Call($f, a, b))
$f(a::TrackedVector, b::TrackedVector) = TrackedArray(Call($f, a, b))
$f(a::TrackedVector, b::AbstractVector) = TrackedArray(Call($f, a, b))
$f(a::AbstractVector, b::TrackedVector) = TrackedArray(Call($f, a, b))
end
end
function back(::typeof(*), Δ, a::AbstractMatrix, b::AbstractVecOrMat)
@back(a, A_mul_Bt(Δ, data(b)))
@back(b, At_mul_B(data(a), Δ))
end
function back(::typeof(Ac_mul_B), Δ, a::AbstractVecOrMat{<:Real}, b::AbstractVecOrMat{<:Real})
@back(a, A_mul_Bt(Δ, data(b))')
@back(b, *(data(a), Δ))
end
# Fast path for matrix-vector
function back(::typeof(*), Δ::AbstractVector, W::TrackedMatrix, x::AbstractVector)
if isleaf(W)
W.grad .+= Δ .* data(x).'
else
back(W, A_mul_Bt(Δ, data(x)))
end
@back(x, At_mul_B(data(W), Δ))
end
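A sketch of what these rules compute for a matrix-vector product (values illustrative; the leaf fast path above writes straight into `W.grad`):
```julia
using Flux, Flux.Tracker

W, x = param(rand(3, 4)), param(rand(4))
y = W * x                 # dispatches to the TrackedMatrix * TrackedVector method above
back!(sum(y))

W.grad                    # ≈ ones(3) * x.data'   (Δ * xᵀ, via the leaf fast path)
x.grad                    # ≈ W.data' * ones(3)   (Wᵀ * Δ)
```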
# NNlib
import NNlib: softmax, ∇softmax


@ -35,3 +35,5 @@ function params(m)
prefor(p -> p isa TrackedArray && push!(ps, p), m)
return ps
end
params(m...) = params(m)
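The new vararg method makes it convenient to optimise several models together (a sketch, assuming `prefor` recurses into tuples):
```julia
using Flux

m1, m2 = Dense(10, 5), Dense(5, 2)
ps = params(m1, m2)   # same as params((m1, m2)): every tracked array from both layers
opt = SGD(ps, 0.1)
```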

test/optimise.jl (new file, 17 lines)

@ -0,0 +1,17 @@
using Flux.Optimise
using Flux.Tracker
@testset "Optimise" begin
w = randn(10, 10)
for Opt in [SGD, Nesterov, Momentum, ADAM, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad]
w′ = param(randn(10, 10))
loss(x) = Flux.mse(w*x, w′*x)
opt = Opt([w′])
for t=1:10^5
l = loss(rand(10))
back!(l)
opt()
end
@test Flux.mse(w, w′) < 0.01
end
end


@ -5,5 +5,6 @@ using Flux, Base.Test
include("utils.jl")
include("tracker.jl")
include("layers/normalisation.jl")
include("optimise.jl")
end


@ -9,6 +9,8 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2)
@test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)
@test gradtest((w, x) -> w'*x, randn(10, 2), randn(10))
@test gradtest(x -> sin.(sum(x, (2, 3))), (3,4,5))
@test gradtest(x -> softmax(x).*(1:3), 3)
@ -32,23 +34,12 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@test gradtest(x -> mean(x, [1, 2]), rand(2, 3, 4))
end
@test gradtest(x -> std(x), rand(5,5))
@test gradtest(x -> std(x, 1), rand(5,5))
@test gradtest(rand(5)) do x
y = x.^2
2y + x
end
for T in [Float32, Float64]
@test isa(param(T(1)), TrackedArray{T, 0})
@test isa(param(rand(T, 2)), TrackedArray{T, 1})
@test isa(param(rand(T, 2,2)), TrackedArray{T, 2})
end
# TODO: do we want this behaviour?
F = typeof(AbstractFloat(1))
for T in [Int32, Int64]
@test isa(param(T(1)), TrackedArray{F, 0})
@test isa(param(rand(T, 2)), TrackedArray{F, 1})
@test isa(param(rand(T, 2,2)), TrackedArray{F, 2})
end
end #testset