Merge branch 'master' into cat-fix
This commit is contained in:
commit
bc8a32bc56
@ -1,6 +1,6 @@
|
|||||||
# Флукс
|
# Флукс
|
||||||
|
|
||||||
[](https://travis-ci.org/FluxML/Flux.jl) [](https://fluxml.github.io/Flux.jl/stable/) [](https://gitter.im/FluxML/Lobby) [Slack](https://discourse.julialang.org/t/announcing-a-julia-slack/4866)
|
[](https://travis-ci.org/FluxML/Flux.jl) [](https://fluxml.github.io/Flux.jl/stable/) [](https://gitter.im/FluxML/Lobby) [Slack](https://slackinvite.julialang.org/)
|
||||||
|
|
||||||
Flux is a refreshing approach to machine learning. It provides lightweight abstractions on top of Julia's native GPU and AD support, while remaining fully hackable (right down to the [GPU kernels](https://github.com/FluxML/CuArrays.jl)).
|
Flux is a refreshing approach to machine learning. It provides lightweight abstractions on top of Julia's native GPU and AD support, while remaining fully hackable (right down to the [GPU kernels](https://github.com/FluxML/CuArrays.jl)).
|
||||||
|
|
||||||
|
3
REQUIRE
3
REQUIRE
@ -3,5 +3,6 @@ DataFlow 0.2.1
|
|||||||
Juno
|
Juno
|
||||||
MacroTools 0.3.3
|
MacroTools 0.3.3
|
||||||
NNlib
|
NNlib
|
||||||
ForwardDiff
|
ForwardDiff 0.5.0
|
||||||
Requires
|
Requires
|
||||||
|
Adapt
|
||||||
|
@ -5,6 +5,7 @@ These core layers form the foundation of almost all neural networks.
|
|||||||
```@docs
|
```@docs
|
||||||
Chain
|
Chain
|
||||||
Dense
|
Dense
|
||||||
|
Conv2D
|
||||||
```
|
```
|
||||||
|
|
||||||
## Recurrent Layers
|
## Recurrent Layers
|
||||||
@ -37,6 +38,7 @@ These layers don't affect the structure of the network but may improve training
|
|||||||
|
|
||||||
```@docs
|
```@docs
|
||||||
Flux.testmode!
|
Flux.testmode!
|
||||||
|
BatchNorm
|
||||||
Dropout
|
Dropout
|
||||||
LayerNorm
|
LayerNorm
|
||||||
```
|
```
|
||||||
|
@ -7,12 +7,14 @@ module Flux
|
|||||||
using Juno, Requires
|
using Juno, Requires
|
||||||
using Lazy: @forward
|
using Lazy: @forward
|
||||||
|
|
||||||
export Chain, Dense, RNN, LSTM, Dropout, LayerNorm,
|
export Chain, Dense, RNN, LSTM, GRU, Conv2D,
|
||||||
SGD, ADAM, Momentum, Nesterov,
|
Dropout, LayerNorm, BatchNorm,
|
||||||
|
SGD, ADAM, Momentum, Nesterov, AMSGrad,
|
||||||
param, params, mapleaves
|
param, params, mapleaves
|
||||||
|
|
||||||
using NNlib
|
using NNlib
|
||||||
export σ, sigmoid, relu, leakyrelu, elu, swish, softmax
|
export σ, sigmoid, relu, leakyrelu, elu, swish, softmax,
|
||||||
|
conv2d, maxpool2d, avgpool2d
|
||||||
|
|
||||||
include("tracker/Tracker.jl")
|
include("tracker/Tracker.jl")
|
||||||
using .Tracker
|
using .Tracker
|
||||||
@ -26,6 +28,7 @@ include("treelike.jl")
|
|||||||
|
|
||||||
include("layers/stateless.jl")
|
include("layers/stateless.jl")
|
||||||
include("layers/basic.jl")
|
include("layers/basic.jl")
|
||||||
|
include("layers/conv.jl")
|
||||||
include("layers/recurrent.jl")
|
include("layers/recurrent.jl")
|
||||||
include("layers/normalisation.jl")
|
include("layers/normalisation.jl")
|
||||||
|
|
||||||
|
@ -23,17 +23,17 @@ end
|
|||||||
|
|
||||||
function symbols()
|
function symbols()
|
||||||
load()
|
load()
|
||||||
Symbol.(split(readstring(deps("CMUDict", "cmudict.symbols")),
|
Symbol.(split(readstring(deps("cmudict", "cmudict.symbols")),
|
||||||
"\n", keep = false))
|
"\n", keep = false))
|
||||||
end
|
end
|
||||||
|
|
||||||
function rawdict()
|
function rawdict()
|
||||||
load()
|
load()
|
||||||
Dict(String(xs[1]) => Symbol.(xs[2:end]) for xs in
|
Dict(String(xs[1]) => Symbol.(xs[2:end]) for xs in
|
||||||
filter(!isempty, split.(split(readstring(deps("CMUDict", "cmudict")), "\n"))))
|
filter(!isempty, split.(split(readstring(deps("cmudict", "cmudict")), "\n"))))
|
||||||
end
|
end
|
||||||
|
|
||||||
validword(s) = ismatch(r"^[\w-\.]+$", s)
|
validword(s) = ismatch(r"^[\w\-\.]+$", s)
|
||||||
|
|
||||||
cmudict() = filter((s, ps) -> validword(s), rawdict())
|
cmudict() = filter((s, ps) -> validword(s), rawdict())
|
||||||
|
|
||||||
|
@ -63,8 +63,10 @@ struct Dense{F,S,T}
|
|||||||
b::T
|
b::T
|
||||||
end
|
end
|
||||||
|
|
||||||
Dense(in::Integer, out::Integer, σ = identity; init = initn) =
|
function Dense(in::Integer, out::Integer, σ = identity;
|
||||||
Dense(σ, param(init(out, in)), param(init(out)))
|
initW = glorot_uniform, initb = zeros)
|
||||||
|
return Dense(σ, param(initW(out, in)), param(initb(out)))
|
||||||
|
end
|
||||||
|
|
||||||
treelike(Dense)
|
treelike(Dense)
|
||||||
|
|
||||||
|
33
src/layers/conv.jl
Normal file
33
src/layers/conv.jl
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
"""
|
||||||
|
Conv2D(size, in=>out)
|
||||||
|
Conv2d(size, in=>out, relu)
|
||||||
|
|
||||||
|
Standard convolutional layer. `size` should be a tuple like `(2, 2)`.
|
||||||
|
`in` and `out` specify the number of input and output channels respectively.
|
||||||
|
|
||||||
|
Data should be stored in HWCN order. In other words, a 100×100 RGB image would
|
||||||
|
be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
|
||||||
|
|
||||||
|
Takes the keyword arguments `pad` and `stride`.
|
||||||
|
"""
|
||||||
|
struct Conv2D{F,A}
|
||||||
|
σ::F
|
||||||
|
weight::A
|
||||||
|
stride::Int
|
||||||
|
pad::Int
|
||||||
|
end
|
||||||
|
|
||||||
|
# Convenience constructor. `k` is the 2-D kernel size and `ch` the `in => out`
# channel pair; the kernel is created with `init(k..., in, out)` and wrapped
# with `param`.
function Conv2D(k::NTuple{2,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
                init = initn, stride = 1, pad = 0)
  kernel = param(init(k..., ch...))
  return Conv2D(σ, kernel, stride, pad)
end
|
||||||
|
|
||||||
|
Flux.treelike(Conv2D)
|
||||||
|
|
||||||
|
# Forward pass: convolve `x` with the layer's kernel using the stored stride
# and padding, then apply the activation elementwise.
function (c::Conv2D)(x)
  y = conv2d(x, c.weight, stride = c.stride, padding = c.pad)
  return c.σ.(y)
end
|
||||||
|
|
||||||
|
# Print the layer as `Conv2D((h, w), in=>out[, σ])`; the activation is omitted
# when it is `identity`.
function Base.show(io::IO, l::Conv2D)
  kh, kw, cin, cout = ntuple(d -> size(l.weight, d), 4)
  print(io, "Conv2D((", kh, ", ", kw, ")", ", ", cin, "=>", cout)
  if l.σ != identity
    print(io, ", ", l.σ)
  end
  print(io, ")")
end
|
@ -2,8 +2,8 @@
|
|||||||
testmode!(m)
|
testmode!(m)
|
||||||
testmode!(m, false)
|
testmode!(m, false)
|
||||||
|
|
||||||
Put layers like [`Dropout`](@ref) and `BatchNorm` into testing mode (or back to
|
Put layers like [`Dropout`](@ref) and [`BatchNorm`](@ref) into testing mode
|
||||||
training mode with `false`).
|
(or back to training mode with `false`).
|
||||||
"""
|
"""
|
||||||
function testmode!(m, val::Bool=true)
|
function testmode!(m, val::Bool=true)
|
||||||
prefor(x -> _testmode!(x, val), m)
|
prefor(x -> _testmode!(x, val), m)
|
||||||
@ -45,6 +45,7 @@ end
|
|||||||
_testmode!(a::Dropout, test) = (a.active = !test)
|
_testmode!(a::Dropout, test) = (a.active = !test)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
LayerNorm(h::Integer)
|
LayerNorm(h::Integer)
|
||||||
|
|
||||||
A [normalisation layer](https://arxiv.org/pdf/1607.06450.pdf) designed to be
|
A [normalisation layer](https://arxiv.org/pdf/1607.06450.pdf) designed to be
|
||||||
@ -65,3 +66,77 @@ treelike(LayerNorm)
|
|||||||
function Base.show(io::IO, l::LayerNorm)
|
function Base.show(io::IO, l::LayerNorm)
|
||||||
print(io, "LayerNorm(", length(l.diag.α), ")")
|
print(io, "LayerNorm(", length(l.diag.α), ")")
|
||||||
end
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
BatchNorm(dims...; λ = identity,
|
||||||
|
initβ = zeros, initγ = ones, ϵ = 1e-8, momentum = .1)
|
||||||
|
|
||||||
|
Batch Normalization Layer for [`Dense`](@ref) layer.
|
||||||
|
|
||||||
|
See [Batch Normalization: Accelerating Deep Network Training by Reducing
|
||||||
|
Internal Covariate Shift](https://arxiv.org/pdf/1502.03167.pdf)
|
||||||
|
|
||||||
|
In the example of MNIST,
|
||||||
|
in order to normalize the input of other layer,
|
||||||
|
put the `BatchNorm` layer before activation function.
|
||||||
|
|
||||||
|
```julia
|
||||||
|
m = Chain(
|
||||||
|
Dense(28^2, 64),
|
||||||
|
BatchNorm(64, λ = relu),
|
||||||
|
Dense(64, 10),
|
||||||
|
BatchNorm(10),
|
||||||
|
softmax)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
mutable struct BatchNorm{F,V,N}
|
||||||
|
λ::F # activation function
|
||||||
|
β::V # bias
|
||||||
|
γ::V # scale
|
||||||
|
μ # moving mean
|
||||||
|
σ # moving std
|
||||||
|
ϵ::N
|
||||||
|
momentum::N
|
||||||
|
active::Bool
|
||||||
|
end
|
||||||
|
|
||||||
|
# Keyword constructor: builds the bias and scale with `initβ(dims)` /
# `initγ(dims)`, starts the moving mean at 0 and the moving std at 1, and
# creates the layer in training mode (`active = true`).
function BatchNorm(dims::Integer...; λ = identity,
                   initβ = zeros, initγ = ones, ϵ = 1e-8, momentum = .1)
  β = param(initβ(dims))
  γ = param(initγ(dims))
  return BatchNorm(λ, β, γ, 0., 1., ϵ, momentum, true)
end
|
||||||
|
|
||||||
|
# Forward pass. While `BN.active` is true the statistics are computed along
# dimension 2 of `x` (treated as the batch dimension) and the moving mean/std
# are updated in place; otherwise the stored moving statistics are used.
function (BN::BatchNorm)(x)
  if BN.active
    T = eltype(x)

    ϵ = T(BN.ϵ)
    m = size(x, 2)  # batch size
    μ = mean(x, 2)
    σ = sqrt.(sum((x .- μ).^2, 2) ./ m .+ ϵ)

    # Update the moving mean/std.
    # NOTE(review): `.data` assumes μ and σ are tracked values here — confirm
    # behaviour for untracked input.
    mtm = T(BN.momentum)
    BN.μ = (1 - mtm) .* BN.μ .+ mtm .* μ.data
    BN.σ = (1 - mtm) .* BN.σ .+ mtm .* σ.data .* m ./ (m - 1)
  else
    μ, σ = BN.μ, BN.σ
  end

  act, scale, shift = BN.λ, BN.γ, BN.β
  return act.(scale .* ((x .- μ) ./ σ) .+ shift)
end
|
||||||
|
|
||||||
|
# Children are listed in struct field order
# (λ, β, γ, μ, σ, ϵ, momentum, active).
children(BN::BatchNorm) =
  (BN.λ, BN.β, BN.γ, BN.μ, BN.σ, BN.ϵ, BN.momentum, BN.active)

# Rebuild the layer with `f` applied to the bias and scale,
# e.g. mapchildren(cu, BN).
# ϵ and momentum share the type parameter `N`, so passing them in the wrong
# order still constructs a layer but silently exchanges the two values; they
# must follow the field order ϵ, then momentum.
mapchildren(f, BN::BatchNorm) =
  BatchNorm(BN.λ, f(BN.β), f(BN.γ), BN.μ, BN.σ, BN.ϵ, BN.momentum, BN.active)
|
||||||
|
|
||||||
|
# Switch the layer between training (`active = true`) and test mode.
function _testmode!(BN::BatchNorm, test)
  return BN.active = !test
end
|
||||||
|
|
||||||
|
# Print the layer as `BatchNorm(dims[, λ = act])`; the activation is omitted
# when it is `identity`.
function Base.show(io::IO, l::BatchNorm)
  print(io, "BatchNorm($(join(size(l.β), ", "))")
  if l.λ != identity
    print(io, ", λ = $(l.λ)")
  end
  print(io, ")")
end
|
||||||
|
@ -79,8 +79,8 @@ struct RNNCell{D,V}
|
|||||||
h::V
|
h::V
|
||||||
end
|
end
|
||||||
|
|
||||||
RNNCell(in::Integer, out::Integer, σ = tanh; init = initn) =
|
RNNCell(in::Integer, out::Integer, σ = tanh; initW = glorot_uniform, initb = zeros) =
|
||||||
RNNCell(Dense(in+out, out, σ, init = init), param(init(out)))
|
RNNCell(Dense(in+out, out, σ, initW = initW, initb = initb), param(initW(out)))
|
||||||
|
|
||||||
function (m::RNNCell)(h, x)
|
function (m::RNNCell)(h, x)
|
||||||
h = m.d(combine(x, h))
|
h = m.d(combine(x, h))
|
||||||
@ -113,10 +113,10 @@ struct LSTMCell{D1,D2,V}
|
|||||||
h::V; c::V
|
h::V; c::V
|
||||||
end
|
end
|
||||||
|
|
||||||
function LSTMCell(in, out; init = initn)
|
function LSTMCell(in, out; initW = glorot_uniform, initb = zeros)
|
||||||
cell = LSTMCell([Dense(in+out, out, σ, init = init) for _ = 1:3]...,
|
cell = LSTMCell([Dense(in+out, out, σ, initW = initW, initb = initb) for _ = 1:3]...,
|
||||||
Dense(in+out, out, tanh, init = init),
|
Dense(in+out, out, tanh, initW = initW, initb = initb),
|
||||||
param(init(out)), param(init(out)))
|
param(initW(out)), param(initW(out)))
|
||||||
cell.forget.b.data .= 1
|
cell.forget.b.data .= 1
|
||||||
return cell
|
return cell
|
||||||
end
|
end
|
||||||
@ -150,3 +150,49 @@ See [this article](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
|
|||||||
for a good overview of the internals.
|
for a good overview of the internals.
|
||||||
"""
|
"""
|
||||||
LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
|
LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
|
||||||
|
|
||||||
|
# GRU
|
||||||
|
|
||||||
|
# Cell of a gated recurrent unit: two gate layers, a candidate layer and the
# initial hidden state.
struct GRUCell{D1,D2,V}
  update::D1     # update gate
  reset::D1      # reset gate
  candidate::D2  # candidate hidden state
  h::V           # initial hidden state, returned by `hidden`
end

# Build a cell for `in` inputs and `out` hidden units. Every layer takes the
# input concatenated with the hidden state, hence the `in + out` input width.
function GRUCell(in, out)
  width = in + out
  update = Dense(width, out, σ)
  reset = Dense(width, out, σ)
  candidate = Dense(width, out, tanh)
  return GRUCell(update, reset, candidate, param(initn(out)))
end
|
||||||
|
|
||||||
|
# One recurrence step: returns the new hidden state twice, once as the carried
# state and once as the output.
function (m::GRUCell)(h, x)
  joined = combine(x, h)
  z = m.update(joined)
  r = m.reset(joined)
  proposal = m.candidate(combine(r .* h, x))
  h′ = (1 .- z) .* h .+ z .* proposal
  return h′, h′
end
|
||||||
|
|
||||||
|
# Initial hidden state of the cell.
function hidden(m::GRUCell)
  return m.h
end
|
||||||
|
|
||||||
|
treelike(GRUCell)
|
||||||
|
|
||||||
|
# Print the cell as `GRUCell(in, out)`. The update gate's weight matrix is
# `out × (in + out)`, so `in` is recovered as columns minus rows.
function Base.show(io::IO, m::GRUCell)
  nout, ntotal = size(m.update.W, 1), size(m.update.W, 2)
  print(io, "GRUCell(", ntotal - nout, ", ", nout, ')')
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
GRU(in::Integer, out::Integer, σ = tanh)
|
||||||
|
|
||||||
|
Gated Recurrent Unit layer. Behaves like an RNN but generally
|
||||||
|
exhibits a longer memory span over sequences.
|
||||||
|
|
||||||
|
See [this article](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
|
||||||
|
for a good overview of the internals.
|
||||||
|
"""
|
||||||
|
GRU(a...; ka...) = Recur(GRUCell(a...; ka...))
|
||||||
|
@ -4,8 +4,9 @@ using NNlib: log_fast
|
|||||||
|
|
||||||
mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)
|
mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)
|
||||||
|
|
||||||
crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat) =
|
function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
|
||||||
-sum(y .* log_fast.(ŷ)) / size(y, 2)
|
return -sum(y .* log_fast.(ŷ) .* weight) / size(y, 2)
|
||||||
|
end
|
||||||
|
|
||||||
@deprecate logloss(x, y) crossentropy(x, y)
|
@deprecate logloss(x, y) crossentropy(x, y)
|
||||||
|
|
||||||
|
@ -18,7 +18,9 @@ end
|
|||||||
|
|
||||||
Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data))
|
Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data))
|
||||||
|
|
||||||
Base.getindex(xs::OneHotMatrix, i::Int, j::Int) = xs.data[j][i]
|
Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs.data[j][i]
|
||||||
|
Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i]
|
||||||
|
Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i])
|
||||||
|
|
||||||
A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)]
|
A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)]
|
||||||
|
|
||||||
@ -26,7 +28,7 @@ Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs
|
|||||||
|
|
||||||
batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(length(first(xs)), xs)
|
batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(length(first(xs)), xs)
|
||||||
|
|
||||||
import NNlib.adapt
|
import Adapt.adapt
|
||||||
|
|
||||||
adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
|
adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
module Optimise
|
module Optimise
|
||||||
|
|
||||||
export update!, params, train!,
|
export update!, params, train!,
|
||||||
SGD, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta
|
SGD, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad
|
||||||
|
|
||||||
struct Param{T}
|
struct Param{T}
|
||||||
x::T
|
x::T
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
call(f, xs...) = f(xs...)
|
call(f, xs...) = f(xs...)
|
||||||
|
|
||||||
|
# note for optimisers: set to zero
|
||||||
|
# p.Δ at the end of the weights update
|
||||||
function optimiser(ps, fs...)
|
function optimiser(ps, fs...)
|
||||||
ps = [Param(p) for p in ps]
|
ps = [Param(p) for p in ps]
|
||||||
fs = map(ps) do p
|
fs = map(ps) do p
|
||||||
@ -10,64 +12,73 @@ function optimiser(ps, fs...)
|
|||||||
end
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
SGD(params, η = 1; decay = 0)
|
SGD(params, η = 0.1; decay = 0)
|
||||||
|
|
||||||
Classic gradient descent optimiser. For each parameter `p` and its
|
Classic gradient descent optimiser with learning rate `η`.
|
||||||
gradient `δp`, this runs `p -= η*δp`.
|
For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`.
|
||||||
|
|
||||||
Supports decayed learning rate decay if the `decay` argument is provided.
|
Supports inverse decaying learning rate if the `decay` argument is provided.
|
||||||
"""
|
"""
|
||||||
SGD(ps, η = 1; decay = 0) =
|
SGD(ps, η = 0.1; decay = 0) =
|
||||||
optimiser(ps, p -> invdecay(p, decay), p -> descent(p, η))
|
optimiser(ps, p -> invdecay(p, decay), p -> descent(p,η))
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Momentum(params, ρ, decay = 0)
|
Momentum(params, η = 0.01; ρ = 0.9, decay = 0)
|
||||||
|
|
||||||
SGD with momentum `ρ` and optional learning rate decay.
|
SGD with learning rate `η`, momentum `ρ` and optional learning rate inverse decay.
|
||||||
"""
|
"""
|
||||||
Momentum(ps, ρ; decay = 0) =
|
Momentum(ps, η = 0.01; ρ = 0.9, decay = 0) =
|
||||||
optimiser(ps, p -> momentum(p, ρ), p -> invdecay(p, decay), p -> descent(p, 1))
|
optimiser(ps, p->invdecay(p,decay), p->momentum(p, ρ, η), p->descent(p,1))
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Nesterov(params, ρ, decay = 0)
|
Nesterov(params, η = 0.01; ρ = 0.9, decay = 0)
|
||||||
|
|
||||||
SGD with Nesterov momentum `ρ` and optional learning rate decay.
|
SGD with learning rate `η`, Nesterov momentum `ρ` and optional learning rate inverse decay.
|
||||||
"""
|
"""
|
||||||
Nesterov(ps, ρ; decay = 0) =
|
Nesterov(ps, η = 0.01; ρ = 0.9, decay = 0) =
|
||||||
optimiser(ps, p -> nesterov(p, ρ), p -> invdecay(p, decay), p -> descent(p, 1))
|
optimiser(ps, p->invdecay(p,decay), p->nesterov(p, ρ, η), p->descent(p,1))
|
||||||
|
|
||||||
"""
|
"""
|
||||||
RMSProp(params; η = 0.001, ρ = 0.9, ϵ = 1e-8, decay = 0)
|
RMSProp(params, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0)
|
||||||
|
|
||||||
[RMSProp](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
|
[RMSProp](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
|
||||||
optimiser. Parameters other than learning rate don't need tuning. Often a good
|
optimiser. Parameters other than learning rate don't need tuning. Often a good
|
||||||
choice for recurrent networks.
|
choice for recurrent networks.
|
||||||
"""
|
"""
|
||||||
RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) =
|
RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) =
|
||||||
optimiser(ps, p -> rmsprop(p; η = η, ρ = ρ, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
|
optimiser(ps, p->rmsprop(p; η=η, ρ=ρ, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
|
||||||
|
|
||||||
"""
|
"""
|
||||||
ADAM(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
|
ADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
|
||||||
|
|
||||||
[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
|
[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
|
||||||
"""
|
"""
|
||||||
ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
|
ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
|
||||||
optimiser(ps, p -> adam(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
|
optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
|
||||||
|
|
||||||
"""
|
"""
|
||||||
ADAGrad(params; η = 0.01, ϵ = 1e-8, decay = 0)
|
ADAGrad(params, η = 0.01; ϵ = 1e-8, decay = 0)
|
||||||
|
|
||||||
[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
|
[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
|
||||||
Parameters don't need tuning.
|
Parameters don't need tuning.
|
||||||
"""
|
"""
|
||||||
ADAGrad(ps; η = 0.01, ϵ = 1e-8, decay = 0) =
|
ADAGrad(ps, η = 0.01; ϵ = 1e-8, decay = 0) =
|
||||||
optimiser(ps, p -> adagrad(p; η = η, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
|
optimiser(ps, p->adagrad(p; η=η, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
|
||||||
|
|
||||||
"""
|
"""
|
||||||
ADADelta(params; η = 0.01, ρ = 0.95, ϵ = 1e-8, decay = 0)
|
ADADelta(params; ρ = 0.9, ϵ = 1e-8, decay = 0)
|
||||||
|
|
||||||
[ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
|
[ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
|
||||||
tuning.
|
tuning.
|
||||||
"""
|
"""
|
||||||
ADADelta(ps; η = 0.01, ρ = 0.95, ϵ = 1e-8, decay = 0) =
|
ADADelta(ps; ρ = 0.9, ϵ = 1e-8, decay = 0) =
|
||||||
optimiser(ps, p -> adadelta(p; ρ = ρ, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
|
optimiser(ps, p->adadelta(p; ρ=ρ, ϵ=ϵ), p->descent(p,1))
|
||||||
|
|
||||||
|
"""
|
||||||
|
AMSGrad(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
|
||||||
|
|
||||||
|
[AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need
|
||||||
|
tuning.
|
||||||
|
"""
|
||||||
|
AMSGrad(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
|
||||||
|
optimiser(ps, p -> amsgrad(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
|
||||||
|
@ -1,74 +1,97 @@
|
|||||||
function descent(p::Param, η::Real)
|
function descent(p::Param, η::Real)
|
||||||
function ()
|
function ()
|
||||||
p.x .-= p.Δ .* η
|
@. p.x -= η * p.Δ
|
||||||
p.Δ .= 0
|
@. p.Δ = 0
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
function momentum(p::Param, ρ::Real)
|
function momentum(p::Param, ρ, η)
|
||||||
mo = zeros(p.x)
|
v = zeros(p.x)
|
||||||
() -> p.Δ .= mo .= ρ .* mo .+ p.Δ
|
|
||||||
end
|
|
||||||
|
|
||||||
function nesterov(p::Param, ρ::Real)
|
|
||||||
mo = zeros(p.x)
|
|
||||||
function ()
|
function ()
|
||||||
mo .= ρ .* mo .+ p.Δ
|
@. v = ρ * v - η * p.Δ
|
||||||
p.Δ .= ρ .* mo .+ p.Δ
|
@. p.Δ = -v
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
function clip(p::Param, thresh::Real)
|
# Ref. https://arxiv.org/pdf/1212.0901.pdf
|
||||||
() -> clamp!(p.Δ, -thresh, thresh)
|
function nesterov(p::Param, ρ, η)
|
||||||
end
|
v = zeros(p.x)
|
||||||
|
|
||||||
function weightdecay(p::Param, γ::Real)
|
|
||||||
() -> p.Δ .+= γ .* p.x
|
|
||||||
end
|
|
||||||
|
|
||||||
function invdecay(p::Param, γ::Real)
|
|
||||||
n = 0
|
|
||||||
function ()
|
function ()
|
||||||
p.Δ .*= 1 / (1 + γ * n)
|
d = @. ρ^2 * v - (1+ρ) * η * p.Δ
|
||||||
n += 1
|
@. v = ρ*v - η*p.Δ
|
||||||
|
@. p.Δ = -d
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
|
function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
|
||||||
acc = zeros(p.x) .+ ϵ
|
acc = zeros(p.x)
|
||||||
function ()
|
function ()
|
||||||
@. acc = ρ * acc + (1 - ρ) * p.Δ ^ 2
|
@. acc = ρ * acc + (1 - ρ) * p.Δ^2
|
||||||
@. p.Δ *= η / √acc
|
@. p.Δ *= η / (√acc + ϵ)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
|
function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
|
||||||
acc = zeros(p.x) .+ ϵ
|
acc = zeros(p.x) .+ ϵ
|
||||||
function ()
|
function ()
|
||||||
@. acc += p.Δ ^ 2
|
@. acc += p.Δ^2
|
||||||
@. p.Δ *= η / √acc
|
@. p.Δ *= η / √acc
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
function adadelta(p::Param; ρ::Real = 0.95, ϵ::Real = 1e-8)
|
function adadelta(p::Param; ρ::Real = 0.9, ϵ::Real = 1e-8)
|
||||||
acc = zeros(p.x) .+ ϵ
|
acc = zeros(p.x)
|
||||||
Δacc = zeros(p.x) .+ ϵ
|
Δacc = zeros(p.x)
|
||||||
function ()
|
function ()
|
||||||
@. acc = ρ * acc + (1 - ρ) * p.Δ ^ 2
|
@. acc = ρ * acc + (1 - ρ) * p.Δ^2
|
||||||
@. p.Δ *= √Δacc / √acc
|
@. p.Δ *= √(Δacc + ϵ) / √(acc + ϵ)
|
||||||
@. Δacc = ρ * Δacc + (1 - ρ) * p.Δ ^ 2
|
@. Δacc = ρ * Δacc + (1 - ρ) * p.Δ^2
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
|
function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
|
||||||
mt = zeros(p.x)
|
mt = zeros(p.x)
|
||||||
vt = zeros(p.x) .+ ϵ
|
vt = zeros(p.x)
|
||||||
β1p, β2p = β1, β2
|
β1p, β2p = β1, β2
|
||||||
function ()
|
function ()
|
||||||
@. mt = β1 * mt + (1 - β1) * p.Δ
|
@. mt = β1 * mt + (1 - β1) * p.Δ
|
||||||
@. vt = β2 * vt + (1 - β2) * p.Δ ^ 2
|
@. vt = β2 * vt + (1 - β2) * p.Δ^2
|
||||||
@. p.Δ = √(1 - β2p) / √(1 - β1p) * mt / √vt * η
|
@. p.Δ = mt / (1 - β1p) / (√(vt / (1 - β2p)) + ϵ) * η
|
||||||
β1p *= β1
|
β1p *= β1
|
||||||
β2p *= β2
|
β2p *= β2
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# AMSGrad update rule. Returns a closure that rewrites `p.Δ` in place with the
# step `η * m / √v̂`, where `m` and `v` are exponential moving averages of the
# gradient and the squared gradient, and `v̂` is the running elementwise
# maximum of `v`.
function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
  m = zeros(p.x)
  v = zeros(p.x) .+ ϵ
  vmax = zeros(p.x) .+ ϵ
  return () -> begin
    m .= β1 .* m .+ (1 - β1) .* p.Δ
    v .= β2 .* v .+ (1 - β2) .* p.Δ .^ 2
    vmax .= max.(vmax, v)
    p.Δ .= η .* m ./ sqrt.(vmax)
  end
end
|
||||||
|
|
||||||
|
# Returns a closure that clamps every entry of `p.Δ` to `[-thresh, thresh]`
# in place.
function clip(p::Param, thresh::Real)
  lo, hi = -thresh, thresh
  return () -> clamp!(p.Δ, lo, hi)
end
|
||||||
|
|
||||||
|
# Returns a closure that adds `γ .* p.x` to the gradient `p.Δ` in place.
# When `γ` is zero the closure does nothing.
function expdecay(p::Param, γ::Real)
  γ == 0 && return () -> nothing
  return () -> (p.Δ .+= γ .* p.x)
end
|
||||||
|
|
||||||
|
# Returns a closure that scales `p.Δ` in place by `1 / (1 + γ * n)`, where `n`
# counts the closure's previous calls. When `γ` is zero the closure does
# nothing.
function invdecay(p::Param, γ::Real)
  γ == 0 && return () -> nothing
  step = 0
  return function ()
    scale = 1 / (1 + γ * step)
    p.Δ .*= scale
    step += 1
  end
end
|
||||||
|
@ -1,15 +1,24 @@
|
|||||||
using Juno
|
using Juno
|
||||||
using Flux.Tracker: back!
|
using Flux.Tracker: back!, value
|
||||||
|
|
||||||
runall(f) = f
|
runall(f) = f
|
||||||
runall(fs::AbstractVector) = () -> foreach(call, fs)
|
runall(fs::AbstractVector) = () -> foreach(call, fs)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
train!(loss, data, opt; cb = () -> ())
|
train!(loss, data, opt)
|
||||||
|
|
||||||
For each datapoint `d` in `data` computes the gradient of `loss(d...)` through
|
For each datapoint `d` in `data` computes the gradient of `loss(d...)` through
|
||||||
backpropagation and calls the optimizer `opt` and the callback `cb`
|
backpropagation and calls the optimizer `opt`.
|
||||||
(i.e. `opt()` and `cb()`).
|
|
||||||
|
Takes a callback as keyword argument `cb`. For example, this will print "training"
|
||||||
|
every 10 seconds:
|
||||||
|
|
||||||
|
```julia
|
||||||
|
Flux.train!(loss, data, opt,
|
||||||
|
cb = throttle(() -> println("training"), 10))
|
||||||
|
```
|
||||||
|
|
||||||
|
The callback can return `:stop` to interrupt the training loop.
|
||||||
|
|
||||||
Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
|
Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
|
||||||
"""
|
"""
|
||||||
@ -18,10 +27,10 @@ function train!(loss, data, opt; cb = () -> ())
|
|||||||
opt = runall(opt)
|
opt = runall(opt)
|
||||||
@progress for d in data
|
@progress for d in data
|
||||||
l = loss(d...)
|
l = loss(d...)
|
||||||
isinf(l.data[]) && error("Loss is Inf")
|
isinf(value(l)) && error("Loss is Inf")
|
||||||
isnan(l.data[]) && error("Loss is NaN")
|
isnan(value(l)) && error("Loss is NaN")
|
||||||
back!(l)
|
back!(l)
|
||||||
opt()
|
opt()
|
||||||
cb()
|
cb() == :stop && break
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -58,6 +58,7 @@ Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
|
|||||||
|
|
||||||
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
|
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
|
||||||
|
|
||||||
|
# TODO decide if keeping both data and value. The problem is TrackedScalar
|
||||||
value(x) = x
|
value(x) = x
|
||||||
value(x::TrackedArray) = data(x)
|
value(x::TrackedArray) = data(x)
|
||||||
value(x::TrackedScalar) = data(x)[]
|
value(x::TrackedScalar) = data(x)[]
|
||||||
@ -69,6 +70,7 @@ Base.:(==)(x::TrackedArray, y::TrackedArray) = value(x) == value(x)
|
|||||||
Base.isless(x::TrackedScalar, y) = isless(value(x), y)
|
Base.isless(x::TrackedScalar, y) = isless(value(x), y)
|
||||||
Base.isless(x, y::TrackedScalar) = isless(x, value(y))
|
Base.isless(x, y::TrackedScalar) = isless(x, value(y))
|
||||||
Base.isless(x::TrackedScalar, y::TrackedScalar) = isless(value(x), value(y))
|
Base.isless(x::TrackedScalar, y::TrackedScalar) = isless(value(x), value(y))
|
||||||
|
Base.isapprox(x::TrackedScalar, y; kws...) = isapprox(x.data[], y; kws...)
|
||||||
|
|
||||||
Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
|
Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
|
||||||
print(io, "TrackedArray{…,$A}")
|
print(io, "TrackedArray{…,$A}")
|
||||||
@ -91,7 +93,7 @@ include("back.jl")
|
|||||||
include("lib.jl")
|
include("lib.jl")
|
||||||
include("numeric.jl")
|
include("numeric.jl")
|
||||||
|
|
||||||
import NNlib.adapt
|
import Adapt.adapt
|
||||||
|
|
||||||
adapt(T, xs::TrackedArray) = TrackedArray(xs.f, adapt(T, xs.data), adapt(T, xs.grad))
|
adapt(T, xs::TrackedArray) = TrackedArray(xs.f, adapt(T, xs.data), adapt(T, xs.grad))
|
||||||
|
|
||||||
|
@ -12,16 +12,17 @@ function scan(x::TrackedArray)
|
|||||||
return
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
back(c::Call, Δ) = back(c.func, Δ, c.args...)
|
back_(f, y, args...) = back(f, args...)
|
||||||
back(::Call{Void}, Δ) = nothing
|
back_(c::Call, y, Δ) = back_(c.func, y, Δ, c.args...)
|
||||||
|
back_(::Call{Void}, y, Δ) = nothing
|
||||||
|
|
||||||
function back(x::TrackedArray, Δ)
|
function back(x::TrackedArray, Δ)
|
||||||
ref = x.ref -= 1
|
ref = x.ref -= 1
|
||||||
if isdefined(x, :grad)
|
if isdefined(x, :grad)
|
||||||
x.grad .+= Δ
|
x.grad .+= Δ
|
||||||
ref == 0 && back(x.f, x.grad)
|
ref == 0 && back_(x.f, x.data, x.grad)
|
||||||
else
|
else
|
||||||
ref == 0 && back(x.f, Δ)
|
ref == 0 && back_(x.f, x.data, Δ)
|
||||||
end
|
end
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
@ -35,6 +36,9 @@ end
|
|||||||
|
|
||||||
# Interface methods
|
# Interface methods
|
||||||
|
|
||||||
|
# TODO: if an error occurs in `back` the refcounts will be broken
|
||||||
|
# and `back` will silently fail to update.
|
||||||
|
|
||||||
function back!(x::TrackedArray, Δ)
|
function back!(x::TrackedArray, Δ)
|
||||||
scan(x)
|
scan(x)
|
||||||
back(x, Δ)
|
back(x, Δ)
|
||||||
|
@ -48,6 +48,12 @@ function back(::typeof(vcat), Δ, xs...)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# Record `reshape` on a tracked array as a `Call`.
function Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...)
  return TrackedArray(Call(reshape, xs, dims...))
end
|
||||||
|
|
||||||
|
# Gradient of `reshape`: pass Δ back to `xs`, reshaped to the size of `xs`.
function back(::typeof(reshape), Δ, xs::TrackedArray, _...)
  return back(xs, reshape(Δ, size(xs)))
end
|
||||||
|
|
||||||
# Reductions
|
# Reductions
|
||||||
|
|
||||||
Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim))
|
Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim))
|
||||||
@ -62,6 +68,15 @@ Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
|
|||||||
Base.mean(xs::TrackedArray) = TrackedArray(Call(mean, xs), toarray(xs.data, mean(xs.data)))
|
Base.mean(xs::TrackedArray) = TrackedArray(Call(mean, xs), toarray(xs.data, mean(xs.data)))
|
||||||
Base.mean(xs::TrackedArray, region) = TrackedArray(Call(mean, xs, region))
|
Base.mean(xs::TrackedArray, region) = TrackedArray(Call(mean, xs, region))
|
||||||
|
|
||||||
|
# Tracked dot product. The scalar result is wrapped via `toarray`, using the
# underlying data of `xs` as the template array. `data(xs)` is used rather than
# `xs.data` so that the method taking an untracked `xs` does not access a
# nonexistent field.
LinAlg.dot(xs::TrackedVector, ys::TrackedVector) =
  TrackedArray(Call(dot, xs, ys), toarray(data(xs), dot(data(xs), data(ys))))

LinAlg.dot(xs::AbstractVector, ys::TrackedVector) =
  TrackedArray(Call(dot, xs, ys), toarray(data(xs), dot(data(xs), data(ys))))

LinAlg.dot(xs::TrackedVector, ys::AbstractVector) =
  TrackedArray(Call(dot, xs, ys), toarray(data(xs), dot(data(xs), data(ys))))
|
||||||
|
|
||||||
|
# Gradient of `dot`: each operand receives Δ times the other operand. The
# other operand is unwrapped with `data`, as the BLAS `back` methods in this
# file do, so the propagated gradient is a plain array even when both inputs
# are tracked.
function back(::typeof(dot), Δ, xs, ys)
  @back(xs, Δ .* data(ys))
  @back(ys, Δ .* data(xs))
end
|
||||||
|
|
||||||
# Hacks to get std working
|
# Hacks to get std working
|
||||||
Base.std(x::TrackedArray; mean = Base.mean(x)) =
|
Base.std(x::TrackedArray; mean = Base.mean(x)) =
|
||||||
sqrt.(sum((x .- mean).^2) ./ (length(x)-1))
|
sqrt.(sum((x .- mean).^2) ./ (length(x)-1))
|
||||||
@ -74,7 +89,7 @@ back(::typeof(mean), Δ, xs::TrackedArray, region) =
|
|||||||
|
|
||||||
# BLAS
|
# BLAS
|
||||||
|
|
||||||
for f in :[*, Ac_mul_B].args
|
for f in :[*, Ac_mul_B, A_mul_Bc].args
|
||||||
@eval begin
|
@eval begin
|
||||||
import Base.$f
|
import Base.$f
|
||||||
$f(a::TrackedMatrix, b::TrackedMatrix) = TrackedArray(Call($f, a, b))
|
$f(a::TrackedMatrix, b::TrackedMatrix) = TrackedArray(Call($f, a, b))
|
||||||
@ -98,7 +113,12 @@ end
|
|||||||
|
|
||||||
function back(::typeof(Ac_mul_B), Δ, a::AbstractVecOrMat{<:Real}, b::AbstractVecOrMat{<:Real})
|
function back(::typeof(Ac_mul_B), Δ, a::AbstractVecOrMat{<:Real}, b::AbstractVecOrMat{<:Real})
|
||||||
@back(a, A_mul_Bt(Δ, data(b))')
|
@back(a, A_mul_Bt(Δ, data(b))')
|
||||||
@back(b, *(data(a), Δ))
|
@back(b, data(a)*Δ)
|
||||||
|
end
|
||||||
|
|
||||||
|
function back(::typeof(A_mul_Bc), Δ, a::AbstractVecOrMat{<:Real}, b::AbstractVecOrMat{<:Real})
|
||||||
|
@back(a, Δ * data(b))
|
||||||
|
@back(b, At_mul_B(data(a), Δ)')
|
||||||
end
|
end
|
||||||
|
|
||||||
# Fast path for matrix-vector
|
# Fast path for matrix-vector
|
||||||
@ -113,12 +133,36 @@ end
|
|||||||
|
|
||||||
# NNlib
|
# NNlib
|
||||||
|
|
||||||
import NNlib: softmax, ∇softmax
|
using NNlib
|
||||||
|
import NNlib: softmax, ∇softmax, conv2d, pool
|
||||||
|
|
||||||
softmax(xs::TrackedArray) = TrackedArray(Call(softmax, xs))
|
softmax(xs::TrackedArray) = TrackedArray(Call(softmax, xs))
|
||||||
|
|
||||||
back(::typeof(softmax), Δ, xs) = @back(xs, ∇softmax(Δ, data(xs)))
|
back(::typeof(softmax), Δ, xs) = @back(xs, ∇softmax(Δ, data(xs)))
|
||||||
|
|
||||||
|
# TODO: can store kwargs efficiently in namedtuples
# Positional-argument form of `conv2d`: `stride` and `pad` are forwarded as
# the `stride` and `padding` keywords.
function _conv2d(x, w, stride, pad)
  return conv2d(x, w, stride = stride, padding = pad)
end
|
||||||
|
|
||||||
|
# Tracked overloads of `conv2d`. Whenever the input, the kernel, or both are
# 4-d tracked arrays, the call is recorded as `Call(_conv2d, ...)` with the
# keyword values passed along positionally.

function conv2d(x::TrackedArray{<:Any,4}, w::TrackedArray{<:Any,4};
                stride = 1, padding = 0)
  return TrackedArray(Call(_conv2d, x, w, stride, padding))
end

function conv2d(x::AbstractArray{<:Any,4}, w::TrackedArray{<:Any,4};
                stride = 1, padding = 0)
  return TrackedArray(Call(_conv2d, x, w, stride, padding))
end

function conv2d(x::TrackedArray{<:Any,4}, w::AbstractArray{<:Any,4};
                stride = 1, padding = 0)
  return TrackedArray(Call(_conv2d, x, w, stride, padding))
end
|
||||||
|
|
||||||
|
# Backward pass for `_conv2d`: sends `NNlib.conv2d_grad_x` to the input and
# `NNlib.conv2d_grad_w` to the kernel, both computed on the untracked data
# with the same `stride` / `padding` that were recorded for the forward call.
function back(::typeof(_conv2d), Δ, x, w, stride, pad)
  xdata, wdata = data(x), data(w)
  @back(x, NNlib.conv2d_grad_x(xdata, wdata, Δ; stride = stride, padding = pad))
  @back(w, NNlib.conv2d_grad_w(xdata, wdata, Δ; stride = stride, padding = pad))
end
|
||||||
|
|
||||||
|
# Positional-argument form of `pool`: `k`, `mode` and `pad` are forwarded as
# the `window`, `mode` and `padding` keywords.
function _pool(x, k, pad, mode)
  return pool(x, window = k, mode = mode, padding = pad)
end
|
||||||
|
|
||||||
|
# Tracked overload of `pool` for 4-d tracked arrays: records the call as
# `Call(_pool, ...)`, passing the keyword values along positionally in the
# order `_pool` expects (window, padding, mode).
function pool(x::TrackedArray{<:Any,4}; window = 2, mode = 0, padding = 0)
  return TrackedArray(Call(_pool, x, window, padding, mode))
end
|
||||||
|
|
||||||
|
# Backward pass for `_pool`: propagates `NNlib.pool_grad` to `x`.
# `y` is handed to `pool_grad` as its second argument — presumably the
# forward output; confirm against the caller of `back_`.
function back_(::typeof(_pool), y, Δ, x, k, pad, mode)
  return back(x, NNlib.pool_grad(data(x), y, Δ, window=k, mode=mode, padding=pad))
end
|
||||||
|
|
||||||
# Broadcasting
|
# Broadcasting
|
||||||
|
|
||||||
using ForwardDiff: Dual, partials
|
using ForwardDiff: Dual, partials
|
||||||
@ -134,9 +178,11 @@ dualify(xs::TrackedArray, ps) = map(x -> Dual(x, ps), data(xs))
|
|||||||
|
|
||||||
function tracked_broadcast(f, args::Vararg{Any,N}) where N
|
function tracked_broadcast(f, args::Vararg{Any,N}) where N
|
||||||
dargs = map((x,i) -> dualify(x, ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N}))
|
dargs = map((x,i) -> dualify(x, ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N}))
|
||||||
|
out = broadcast(f, dargs...)
|
||||||
|
eltype(out) <: Dual || return out
|
||||||
# TrackedArray(Call(Broadcasted(broadcast(f, dargs...)), args...))
|
# TrackedArray(Call(Broadcasted(broadcast(f, dargs...)), args...))
|
||||||
# Works around a 0.6 type inference issue
|
# Works around a 0.6 type inference issue
|
||||||
b = Broadcasted(broadcast(f, dargs...))
|
b = Broadcasted(out)
|
||||||
TrackedArray(Call(b, args...), b())
|
TrackedArray(Call(b, args...), b())
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -19,4 +19,4 @@ function ngradient(f, xs::AbstractArray...)
|
|||||||
return grads
|
return grads
|
||||||
end
|
end
|
||||||
|
|
||||||
gradcheck(f, xs...) = all(isapprox.(ngradient(f, xs...), gradient(f, xs...), rtol = 1e-6))
|
gradcheck(f, xs...) = all(isapprox.(ngradient(f, xs...), gradient(f, xs...), rtol = 1e-5))
|
||||||
|
30
src/utils.jl
30
src/utils.jl
@ -1,8 +1,8 @@
|
|||||||
# Arrays
|
# Arrays
|
||||||
|
|
||||||
initn(dims...) = randn(dims...)/100
|
initn(dims...) = randn(dims...)/100
|
||||||
|
glorot_uniform(dims...) = (rand(dims...) - 0.5)*sqrt(24.0/(sum(dims)))
|
||||||
flatten(xs) = reshape(xs, size(xs, 1), :)
|
glorot_normal(dims...) = (randn(dims...)*sqrt(2.0/sum(dims)))
|
||||||
|
|
||||||
unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
|
unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
|
||||||
|
|
||||||
@ -93,13 +93,14 @@ but if you'd like to disable the execution on the leading edge, pass
|
|||||||
function throttle(f, timeout; leading=true, trailing=false)
|
function throttle(f, timeout; leading=true, trailing=false)
|
||||||
cooldown = true
|
cooldown = true
|
||||||
later = nothing
|
later = nothing
|
||||||
|
result = nothing
|
||||||
|
|
||||||
function throttled(args...; kwargs...)
|
function throttled(args...; kwargs...)
|
||||||
yield()
|
yield()
|
||||||
|
|
||||||
if cooldown
|
if cooldown
|
||||||
if leading
|
if leading
|
||||||
f(args...; kwargs...)
|
result = f(args...; kwargs...)
|
||||||
else
|
else
|
||||||
later = () -> f(args...; kwargs...)
|
later = () -> f(args...; kwargs...)
|
||||||
end
|
end
|
||||||
@ -114,9 +115,28 @@ function throttle(f, timeout; leading=true, trailing=false)
|
|||||||
cooldown = true
|
cooldown = true
|
||||||
end
|
end
|
||||||
elseif trailing
|
elseif trailing
|
||||||
later = () -> f(args...; kwargs...)
|
later = () -> (result = f(args...; kwargs...))
|
||||||
end
|
end
|
||||||
|
|
||||||
nothing
|
return result
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
"""
    jacobian(m, x)

Calculate the output jacobian `J = d/dx m(x)` such that each row `i` of `J`
corresponds to the gradient `J[i,:] = ∇ₓ(m(x)[i])`.

`x` is wrapped with `param` before being passed to `m`. The result has
`length(m(x))` rows and `length(x)` columns, with element type `eltype(x)`.
"""
function jacobian(m,x)
  xp = param(x)
  y  = m(xp)
  k  = length(y)
  n  = length(x)
  # Gradients are stored column by column and transposed on return.
  J  = Matrix{eltype(x)}(n,k)
  for i = 1:k
    Flux.back!(y[i]) # Populate gradient accumulator
    J[:,i] = xp.grad
    # Reset gradient accumulator by assignment. Multiplying by zero would
    # not clear `NaN`/`Inf` entries (`NaN * 0` is `NaN`), letting a single
    # non-finite gradient leak into every later column.
    xp.grad .= 0
  end
  J'
end
|
||||||
|
@ -1,3 +1,8 @@
|
|||||||
using Flux.Data
|
using Flux.Data
|
||||||
|
using Base.Test
|
||||||
|
|
||||||
@test cmudict()["CATASTROPHE"] == :[K,AH0,T,AE1,S,T,R,AH0,F,IY0].args
|
@test cmudict()["CATASTROPHE"] == :[K,AH0,T,AE1,S,T,R,AH0,F,IY0].args
|
||||||
|
|
||||||
|
@test length(CMUDict.phones()) == 39
|
||||||
|
|
||||||
|
@test length(CMUDict.symbols()) == 84
|
||||||
|
@ -26,3 +26,55 @@ using Flux: testmode!
|
|||||||
y = m(x)
|
y = m(x)
|
||||||
@test count(a->a == 0, y) == 0
|
@test count(a->a == 0, y) == 0
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@testset "BatchNorm" begin
|
||||||
|
let m = BatchNorm(2), x = param([1 2; 3 4; 5 6]')
|
||||||
|
|
||||||
|
@test m.β.data == [0, 0] # initβ(2)
|
||||||
|
@test m.γ.data == [1, 1] # initγ(2)
|
||||||
|
# initial m.σ is 1
|
||||||
|
# initial m.μ is 0
|
||||||
|
@test m.active
|
||||||
|
|
||||||
|
# @test m(x).data ≈ [-1 -1; 0 0; 1 1]'
|
||||||
|
m(x)
|
||||||
|
|
||||||
|
# julia> x
|
||||||
|
# 2×3 Array{Float64,2}:
|
||||||
|
# 1.0 3.0 5.0
|
||||||
|
# 2.0 4.0 6.0
|
||||||
|
#
|
||||||
|
# μ of batch will be
|
||||||
|
# (1. + 3. + 5.) / 3 = 3
|
||||||
|
# (2. + 4. + 6.) / 3 = 4
|
||||||
|
#
|
||||||
|
# ∴ update rule with momentum:
|
||||||
|
# .1 * 3 + 0 = .3
|
||||||
|
# .1 * 4 + 0 = .4
|
||||||
|
@test m.μ ≈ reshape([0.3, 0.4], 2, 1)
|
||||||
|
|
||||||
|
# julia> .1 .* std(x, 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
|
||||||
|
# 2×1 Array{Float64,2}:
|
||||||
|
# 1.14495
|
||||||
|
# 1.14495
|
||||||
|
@test m.σ ≈ .1 .* std(x.data, 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
|
||||||
|
|
||||||
|
testmode!(m)
|
||||||
|
@test !m.active
|
||||||
|
|
||||||
|
x′ = m(x).data
|
||||||
|
@test x′[1] ≈ (1 - 0.3) / 1.1449489742783179
|
||||||
|
end
|
||||||
|
|
||||||
|
# with activation function
|
||||||
|
let m = BatchNorm(2, λ = σ), x = param([1 2; 3 4; 5 6]')
|
||||||
|
@test m.active
|
||||||
|
m(x)
|
||||||
|
|
||||||
|
testmode!(m)
|
||||||
|
@test !m.active
|
||||||
|
|
||||||
|
x′ = m(x).data
|
||||||
|
@test x′[1] ≈ σ((1 - 0.3) / 1.1449489742783179)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
26
test/layers/stateless.jl
Normal file
26
test/layers/stateless.jl
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
using Flux: onehotbatch, mse, crossentropy
|
||||||
|
|
||||||
|
@testset "losses" begin
|
||||||
|
# First, regression-style y's
|
||||||
|
y = [1, 1, 0, 0]
|
||||||
|
y_hat = [.9, .1, .1, .9]
|
||||||
|
|
||||||
|
@testset "mse" begin
|
||||||
|
@test mse(y_hat, y) ≈ (.1^2 + .9^2)/2
|
||||||
|
end
|
||||||
|
|
||||||
|
# Now onehot y's
|
||||||
|
y = onehotbatch([1, 1, 0, 0], 0:1)
|
||||||
|
y_hat = [.1 .9; .9 .1; .9 .1; .1 .9]'
|
||||||
|
y_logloss = 1.203972804325936
|
||||||
|
|
||||||
|
@testset "crossentropy" begin
|
||||||
|
@test crossentropy(y_hat, y) ≈ y_logloss
|
||||||
|
end
|
||||||
|
|
||||||
|
@testset "weighted_crossentropy" begin
|
||||||
|
@test crossentropy(y_hat, y, weight = ones(2)) ≈ y_logloss
|
||||||
|
@test crossentropy(y_hat, y, weight = [.5, .5]) ≈ y_logloss/2
|
||||||
|
@test crossentropy(y_hat, y, weight = [2, .5]) ≈ 1.5049660054074199
|
||||||
|
end
|
||||||
|
end
|
29
test/optimise.jl
Normal file
29
test/optimise.jl
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
using Flux.Optimise
|
||||||
|
using Flux.Tracker
|
||||||
|
|
||||||
|
@testset "Optimise" begin
|
||||||
|
w = randn(10, 10)
|
||||||
|
for Opt in [SGD, Nesterov, Momentum, ADAM, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad]
|
||||||
|
w′ = param(randn(10, 10))
|
||||||
|
loss(x) = Flux.mse(w*x, w′*x)
|
||||||
|
opt = Opt([w′])
|
||||||
|
for t=1:10^5
|
||||||
|
l = loss(rand(10))
|
||||||
|
back!(l)
|
||||||
|
opt()
|
||||||
|
end
|
||||||
|
@test Flux.mse(w, w′) < 0.01
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
@testset "Training Loop" begin
|
||||||
|
i = 0
|
||||||
|
l = param(1)
|
||||||
|
|
||||||
|
Flux.train!(() -> (sleep(0.1); i += 1; l),
|
||||||
|
Iterators.repeated((), 100),
|
||||||
|
()->(),
|
||||||
|
cb = Flux.throttle(() -> (i > 3 && :stop), 1))
|
||||||
|
|
||||||
|
@test 3 < i < 50
|
||||||
|
end
|
@ -5,5 +5,8 @@ using Flux, Base.Test
|
|||||||
include("utils.jl")
|
include("utils.jl")
|
||||||
include("tracker.jl")
|
include("tracker.jl")
|
||||||
include("layers/normalisation.jl")
|
include("layers/normalisation.jl")
|
||||||
|
include("layers/stateless.jl")
|
||||||
|
include("optimise.jl")
|
||||||
|
include("data.jl")
|
||||||
|
|
||||||
end
|
end
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
using Flux.Tracker, Base.Test, NNlib
|
using Flux.Tracker, Base.Test, NNlib
|
||||||
using Flux.Tracker: gradcheck
|
using Flux.Tracker: gradcheck
|
||||||
|
using NNlib
|
||||||
|
|
||||||
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(f(xs...)), xs...)
|
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(f(xs...)), xs...)
|
||||||
gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
||||||
@ -10,6 +11,7 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
|||||||
@test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)
|
@test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)
|
||||||
|
|
||||||
@test gradtest((w, x) -> w'*x, randn(10, 2), randn(10))
|
@test gradtest((w, x) -> w'*x, randn(10, 2), randn(10))
|
||||||
|
@test gradtest((w, x) -> w*x', randn(5,5), randn(5,5))
|
||||||
|
|
||||||
@test gradtest(x -> sin.(sum(x, (2, 3))), (3,4,5))
|
@test gradtest(x -> sin.(sum(x, (2, 3))), (3,4,5))
|
||||||
|
|
||||||
@ -37,9 +39,17 @@ end
|
|||||||
@test gradtest(x -> std(x), rand(5,5))
|
@test gradtest(x -> std(x), rand(5,5))
|
||||||
@test gradtest(x -> std(x, 1), rand(5,5))
|
@test gradtest(x -> std(x, 1), rand(5,5))
|
||||||
|
|
||||||
|
@test gradtest((x, y) -> x .* y, rand(5), rand(5))
|
||||||
|
|
||||||
@test gradtest(rand(5)) do x
|
@test gradtest(rand(5)) do x
|
||||||
y = x.^2
|
y = x.^2
|
||||||
2y + x
|
2y + x
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@test gradtest(conv2d, rand(10, 10, 3, 2), randn(2, 2, 3, 2))
|
||||||
|
@test gradtest(x -> maxpool2d(x, 2), rand(10, 10, 3, 2))
|
||||||
|
@test gradtest(x -> avgpool2d(x, 2), rand(10, 10, 3, 2))
|
||||||
|
|
||||||
|
@test (param([1,2,3]) .< 2) == [true, false, false]
|
||||||
|
|
||||||
end #testset
|
end #testset
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
using Flux: throttle
|
using Flux: throttle, initn, glorot_uniform, glorot_normal, jacobian
|
||||||
|
|
||||||
@testset "Throttle" begin
|
@testset "Throttle" begin
|
||||||
@testset "default behaviour" begin
|
@testset "default behaviour" begin
|
||||||
@ -47,3 +47,35 @@ using Flux: throttle
|
|||||||
@test a == [1, 3]
|
@test a == [1, 3]
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@testset "Jacobian" begin
|
||||||
|
A = param(randn(2,2))
|
||||||
|
x = randn(2)
|
||||||
|
m(x) = A*x
|
||||||
|
y = m(x)
|
||||||
|
J = jacobian(m,x)
|
||||||
|
@test J ≈ A.data
|
||||||
|
end
|
||||||
|
|
||||||
|
@testset "Initialization" begin
|
||||||
|
# Set random seed so that these tests don't fail randomly
|
||||||
|
srand(0)
|
||||||
|
# initn() should yield a kernel with stddev ~= 1e-2
|
||||||
|
v = initn(10, 10)
|
||||||
|
@test std(v) > 0.9*1e-2
|
||||||
|
@test std(v) < 1.1*1e-2
|
||||||
|
|
||||||
|
# glorot_uniform should yield a kernel with values spanning ~ ±sqrt(6/(n_in + n_out)),
|
||||||
|
# and glorot_normal should yield a kernel with stddev ~= sqrt(2/(n_in + n_out))
|
||||||
|
for (n_in, n_out) in [(100, 100), (100, 400)]
|
||||||
|
v = glorot_uniform(n_in, n_out)
|
||||||
|
@test minimum(v) > -1.1*sqrt(6/(n_in + n_out))
|
||||||
|
@test minimum(v) < -0.9*sqrt(6/(n_in + n_out))
|
||||||
|
@test maximum(v) > 0.9*sqrt(6/(n_in + n_out))
|
||||||
|
@test maximum(v) < 1.1*sqrt(6/(n_in + n_out))
|
||||||
|
|
||||||
|
v = glorot_normal(n_in, n_out)
|
||||||
|
@test std(v) > 0.9*sqrt(2/(n_in + n_out))
|
||||||
|
@test std(v) < 1.1*sqrt(2/(n_in + n_out))
|
||||||
|
end
|
||||||
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user