diff --git a/README.md b/README.md index c622df38..f46ae344 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Флукс -[![Build Status](https://travis-ci.org/FluxML/Flux.jl.svg?branch=master)](https://travis-ci.org/FluxML/Flux.jl) [![](https://img.shields.io/badge/docs-stable-blue.svg)](https://fluxml.github.io/Flux.jl/stable/) [![Join the chat at https://gitter.im/FluxML](https://badges.gitter.im/FluxML/Lobby.svg)](https://gitter.im/FluxML/Lobby) [Slack](https://discourse.julialang.org/t/announcing-a-julia-slack/4866) +[![Build Status](https://travis-ci.org/FluxML/Flux.jl.svg?branch=master)](https://travis-ci.org/FluxML/Flux.jl) [![](https://img.shields.io/badge/docs-stable-blue.svg)](https://fluxml.github.io/Flux.jl/stable/) [![Join the chat at https://gitter.im/FluxML](https://badges.gitter.im/FluxML/Lobby.svg)](https://gitter.im/FluxML/Lobby) [Slack](https://slackinvite.julialang.org/) Flux is a refreshing approach to machine learning. It provides lightweight abstractions on top of Julia's native GPU and AD support, while remaining fully hackable (right down to the [GPU kernels](https://github.com/FluxML/CuArrays.jl)). diff --git a/REQUIRE b/REQUIRE index d124b931..8e718a92 100644 --- a/REQUIRE +++ b/REQUIRE @@ -3,5 +3,6 @@ DataFlow 0.2.1 Juno MacroTools 0.3.3 NNlib -ForwardDiff +ForwardDiff 0.5.0 Requires +Adapt diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md index 1fd87d41..cb0c6615 100644 --- a/docs/src/models/layers.md +++ b/docs/src/models/layers.md @@ -5,6 +5,7 @@ These core layers form the foundation of almost all neural networks. ```@docs Chain Dense +Conv2D ``` ## Recurrent Layers @@ -37,6 +38,7 @@ These layers don't affect the structure of the network but may improve training ```@docs Flux.testmode! +BatchNorm Dropout LayerNorm ``` diff --git a/src/Flux.jl b/src/Flux.jl index 7671ddd2..75d2b2b3 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -7,12 +7,14 @@ module Flux using Juno, Requires using Lazy: @forward -export Chain, Dense, RNN, LSTM, Dropout, LayerNorm, - SGD, ADAM, Momentum, Nesterov, +export Chain, Dense, RNN, LSTM, GRU, Conv2D, + Dropout, LayerNorm, BatchNorm, + SGD, ADAM, Momentum, Nesterov, AMSGrad, param, params, mapleaves using NNlib -export σ, sigmoid, relu, leakyrelu, elu, swish, softmax +export σ, sigmoid, relu, leakyrelu, elu, swish, softmax, + conv2d, maxpool2d, avgpool2d include("tracker/Tracker.jl") using .Tracker @@ -26,6 +28,7 @@ include("treelike.jl") include("layers/stateless.jl") include("layers/basic.jl") +include("layers/conv.jl") include("layers/recurrent.jl") include("layers/normalisation.jl") diff --git a/src/data/cmudict.jl b/src/data/cmudict.jl index 88b9c6c0..4307f211 100644 --- a/src/data/cmudict.jl +++ b/src/data/cmudict.jl @@ -23,17 +23,17 @@ end function symbols() load() - Symbol.(split(readstring(deps("CMUDict", "cmudict.symbols")), + Symbol.(split(readstring(deps("cmudict", "cmudict.symbols")), "\n", keep = false)) end function rawdict() load() Dict(String(xs[1]) => Symbol.(xs[2:end]) for xs in - filter(!isempty, split.(split(readstring(deps("CMUDict", "cmudict")), "\n")))) + filter(!isempty, split.(split(readstring(deps("cmudict", "cmudict")), "\n")))) end -validword(s) = ismatch(r"^[\w-\.]+$", s) +validword(s) = ismatch(r"^[\w\-\.]+$", s) cmudict() = filter((s, ps) -> validword(s), rawdict()) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index aa101c43..9f458ab4 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -63,8 +63,10 @@ struct Dense{F,S,T} b::T end -Dense(in::Integer, out::Integer, σ = identity; init = initn) = - Dense(σ, param(init(out, in)), param(init(out))) +function Dense(in::Integer, out::Integer, σ = identity; + initW = glorot_uniform, initb = zeros) + return Dense(σ, param(initW(out, in)), param(initb(out))) +end treelike(Dense) diff --git a/src/layers/conv.jl b/src/layers/conv.jl new file mode 100644 index 00000000..85b05894 --- /dev/null +++ b/src/layers/conv.jl @@ -0,0 +1,33 @@ +""" + Conv2D(size, in=>out) + Conv2d(size, in=>out, relu) + +Standard convolutional layer. `size` should be a tuple like `(2, 2)`. +`in` and `out` specify the number of input and output channels respectively. + +Data should be stored in HWCN order. In other words, a 100×100 RGB image would +be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array. + +Takes the keyword arguments `pad` and `stride`. +""" +struct Conv2D{F,A} + σ::F + weight::A + stride::Int + pad::Int +end + +Conv2D(k::NTuple{2,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; + init = initn, stride = 1, pad = 0) = + Conv2D(σ, param(init(k..., ch...)), stride, pad) + +Flux.treelike(Conv2D) + +(c::Conv2D)(x) = c.σ.(conv2d(x, c.weight, stride = c.stride, padding = c.pad)) + +function Base.show(io::IO, l::Conv2D) + print(io, "Conv2D((", size(l.weight, 1), ", ", size(l.weight, 2), ")") + print(io, ", ", size(l.weight, 3), "=>", size(l.weight, 4)) + l.σ == identity || print(io, ", ", l.σ) + print(io, ")") +end diff --git a/src/layers/normalisation.jl b/src/layers/normalisation.jl index d296b0a3..a018a073 100644 --- a/src/layers/normalisation.jl +++ b/src/layers/normalisation.jl @@ -2,8 +2,8 @@ testmode!(m) testmode!(m, false) -Put layers like [`Dropout`](@ref) and `BatchNorm` into testing mode (or back to -training mode with `false`). +Put layers like [`Dropout`](@ref) and [`BatchNorm`](@ref) into testing mode +(or back to training mode with `false`). """ function testmode!(m, val::Bool=true) prefor(x -> _testmode!(x, val), m) @@ -45,6 +45,7 @@ end _testmode!(a::Dropout, test) = (a.active = !test) """ + LayerNorm(h::Integer) A [normalisation layer](https://arxiv.org/pdf/1607.06450.pdf) designed to be @@ -65,3 +66,77 @@ treelike(LayerNorm) function Base.show(io::IO, l::LayerNorm) print(io, "LayerNorm(", length(l.diag.α), ")") end + +""" + BatchNorm(dims...; λ = identity, + initβ = zeros, initγ = ones, ϵ = 1e-8, momentum = .1) + +Batch Normalization Layer for [`Dense`](@ref) layer. + +See [Batch Normalization: Accelerating Deep Network Training by Reducing + Internal Covariate Shift](https://arxiv.org/pdf/1502.03167.pdf) + +In the example of MNIST, +in order to normalize the input of other layer, +put the `BatchNorm` layer before activation function. + +```julia +m = Chain( + Dense(28^2, 64), + BatchNorm(64, λ = relu), + Dense(64, 10), + BatchNorm(10), + softmax) +``` +""" +mutable struct BatchNorm{F,V,N} + λ::F # activation function + β::V # bias + γ::V # scale + μ # moving mean + σ # moving std + ϵ::N + momentum::N + active::Bool +end + +BatchNorm(dims::Integer...; λ = identity, + initβ = zeros, initγ = ones, ϵ = 1e-8, momentum = .1) = + BatchNorm(λ, param(initβ(dims)), param(initγ(dims)), 0., 1., ϵ, momentum, true) + +function (BN::BatchNorm)(x) + λ, γ, β = BN.λ, BN.γ, BN.β + + if !BN.active + μ = BN.μ + σ = BN.σ + else + T = eltype(x) + + ϵ = T(BN.ϵ) + m = size(x, 2) # batch size + μ = mean(x, 2) + σ = sqrt.(sum((x .- μ).^2, 2) ./ m .+ ϵ) + + # update moving mean/std + mtm = T(BN.momentum) + BN.μ = (1 - mtm) .* BN.μ .+ mtm .* μ.data + BN.σ = (1 - mtm) .* BN.σ .+ mtm .* σ.data .* m ./ (m - 1) + end + + λ.(γ .* ((x .- μ) ./ σ) .+ β) +end + +children(BN::BatchNorm) = + (BN.λ, BN.β, BN.γ, BN.μ, BN.σ, BN.momentum, BN.ϵ, BN.active) + +mapchildren(f, BN::BatchNorm) = # e.g. mapchildren(cu, BN) + BatchNorm(BN.λ, f(BN.β), f(BN.γ), BN.μ, BN.σ, BN.momentum, BN.ϵ, BN.active) + +_testmode!(BN::BatchNorm, test) = (BN.active = !test) + +function Base.show(io::IO, l::BatchNorm) + print(io, "BatchNorm($(join(size(l.β), ", "))") + (l.λ == identity) || print(io, ", λ = $(l.λ)") + print(io, ")") +end diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 599776ce..e4eb0c3d 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -79,8 +79,8 @@ struct RNNCell{D,V} h::V end -RNNCell(in::Integer, out::Integer, σ = tanh; init = initn) = - RNNCell(Dense(in+out, out, σ, init = init), param(init(out))) +RNNCell(in::Integer, out::Integer, σ = tanh; initW = glorot_uniform, initb = zeros) = + RNNCell(Dense(in+out, out, σ, initW = initW, initb = initb), param(initW(out))) function (m::RNNCell)(h, x) h = m.d(combine(x, h)) @@ -113,10 +113,10 @@ struct LSTMCell{D1,D2,V} h::V; c::V end -function LSTMCell(in, out; init = initn) - cell = LSTMCell([Dense(in+out, out, σ, init = init) for _ = 1:3]..., - Dense(in+out, out, tanh, init = init), - param(init(out)), param(init(out))) +function LSTMCell(in, out; initW = glorot_uniform, initb = zeros) + cell = LSTMCell([Dense(in+out, out, σ, initW = initW, initb = initb) for _ = 1:3]..., + Dense(in+out, out, tanh, initW = initW, initb = initb), + param(initW(out)), param(initW(out))) cell.forget.b.data .= 1 return cell end @@ -150,3 +150,49 @@ See [this article](http://colah.github.io/posts/2015-08-Understanding-LSTMs/) for a good overview of the internals. """ LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...)) + +# GRU + +struct GRUCell{D1,D2,V} + update::D1 + reset::D1 + candidate::D2 + h::V +end + +function GRUCell(in, out) + cell = GRUCell(Dense(in+out, out, σ), + Dense(in+out, out, σ), + Dense(in+out, out, tanh), + param(initn(out))) + return cell +end + +function (m::GRUCell)(h, x) + x′ = combine(x, h) + z = m.update(x′) + r = m.reset(x′) + h̃ = m.candidate(combine(r.*h, x)) + h = (1.-z).*h .+ z.*h̃ + return h, h +end + +hidden(m::GRUCell) = m.h + +treelike(GRUCell) + +Base.show(io::IO, m::GRUCell) = + print(io, "GRUCell(", + size(m.update.W, 2) - size(m.update.W, 1), ", ", + size(m.update.W, 1), ')') + +""" + GRU(in::Integer, out::Integer, σ = tanh) + +Gated Recurrent Unit layer. Behaves like an RNN but generally +exhibits a longer memory span over sequences. + +See [this article](http://colah.github.io/posts/2015-08-Understanding-LSTMs/) +for a good overview of the internals. +""" +GRU(a...; ka...) = Recur(GRUCell(a...; ka...)) diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index edbdec58..63c40cb8 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -4,8 +4,9 @@ using NNlib: log_fast mse(ŷ, y) = sum((ŷ .- y).^2)/length(y) -crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat) = - -sum(y .* log_fast.(ŷ)) / size(y, 2) +function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1) + return -sum(y .* log_fast.(ŷ) .* weight) / size(y, 2) +end @deprecate logloss(x, y) crossentropy(x, y) diff --git a/src/onehot.jl b/src/onehot.jl index f94fb93e..b1a1a970 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -18,7 +18,9 @@ end Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data)) -Base.getindex(xs::OneHotMatrix, i::Int, j::Int) = xs.data[j][i] +Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs.data[j][i] +Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i] +Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i]) A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)] @@ -26,7 +28,7 @@ Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(length(first(xs)), xs) -import NNlib.adapt +import Adapt.adapt adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data)) diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl index 5f144b65..acec542e 100644 --- a/src/optimise/Optimise.jl +++ b/src/optimise/Optimise.jl @@ -1,7 +1,7 @@ module Optimise export update!, params, train!, - SGD, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta + SGD, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad struct Param{T} x::T diff --git a/src/optimise/interface.jl b/src/optimise/interface.jl index 0b2a25ae..42b05dc8 100644 --- a/src/optimise/interface.jl +++ b/src/optimise/interface.jl @@ -1,5 +1,7 @@ call(f, xs...) = f(xs...) +# note for optimisers: set to zero +# p.Δ at the end of the weigths update function optimiser(ps, fs...) ps = [Param(p) for p in ps] fs = map(ps) do p @@ -10,64 +12,73 @@ function optimiser(ps, fs...) end """ - SGD(params, η = 1; decay = 0) + SGD(params, η = 0.1; decay = 0) -Classic gradient descent optimiser. For each parameter `p` and its -gradient `δp`, this runs `p -= η*δp`. +Classic gradient descent optimiser with learning rate `η`. +For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`. -Supports decayed learning rate decay if the `decay` argument is provided. +Supports inverse decaying learning rate if the `decay` argument is provided. """ -SGD(ps, η = 1; decay = 0) = - optimiser(ps, p -> invdecay(p, decay), p -> descent(p, η)) +SGD(ps, η = 0.1; decay = 0) = + optimiser(ps, p -> invdecay(p, decay), p -> descent(p,η)) """ - Momentum(params, ρ, decay = 0) + Momentum(params, η = 0.01; ρ = 0.9, decay = 0) -SGD with momentum `ρ` and optional learning rate decay. +SGD with learning rate `η`, momentum `ρ` and optional learning rate inverse decay. """ -Momentum(ps, ρ; decay = 0) = - optimiser(ps, p -> momentum(p, ρ), p -> invdecay(p, decay), p -> descent(p, 1)) +Momentum(ps, η = 0.01; ρ = 0.9, decay = 0) = + optimiser(ps, p->invdecay(p,decay), p->momentum(p, ρ, η), p->descent(p,1)) """ - Nesterov(params, ρ, decay = 0) + Nesterov(params, η = 0.01; ρ = 0.9, decay = 0) -SGD with Nesterov momentum `ρ` and optional learning rate decay. +SGD with learning rate `η`, Nesterov momentum `ρ` and optional learning rate inverse decay. """ -Nesterov(ps, ρ; decay = 0) = - optimiser(ps, p -> nesterov(p, ρ), p -> invdecay(p, decay), p -> descent(p, 1)) +Nesterov(ps, η = 0.01; ρ = 0.9, decay = 0) = + optimiser(ps, p->invdecay(p,decay), p->nesterov(p, ρ, η), p->descent(p,1)) """ - RMSProp(params; η = 0.001, ρ = 0.9, ϵ = 1e-8, decay = 0) + RMSProp(params, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) [RMSProp](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) optimiser. Parameters other than learning rate don't need tuning. Often a good choice for recurrent networks. """ RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) = - optimiser(ps, p -> rmsprop(p; η = η, ρ = ρ, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1)) + optimiser(ps, p->rmsprop(p; η=η, ρ=ρ, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1)) """ - ADAM(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) + ADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) [ADAM](https://arxiv.org/abs/1412.6980v8) optimiser. """ ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) = - optimiser(ps, p -> adam(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1)) + optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1)) """ - ADAGrad(params; η = 0.01, ϵ = 1e-8, decay = 0) + ADAGrad(params, η = 0.01; ϵ = 1e-8, decay = 0) [ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser. Parameters don't need tuning. """ -ADAGrad(ps; η = 0.01, ϵ = 1e-8, decay = 0) = - optimiser(ps, p -> adagrad(p; η = η, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1)) +ADAGrad(ps, η = 0.01; ϵ = 1e-8, decay = 0) = + optimiser(ps, p->adagrad(p; η=η, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1)) """ - ADADelta(params; η = 0.01, ρ = 0.95, ϵ = 1e-8, decay = 0) + ADADelta(params; ρ = 0.9, ϵ = 1e-8, decay = 0) [ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need tuning. """ -ADADelta(ps; η = 0.01, ρ = 0.95, ϵ = 1e-8, decay = 0) = - optimiser(ps, p -> adadelta(p; ρ = ρ, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1)) +ADADelta(ps; ρ = 0.9, ϵ = 1e-8, decay = 0) = + optimiser(ps, p->adadelta(p; ρ=ρ, ϵ=ϵ), p->descent(p,1)) + +""" + AMSGrad(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) + +[AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need +tuning. +""" +AMSGrad(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) = + optimiser(ps, p -> amsgrad(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1)) diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl index abc54090..c09e6131 100644 --- a/src/optimise/optimisers.jl +++ b/src/optimise/optimisers.jl @@ -1,74 +1,97 @@ function descent(p::Param, η::Real) function () - p.x .-= p.Δ .* η - p.Δ .= 0 + @. p.x -= η * p.Δ + @. p.Δ = 0 end end -function momentum(p::Param, ρ::Real) - mo = zeros(p.x) - () -> p.Δ .= mo .= ρ .* mo .+ p.Δ -end - -function nesterov(p::Param, ρ::Real) - mo = zeros(p.x) +function momentum(p::Param, ρ, η) + v = zeros(p.x) function () - mo .= ρ .* mo .+ p.Δ - p.Δ .= ρ .* mo .+ p.Δ + @. v = ρ * v - η * p.Δ + @. p.Δ = -v end end -function clip(p::Param, thresh::Real) - () -> clamp!(p.Δ, -thresh, thresh) -end - -function weightdecay(p::Param, γ::Real) - () -> p.Δ .+= γ .* p.x -end - -function invdecay(p::Param, γ::Real) - n = 0 +# Ref. https://arxiv.org/pdf/1212.0901.pdf +function nesterov(p::Param, ρ, η) + v = zeros(p.x) function () - p.Δ .*= 1 / (1 + γ * n) - n += 1 + d = @. ρ^2 * v - (1+ρ) * η * p.Δ + @. v = ρ*v - η*p.Δ + @. p.Δ = -d end end function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8) - acc = zeros(p.x) .+ ϵ + acc = zeros(p.x) function () - @. acc = ρ * acc + (1 - ρ) * p.Δ ^ 2 - @. p.Δ *= η / √acc + @. acc = ρ * acc + (1 - ρ) * p.Δ^2 + @. p.Δ *= η / (√acc + ϵ) end end function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8) acc = zeros(p.x) .+ ϵ function () - @. acc += p.Δ ^ 2 + @. acc += p.Δ^2 @. p.Δ *= η / √acc end end -function adadelta(p::Param; ρ::Real = 0.95, ϵ::Real = 1e-8) - acc = zeros(p.x) .+ ϵ - Δacc = zeros(p.x) .+ ϵ +function adadelta(p::Param; ρ::Real = 0.9, ϵ::Real = 1e-8) + acc = zeros(p.x) + Δacc = zeros(p.x) function () - @. acc = ρ * acc + (1 - ρ) * p.Δ ^ 2 - @. p.Δ *= √Δacc / √acc - @. Δacc = ρ * Δacc + (1 - ρ) * p.Δ ^ 2 - end + @. acc = ρ * acc + (1 - ρ) * p.Δ^2 + @. p.Δ *= √(Δacc + ϵ) / √(acc + ϵ) + @. Δacc = ρ * Δacc + (1 - ρ) * p.Δ^2 + end end function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8) mt = zeros(p.x) - vt = zeros(p.x) .+ ϵ + vt = zeros(p.x) β1p, β2p = β1, β2 function () @. mt = β1 * mt + (1 - β1) * p.Δ - @. vt = β2 * vt + (1 - β2) * p.Δ ^ 2 - @. p.Δ = √(1 - β2p) / √(1 - β1p) * mt / √vt * η + @. vt = β2 * vt + (1 - β2) * p.Δ^2 + @. p.Δ = mt / (1 - β1p) / (√(vt / (1 - β2p)) + ϵ) * η β1p *= β1 β2p *= β2 end end + +function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8) + mt = zeros(p.x) + vt = zeros(p.x) .+ ϵ + v̂t = zeros(p.x) .+ ϵ + function () + @. mt = β1 * mt + (1 - β1) * p.Δ + @. vt = β2 * vt + (1 - β2) * p.Δ ^ 2 + @. v̂t = max.(v̂t, vt) + @. p.Δ = η * mt / √v̂t + end +end + +clip(p::Param, thresh::Real) = () -> clamp!(p.Δ, -thresh, thresh) + +function expdecay(p::Param, γ::Real) + if γ != 0 + return () -> p.Δ .+= γ .* p.x + else + return () -> nothing + end +end + +function invdecay(p::Param, γ::Real) + if γ != 0 + n = 0 + return () -> begin + p.Δ .*= 1 / (1 + γ * n) + n += 1 + end + else + return () -> nothing + end +end diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 0809e86b..31812fa0 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -1,15 +1,24 @@ using Juno -using Flux.Tracker: back! +using Flux.Tracker: back!, value runall(f) = f runall(fs::AbstractVector) = () -> foreach(call, fs) """ - train!(loss, data, opt; cb = () -> ()) + train!(loss, data, opt) For each datapoint `d` in `data` computes the gradient of `loss(d...)` through -backpropagation and calls the optimizer `opt` and the callback `cb` -(i.e. `opt()` and `cb()`). +backpropagation and calls the optimizer `opt`. + +Takes a callback as keyword argument `cb`. For example, this will print "training" +every 10 seconds: + +```julia +Flux.train!(loss, data, opt, + cb = throttle(() -> println("training"), 10)) +``` + +The callback can return `:stop` to interrupt the training loop. Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays. """ @@ -18,10 +27,10 @@ function train!(loss, data, opt; cb = () -> ()) opt = runall(opt) @progress for d in data l = loss(d...) - isinf(l.data[]) && error("Loss is Inf") - isnan(l.data[]) && error("Loss is NaN") + isinf(value(l)) && error("Loss is Inf") + isnan(value(l)) && error("Loss is NaN") back!(l) opt() - cb() + cb() == :stop && break end end diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 74ed2d75..aa2bc6ea 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -58,6 +58,7 @@ Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) = Base.similar(x::TrackedArray, T::Type) = similar(data(x), T) +# TODO decide if keeping both data and value. The problem is TrackedScalar value(x) = x value(x::TrackedArray) = data(x) value(x::TrackedScalar) = data(x)[] @@ -69,6 +70,7 @@ Base.:(==)(x::TrackedArray, y::TrackedArray) = value(x) == value(x) Base.isless(x::TrackedScalar, y) = isless(value(x), y) Base.isless(x, y::TrackedScalar) = isless(x, value(y)) Base.isless(x::TrackedScalar, y::TrackedScalar) = isless(value(x), value(y)) +Base.isapprox(x::TrackedScalar, y; kws...) = isapprox(x.data[], y; kws...) Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} = print(io, "TrackedArray{…,$A}") @@ -91,7 +93,7 @@ include("back.jl") include("lib.jl") include("numeric.jl") -import NNlib.adapt +import Adapt.adapt adapt(T, xs::TrackedArray) = TrackedArray(xs.f, adapt(T, xs.data), adapt(T, xs.grad)) diff --git a/src/tracker/back.jl b/src/tracker/back.jl index 39810069..b4cd27c6 100644 --- a/src/tracker/back.jl +++ b/src/tracker/back.jl @@ -12,16 +12,17 @@ function scan(x::TrackedArray) return end -back(c::Call, Δ) = back(c.func, Δ, c.args...) -back(::Call{Void}, Δ) = nothing +back_(f, y, args...) = back(f, args...) +back_(c::Call, y, Δ) = back_(c.func, y, Δ, c.args...) +back_(::Call{Void}, y, Δ) = nothing function back(x::TrackedArray, Δ) ref = x.ref -= 1 if isdefined(x, :grad) x.grad .+= Δ - ref == 0 && back(x.f, x.grad) + ref == 0 && back_(x.f, x.data, x.grad) else - ref == 0 && back(x.f, Δ) + ref == 0 && back_(x.f, x.data, Δ) end return end @@ -35,6 +36,9 @@ end # Interface methods +# TODO: if an error occurs in `back` the refcounts will be broken +# and `back` will silently fail to update. + function back!(x::TrackedArray, Δ) scan(x) back(x, Δ) diff --git a/src/tracker/lib.jl b/src/tracker/lib.jl index be77634b..72b863d6 100644 --- a/src/tracker/lib.jl +++ b/src/tracker/lib.jl @@ -48,6 +48,12 @@ function back(::typeof(vcat), Δ, xs...) end end +Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) = + TrackedArray(Call(reshape, xs, dims...)) + +back(::typeof(reshape), Δ, xs::TrackedArray, _...) = + back(xs, reshape(Δ, size(xs))) + # Reductions Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim)) @@ -62,6 +68,15 @@ Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...) Base.mean(xs::TrackedArray) = TrackedArray(Call(mean, xs), toarray(xs.data, mean(xs.data))) Base.mean(xs::TrackedArray, region) = TrackedArray(Call(mean, xs, region)) +LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = TrackedArray(Call(dot, xs, ys), toarray(xs.data, dot(data(xs), data(ys)))) +LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = TrackedArray(Call(dot, xs, ys), toarray(xs.data, dot(data(xs), data(ys)))) +LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = TrackedArray(Call(dot, xs, ys), toarray(xs.data, dot(data(xs), data(ys)))) + +function back(::typeof(dot), Δ, xs, ys) + @back(xs, Δ.*ys) + @back(ys, Δ.*xs) +end + # Hacks to get std working Base.std(x::TrackedArray; mean = Base.mean(x)) = sqrt.(sum((x .- mean).^2) ./ (length(x)-1)) @@ -74,7 +89,7 @@ back(::typeof(mean), Δ, xs::TrackedArray, region) = # BLAS -for f in :[*, Ac_mul_B].args +for f in :[*, Ac_mul_B, A_mul_Bc].args @eval begin import Base.$f $f(a::TrackedMatrix, b::TrackedMatrix) = TrackedArray(Call($f, a, b)) @@ -98,7 +113,12 @@ end function back(::typeof(Ac_mul_B), Δ, a::AbstractVecOrMat{<:Real}, b::AbstractVecOrMat{<:Real}) @back(a, A_mul_Bt(Δ, data(b))') - @back(b, *(data(a), Δ)) + @back(b, data(a)*Δ) +end + +function back(::typeof(A_mul_Bc), Δ, a::AbstractVecOrMat{<:Real}, b::AbstractVecOrMat{<:Real}) + @back(a, Δ * data(b)) + @back(b, At_mul_B(data(a), Δ)') end # Fast path for matrix-vector @@ -113,12 +133,36 @@ end # NNlib -import NNlib: softmax, ∇softmax +using NNlib +import NNlib: softmax, ∇softmax, conv2d, pool softmax(xs::TrackedArray) = TrackedArray(Call(softmax, xs)) back(::typeof(softmax), Δ, xs) = @back(xs, ∇softmax(Δ, data(xs))) +# TODO: can store kwargs efficiently in namedtuples +_conv2d(x, w, stride, pad) = conv2d(x, w, stride = stride, padding = pad) + +conv2d(x::TrackedArray{<:Any,4}, w::TrackedArray{<:Any,4}; stride = 1, padding = 0) = + TrackedArray(Call(_conv2d, x, w, stride, padding)) +conv2d(x::AbstractArray{<:Any,4}, w::TrackedArray{<:Any,4}; stride = 1, padding = 0) = + TrackedArray(Call(_conv2d, x, w, stride, padding)) +conv2d(x::TrackedArray{<:Any,4}, w::AbstractArray{<:Any,4}; stride = 1, padding = 0) = + TrackedArray(Call(_conv2d, x, w, stride, padding)) + +function back(::typeof(_conv2d), Δ, x, w, stride, pad) + @back(x, NNlib.conv2d_grad_x(data(x), data(w), Δ; stride = stride, padding = pad)) + @back(w, NNlib.conv2d_grad_w(data(x), data(w), Δ; stride = stride, padding = pad)) +end + +_pool(x, k, pad, mode) = pool(x, window = k, mode = mode, padding = pad) + +pool(x::TrackedArray{<:Any,4}; window = 2, mode = 0, padding = 0) = + TrackedArray(Call(_pool, x, window, padding, mode)) + +back_(::typeof(_pool), y, Δ, x, k, pad, mode) = + back(x, NNlib.pool_grad(data(x), y, Δ, window=k, mode=mode, padding=pad)) + # Broadcasting using ForwardDiff: Dual, partials @@ -134,9 +178,11 @@ dualify(xs::TrackedArray, ps) = map(x -> Dual(x, ps), data(xs)) function tracked_broadcast(f, args::Vararg{Any,N}) where N dargs = map((x,i) -> dualify(x, ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N})) + out = broadcast(f, dargs...) + eltype(out) <: Dual || return out # TrackedArray(Call(Broadcasted(broadcast(f, dargs...)), args...)) # Works around a 0.6 type inference issue - b = Broadcasted(broadcast(f, dargs...)) + b = Broadcasted(out) TrackedArray(Call(b, args...), b()) end diff --git a/src/tracker/numeric.jl b/src/tracker/numeric.jl index 68211aa3..cbcd3ad8 100644 --- a/src/tracker/numeric.jl +++ b/src/tracker/numeric.jl @@ -19,4 +19,4 @@ function ngradient(f, xs::AbstractArray...) return grads end -gradcheck(f, xs...) = all(isapprox.(ngradient(f, xs...), gradient(f, xs...), rtol = 1e-6)) +gradcheck(f, xs...) = all(isapprox.(ngradient(f, xs...), gradient(f, xs...), rtol = 1e-5)) diff --git a/src/utils.jl b/src/utils.jl index f822c111..bba3e416 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,8 +1,8 @@ # Arrays initn(dims...) = randn(dims...)/100 - -flatten(xs) = reshape(xs, size(xs, 1), :) +glorot_uniform(dims...) = (rand(dims...) - 0.5)*sqrt(24.0/(sum(dims))) +glorot_normal(dims...) = (randn(dims...)*sqrt(2.0/sum(dims))) unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...)) @@ -93,13 +93,14 @@ but if you'd like to disable the execution on the leading edge, pass function throttle(f, timeout; leading=true, trailing=false) cooldown = true later = nothing + result = nothing function throttled(args...; kwargs...) yield() if cooldown if leading - f(args...; kwargs...) + result = f(args...; kwargs...) else later = () -> f(args...; kwargs...) end @@ -114,9 +115,28 @@ function throttle(f, timeout; leading=true, trailing=false) cooldown = true end elseif trailing - later = () -> f(args...; kwargs...) + later = () -> (result = f(args...; kwargs...)) end - nothing + return result end end + +""" + J = jacobian(m,x) + +Calculate the output jacobian `J = d/dx m(x)` such that each row `i` of `J` corresponds to the gradient `J[i,:] = ∇ₓ(m(x)[i])` +""" +function jacobian(m,x) + xp = param(x) + y = m(xp) + k = length(y) + n = length(x) + J = Matrix{eltype(x)}(n,k) + for i = 1:k + Flux.back!(y[i]) # Populate gradient accumulator + J[:,i] = xp.grad + xp.grad .*= 0 # Reset gradient accumulator + end + J' +end diff --git a/test/data.jl b/test/data.jl index 1b93ab3c..5a4c9ce6 100644 --- a/test/data.jl +++ b/test/data.jl @@ -1,3 +1,8 @@ using Flux.Data +using Base.Test @test cmudict()["CATASTROPHE"] == :[K,AH0,T,AE1,S,T,R,AH0,F,IY0].args + +@test length(CMUDict.phones()) == 39 + +@test length(CMUDict.symbols()) == 84 diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl index 5a302a51..118a5700 100644 --- a/test/layers/normalisation.jl +++ b/test/layers/normalisation.jl @@ -26,3 +26,55 @@ using Flux: testmode! y = m(x) @test count(a->a == 0, y) == 0 end + +@testset "BatchNorm" begin + let m = BatchNorm(2), x = param([1 2; 3 4; 5 6]') + + @test m.β.data == [0, 0] # initβ(2) + @test m.γ.data == [1, 1] # initγ(2) + # initial m.σ is 1 + # initial m.μ is 0 + @test m.active + + # @test m(x).data ≈ [-1 -1; 0 0; 1 1]' + m(x) + + # julia> x + # 2×3 Array{Float64,2}: + # 1.0 3.0 5.0 + # 2.0 4.0 6.0 + # + # μ of batch will be + # (1. + 3. + 5.) / 3 = 3 + # (2. + 4. + 6.) / 3 = 4 + # + # ∴ update rule with momentum: + # .1 * 3 + 0 = .3 + # .1 * 4 + 0 = .4 + @test m.μ ≈ reshape([0.3, 0.4], 2, 1) + + # julia> .1 .* std(x, 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.] + # 2×1 Array{Float64,2}: + # 1.14495 + # 1.14495 + @test m.σ ≈ .1 .* std(x.data, 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.] + + testmode!(m) + @test !m.active + + x′ = m(x).data + @test x′[1] ≈ (1 - 0.3) / 1.1449489742783179 + end + + # with activation function + let m = BatchNorm(2, λ = σ), x = param([1 2; 3 4; 5 6]') + @test m.active + m(x) + + testmode!(m) + @test !m.active + + x′ = m(x).data + @test x′[1] ≈ σ((1 - 0.3) / 1.1449489742783179) + end +end diff --git a/test/layers/stateless.jl b/test/layers/stateless.jl new file mode 100644 index 00000000..23304eb1 --- /dev/null +++ b/test/layers/stateless.jl @@ -0,0 +1,26 @@ +using Flux: onehotbatch, mse, crossentropy + +@testset "losses" begin + # First, regression-style y's + y = [1, 1, 0, 0] + y_hat = [.9, .1, .1, .9] + + @testset "mse" begin + @test mse(y_hat, y) ≈ (.1^2 + .9^2)/2 + end + + # Now onehot y's + y = onehotbatch([1, 1, 0, 0], 0:1) + y_hat = [.1 .9; .9 .1; .9 .1; .1 .9]' + y_logloss = 1.203972804325936 + + @testset "crossentropy" begin + @test crossentropy(y_hat, y) ≈ y_logloss + end + + @testset "weighted_crossentropy" begin + @test crossentropy(y_hat, y, weight = ones(2)) ≈ y_logloss + @test crossentropy(y_hat, y, weight = [.5, .5]) ≈ y_logloss/2 + @test crossentropy(y_hat, y, weight = [2, .5]) ≈ 1.5049660054074199 + end +end diff --git a/test/optimise.jl b/test/optimise.jl new file mode 100644 index 00000000..66c50037 --- /dev/null +++ b/test/optimise.jl @@ -0,0 +1,29 @@ +using Flux.Optimise +using Flux.Tracker + +@testset "Optimise" begin + w = randn(10, 10) + for Opt in [SGD, Nesterov, Momentum, ADAM, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad] + w′ = param(randn(10, 10)) + loss(x) = Flux.mse(w*x, w′*x) + opt = Opt([w′]) + for t=1:10^5 + l = loss(rand(10)) + back!(l) + opt() + end + @test Flux.mse(w, w′) < 0.01 + end +end + +@testset "Training Loop" begin + i = 0 + l = param(1) + + Flux.train!(() -> (sleep(0.1); i += 1; l), + Iterators.repeated((), 100), + ()->(), + cb = Flux.throttle(() -> (i > 3 && :stop), 1)) + + @test 3 < i < 50 +end diff --git a/test/runtests.jl b/test/runtests.jl index efd1a462..553545e9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,5 +5,8 @@ using Flux, Base.Test include("utils.jl") include("tracker.jl") include("layers/normalisation.jl") +include("layers/stateless.jl") +include("optimise.jl") +include("data.jl") end diff --git a/test/tracker.jl b/test/tracker.jl index 81a72566..90eb0af1 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -1,5 +1,6 @@ using Flux.Tracker, Base.Test, NNlib using Flux.Tracker: gradcheck +using NNlib gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(f(xs...)), xs...) gradtest(f, dims...) = gradtest(f, rand.(dims)...) @@ -10,6 +11,7 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...) @test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2) @test gradtest((w, x) -> w'*x, randn(10, 2), randn(10)) +@test gradtest((w, x) -> w*x', randn(5,5), randn(5,5)) @test gradtest(x -> sin.(sum(x, (2, 3))), (3,4,5)) @@ -37,9 +39,17 @@ end @test gradtest(x -> std(x), rand(5,5)) @test gradtest(x -> std(x, 1), rand(5,5)) +@test gradtest((x, y) -> x .* y, rand(5), rand(5)) + @test gradtest(rand(5)) do x y = x.^2 2y + x end +@test gradtest(conv2d, rand(10, 10, 3, 2), randn(2, 2, 3, 2)) +@test gradtest(x -> maxpool2d(x, 2), rand(10, 10, 3, 2)) +@test gradtest(x -> avgpool2d(x, 2), rand(10, 10, 3, 2)) + +@test (param([1,2,3]) .< 2) == [true, false, false] + end #testset diff --git a/test/utils.jl b/test/utils.jl index 7638fd2a..7a00b57d 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,4 +1,4 @@ -using Flux: throttle +using Flux: throttle, initn, glorot_uniform, glorot_normal, jacobian @testset "Throttle" begin @testset "default behaviour" begin @@ -47,3 +47,35 @@ using Flux: throttle @test a == [1, 3] end end + +@testset "Jacobian" begin + A = param(randn(2,2)) + x = randn(2) + m(x) = A*x + y = m(x) + J = jacobian(m,x) + @test J ≈ A.data +end + +@testset "Initialization" begin + # Set random seed so that these tests don't fail randomly + srand(0) + # initn() should yield a kernel with stddev ~= 1e-2 + v = initn(10, 10) + @test std(v) > 0.9*1e-2 + @test std(v) < 1.1*1e-2 + + # glorot_uniform should yield a kernel with stddev ~= sqrt(6/(n_in + n_out)), + # and glorot_normal should yield a kernel with stddev != 2/(n_in _ n_out) + for (n_in, n_out) in [(100, 100), (100, 400)] + v = glorot_uniform(n_in, n_out) + @test minimum(v) > -1.1*sqrt(6/(n_in + n_out)) + @test minimum(v) < -0.9*sqrt(6/(n_in + n_out)) + @test maximum(v) > 0.9*sqrt(6/(n_in + n_out)) + @test maximum(v) < 1.1*sqrt(6/(n_in + n_out)) + + v = glorot_normal(n_in, n_out) + @test std(v) > 0.9*sqrt(2/(n_in + n_out)) + @test std(v) < 1.1*sqrt(2/(n_in + n_out)) + end +end