Merge branch 'master' into cat-fix

This commit is contained in:
GenaBitu 2018-01-16 11:01:31 +01:00
commit bc8a32bc56
No known key found for this signature in database
GPG Key ID: 6E647E317A9DD426
27 changed files with 544 additions and 107 deletions

View File

@ -1,6 +1,6 @@
# Флукс
[![Build Status](https://travis-ci.org/FluxML/Flux.jl.svg?branch=master)](https://travis-ci.org/FluxML/Flux.jl) [![](https://img.shields.io/badge/docs-stable-blue.svg)](https://fluxml.github.io/Flux.jl/stable/) [![Join the chat at https://gitter.im/FluxML](https://badges.gitter.im/FluxML/Lobby.svg)](https://gitter.im/FluxML/Lobby) [Slack](https://discourse.julialang.org/t/announcing-a-julia-slack/4866)
[![Build Status](https://travis-ci.org/FluxML/Flux.jl.svg?branch=master)](https://travis-ci.org/FluxML/Flux.jl) [![](https://img.shields.io/badge/docs-stable-blue.svg)](https://fluxml.github.io/Flux.jl/stable/) [![Join the chat at https://gitter.im/FluxML](https://badges.gitter.im/FluxML/Lobby.svg)](https://gitter.im/FluxML/Lobby) [Slack](https://slackinvite.julialang.org/)
Flux is a refreshing approach to machine learning. It provides lightweight abstractions on top of Julia's native GPU and AD support, while remaining fully hackable (right down to the [GPU kernels](https://github.com/FluxML/CuArrays.jl)).

View File

@ -3,5 +3,6 @@ DataFlow 0.2.1
Juno
MacroTools 0.3.3
NNlib
ForwardDiff
ForwardDiff 0.5.0
Requires
Adapt

View File

@ -5,6 +5,7 @@ These core layers form the foundation of almost all neural networks.
```@docs
Chain
Dense
Conv2D
```
## Recurrent Layers
@ -37,6 +38,7 @@ These layers don't affect the structure of the network but may improve training
```@docs
Flux.testmode!
BatchNorm
Dropout
LayerNorm
```

View File

@ -7,12 +7,14 @@ module Flux
using Juno, Requires
using Lazy: @forward
export Chain, Dense, RNN, LSTM, Dropout, LayerNorm,
SGD, ADAM, Momentum, Nesterov,
export Chain, Dense, RNN, LSTM, GRU, Conv2D,
Dropout, LayerNorm, BatchNorm,
SGD, ADAM, Momentum, Nesterov, AMSGrad,
param, params, mapleaves
using NNlib
export σ, sigmoid, relu, leakyrelu, elu, swish, softmax
export σ, sigmoid, relu, leakyrelu, elu, swish, softmax,
conv2d, maxpool2d, avgpool2d
include("tracker/Tracker.jl")
using .Tracker
@ -26,6 +28,7 @@ include("treelike.jl")
include("layers/stateless.jl")
include("layers/basic.jl")
include("layers/conv.jl")
include("layers/recurrent.jl")
include("layers/normalisation.jl")

View File

@ -23,17 +23,17 @@ end
function symbols()
load()
Symbol.(split(readstring(deps("CMUDict", "cmudict.symbols")),
Symbol.(split(readstring(deps("cmudict", "cmudict.symbols")),
"\n", keep = false))
end
function rawdict()
load()
Dict(String(xs[1]) => Symbol.(xs[2:end]) for xs in
filter(!isempty, split.(split(readstring(deps("CMUDict", "cmudict")), "\n"))))
filter(!isempty, split.(split(readstring(deps("cmudict", "cmudict")), "\n"))))
end
validword(s) = ismatch(r"^[\w-\.]+$", s)
validword(s) = ismatch(r"^[\w\-\.]+$", s)
cmudict() = filter((s, ps) -> validword(s), rawdict())

View File

@ -63,8 +63,10 @@ struct Dense{F,S,T}
b::T
end
Dense(in::Integer, out::Integer, σ = identity; init = initn) =
Dense(σ, param(init(out, in)), param(init(out)))
function Dense(in::Integer, out::Integer, σ = identity;
initW = glorot_uniform, initb = zeros)
return Dense(σ, param(initW(out, in)), param(initb(out)))
end
treelike(Dense)

33
src/layers/conv.jl Normal file
View File

@ -0,0 +1,33 @@
"""
    Conv2D(size, in=>out)
    Conv2D(size, in=>out, relu)

Standard convolutional layer. `size` should be a tuple like `(2, 2)`.
`in` and `out` specify the number of input and output channels respectively.
The optional third argument is an activation function that is broadcast over
the convolution output (`identity` by default).

Data should be stored in HWCN order. In other words, a 100×100 RGB image would
be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.

Takes the keyword arguments `pad` and `stride`, plus `init` for the function
used to create the kernel.
"""
struct Conv2D{F,A}
  σ::F        # activation, broadcast over the convolution output
  weight::A   # kernel; the constructor builds it with size (size..., in, out)
  stride::Int
  pad::Int
end
# Convenience constructor: allocates a tracked kernel of size (k..., in, out)
# with `init` and stores the activation, stride and padding alongside it.
function Conv2D(k::NTuple{2,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
                init = initn, stride = 1, pad = 0)
  weight = param(init(k..., ch...))
  return Conv2D(σ, weight, stride, pad)
end
Flux.treelike(Conv2D)
# Forward pass: convolve with the stored kernel, then apply the activation
# elementwise.
function (c::Conv2D)(x)
  y = conv2d(x, c.weight, stride = c.stride, padding = c.pad)
  return c.σ.(y)
end
# Print the layer as `Conv2D((kh, kw), in=>out[, σ])`; the activation is omitted
# when it is `identity`.
function Base.show(io::IO, l::Conv2D)
  k1, k2 = size(l.weight, 1), size(l.weight, 2)
  cin, cout = size(l.weight, 3), size(l.weight, 4)
  print(io, "Conv2D((", k1, ", ", k2, ")")
  print(io, ", ", cin, "=>", cout)
  if l.σ != identity
    print(io, ", ", l.σ)
  end
  print(io, ")")
end

View File

@ -2,8 +2,8 @@
testmode!(m)
testmode!(m, false)
Put layers like [`Dropout`](@ref) and `BatchNorm` into testing mode (or back to
training mode with `false`).
Put layers like [`Dropout`](@ref) and [`BatchNorm`](@ref) into testing mode
(or back to training mode with `false`).
"""
function testmode!(m, val::Bool=true)
prefor(x -> _testmode!(x, val), m)
@ -45,6 +45,7 @@ end
_testmode!(a::Dropout, test) = (a.active = !test)
"""
LayerNorm(h::Integer)
A [normalisation layer](https://arxiv.org/pdf/1607.06450.pdf) designed to be
@ -65,3 +66,77 @@ treelike(LayerNorm)
function Base.show(io::IO, l::LayerNorm)
print(io, "LayerNorm(", length(l.diag.α), ")")
end
"""
    BatchNorm(dims...; λ = identity,
              initβ = zeros, initγ = ones, ϵ = 1e-8, momentum = .1)

Batch Normalization layer, intended to follow a [`Dense`](@ref) layer.

See [Batch Normalization: Accelerating Deep Network Training by Reducing
Internal Covariate Shift](https://arxiv.org/pdf/1502.03167.pdf)

To normalise the input of the following layer, put the `BatchNorm` layer before
the activation function and pass that activation as `λ`. For example, with
MNIST:

```julia
m = Chain(
  Dense(28^2, 64),
  BatchNorm(64, λ = relu),
  Dense(64, 10),
  BatchNorm(10),
  softmax)
```
"""
mutable struct BatchNorm{F,V,N}
  λ::F  # activation function
  β::V  # bias
  γ::V  # scale
  μ     # moving mean
  σ     # moving std
  ϵ::N          # added under the square root when computing the batch std
  momentum::N   # weight given to the current batch when updating μ and σ
  active::Bool  # true while training; `_testmode!` sets it to `!test`
end
# Keyword constructor: tracked bias/scale from `initβ`/`initγ`, moving mean 0,
# moving std 1, and the layer starts in training mode.
function BatchNorm(dims::Integer...; λ = identity,
                   initβ = zeros, initγ = ones, ϵ = 1e-8, momentum = .1)
  β = param(initβ(dims))
  γ = param(initγ(dims))
  return BatchNorm(λ, β, γ, 0., 1., ϵ, momentum, true)
end
# Forward pass. In training mode (`BN.active`) the statistics are computed from
# `x` along dimension 2 and the stored moving mean/std are updated as a side
# effect; in test mode the stored moving statistics are used unchanged.
function (BN::BatchNorm)(x)
  λ, γ, β = BN.λ, BN.γ, BN.β

  if !BN.active
    # test mode: normalise with the moving statistics
    μ = BN.μ
    σ = BN.σ
  else
    T = eltype(x)

    ϵ = T(BN.ϵ)
    m = size(x, 2)  # batch size
    μ = mean(x, 2)
    σ = sqrt.(sum((x .- μ).^2, 2) ./ m .+ ϵ)

    # update moving mean/std
    # NOTE(review): `.data` assumes μ and σ are tracked arrays, i.e. that `x`
    # is tracked — confirm plain-array inputs are not expected here.
    mtm = T(BN.momentum)
    BN.μ = (1 - mtm) .* BN.μ .+ mtm .* μ.data
    # the m / (m - 1) factor divides by zero when the batch size is 1
    BN.σ = (1 - mtm) .* BN.σ .+ mtm .* σ.data .* m ./ (m - 1)
  end

  λ.(γ .* ((x .- μ) ./ σ) .+ β)
end
# Children for tree traversal, listed in struct field order.
children(BN::BatchNorm) =
  (BN.λ, BN.β, BN.γ, BN.μ, BN.σ, BN.ϵ, BN.momentum, BN.active)

# Rebuild the layer with `f` applied to the trainable parameters only.
# The positional constructor takes (λ, β, γ, μ, σ, ϵ, momentum, active); ϵ and
# momentum share a type, so passing them in the wrong order would not error but
# would silently swap the two values.
mapchildren(f, BN::BatchNorm) =  # e.g. mapchildren(cu, BN)
  BatchNorm(BN.λ, f(BN.β), f(BN.γ), BN.μ, BN.σ, BN.ϵ, BN.momentum, BN.active)
_testmode!(BN::BatchNorm, test) = (BN.active = !test)
# Print as `BatchNorm(dims[, λ = f])`; the activation is shown only when it is
# not `identity`.
function Base.show(io::IO, l::BatchNorm)
  print(io, "BatchNorm($(join(size(l.β), ", "))")
  if l.λ != identity
    print(io, ", λ = $(l.λ)")
  end
  print(io, ")")
end

View File

@ -79,8 +79,8 @@ struct RNNCell{D,V}
h::V
end
RNNCell(in::Integer, out::Integer, σ = tanh; init = initn) =
RNNCell(Dense(in+out, out, σ, init = init), param(init(out)))
RNNCell(in::Integer, out::Integer, σ = tanh; initW = glorot_uniform, initb = zeros) =
RNNCell(Dense(in+out, out, σ, initW = initW, initb = initb), param(initW(out)))
function (m::RNNCell)(h, x)
h = m.d(combine(x, h))
@ -113,10 +113,10 @@ struct LSTMCell{D1,D2,V}
h::V; c::V
end
function LSTMCell(in, out; init = initn)
cell = LSTMCell([Dense(in+out, out, σ, init = init) for _ = 1:3]...,
Dense(in+out, out, tanh, init = init),
param(init(out)), param(init(out)))
function LSTMCell(in, out; initW = glorot_uniform, initb = zeros)
cell = LSTMCell([Dense(in+out, out, σ, initW = initW, initb = initb) for _ = 1:3]...,
Dense(in+out, out, tanh, initW = initW, initb = initb),
param(initW(out)), param(initW(out)))
cell.forget.b.data .= 1
return cell
end
@ -150,3 +150,49 @@ See [this article](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.
"""
LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
# GRU

# One GRU step: three layers (two gates and the candidate state) plus the
# initial hidden state.
struct GRUCell{D1,D2,V}
  update::D1     # update gate, applied to the concatenated input and state
  reset::D1      # reset gate, applied to the concatenated input and state
  candidate::D2  # produces the candidate state
  h::V           # initial hidden state, returned by `hidden`
end
# Build a cell for `in` inputs and `out` hidden units. Every layer sees the
# input concatenated with the hidden state, hence `in+out` input features.
function GRUCell(in, out)
  update    = Dense(in+out, out, σ)
  reset     = Dense(in+out, out, σ)
  candidate = Dense(in+out, out, tanh)
  return GRUCell(update, reset, candidate, param(initn(out)))
end
# One recurrence step. Returns `(new_state, output)`, which are the same value
# for a GRU.
function (m::GRUCell)(h, x)
  # Keep the concatenation under its own name: `x` itself is still needed for
  # the candidate layer, whose input must have `in+out` features.
  xh = combine(x, h)
  z = m.update(xh)   # update gate
  r = m.reset(xh)    # reset gate
  candidate = m.candidate(combine(r.*h, x))
  # Blend the previous state and the candidate according to the update gate.
  h = (1 .- z).*h .+ z.*candidate
  return h, h
end
hidden(m::GRUCell) = m.h
treelike(GRUCell)
# Print as `GRUCell(in, out)`. The update layer's weight is out × (in+out), so
# the input size is recovered as columns minus rows.
function Base.show(io::IO, m::GRUCell)
  nout, ntotal = size(m.update.W, 1), size(m.update.W, 2)
  print(io, "GRUCell(", ntotal - nout, ", ", nout, ')')
end
"""
    GRU(in::Integer, out::Integer)

Gated Recurrent Unit layer. Behaves like an RNN but generally
exhibits a longer memory span over sequences.

Arguments are forwarded to `GRUCell`, and the resulting cell is wrapped in
`Recur`.

See [this article](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.
"""
GRU(a...; ka...) = Recur(GRUCell(a...; ka...))

View File

@ -4,8 +4,9 @@ using NNlib: log_fast
# Mean squared error between the prediction `ŷ` and the target `y`.
mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)
# Cross entropy between the prediction `ŷ` and the target `y`, summed over all
# entries and divided by the number of columns of `y`. `weight` is broadcast
# against `y .* log(ŷ)`; the default of 1 gives the unweighted loss.
function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
  return -sum(y .* log_fast.(ŷ) .* weight) / size(y, 2)
end
@deprecate logloss(x, y) crossentropy(x, y)

View File

@ -18,7 +18,9 @@ end
Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data))
Base.getindex(xs::OneHotMatrix, i::Int, j::Int) = xs.data[j][i]
Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs.data[j][i]
Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i]
Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i])
A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)]
@ -26,7 +28,7 @@ Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs
batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(length(first(xs)), xs)
import NNlib.adapt
import Adapt.adapt
adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))

View File

@ -1,7 +1,7 @@
module Optimise
export update!, params, train!,
SGD, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta
SGD, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad
struct Param{T}
x::T

View File

@ -1,5 +1,7 @@
call(f, xs...) = f(xs...)
# Note for optimisers: reset `p.Δ` to zero
# at the end of the weights update.
function optimiser(ps, fs...)
ps = [Param(p) for p in ps]
fs = map(ps) do p
@ -10,34 +12,34 @@ function optimiser(ps, fs...)
end
"""
SGD(params, η = 1; decay = 0)
SGD(params, η = 0.1; decay = 0)
Classic gradient descent optimiser. For each parameter `p` and its
gradient `δp`, this runs `p -= η*δp`.
Classic gradient descent optimiser with learning rate `η`.
For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`.
Supports decayed learning rate decay if the `decay` argument is provided.
Supports inverse decaying learning rate if the `decay` argument is provided.
"""
SGD(ps, η = 1; decay = 0) =
SGD(ps, η = 0.1; decay = 0) =
optimiser(ps, p -> invdecay(p, decay), p -> descent(p,η))
"""
Momentum(params, ρ, decay = 0)
Momentum(params, η = 0.01; ρ = 0.9, decay = 0)
SGD with momentum `ρ` and optional learning rate decay.
SGD with learning rate `η`, momentum `ρ` and optional learning rate inverse decay.
"""
Momentum(ps, ρ; decay = 0) =
optimiser(ps, p -> momentum(p, ρ), p -> invdecay(p, decay), p -> descent(p, 1))
Momentum(ps, η = 0.01; ρ = 0.9, decay = 0) =
optimiser(ps, p->invdecay(p,decay), p->momentum(p, ρ, η), p->descent(p,1))
"""
Nesterov(params, ρ, decay = 0)
Nesterov(params, η = 0.01; ρ = 0.9, decay = 0)
SGD with Nesterov momentum `ρ` and optional learning rate decay.
SGD with learning rate `η`, Nesterov momentum `ρ` and optional learning rate inverse decay.
"""
Nesterov(ps, ρ; decay = 0) =
optimiser(ps, p -> nesterov(p, ρ), p -> invdecay(p, decay), p -> descent(p, 1))
Nesterov(ps, η = 0.01; ρ = 0.9, decay = 0) =
optimiser(ps, p->invdecay(p,decay), p->nesterov(p, ρ, η), p->descent(p,1))
"""
RMSProp(params; η = 0.001, ρ = 0.9, ϵ = 1e-8, decay = 0)
RMSProp(params, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0)
[RMSProp](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
optimiser. Parameters other than learning rate don't need tuning. Often a good
@ -47,7 +49,7 @@ RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) =
optimiser(ps, p->rmsprop(p; η=η, ρ=ρ, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
"""
ADAM(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
ADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
"""
@ -55,19 +57,28 @@ ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
"""
ADAGrad(params; η = 0.01, ϵ = 1e-8, decay = 0)
ADAGrad(params, η = 0.01; ϵ = 1e-8, decay = 0)
[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
Parameters don't need tuning.
"""
ADAGrad(ps; η = 0.01, ϵ = 1e-8, decay = 0) =
ADAGrad(ps, η = 0.01; ϵ = 1e-8, decay = 0) =
optimiser(ps, p->adagrad(p; η=η, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
"""
ADADelta(params; η = 0.01, ρ = 0.95, ϵ = 1e-8, decay = 0)
ADADelta(params; ρ = 0.9, ϵ = 1e-8, decay = 0)
[ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
tuning.
"""
ADADelta(ps; η = 0.01, ρ = 0.95, ϵ = 1e-8, decay = 0) =
optimiser(ps, p -> adadelta(p; ρ = ρ, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
ADADelta(ps; ρ = 0.9, ϵ = 1e-8, decay = 0) =
optimiser(ps, p->adadelta(p; ρ=ρ, ϵ=ϵ), p->descent(p,1))
"""
AMSGrad(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need
tuning.
"""
AMSGrad(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p -> amsgrad(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))

View File

@ -1,44 +1,33 @@
function descent(p::Param, η::Real)
function ()
p.x .-= p.Δ .* η
p.Δ .= 0
@. p.x -= η * p.Δ
@. p.Δ = 0
end
end
function momentum(p::Param, ρ::Real)
mo = zeros(p.x)
() -> p.Δ .= mo .= ρ .* mo .+ p.Δ
end
function nesterov(p::Param, ρ::Real)
mo = zeros(p.x)
function momentum(p::Param, ρ, η)
v = zeros(p.x)
function ()
mo .= ρ .* mo .+ p.Δ
p.Δ .= ρ .* mo .+ p.Δ
@. v = ρ * v - η * p.Δ
@. p.Δ = -v
end
end
function clip(p::Param, thresh::Real)
() -> clamp!(p.Δ, -thresh, thresh)
end
function weightdecay(p::Param, γ::Real)
() -> p.Δ .+= γ .* p.x
end
function invdecay(p::Param, γ::Real)
n = 0
# Ref. https://arxiv.org/pdf/1212.0901.pdf
function nesterov(p::Param, ρ, η)
v = zeros(p.x)
function ()
p.Δ .*= 1 / (1 + γ * n)
n += 1
d = @. ρ^2 * v - (1+ρ) * η * p.Δ
@. v = ρ*v - η*p.Δ
@. p.Δ = -d
end
end
# RMSProp update rule. Returns a closure that rescales the gradient `p.Δ`
# in place; the parameter itself is updated by a later `descent` step.
function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
  acc = zeros(p.x)  # running average of squared gradients
  function ()
    @. acc = ρ * acc + (1 - ρ) * p.Δ^2
    # Divide by the root of the running average; ϵ guards against division
    # by zero while `acc` is still zero.
    @. p.Δ *= η / (sqrt(acc) + ϵ)
  end
end
@ -50,25 +39,59 @@ function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
end
end
# ADADelta update rule. Returns a closure that rescales the gradient `p.Δ`
# in place using running averages of squared gradients and squared updates.
function adadelta(p::Param; ρ::Real = 0.9, ϵ::Real = 1e-8)
  acc = zeros(p.x)   # running average of squared gradients
  Δacc = zeros(p.x)  # running average of squared updates
  function ()
    @. acc = ρ * acc + (1 - ρ) * p.Δ^2
    # Ratio of the two RMS values; ϵ keeps both roots away from zero.
    @. p.Δ *= sqrt(Δacc + ϵ) / sqrt(acc + ϵ)
    # `p.Δ` now holds the update, which is what Δacc accumulates.
    @. Δacc = ρ * Δacc + (1 - ρ) * p.Δ^2
  end
end
# ADAM update rule. Returns a closure that replaces the gradient `p.Δ` with the
# bias-corrected ADAM step.
function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
  mt = zeros(p.x)  # first-moment estimate
  vt = zeros(p.x)  # second-moment estimate
  β1p, β2p = β1, β2  # running powers β1^t and β2^t for bias correction
  function ()
    @. mt = β1 * mt + (1 - β1) * p.Δ
    @. vt = β2 * vt + (1 - β2) * p.Δ^2
    # Bias-corrected first moment over the root of the bias-corrected second
    # moment.
    @. p.Δ = mt / (1 - β1p) / (sqrt(vt / (1 - β2p)) + ϵ) * η
    β1p *= β1
    β2p *= β2
  end
end
# AMSGrad update rule. Like ADAM, but divides by the root of the running
# maximum of the second-moment estimate. Returns a closure that replaces the
# gradient `p.Δ` with the step.
function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
  mt = zeros(p.x)
  # Both second-moment buffers start at ϵ so the division below is never by
  # zero.
  vt = zeros(p.x) .+ ϵ
  v̂t = zeros(p.x) .+ ϵ
  function ()
    @. mt = β1 * mt + (1 - β1) * p.Δ
    @. vt = β2 * vt + (1 - β2) * p.Δ ^ 2
    @. v̂t = max(v̂t, vt)
    @. p.Δ = η * mt / sqrt(v̂t)
  end
end
clip(p::Param, thresh::Real) = () -> clamp!(p.Δ, -thresh, thresh)
# Returns a closure that adds `γ .* p.x` to the gradient `p.Δ`, or a no-op
# closure when `γ` is zero.
function expdecay(p::Param, γ::Real)
  γ == 0 && return () -> nothing
  return () -> p.Δ .+= γ .* p.x
end
# Returns a closure that scales the gradient `p.Δ` by `1 / (1 + γ * n)`, where
# `n` counts how many times the closure has run so far. With `γ == 0` the
# closure does nothing.
function invdecay(p::Param, γ::Real)
  γ == 0 && return () -> nothing
  n = 0
  return function ()
    p.Δ .*= 1 / (1 + γ * n)
    n += 1
  end
end

View File

@ -1,15 +1,24 @@
using Juno
using Flux.Tracker: back!
using Flux.Tracker: back!, value
runall(f) = f
runall(fs::AbstractVector) = () -> foreach(call, fs)
"""
train!(loss, data, opt; cb = () -> ())
train!(loss, data, opt)
For each datapoint `d` in `data` computes the gradient of `loss(d...)` through
backpropagation and calls the optimizer `opt` and the callback `cb`
(i.e. `opt()` and `cb()`).
backpropagation and calls the optimizer `opt`.
Takes a callback as keyword argument `cb`. For example, this will print "training"
every 10 seconds:
```julia
Flux.train!(loss, data, opt,
cb = throttle(() -> println("training"), 10))
```
The callback can return `:stop` to interrupt the training loop.
Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
"""
@ -18,10 +27,10 @@ function train!(loss, data, opt; cb = () -> ())
opt = runall(opt)
@progress for d in data
l = loss(d...)
isinf(l.data[]) && error("Loss is Inf")
isnan(l.data[]) && error("Loss is NaN")
isinf(value(l)) && error("Loss is Inf")
isnan(value(l)) && error("Loss is NaN")
back!(l)
opt()
cb()
cb() == :stop && break
end
end

View File

@ -58,6 +58,7 @@ Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
# TODO decide if keeping both data and value. The problem is TrackedScalar
value(x) = x
value(x::TrackedArray) = data(x)
value(x::TrackedScalar) = data(x)[]
@ -69,6 +70,7 @@ Base.:(==)(x::TrackedArray, y::TrackedArray) = value(x) == value(x)
Base.isless(x::TrackedScalar, y) = isless(value(x), y)
Base.isless(x, y::TrackedScalar) = isless(x, value(y))
Base.isless(x::TrackedScalar, y::TrackedScalar) = isless(value(x), value(y))
Base.isapprox(x::TrackedScalar, y; kws...) = isapprox(x.data[], y; kws...)
Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
print(io, "TrackedArray{…,$A}")
@ -91,7 +93,7 @@ include("back.jl")
include("lib.jl")
include("numeric.jl")
import NNlib.adapt
import Adapt.adapt
adapt(T, xs::TrackedArray) = TrackedArray(xs.f, adapt(T, xs.data), adapt(T, xs.grad))

View File

@ -12,16 +12,17 @@ function scan(x::TrackedArray)
return
end
back(c::Call, Δ) = back(c.func, Δ, c.args...)
back(::Call{Void}, Δ) = nothing
back_(f, y, args...) = back(f, args...)
back_(c::Call, y, Δ) = back_(c.func, y, Δ, c.args...)
back_(::Call{Void}, y, Δ) = nothing
function back(x::TrackedArray, Δ)
ref = x.ref -= 1
if isdefined(x, :grad)
x.grad .+= Δ
ref == 0 && back(x.f, x.grad)
ref == 0 && back_(x.f, x.data, x.grad)
else
ref == 0 && back(x.f, Δ)
ref == 0 && back_(x.f, x.data, Δ)
end
return
end
@ -35,6 +36,9 @@ end
# Interface methods
# TODO: if an error occurs in `back` the refcounts will be broken
# and `back` will silently fail to update.
function back!(x::TrackedArray, Δ)
scan(x)
back(x, Δ)

View File

@ -48,6 +48,12 @@ function back(::typeof(vcat), Δ, xs...)
end
end
# Record `reshape` as a tracked call instead of reshaping the raw data.
Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) =
  TrackedArray(Call(reshape, xs, dims...))

# Backward pass: reshape the incoming gradient to the input's size. The target
# dimensions (`_...`) are not needed.
back(::typeof(reshape), Δ, xs::TrackedArray, _...) =
  back(xs, reshape(Δ, size(xs)))
# Reductions
Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim))
@ -62,6 +68,15 @@ Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
Base.mean(xs::TrackedArray) = TrackedArray(Call(mean, xs), toarray(xs.data, mean(xs.data)))
Base.mean(xs::TrackedArray, region) = TrackedArray(Call(mean, xs, region))
LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = TrackedArray(Call(dot, xs, ys), toarray(xs.data, dot(data(xs), data(ys))))
# `xs` is not tracked in this method, so it has no `data` field; the array type
# for the result is taken from the tracked argument `ys` instead.
LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = TrackedArray(Call(dot, xs, ys), toarray(ys.data, dot(data(xs), data(ys))))
LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = TrackedArray(Call(dot, xs, ys), toarray(xs.data, dot(data(xs), data(ys))))
function back(::typeof(dot), Δ, xs, ys)
@back(xs, Δ.*ys)
@back(ys, Δ.*xs)
end
# Hacks to get std working
Base.std(x::TrackedArray; mean = Base.mean(x)) =
sqrt.(sum((x .- mean).^2) ./ (length(x)-1))
@ -74,7 +89,7 @@ back(::typeof(mean), Δ, xs::TrackedArray, region) =
# BLAS
for f in :[*, Ac_mul_B].args
for f in :[*, Ac_mul_B, A_mul_Bc].args
@eval begin
import Base.$f
$f(a::TrackedMatrix, b::TrackedMatrix) = TrackedArray(Call($f, a, b))
@ -98,7 +113,12 @@ end
function back(::typeof(Ac_mul_B), Δ, a::AbstractVecOrMat{<:Real}, b::AbstractVecOrMat{<:Real})
@back(a, A_mul_Bt(Δ, data(b))')
@back(b, *(data(a), Δ))
@back(b, data(a)*Δ)
end
# Backward pass for `A_mul_Bc(a, b)`. Each `@back` call is given the gradient
# contribution for the argument it names; the raw `data` of the other argument
# is used to compute it.
function back(::typeof(A_mul_Bc), Δ, a::AbstractVecOrMat{<:Real}, b::AbstractVecOrMat{<:Real})
  @back(a, Δ * data(b))
  @back(b, At_mul_B(data(a), Δ)')
end
# Fast path for matrix-vector
@ -113,12 +133,36 @@ end
# NNlib
import NNlib: softmax, ∇softmax
using NNlib
import NNlib: softmax, ∇softmax, conv2d, pool
softmax(xs::TrackedArray) = TrackedArray(Call(softmax, xs))
back(::typeof(softmax), Δ, xs) = @back(xs, ∇softmax(Δ, data(xs)))
# TODO: can store kwargs efficiently in namedtuples
# Positional wrapper around `conv2d`, so that stride and padding are stored as
# ordinary `Call` arguments and handed back to `back` below.
_conv2d(x, w, stride, pad) = conv2d(x, w, stride = stride, padding = pad)

# Tracked overloads for the three cases: both tracked, only the kernel tracked,
# and only the input tracked.
conv2d(x::TrackedArray{<:Any,4}, w::TrackedArray{<:Any,4}; stride = 1, padding = 0) =
  TrackedArray(Call(_conv2d, x, w, stride, padding))
conv2d(x::AbstractArray{<:Any,4}, w::TrackedArray{<:Any,4}; stride = 1, padding = 0) =
  TrackedArray(Call(_conv2d, x, w, stride, padding))
conv2d(x::TrackedArray{<:Any,4}, w::AbstractArray{<:Any,4}; stride = 1, padding = 0) =
  TrackedArray(Call(_conv2d, x, w, stride, padding))

# Backward pass: gradients for the input and the kernel come from the matching
# NNlib gradient functions, called with the same stride and padding as the
# forward pass.
function back(::typeof(_conv2d), Δ, x, w, stride, pad)
  @back(x, NNlib.conv2d_grad_x(data(x), data(w), Δ; stride = stride, padding = pad))
  @back(w, NNlib.conv2d_grad_w(data(x), data(w), Δ; stride = stride, padding = pad))
end
# Positional wrapper around `pool`, so that window, padding and mode are stored
# as ordinary `Call` arguments.
_pool(x, k, pad, mode) = pool(x, window = k, mode = mode, padding = pad)

# Tracked overload; note the argument order passed to `_pool` is
# (window, padding, mode).
pool(x::TrackedArray{<:Any,4}; window = 2, mode = 0, padding = 0) =
  TrackedArray(Call(_pool, x, window, padding, mode))

# Defined on `back_` rather than `back` because `NNlib.pool_grad` also needs
# the forward output `y`.
back_(::typeof(_pool), y, Δ, x, k, pad, mode) =
  back(x, NNlib.pool_grad(data(x), y, Δ, window=k, mode=mode, padding=pad))
# Broadcasting
using ForwardDiff: Dual, partials
@ -134,9 +178,11 @@ dualify(xs::TrackedArray, ps) = map(x -> Dual(x, ps), data(xs))
function tracked_broadcast(f, args::Vararg{Any,N}) where N
dargs = map((x,i) -> dualify(x, ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N}))
out = broadcast(f, dargs...)
eltype(out) <: Dual || return out
# TrackedArray(Call(Broadcasted(broadcast(f, dargs...)), args...))
# Works around a 0.6 type inference issue
b = Broadcasted(broadcast(f, dargs...))
b = Broadcasted(out)
TrackedArray(Call(b, args...), b())
end

View File

@ -19,4 +19,4 @@ function ngradient(f, xs::AbstractArray...)
return grads
end
gradcheck(f, xs...) = all(isapprox.(ngradient(f, xs...), gradient(f, xs...), rtol = 1e-6))
gradcheck(f, xs...) = all(isapprox.(ngradient(f, xs...), gradient(f, xs...), rtol = 1e-5))

View File

@ -1,8 +1,8 @@
# Arrays
initn(dims...) = randn(dims...)/100
flatten(xs) = reshape(xs, size(xs, 1), :)
# Glorot (Xavier) uniform initialiser: values are uniform in
# ±0.5*sqrt(24/sum(dims)), i.e. ±sqrt(6/sum(dims)).
glorot_uniform(dims...) = (rand(dims...) .- 0.5) .* sqrt(24.0/sum(dims))

# Glorot (Xavier) normal initialiser: standard normal samples scaled to a
# standard deviation of sqrt(2/sum(dims)).
glorot_normal(dims...) = randn(dims...) .* sqrt(2.0/sum(dims))
unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
@ -93,13 +93,14 @@ but if you'd like to disable the execution on the leading edge, pass
function throttle(f, timeout; leading=true, trailing=false)
cooldown = true
later = nothing
result = nothing
function throttled(args...; kwargs...)
yield()
if cooldown
if leading
f(args...; kwargs...)
result = f(args...; kwargs...)
else
later = () -> f(args...; kwargs...)
end
@ -114,9 +115,28 @@ function throttle(f, timeout; leading=true, trailing=false)
cooldown = true
end
elseif trailing
later = () -> f(args...; kwargs...)
later = () -> (result = f(args...; kwargs...))
end
nothing
return result
end
end
"""
    J = jacobian(m,x)

Calculate the output jacobian `J = d/dx m(x)` such that each row `i` of `J`
corresponds to the gradient `J[i,:] = ∇ₓ(m(x)[i])`.

`x` is wrapped with `param` and each output element is back-propagated
separately, so the cost is one backward pass per element of `m(x)`.
"""
function jacobian(m,x)
  xp = param(x)
  y  = m(xp)
  k  = length(y)
  n  = length(x)
  J  = Matrix{eltype(x)}(n,k)
  for i = 1:k
    Flux.back!(y[i]) # Populate gradient accumulator
    J[:,i] = xp.grad
    # Reset gradient accumulator. Assign zero rather than multiply by it:
    # NaN * 0 and Inf * 0 are NaN, so a multiplicative reset would leak a bad
    # gradient into every later column.
    xp.grad .= 0
  end
  J'
end

View File

@ -1,3 +1,8 @@
using Flux.Data
using Base.Test
@test cmudict()["CATASTROPHE"] == :[K,AH0,T,AE1,S,T,R,AH0,F,IY0].args
@test length(CMUDict.phones()) == 39
@test length(CMUDict.symbols()) == 84

View File

@ -26,3 +26,55 @@ using Flux: testmode!
y = m(x)
@test count(a->a == 0, y) == 0
end
@testset "BatchNorm" begin
  let m = BatchNorm(2), x = param([1 2; 3 4; 5 6]')

    @test m.β.data == [0, 0]  # initβ(2)
    @test m.γ.data == [1, 1]  # initγ(2)
    # initial m.σ is 1
    # initial m.μ is 0
    @test m.active

    # @test m(x).data ≈ [-1 -1; 0 0; 1 1]'
    m(x)

    # julia> x
    # 2×3 Array{Float64,2}:
    #  1.0  3.0  5.0
    #  2.0  4.0  6.0
    #
    # μ of batch will be
    #  (1. + 3. + 5.) / 3 = 3
    #  (2. + 4. + 6.) / 3 = 4
    #
    # ∴ update rule with momentum:
    #  .1 * 3 + 0 = .3
    #  .1 * 4 + 0 = .4
    @test m.μ ≈ reshape([0.3, 0.4], 2, 1)

    # julia> .1 .* std(x, 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
    # 2×1 Array{Float64,2}:
    #  1.14495
    #  1.14495
    @test m.σ ≈ .1 .* std(x.data, 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]

    testmode!(m)
    @test !m.active

    # In test mode the stored moving statistics are used for normalisation.
    x = m(x).data
    @test x[1] ≈ (1 - 0.3) / 1.1449489742783179
  end

  # with activation function
  let m = BatchNorm(2, λ = σ), x = param([1 2; 3 4; 5 6]')
    @test m.active
    m(x)

    testmode!(m)
    @test !m.active

    x = m(x).data
    @test x[1] ≈ σ((1 - 0.3) / 1.1449489742783179)
  end
end

26
test/layers/stateless.jl Normal file
View File

@ -0,0 +1,26 @@
using Flux: onehotbatch, mse, crossentropy
@testset "losses" begin
  # First, regression-style y's
  y = [1, 1, 0, 0]
  y_hat = [.9, .1, .1, .9]

  @testset "mse" begin
    @test mse(y_hat, y) ≈ (.1^2 + .9^2)/2
  end

  # Now onehot y's
  y = onehotbatch([1, 1, 0, 0], 0:1)
  y_hat = [.1 .9; .9 .1; .9 .1; .1 .9]'
  y_logloss = 1.203972804325936

  @testset "crossentropy" begin
    @test crossentropy(y_hat, y) ≈ y_logloss
  end

  @testset "weighted_crossentropy" begin
    @test crossentropy(y_hat, y, weight = ones(2)) ≈ y_logloss
    @test crossentropy(y_hat, y, weight = [.5, .5]) ≈ y_logloss/2
    @test crossentropy(y_hat, y, weight = [2, .5]) ≈ 1.5049660054074199
  end
end

29
test/optimise.jl Normal file
View File

@ -0,0 +1,29 @@
using Flux.Optimise
using Flux.Tracker
@testset "Optimise" begin
  # Fixed target weights that every optimiser has to recover.
  w = randn(10, 10)
  for Opt in [SGD, Nesterov, Momentum, ADAM, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad]
    # The trained parameter needs its own name: if it reused `w`, the loss
    # would compare `w` with itself and the final check could never fail.
    w_fit = param(randn(10, 10))
    loss(x) = Flux.mse(w*x, w_fit*x)
    opt = Opt([w_fit])
    for t=1:10^5
      l = loss(rand(10))
      back!(l)
      opt()
    end
    @test Flux.mse(w, w_fit) < 0.01
  end
end
# Checks that a callback returning `:stop` ends `train!` early.
@testset "Training Loop" begin
  i = 0          # number of loss evaluations so far
  l = param(1)

  # Each loss evaluation sleeps 0.1 s and bumps `i`. The throttled callback
  # returns `:stop` once `i > 3`, so training should end well before all 100
  # data points have been consumed.
  Flux.train!(() -> (sleep(0.1); i += 1; l),
              Iterators.repeated((), 100),
              ()->(),
              cb = Flux.throttle(() -> (i > 3 && :stop), 1))

  @test 3 < i < 50
end

View File

@ -5,5 +5,8 @@ using Flux, Base.Test
include("utils.jl")
include("tracker.jl")
include("layers/normalisation.jl")
include("layers/stateless.jl")
include("optimise.jl")
include("data.jl")
end

View File

@ -1,5 +1,6 @@
using Flux.Tracker, Base.Test, NNlib
using Flux.Tracker: gradcheck
using NNlib
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(f(xs...)), xs...)
gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@ -10,6 +11,7 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)
@test gradtest((w, x) -> w'*x, randn(10, 2), randn(10))
@test gradtest((w, x) -> w*x', randn(5,5), randn(5,5))
@test gradtest(x -> sin.(sum(x, (2, 3))), (3,4,5))
@ -37,9 +39,17 @@ end
@test gradtest(x -> std(x), rand(5,5))
@test gradtest(x -> std(x, 1), rand(5,5))
@test gradtest((x, y) -> x .* y, rand(5), rand(5))
@test gradtest(rand(5)) do x
y = x.^2
2y + x
end
@test gradtest(conv2d, rand(10, 10, 3, 2), randn(2, 2, 3, 2))
@test gradtest(x -> maxpool2d(x, 2), rand(10, 10, 3, 2))
@test gradtest(x -> avgpool2d(x, 2), rand(10, 10, 3, 2))
@test (param([1,2,3]) .< 2) == [true, false, false]
end #testset

View File

@ -1,4 +1,4 @@
using Flux: throttle
using Flux: throttle, initn, glorot_uniform, glorot_normal, jacobian
@testset "Throttle" begin
@testset "default behaviour" begin
@ -47,3 +47,35 @@ using Flux: throttle
@test a == [1, 3]
end
end
@testset "Jacobian" begin
  A = param(randn(2,2))
  x = randn(2)
  m(x) = A*x
  y = m(x)
  J = jacobian(m,x)
  # For the linear map x -> A*x the jacobian is A itself.
  @test J ≈ A.data
end
@testset "Initialization" begin
  # Set random seed so that these tests don't fail randomly
  srand(0)
  # initn() should yield a kernel with stddev ~= 1e-2
  v = initn(10, 10)
  @test std(v) > 0.9*1e-2
  @test std(v) < 1.1*1e-2

  # glorot_uniform should yield a kernel whose extreme values are close to
  # ±sqrt(6/(n_in + n_out)), and glorot_normal should yield a kernel with
  # stddev ~= sqrt(2/(n_in + n_out))
  for (n_in, n_out) in [(100, 100), (100, 400)]
    v = glorot_uniform(n_in, n_out)
    @test minimum(v) > -1.1*sqrt(6/(n_in + n_out))
    @test minimum(v) < -0.9*sqrt(6/(n_in + n_out))
    @test maximum(v) >  0.9*sqrt(6/(n_in + n_out))
    @test maximum(v) <  1.1*sqrt(6/(n_in + n_out))

    v = glorot_normal(n_in, n_out)
    @test std(v) > 0.9*sqrt(2/(n_in + n_out))
    @test std(v) < 1.1*sqrt(2/(n_in + n_out))
  end
end