Merge branch 'master' into jacobian

Mike J Innes 2017-12-13 17:06:23 +00:00
commit 95d1287455
20 changed files with 349 additions and 85 deletions

View File

@@ -37,6 +37,7 @@ These layers don't affect the structure of the network but may improve training
 ```@docs
 Flux.testmode!
+BatchNorm
 Dropout
 LayerNorm
 ```

View File

@@ -7,12 +7,13 @@ module Flux
 using Juno, Requires
 using Lazy: @forward

-export Chain, Dense, RNN, LSTM, Dropout, LayerNorm,
-  SGD, ADAM, Momentum, Nesterov,
-  param, params, mapleaves, jacobian
+export Chain, Dense, RNN, LSTM,
+  Dropout, LayerNorm, BatchNorm,
+  SGD, ADAM, Momentum, Nesterov, AMSGrad,
+  param, params, mapleaves

 using NNlib
-export σ, relu, leakyrelu, elu, swish, softmax
+export σ, sigmoid, relu, leakyrelu, elu, swish, softmax

 include("tracker/Tracker.jl")
 using .Tracker

View File

@@ -33,7 +33,7 @@ function rawdict()
   filter(!isempty, split.(split(readstring(deps("CMUDict", "cmudict")), "\n"))))
 end

-validword(s) = ismatch(r"^[\w-\.]+$", s)
+validword(s) = ismatch(r"^[\w\-\.]+$", s)

 cmudict() = filter((s, ps) -> validword(s), rawdict())

View File

@@ -63,8 +63,10 @@ struct Dense{F,S,T}
   b::T
 end

-Dense(in::Integer, out::Integer, σ = identity; init = initn) =
-  Dense(σ, param(init(out, in)), param(init(out)))
+function Dense(in::Integer, out::Integer, σ = identity;
+               initW = glorot_uniform, initb = zeros)
+  return Dense(σ, param(initW(out, in)), param(initb(out)))
+end

 treelike(Dense)
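The new `initW`/`initb` keywords replace the single `init` argument; a quick usage sketch (the layer sizes and custom initialisers here are arbitrary examples, not taken from the diff):

```julia
x = rand(28^2)
d = Dense(28^2, 64, relu)            # weights via glorot_uniform, biases via zeros
d2 = Dense(28^2, 64, relu, initW = randn, initb = ones)  # any (dims...) -> Array works
y = d(x)                             # 64-element (tracked) output
```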

View File

@@ -2,8 +2,8 @@
     testmode!(m)
     testmode!(m, false)

-Put layers like [`Dropout`](@ref) and `BatchNorm` into testing mode (or back to
-training mode with `false`).
+Put layers like [`Dropout`](@ref) and [`BatchNorm`](@ref) into testing mode
+(or back to training mode with `false`).
 """
 function testmode!(m, val::Bool=true)
   prefor(x -> _testmode!(x, val), m)

@@ -45,6 +45,7 @@ end
 _testmode!(a::Dropout, test) = (a.active = !test)

 """
+    LayerNorm(h::Integer)

 A [normalisation layer](https://arxiv.org/pdf/1607.06450.pdf) designed to be

@@ -65,3 +66,77 @@ treelike(LayerNorm)
 function Base.show(io::IO, l::LayerNorm)
   print(io, "LayerNorm(", length(l.diag.α), ")")
 end
+
"""
BatchNorm(dims...; λ = identity,
initβ = zeros, initγ = ones, ϵ = 1e-8, momentum = .1)
Batch Normalization Layer for [`Dense`](@ref) layer.
See [Batch Normalization: Accelerating Deep Network Training by Reducing
Internal Covariate Shift](https://arxiv.org/pdf/1502.03167.pdf)
In the example of MNIST,
in order to normalize the input of other layer,
put the `BatchNorm` layer before activation function.
```julia
m = Chain(
Dense(28^2, 64),
BatchNorm(64, λ = relu),
Dense(64, 10),
BatchNorm(10),
softmax)
```
"""
+mutable struct BatchNorm{F,V,N}
+  λ::F  # activation function
+  β::V  # bias
+  γ::V  # scale
+  μ     # moving mean
+  σ     # moving std
+  ϵ::N
+  momentum::N
+  active::Bool
+end
+
+BatchNorm(dims::Integer...; λ = identity,
+          initβ = zeros, initγ = ones, ϵ = 1e-8, momentum = .1) =
+  BatchNorm(λ, param(initβ(dims)), param(initγ(dims)), 0., 1., ϵ, momentum, true)
+
+function (BN::BatchNorm)(x)
+  λ, γ, β = BN.λ, BN.γ, BN.β
+
+  if !BN.active
+    μ = BN.μ
+    σ = BN.σ
+  else
+    T = eltype(x)
+
+    ϵ = T(BN.ϵ)
+    m = size(x, 2)  # batch size
+    μ = mean(x, 2)
+    σ = sqrt.(sum((x .- μ).^2, 2) ./ m .+ ϵ)
+
+    # update moving mean/std
+    mtm = T(BN.momentum)
+    BN.μ = (1 - mtm) .* BN.μ .+ mtm .* μ.data
+    BN.σ = (1 - mtm) .* BN.σ .+ mtm .* σ.data .* m ./ (m - 1)
+  end
+
+  λ.(γ .* ((x .- μ) ./ σ) .+ β)
+end
+
+children(BN::BatchNorm) =
+  (BN.λ, BN.β, BN.γ, BN.μ, BN.σ, BN.momentum, BN.ϵ, BN.active)
+
+mapchildren(f, BN::BatchNorm) =  # e.g. mapchildren(cu, BN)
+  BatchNorm(BN.λ, f(BN.β), f(BN.γ), BN.μ, BN.σ, BN.momentum, BN.ϵ, BN.active)
+
+_testmode!(BN::BatchNorm, test) = (BN.active = !test)
+
+function Base.show(io::IO, l::BatchNorm)
+  print(io, "BatchNorm($(join(size(l.β), ", "))")
+  (l.λ == identity) || print(io, ", λ = $(l.λ)")
+  print(io, ")")
+end

View File

@@ -79,8 +79,8 @@ struct RNNCell{D,V}
   h::V
 end

-RNNCell(in::Integer, out::Integer, σ = tanh; init = initn) =
-  RNNCell(Dense(in+out, out, σ, init = init), param(init(out)))
+RNNCell(in::Integer, out::Integer, σ = tanh; initW = glorot_uniform, initb = zeros) =
+  RNNCell(Dense(in+out, out, σ, initW = initW, initb = initb), param(initW(out)))

 function (m::RNNCell)(h, x)
   h = m.d(combine(x, h))

@@ -113,10 +113,10 @@ struct LSTMCell{D1,D2,V}
   h::V; c::V
 end

-function LSTMCell(in, out; init = initn)
-  cell = LSTMCell([Dense(in+out, out, σ, init = init) for _ = 1:3]...,
-                  Dense(in+out, out, tanh, init = init),
-                  param(init(out)), param(init(out)))
+function LSTMCell(in, out; initW = glorot_uniform, initb = zeros)
+  cell = LSTMCell([Dense(in+out, out, σ, initW = initW, initb = initb) for _ = 1:3]...,
+                  Dense(in+out, out, tanh, initW = initW, initb = initb),
+                  param(initW(out)), param(initW(out)))
   cell.forget.b.data .= 1
   return cell
 end

View File

@@ -1,15 +1,18 @@
+using NNlib: log_fast
+
 # Cost functions

 mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)

-crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat) =
-  -sum(y .* log.(ŷ)) / size(y, 2)
+function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
+  return -sum(y .* log_fast.(ŷ) .* weight) / size(y, 2)
+end

 @deprecate logloss(x, y) crossentropy(x, y)

 function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat)
   logŷ = logŷ .- maximum(logŷ, 1)
-  ypred = logŷ .- log.(sum(exp.(logŷ), 1))
+  ypred = logŷ .- log_fast.(sum(exp.(logŷ), 1))
   -sum(y .* ypred) / size(y, 2)
 end
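The `weight` keyword just broadcasts inside the sum, so a vector of per-class weights scales each row of `y .* log(ŷ)`; a tiny illustration with made-up numbers (using plain `log` rather than NNlib's `log_fast`):

```julia
ŷ = [0.1 0.9; 0.9 0.1]      # 2 classes × 2 samples, columns sum to 1
y = [0.0 1.0; 1.0 0.0]      # one-hot targets
unweighted = -sum(y .* log.(ŷ)) / size(y, 2)                 # ≈ 0.105
halved     = -sum(y .* log.(ŷ) .* [0.5, 0.5]) / size(y, 2)   # exactly unweighted / 2
```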

View File

@@ -42,7 +42,14 @@ function onehot(l, labels)
   OneHotVector(i, length(labels))
 end

-onehotbatch(ls, labels) = OneHotMatrix(length(labels), [onehot(l, labels) for l in ls])
+function onehot(l, labels, unk)
+  i = findfirst(labels, l)
+  i > 0 || return onehot(unk, labels)
+  OneHotVector(i, length(labels))
+end
+
+onehotbatch(ls, labels, unk...) =
+  OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])

 argmax(y::AbstractVector, labels = 1:length(y)) =
   labels[findfirst(y, maximum(y))]
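A short sketch of how the new `unk` fallback behaves (the label set is an arbitrary example):

```julia
labels = [:a, :b, :unk]
onehot(:b, labels)                       # OneHotVector with the 2nd of 3 bits set
onehot(:z, labels, :unk)                 # :z is not in labels, so this encodes :unk instead
onehotbatch([:a, :z, :b], labels, :unk)  # 3×3 OneHotMatrix; the unknown :z column maps to :unk
```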

View File

@@ -1,7 +1,7 @@
 module Optimise

 export update!, params, train!,
-  SGD, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta
+  SGD, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad

 struct Param{T}
   x::T

View File

@@ -1,5 +1,7 @@
 call(f, xs...) = f(xs...)

+# note for optimisers: set p.Δ to zero
+# at the end of the weights update
 function optimiser(ps, fs...)
   ps = [Param(p) for p in ps]
   fs = map(ps) do p
@@ -10,64 +12,73 @@ function optimiser(ps, fs...)
 end

 """
-    SGD(params, η = 1; decay = 0)
+    SGD(params, η = 0.1; decay = 0)

-Classic gradient descent optimiser. For each parameter `p` and its
-gradient `δp`, this runs `p -= η*δp`.
+Classic gradient descent optimiser with learning rate `η`.
+For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`.

-Supports decayed learning rate decay if the `decay` argument is provided.
+Supports inverse decaying learning rate if the `decay` argument is provided.
 """
-SGD(ps, η = 1; decay = 0) =
-  optimiser(ps, p -> invdecay(p, decay), p -> descent(p, η))
+SGD(ps, η = 0.1; decay = 0) =
+  optimiser(ps, p -> invdecay(p, decay), p -> descent(p,η))

 """
-    Momentum(params, ρ, decay = 0)
+    Momentum(params, η = 0.01; ρ = 0.9, decay = 0)

-SGD with momentum `ρ` and optional learning rate decay.
+SGD with learning rate `η`, momentum `ρ` and optional learning rate inverse decay.
 """
-Momentum(ps, ρ; decay = 0) =
-  optimiser(ps, p -> momentum(p, ρ), p -> invdecay(p, decay), p -> descent(p, 1))
+Momentum(ps, η = 0.01; ρ = 0.9, decay = 0) =
+  optimiser(ps, p->invdecay(p,decay), p->momentum(p, ρ, η), p->descent(p,1))

 """
-    Nesterov(params, ρ, decay = 0)
+    Nesterov(params, η = 0.01; ρ = 0.9, decay = 0)

-SGD with Nesterov momentum `ρ` and optional learning rate decay.
+SGD with learning rate `η`, Nesterov momentum `ρ` and optional learning rate inverse decay.
 """
-Nesterov(ps, ρ; decay = 0) =
-  optimiser(ps, p -> nesterov(p, ρ), p -> invdecay(p, decay), p -> descent(p, 1))
+Nesterov(ps, η = 0.01; ρ = 0.9, decay = 0) =
+  optimiser(ps, p->invdecay(p,decay), p->nesterov(p, ρ, η), p->descent(p,1))

 """
-    RMSProp(params; η = 0.001, ρ = 0.9, ϵ = 1e-8, decay = 0)
+    RMSProp(params, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0)

 [RMSProp](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
 optimiser. Parameters other than learning rate don't need tuning. Often a good
 choice for recurrent networks.
 """
 RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) =
-  optimiser(ps, p -> rmsprop(p; η = η, ρ = ρ, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
+  optimiser(ps, p->rmsprop(p; η=η, ρ=ρ, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))

 """
-    ADAM(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
+    ADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)

 [ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
 """
 ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
-  optimiser(ps, p -> adam(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
+  optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))

 """
-    ADAGrad(params; η = 0.01, ϵ = 1e-8, decay = 0)
+    ADAGrad(params, η = 0.01; ϵ = 1e-8, decay = 0)

 [ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
 Parameters don't need tuning.
 """
-ADAGrad(ps; η = 0.01, ϵ = 1e-8, decay = 0) =
-  optimiser(ps, p -> adagrad(p; η = η, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
+ADAGrad(ps, η = 0.01; ϵ = 1e-8, decay = 0) =
+  optimiser(ps, p->adagrad(p; η=η, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))

 """
-    ADADelta(params; η = 0.01, ρ = 0.95, ϵ = 1e-8, decay = 0)
+    ADADelta(params; ρ = 0.9, ϵ = 1e-8, decay = 0)

 [ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
 tuning.
 """
-ADADelta(ps; η = 0.01, ρ = 0.95, ϵ = 1e-8, decay = 0) =
-  optimiser(ps, p -> adadelta(p; ρ = ρ, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
+ADADelta(ps; ρ = 0.9, ϵ = 1e-8, decay = 0) =
+  optimiser(ps, p->adadelta(p; ρ=ρ, ϵ=ϵ), p->descent(p,1))
+
+"""
+    AMSGrad(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
+
+[AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need
+tuning.
+"""
+AMSGrad(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
+  optimiser(ps, p -> amsgrad(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
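With these signatures the learning rate is positional everywhere except `ADADelta`; a usage sketch against a hypothetical model `m` (the values shown are just the defaults from the docstrings above):

```julia
# `m` stands in for any Flux model; params(m) collects its tracked parameters.
opt = SGD(params(m), 0.1)                  # plain gradient descent
opt = Momentum(params(m), 0.01; ρ = 0.9)   # η positional, momentum by keyword
opt = AMSGrad(params(m))                   # η = 0.001, β1 = 0.9, β2 = 0.999 by default

opt()  # each constructor returns a closure; call it after back!(loss) to apply one update
```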

View File

@@ -1,74 +1,97 @@
 function descent(p::Param, η::Real)
   function ()
-    p.x .-= p.Δ .* η
-    p.Δ .= 0
+    @. p.x -= η * p.Δ
+    @. p.Δ = 0
   end
 end

-function momentum(p::Param, ρ::Real)
-  mo = zeros(p.x)
-  () -> p.Δ .= mo .= ρ .* mo .+ p.Δ
-end
-
-function nesterov(p::Param, ρ::Real)
-  mo = zeros(p.x)
+function momentum(p::Param, ρ, η)
+  v = zeros(p.x)
   function ()
-    mo .= ρ .* mo .+ p.Δ
-    p.Δ .= ρ .* mo .+ p.Δ
+    @. v = ρ * v - η * p.Δ
+    @. p.Δ = -v
   end
 end

-function clip(p::Param, thresh::Real)
-  () -> clamp!(p.Δ, -thresh, thresh)
-end
-
-function weightdecay(p::Param, γ::Real)
-  () -> p.Δ .+= γ .* p.x
-end
-
-function invdecay(p::Param, γ::Real)
-  n = 0
+# Ref. https://arxiv.org/pdf/1212.0901.pdf
+function nesterov(p::Param, ρ, η)
+  v = zeros(p.x)
   function ()
-    p.Δ .*= 1 / (1 + γ * n)
-    n += 1
+    d = @. ρ^2 * v - (1+ρ) * η * p.Δ
+    @. v = ρ*v - η*p.Δ
+    @. p.Δ = -d
   end
 end

 function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
-  acc = zeros(p.x) .+ ϵ
+  acc = zeros(p.x)
   function ()
-    @. acc = ρ * acc + (1 - ρ) * p.Δ ^ 2
-    @. p.Δ *= η / √acc
+    @. acc = ρ * acc + (1 - ρ) * p.Δ^2
+    @. p.Δ *= η / √(acc + ϵ)
   end
 end

 function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
   acc = zeros(p.x) .+ ϵ
   function ()
-    @. acc += p.Δ ^ 2
+    @. acc += p.Δ^2
     @. p.Δ *= η / √acc
   end
 end

-function adadelta(p::Param; ρ::Real = 0.95, ϵ::Real = 1e-8)
-  acc = zeros(p.x) .+ ϵ
-  Δacc = zeros(p.x) .+ ϵ
+function adadelta(p::Param; ρ::Real = 0.9, ϵ::Real = 1e-8)
+  acc = zeros(p.x)
+  Δacc = zeros(p.x)
   function ()
-    @. acc = ρ * acc + (1 - ρ) * p.Δ ^ 2
-    @. p.Δ *= √Δacc / √acc
-    @. Δacc = ρ * Δacc + (1 - ρ) * p.Δ ^ 2
+    @. acc = ρ * acc + (1 - ρ) * p.Δ^2
+    @. p.Δ *= √(Δacc + ϵ) / √(acc + ϵ)
+    @. Δacc = ρ * Δacc + (1 - ρ) * p.Δ^2
   end
 end

 function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
   mt = zeros(p.x)
-  vt = zeros(p.x) .+ ϵ
+  vt = zeros(p.x)
   β1p, β2p = β1, β2
   function ()
     @. mt = β1 * mt + (1 - β1) * p.Δ
-    @. vt = β2 * vt + (1 - β2) * p.Δ ^ 2
-    @. p.Δ = √(1 - β2p) / (1 - β1p) * mt / √vt * η
+    @. vt = β2 * vt + (1 - β2) * p.Δ^2
+    @. p.Δ = mt / (1 - β1p) / (√(vt / (1 - β2p)) + ϵ) * η
     β1p *= β1
     β2p *= β2
   end
 end
+
+function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
+  mt = zeros(p.x)
+  vt = zeros(p.x) .+ ϵ
+  v̂t = zeros(p.x) .+ ϵ
+  function ()
+    @. mt = β1 * mt + (1 - β1) * p.Δ
+    @. vt = β2 * vt + (1 - β2) * p.Δ ^ 2
+    @. v̂t = max.(v̂t, vt)
+    @. p.Δ = η * mt / √v̂t
+  end
+end
+
+clip(p::Param, thresh::Real) = () -> clamp!(p.Δ, -thresh, thresh)
+
+function expdecay(p::Param, γ::Real)
+  if γ != 0
+    return () -> p.Δ .+= γ .* p.x
+  else
+    return () -> nothing
+  end
+end
+
+function invdecay(p::Param, γ::Real)
+  if γ != 0
+    n = 0
+    return () -> begin
+      p.Δ .*= 1 / (1 + γ * n)
+      n += 1
+    end
+  else
+    return () -> nothing
+  end
+end
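To make the new `amsgrad` rule easier to read outside the closure machinery, here is the same update restated as a standalone function on plain arrays (illustrative only; the version above keeps `mt`, `vt`, `v̂t` as closed-over state and writes the step into `p.Δ`):

```julia
# One AMSGrad step on plain arrays. v̂t is assumed pre-initialised to ϵ,
# mirroring `zeros(p.x) .+ ϵ` above, so the division is safe.
function amsgrad_step!(x, Δ, mt, vt, v̂t; η = 0.001, β1 = 0.9, β2 = 0.999)
  @. mt = β1 * mt + (1 - β1) * Δ     # first-moment estimate
  @. vt = β2 * vt + (1 - β2) * Δ^2   # second-moment estimate
  @. v̂t = max(v̂t, vt)                # monotone max keeps the step size from growing back
  @. x -= η * mt / √v̂t               # apply the update directly to the weights
  return x
end
```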

View File

@@ -40,7 +40,7 @@ TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x))

 isleaf(x::TrackedArray) = x.f == Call(nothing)

-param(xs) = TrackedArray(AbstractFloat.(xs))
+param(xs) = TrackedArray(map(x -> AbstractFloat(x), xs))
 param(xs::Real) = param(fill(xs))

 istracked(x::TrackedArray) = true

@@ -58,6 +58,7 @@ Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =

 Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)

+# TODO decide if keeping both data and value. The problem is TrackedScalar
 value(x) = x
 value(x::TrackedArray) = data(x)
 value(x::TrackedScalar) = data(x)[]

@@ -69,6 +70,7 @@ Base.:(==)(x::TrackedArray, y::TrackedArray) = value(x) == value(x)
 Base.isless(x::TrackedScalar, y) = isless(value(x), y)
 Base.isless(x, y::TrackedScalar) = isless(x, value(y))
 Base.isless(x::TrackedScalar, y::TrackedScalar) = isless(value(x), value(y))
+Base.isapprox(x::TrackedScalar, y; kws...) = isapprox(x.data[], y; kws...)
 Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
   print(io, "TrackedArray{…,$A}")

View File

@@ -58,6 +58,15 @@ Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
 Base.mean(xs::TrackedArray) = TrackedArray(Call(mean, xs), toarray(xs.data, mean(xs.data)))
 Base.mean(xs::TrackedArray, region) = TrackedArray(Call(mean, xs, region))

+LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = TrackedArray(Call(dot, xs, ys), toarray(xs.data, dot(data(xs), data(ys))))
+LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = TrackedArray(Call(dot, xs, ys), toarray(xs.data, dot(data(xs), data(ys))))
+LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = TrackedArray(Call(dot, xs, ys), toarray(xs.data, dot(data(xs), data(ys))))
+
+function back(::typeof(dot), Δ, xs, ys)
+  @back(xs, Δ.*ys)
+  @back(ys, Δ.*xs)
+end
+
 # Hacks to get std working
 Base.std(x::TrackedArray; mean = Base.mean(x)) =
   sqrt.(sum((x .- mean).^2) ./ (length(x)-1))

@@ -70,7 +79,7 @@ back(::typeof(mean), Δ, xs::TrackedArray, region) =

 # BLAS

-for f in :[*, Ac_mul_B].args
+for f in :[*, Ac_mul_B, A_mul_Bc].args
   @eval begin
     import Base.$f
     $f(a::TrackedMatrix, b::TrackedMatrix) = TrackedArray(Call($f, a, b))

@@ -94,7 +103,12 @@ end

 function back(::typeof(Ac_mul_B), Δ, a::AbstractVecOrMat{<:Real}, b::AbstractVecOrMat{<:Real})
   @back(a, A_mul_Bt(Δ, data(b))')
-  @back(b, *(data(a), Δ))
+  @back(b, data(a)*Δ)
+end
+
+function back(::typeof(A_mul_Bc), Δ, a::AbstractVecOrMat{<:Real}, b::AbstractVecOrMat{<:Real})
+  @back(a, Δ * data(b))
+  @back(b, At_mul_B(data(a), Δ)')
 end

 # Fast path for matrix-vector

View File

@@ -1,6 +1,8 @@
 # Arrays

 initn(dims...) = randn(dims...)/100
+glorot_uniform(dims...) = (rand(dims...) - 0.5)*sqrt(24.0/(sum(dims)))
+glorot_normal(dims...) = (randn(dims...)*sqrt(2.0/sum(dims)))

 flatten(xs) = reshape(xs, size(xs, 1), :)
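The constant in `glorot_uniform` comes from the Glorot/Bengio variance target 2/(fan_in + fan_out): `rand() - 0.5` is uniform with variance 1/12, so scaling by `sqrt(24/sum(dims))` gives variance `2/sum(dims)` and bounds `±sqrt(6/sum(dims))`. A quick numeric check (sizes arbitrary):

```julia
n_in, n_out = 100, 400
v = glorot_uniform(n_in, n_out)
std(v)            # ≈ sqrt(2 / (n_in + n_out)) ≈ 0.0632
maximum(abs, v)   # ≤ sqrt(6 / (n_in + n_out)) ≈ 0.1095
```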

View File

@@ -26,3 +26,55 @@ using Flux: testmode!
   y = m(x)
   @test count(a->a == 0, y) == 0
 end
+
+@testset "BatchNorm" begin
+  let m = BatchNorm(2), x = param([1 2; 3 4; 5 6]')
+
+    @test m.β.data == [0, 0]  # initβ(2)
+    @test m.γ.data == [1, 1]  # initγ(2)
+
+    # initial m.σ is 1
+    # initial m.μ is 0
+    @test m.active
+
+    # @test m(x).data ≈ [-1 -1; 0 0; 1 1]'
+    m(x)
+
+    # julia> x
+    # 2×3 Array{Float64,2}:
+    #  1.0  3.0  5.0
+    #  2.0  4.0  6.0
+    #
+    # μ of batch will be
+    #  (1. + 3. + 5.) / 3 = 3
+    #  (2. + 4. + 6.) / 3 = 4
+    #
+    # ∴ update rule with momentum:
+    #  .1 * 3 + 0 = .3
+    #  .1 * 4 + 0 = .4
+    @test m.μ ≈ reshape([0.3, 0.4], 2, 1)
+
+    # julia> .1 .* std(x, 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
+    # 2×1 Array{Float64,2}:
+    #  1.14495
+    #  1.14495
+    @test m.σ ≈ .1 .* std(x.data, 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
+
+    testmode!(m)
+    @test !m.active
+
+    x′ = m(x).data
+    @test x′[1] ≈ (1 - 0.3) / 1.1449489742783179
+  end
+
+  # with activation function
+  let m = BatchNorm(2, λ = σ), x = param([1 2; 3 4; 5 6]')
+    @test m.active
+    m(x)
+
+    testmode!(m)
+    @test !m.active
+
+    x′ = m(x).data
+    @test x′[1] ≈ σ((1 - 0.3) / 1.1449489742783179)
+  end
+end

test/layers/stateless.jl (new file, 26 lines)
View File

@@ -0,0 +1,26 @@
+using Flux: onehotbatch, mse, crossentropy
+
+@testset "losses" begin
+  # First, regression-style y's
+  y = [1, 1, 0, 0]
+  y_hat = [.9, .1, .1, .9]
+
+  @testset "mse" begin
+    @test mse(y_hat, y) ≈ (.1^2 + .9^2)/2
+  end
+
+  # Now onehot y's
+  y = onehotbatch([1, 1, 0, 0], 0:1)
+  y_hat = [.1 .9; .9 .1; .9 .1; .1 .9]'
+  y_logloss = 1.203972804325936
+
+  @testset "crossentropy" begin
+    @test crossentropy(y_hat, y) ≈ y_logloss
+  end
+
+  @testset "weighted_crossentropy" begin
+    @test crossentropy(y_hat, y, weight = ones(2)) ≈ y_logloss
+    @test crossentropy(y_hat, y, weight = [.5, .5]) ≈ y_logloss/2
+    @test crossentropy(y_hat, y, weight = [2, .5]) ≈ 1.5049660054074199
+  end
+end

test/optimise.jl (new file, 17 lines)
View File

@@ -0,0 +1,17 @@
+using Flux.Optimise
+using Flux.Tracker
+
+@testset "Optimise" begin
+  w = randn(10, 10)
+  for Opt in [SGD, Nesterov, Momentum, ADAM, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad]
+    w′ = param(randn(10, 10))
+    loss(x) = Flux.mse(w*x, w′*x)
+    opt = Opt([w′])
+    for t = 1:10^5
+      l = loss(rand(10))
+      back!(l)
+      opt()
+    end
+    @test Flux.mse(w, w′) < 0.01
+  end
+end

View File

@@ -5,5 +5,7 @@ using Flux, Base.Test
 include("utils.jl")
 include("tracker.jl")
 include("layers/normalisation.jl")
+include("layers/stateless.jl")
+include("optimise.jl")

 end

View File

@@ -10,6 +10,7 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)

 @test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)
 @test gradtest((w, x) -> w'*x, randn(10, 2), randn(10))
+@test gradtest((w, x) -> w*x', randn(5,5), randn(5,5))

 @test gradtest(x -> sin.(sum(x, (2, 3))), (3,4,5))

@@ -37,6 +38,8 @@ end
 @test gradtest(x -> std(x), rand(5,5))
 @test gradtest(x -> std(x, 1), rand(5,5))

+@test gradtest((x, y) -> x .* y, rand(5), rand(5))
+
 @test gradtest(rand(5)) do x
   y = x.^2
   2y + x

View File

@@ -1,4 +1,4 @@
-using Flux: throttle
+using Flux: throttle, initn, glorot_uniform, glorot_normal

 @testset "Throttle" begin
   @testset "default behaviour" begin

@@ -56,3 +56,26 @@ end
   J = jacobian(m,x)
   @test J ≈ A.data
 end
+
+@testset "Initialization" begin
+  # Set random seed so that these tests don't fail randomly
+  srand(0)
+  # initn() should yield a kernel with stddev ~= 1e-2
+  v = initn(10, 10)
+  @test std(v) > 0.9*1e-2
+  @test std(v) < 1.1*1e-2
+
+  # glorot_uniform should yield values within ±sqrt(6/(n_in + n_out)),
+  # and glorot_normal should yield a kernel with stddev ~= sqrt(2/(n_in + n_out))
+  for (n_in, n_out) in [(100, 100), (100, 400)]
+    v = glorot_uniform(n_in, n_out)
+    @test minimum(v) > -1.1*sqrt(6/(n_in + n_out))
+    @test minimum(v) < -0.9*sqrt(6/(n_in + n_out))
+    @test maximum(v) > 0.9*sqrt(6/(n_in + n_out))
+    @test maximum(v) < 1.1*sqrt(6/(n_in + n_out))
+
+    v = glorot_normal(n_in, n_out)
+    @test std(v) > 0.9*sqrt(2/(n_in + n_out))
+    @test std(v) < 1.1*sqrt(2/(n_in + n_out))
+  end
+end
end