From 21ea93ffcd08c87ed5dfae5bc6645852744160fe Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Thu, 2 Nov 2017 11:44:39 +0000 Subject: [PATCH 01/26] rename treelike --- src/Flux.jl | 2 +- src/{tree.jl => treelike.jl} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename src/{tree.jl => treelike.jl} (100%) diff --git a/src/Flux.jl b/src/Flux.jl index 242c8b1f..ff78593f 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -22,7 +22,7 @@ using .Optimise include("utils.jl") include("onehot.jl") -include("tree.jl") +include("treelike.jl") include("layers/stateless.jl") include("layers/basic.jl") diff --git a/src/tree.jl b/src/treelike.jl similarity index 100% rename from src/tree.jl rename to src/treelike.jl From efa51f02e7a7ea28d79aabe496cdb57aedbae4fd Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 11 Oct 2017 11:54:18 +0100 Subject: [PATCH 02/26] basic batch type --- src/Flux.jl | 2 ++ src/batches/Batches.jl | 7 +++++++ src/batches/batch.jl | 8 ++++++++ 3 files changed, 17 insertions(+) create mode 100644 src/batches/Batches.jl create mode 100644 src/batches/batch.jl diff --git a/src/Flux.jl b/src/Flux.jl index ff78593f..acefff19 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -31,4 +31,6 @@ include("layers/normalisation.jl") include("data/Data.jl") +include("batches/Batches.jl") + end # module diff --git a/src/batches/Batches.jl b/src/batches/Batches.jl new file mode 100644 index 00000000..066f4d1c --- /dev/null +++ b/src/batches/Batches.jl @@ -0,0 +1,7 @@ +module Batches + +import ..Flux + +include("batch.jl") + +end diff --git a/src/batches/batch.jl b/src/batches/batch.jl new file mode 100644 index 00000000..5a2eb82e --- /dev/null +++ b/src/batches/batch.jl @@ -0,0 +1,8 @@ +struct Batch{T,A,M} + data::A + mask::M +end + +Batch{T}(data, mask) where T = Batch{T,typeof(data),typeof(mask)}(data, mask) + +Batch(xs) = Batch{typeof(first(xs))}(Flux.batch(xs),trues(length(xs))) From 97244e0a68fa8cbae17f8065160126897a674009 Mon Sep 17 00:00:00 2001 From: Fredrik Bagge Carlson Date: Sat, 4 Nov 2017 13:27:32 +0100 Subject: [PATCH 03/26] Allow array of optimisers to train! This allows an array of optimisers to be sent to `train!` --- src/optimise/train.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 2a2ec5eb..0809e86b 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -1,8 +1,8 @@ using Juno using Flux.Tracker: back! -tocb(f) = f -tocb(fs::AbstractVector) = () -> foreach(call, fs) +runall(f) = f +runall(fs::AbstractVector) = () -> foreach(call, fs) """ train!(loss, data, opt; cb = () -> ()) @@ -11,10 +11,11 @@ For each datapoint `d` in `data` computes the gradient of `loss(d...)` through backpropagation and calls the optimizer `opt` and the callback `cb` (i.e. `opt()` and `cb()`). -Multiple callbacks can be passed to `cb` as an array. +Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays. """ function train!(loss, data, opt; cb = () -> ()) - cb = tocb(cb) + cb = runall(cb) + opt = runall(opt) @progress for d in data l = loss(d...) isinf(l.data[]) && error("Loss is Inf") From d6423eefe54b8ba822ed49b8b5c0d52dbe58ae1d Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 7 Nov 2017 19:34:27 +0000 Subject: [PATCH 04/26] matrix-vector fast path --- src/tracker/Tracker.jl | 2 ++ src/tracker/lib.jl | 10 ++++++++++ 2 files changed, 12 insertions(+) diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index d6fa6f35..5e26a051 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -38,6 +38,8 @@ TrackedArray(c::Call) = TrackedArray(c, c()) TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x)) +isleaf(x::TrackedArray) = x.f == Call(nothing) + param(xs) = TrackedArray(AbstractFloat.(xs)) param(xs::Real) = param(fill(xs)) diff --git a/src/tracker/lib.jl b/src/tracker/lib.jl index a90eb932..2ee5d659 100644 --- a/src/tracker/lib.jl +++ b/src/tracker/lib.jl @@ -79,6 +79,16 @@ function back(::typeof(*), Δ, a::AbstractMatrix, b::AbstractVecOrMat) @back(b, At_mul_B(data(a), Δ)) end +# Fast path for matrix-vector +function back(::typeof(*), Δ::AbstractVector, W::TrackedMatrix, x::AbstractVector) + if isleaf(W) + W.grad .+= Δ .* data(x).' + else + back(W, A_mul_Bt(Δ, data(x))) + end + @back(x, At_mul_B(data(W), Δ)) +end + # NNlib import NNlib: softmax, ∇softmax From d4229c4815a265d2ba084dc2b5b6db264cea497d Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 7 Nov 2017 19:34:35 +0000 Subject: [PATCH 05/26] useful params method --- src/treelike.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/treelike.jl b/src/treelike.jl index 899fccea..097ccdc6 100644 --- a/src/treelike.jl +++ b/src/treelike.jl @@ -35,3 +35,5 @@ function params(m) prefor(p -> p isa TrackedArray && push!(ps, p), m) return ps end + +params(m...) = params(m) From fcd091e8f06fc7a8824c4ca12d38dd23a4da4f08 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 8 Nov 2017 22:00:19 +0000 Subject: [PATCH 06/26] Ac_mul_B derivatives --- src/tracker/lib.jl | 28 ++++++++++++++++++++-------- test/tracker.jl | 2 ++ 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/src/tracker/lib.jl b/src/tracker/lib.jl index 2ee5d659..aab26dfe 100644 --- a/src/tracker/lib.jl +++ b/src/tracker/lib.jl @@ -1,5 +1,3 @@ -import Base: * - toarray(xs::AbstractArray, ys::AbstractArray) = ys toarray(xs::AbstractArray, y) = similar(xs, typeof(y), ()) .= y @@ -66,19 +64,33 @@ back(::typeof(mean), Δ, xs::TrackedArray, region) = # BLAS -a::TrackedMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b)) -a::TrackedMatrix * b::AbstractMatrix = TrackedArray(Call(*, a, b)) -a::AbstractMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b)) +for f in :[*, Ac_mul_B].args + @eval begin + import Base.$f + $f(a::TrackedMatrix, b::TrackedMatrix) = TrackedArray(Call($f, a, b)) + $f(a::TrackedMatrix, b::AbstractMatrix) = TrackedArray(Call($f, a, b)) + $f(a::AbstractMatrix, b::TrackedMatrix) = TrackedArray(Call($f, a, b)) -a::TrackedMatrix * b::TrackedVector = TrackedArray(Call(*, a, b)) -a::TrackedMatrix * b::AbstractVector = TrackedArray(Call(*, a, b)) -a::AbstractMatrix * b::TrackedVector = TrackedArray(Call(*, a, b)) + $f(a::TrackedMatrix, b::TrackedVector) = TrackedArray(Call($f, a, b)) + $f(a::TrackedMatrix, b::AbstractVector) = TrackedArray(Call($f, a, b)) + $f(a::AbstractMatrix, b::TrackedVector) = TrackedArray(Call($f, a, b)) + + $f(a::TrackedVector, b::TrackedVector) = TrackedArray(Call($f, a, b)) + $f(a::TrackedVector, b::AbstractVector) = TrackedArray(Call($f, a, b)) + $f(a::AbstractVector, b::TrackedVector) = TrackedArray(Call($f, a, b)) + end +end function back(::typeof(*), Δ, a::AbstractMatrix, b::AbstractVecOrMat) @back(a, A_mul_Bt(Δ, data(b))) @back(b, At_mul_B(data(a), Δ)) end +function back(::typeof(Ac_mul_B), Δ, a::AbstractVecOrMat{<:Real}, b::AbstractVecOrMat{<:Real}) + @back(a, A_mul_Bt(Δ, data(b))') + @back(b, *(data(a), Δ)) +end + # Fast path for matrix-vector function back(::typeof(*), Δ::AbstractVector, W::TrackedMatrix, x::AbstractVector) if isleaf(W) diff --git a/test/tracker.jl b/test/tracker.jl index 52a73a07..69f37367 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -9,6 +9,8 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...) @test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2) @test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2) +@test gradtest((w, x) -> w'*x, randn(10, 2), randn(10)) + @test gradtest(x -> sin.(sum(x, (2, 3))), (3,4,5)) @test gradtest(x -> softmax(x).*(1:3), 3) From bdf02e42aee308125cf3a9a7a05bb3f7d24d4942 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 8 Nov 2017 22:00:31 +0000 Subject: [PATCH 07/26] test tweaks --- test/tracker.jl | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/test/tracker.jl b/test/tracker.jl index 69f37367..f2a369f8 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -39,18 +39,4 @@ end 2y + x end -for T in [Float32, Float64] - @test isa(param(T(1)), TrackedArray{T, 0}) - @test isa(param(rand(T, 2)), TrackedArray{T, 1}) - @test isa(param(rand(T, 2,2)), TrackedArray{T, 2}) -end - -# TODO: do we wand this behaviour ?? -F = typeof(AbstractFloat(1)) -for T in [Int32, Int64] - @test isa(param(T(1)), TrackedArray{F, 0}) - @test isa(param(rand(T, 2)), TrackedArray{F, 1}) - @test isa(param(rand(T, 2,2)), TrackedArray{F, 2}) -end - end #testset From e5d99d784ec23d32e679b9f5a72cacb32ac5d361 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Thu, 9 Nov 2017 14:53:26 +0000 Subject: [PATCH 08/26] fixes #79 --- src/onehot.jl | 11 +++++++++-- src/tracker/Tracker.jl | 2 +- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/onehot.jl b/src/onehot.jl index 5414773c..f8061063 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -1,3 +1,5 @@ +import Base: * + struct OneHotVector <: AbstractVector{Bool} ix::UInt32 of::UInt32 @@ -7,7 +9,7 @@ Base.size(xs::OneHotVector) = (Int64(xs.of),) Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix -Base.:*(A::AbstractMatrix, b::OneHotVector) = A[:, b.ix] +A::AbstractMatrix * b::OneHotVector = A[:, b.ix] struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool} height::Int @@ -18,7 +20,7 @@ Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data)) Base.getindex(xs::OneHotMatrix, i::Int, j::Int) = xs.data[j][i] -Base.:*(A::AbstractMatrix, B::OneHotMatrix) = A[:, map(x->x.ix, B.data)] +A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)] Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...]) @@ -47,3 +49,8 @@ argmax(y::AbstractVector, labels = 1:length(y)) = argmax(y::AbstractMatrix, l...) = squeeze(mapslices(y -> argmax(y, l...), y, 1), 1) + +# Ambiguity hack + +a::TrackedMatrix * b::OneHotVector = TrackedArray(Tracker.Call(*, a, b)) +a::TrackedMatrix * b::OneHotMatrix = TrackedArray(Tracker.Call(*, a, b)) diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 5e26a051..3a64fcb7 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -1,6 +1,6 @@ module Tracker -export TrackedArray, param, back! +export TrackedArray, TrackedVector, TrackedMatrix, param, back! data(x) = x istracked(x) = false From 2cb94981a0176f070eb2dec31c00ef125613ce3f Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Fri, 27 Oct 2017 12:05:37 +0100 Subject: [PATCH 09/26] gpu-ready log --- src/Flux.jl | 1 + src/layers/stateless.jl | 4 +-- src/numeric.jl | 80 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 83 insertions(+), 2 deletions(-) create mode 100644 src/numeric.jl diff --git a/src/Flux.jl b/src/Flux.jl index acefff19..ce3861e5 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -21,6 +21,7 @@ include("optimise/Optimise.jl") using .Optimise include("utils.jl") +include("numeric.jl") include("onehot.jl") include("treelike.jl") diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index 3931c216..56d18349 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -3,12 +3,12 @@ mse(ŷ, y) = sum((ŷ .- y).^2)/length(y) crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat) = - -sum(y .* log.(ŷ)) / size(y, 2) + -sum(y .* log_fast.(ŷ)) / size(y, 2) @deprecate logloss(x, y) crossentropy(x, y) function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat) logŷ = logŷ .- maximum(logŷ, 1) - ypred = logŷ .- log.(sum(exp.(logŷ), 1)) + ypred = logŷ .- log_fast.(sum(exp.(logŷ), 1)) -sum(y .* ypred) / size(y, 2) end diff --git a/src/numeric.jl b/src/numeric.jl new file mode 100644 index 00000000..9c444043 --- /dev/null +++ b/src/numeric.jl @@ -0,0 +1,80 @@ +using Base.Math: @horner, significand_bits, exponent_raw_max, exponent_bias + +if VERSION < v"0.7.0-DEV.1430" + using Base.Math.fpinttype +else + using Base.uinttype +end + +# log_fast from +# https://github.com/musm/SLEEF.jl/blob/c9dcd2eb090d69ec40790f19798c5fef2aba2616/src/log.jl + +const MLN2 = 6.931471805599453094172321214581765680755001343602552541206800094933936219696955e-01 # log(2) + +@inline float2integer(d::Float64) = (reinterpret(Int64, d) >> significand_bits(Float64)) % Int +@inline float2integer(d::Float32) = (reinterpret(Int32, d) >> significand_bits(Float32)) % Int + +@inline function ilogb2k(d::T) where {T<:Union{Float32,Float64}} + (float2integer(d) & exponent_raw_max(T)) - exponent_bias(T) +end + +@inline function ldexp3k(x::T, e::Int) where {T<:Union{Float32,Float64}} + if VERSION < v"0.7.0-DEV.1430" + reinterpret(T, reinterpret(Unsigned, x) + (Int64(e) << significand_bits(T)) % fpinttype(T)) + else + reinterpret(T, reinterpret(Unsigned, x) + (Int64(e) << significand_bits(T)) % uinttype(T)) + end +end + +""" + log_fast(x) +Compute the natural logarithm of `x`. The inverse of the natural logarithm is +the natural expoenential function `exp(x)` +""" +function log_fast end + +let +global log_fast + +c8d = 0.153487338491425068243146 +c7d = 0.152519917006351951593857 +c6d = 0.181863266251982985677316 +c5d = 0.222221366518767365905163 +c4d = 0.285714294746548025383248 +c3d = 0.399999999950799600689777 +c2d = 0.6666666666667778740063 +c1d = 2.0 + +c5f = 0.2392828464508056640625f0 +c4f = 0.28518211841583251953125f0 +c3f = 0.400005877017974853515625f0 +c2f = 0.666666686534881591796875f0 +c1f = 2f0 + +global @inline log_fast_kernel(x::Float64) = @horner x c1d c2d c3d c4d c5d c6d c7d c8d +global @inline log_fast_kernel(x::Float32) = @horner x c1f c2f c3f c4f c5f + +function log_fast(d::T) where {T<:Union{Float32,Float64}} + o = d < realmin(T) + o && (d *= T(Int64(1) << 32) * T(Int64(1) << 32)) + + e = ilogb2k(d * T(1.0/0.75)) + m = ldexp3k(d, -e) + o && (e -= 64) + + x = (m - 1) / (m + 1) + x2 = x * x + + t = log_fast_kernel(x2) + + x = x * t + T(MLN2) * e + + isinf(d) && (x = T(Inf)) + (d < 0 || isnan(d)) && (x = T(NaN)) + d == 0 && (x = -T(Inf)) + + return x +end +end + +log_fast(x::Union{Int32,Int64}) = log_fast(float(x)) From e0657d93ecccf1b1ac924a42909a0c79b9433df4 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Thu, 9 Nov 2017 15:03:57 +0000 Subject: [PATCH 10/26] mv numeric.jl to nnlib --- src/Flux.jl | 1 - src/layers/stateless.jl | 2 ++ src/numeric.jl | 80 ----------------------------------------- 3 files changed, 2 insertions(+), 81 deletions(-) delete mode 100644 src/numeric.jl diff --git a/src/Flux.jl b/src/Flux.jl index ce3861e5..acefff19 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -21,7 +21,6 @@ include("optimise/Optimise.jl") using .Optimise include("utils.jl") -include("numeric.jl") include("onehot.jl") include("treelike.jl") diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index 56d18349..834068aa 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -1,3 +1,5 @@ +using NNlib: log_fast + # Cost functions mse(ŷ, y) = sum((ŷ .- y).^2)/length(y) diff --git a/src/numeric.jl b/src/numeric.jl deleted file mode 100644 index 9c444043..00000000 --- a/src/numeric.jl +++ /dev/null @@ -1,80 +0,0 @@ -using Base.Math: @horner, significand_bits, exponent_raw_max, exponent_bias - -if VERSION < v"0.7.0-DEV.1430" - using Base.Math.fpinttype -else - using Base.uinttype -end - -# log_fast from -# https://github.com/musm/SLEEF.jl/blob/c9dcd2eb090d69ec40790f19798c5fef2aba2616/src/log.jl - -const MLN2 = 6.931471805599453094172321214581765680755001343602552541206800094933936219696955e-01 # log(2) - -@inline float2integer(d::Float64) = (reinterpret(Int64, d) >> significand_bits(Float64)) % Int -@inline float2integer(d::Float32) = (reinterpret(Int32, d) >> significand_bits(Float32)) % Int - -@inline function ilogb2k(d::T) where {T<:Union{Float32,Float64}} - (float2integer(d) & exponent_raw_max(T)) - exponent_bias(T) -end - -@inline function ldexp3k(x::T, e::Int) where {T<:Union{Float32,Float64}} - if VERSION < v"0.7.0-DEV.1430" - reinterpret(T, reinterpret(Unsigned, x) + (Int64(e) << significand_bits(T)) % fpinttype(T)) - else - reinterpret(T, reinterpret(Unsigned, x) + (Int64(e) << significand_bits(T)) % uinttype(T)) - end -end - -""" - log_fast(x) -Compute the natural logarithm of `x`. The inverse of the natural logarithm is -the natural expoenential function `exp(x)` -""" -function log_fast end - -let -global log_fast - -c8d = 0.153487338491425068243146 -c7d = 0.152519917006351951593857 -c6d = 0.181863266251982985677316 -c5d = 0.222221366518767365905163 -c4d = 0.285714294746548025383248 -c3d = 0.399999999950799600689777 -c2d = 0.6666666666667778740063 -c1d = 2.0 - -c5f = 0.2392828464508056640625f0 -c4f = 0.28518211841583251953125f0 -c3f = 0.400005877017974853515625f0 -c2f = 0.666666686534881591796875f0 -c1f = 2f0 - -global @inline log_fast_kernel(x::Float64) = @horner x c1d c2d c3d c4d c5d c6d c7d c8d -global @inline log_fast_kernel(x::Float32) = @horner x c1f c2f c3f c4f c5f - -function log_fast(d::T) where {T<:Union{Float32,Float64}} - o = d < realmin(T) - o && (d *= T(Int64(1) << 32) * T(Int64(1) << 32)) - - e = ilogb2k(d * T(1.0/0.75)) - m = ldexp3k(d, -e) - o && (e -= 64) - - x = (m - 1) / (m + 1) - x2 = x * x - - t = log_fast_kernel(x2) - - x = x * t + T(MLN2) * e - - isinf(d) && (x = T(Inf)) - (d < 0 || isnan(d)) && (x = T(NaN)) - d == 0 && (x = -T(Inf)) - - return x -end -end - -log_fast(x::Union{Int32,Int64}) = log_fast(float(x)) From 8991ce028ca02ed9d4c3286eba3468d2fe6e9ec1 Mon Sep 17 00:00:00 2001 From: Fredrik Bagge Carlson Date: Tue, 14 Nov 2017 17:32:16 +0100 Subject: [PATCH 11/26] Fix bug in rmsprop and adadelta MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `@. p.Δ = η * p.Δ / √acc` parses correctly while `@. p.Δ /= √acc*η` seems to parse like `@. p.Δ /= (√acc*η)`, hence the step size was de facto interpreted as `1/η` --- src/optimise/optimisers.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl index 95b31b98..1ffd8982 100644 --- a/src/optimise/optimisers.jl +++ b/src/optimise/optimisers.jl @@ -38,7 +38,7 @@ function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8) acc = zeros(p.x) .+ ϵ function () @. acc = ρ * acc + (1 - ρ) * p.Δ ^ 2 - @. p.Δ /= √acc * η + @. p.Δ = η * p.Δ / √acc end end @@ -46,7 +46,7 @@ function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8) acc = zeros(p.x) .+ ϵ function () @. acc += p.Δ ^ 2 - @. p.Δ /= √acc * η + @. p.Δ = η * p.Δ / √acc end end From 187fddc11c2f0733d5e6a1644c2167d8bde590ab Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 21 Nov 2017 12:29:02 +0100 Subject: [PATCH 12/26] doc fixes --- docs/src/models/layers.md | 1 + docs/src/training/optimisers.md | 3 --- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md index 5d5d2ee8..f92f751a 100644 --- a/docs/src/models/layers.md +++ b/docs/src/models/layers.md @@ -36,5 +36,6 @@ swish These layers don't affect the structure of the network but may improve training times or reduce overfitting. ```@docs +Flux.testmode! Dropout ``` diff --git a/docs/src/training/optimisers.md b/docs/src/training/optimisers.md index 3af5604b..56f511e4 100644 --- a/docs/src/training/optimisers.md +++ b/docs/src/training/optimisers.md @@ -58,8 +58,5 @@ All optimisers return a function that, when called, will update the parameters p SGD Momentum Nesterov -RMSProp ADAM -ADAGrad -ADADelta ``` From e51268caf57cb259a74a6f7f71bc4235b8891d90 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 21 Nov 2017 12:59:39 +0100 Subject: [PATCH 13/26] mention treelike --- docs/src/models/basics.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/docs/src/models/basics.md b/docs/src/models/basics.md index 6fbd0792..02225279 100644 --- a/docs/src/models/basics.md +++ b/docs/src/models/basics.md @@ -151,3 +151,13 @@ m = Chain(x -> x^2, x -> x+1) m(5) # => 26 ``` + +## Layer helpers + +Flux provides a set of helpers for custom layers, which you can enable by calling + +```julia +Flux.treelike(Affine) +``` + +This enables a useful extra set of functionality for our `Affine` layer, such as [collecting its parameters](../training/optimisers.md) or [moving it to the GPU](../gpu.md). From 979949d01adab7bec0711771785eb02b6109788f Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 21 Nov 2017 15:25:09 +0100 Subject: [PATCH 14/26] style --- src/optimise/optimisers.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl index 1ffd8982..abc54090 100644 --- a/src/optimise/optimisers.jl +++ b/src/optimise/optimisers.jl @@ -38,7 +38,7 @@ function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8) acc = zeros(p.x) .+ ϵ function () @. acc = ρ * acc + (1 - ρ) * p.Δ ^ 2 - @. p.Δ = η * p.Δ / √acc + @. p.Δ *= η / √acc end end @@ -46,7 +46,7 @@ function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8) acc = zeros(p.x) .+ ϵ function () @. acc += p.Δ ^ 2 - @. p.Δ = η * p.Δ / √acc + @. p.Δ *= η / √acc end end From 11d53781b254bbb0fbe8a1c1313a3b05efc61112 Mon Sep 17 00:00:00 2001 From: skariel Date: Tue, 10 Oct 2017 23:33:37 +0300 Subject: [PATCH 15/26] adding layer normalization --- src/layers/basic.jl | 30 ++++++++++++++++++++++++++++++ src/layers/stateless.jl | 23 +++++++++++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 969a261c..03a340df 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -78,3 +78,33 @@ function Base.show(io::IO, l::Dense) l.σ == identity || print(io, ", ", l.σ) print(io, ")") end + +""" + ElementwiseLinear(in::Integer) + +Creates an element-wise linear transformation layer with learnable +vectors α and β: + + y = α .* x .+ b + +The input `x` must be a vector of length `in`, or a batch of vectors represented +as an `in × N` matrix. The out `y` will be a vector or batch of length `in`. +""" +struct ElementwiseLinear{T} + α::T + β::T +end + +ElementwiseLinear(in::Integer; initα = ones, initβ = zeros) = + ElementwiseLinear(param(initα(in)), param(initβ(in))) + +treelike(ElementwiseLinear) + +function (a::ElementwiseLinear)(x) + α, β = a.α, a.β + α.*x .+ β +end + +function Base.show(io::IO, l::ElementwiseLinear) + print(io, "ElementwiseLinear(", length(l.α), ")") +end diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index 3931c216..8d0276e8 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -12,3 +12,26 @@ function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat) ypred = logŷ .- log.(sum(exp.(logŷ), 1)) -sum(y .* ypred) / size(y, 2) end + +""" + layernormalization(α=1.0, β=0.0) + +Creates a normalization layer based on https://arxiv.org/pdf/1607.06450.pdf + +The differences are: + +1) std here divides by N-1 (as does std in Julia) vs the paper N +2) this layer α and β are constant numbers (i.e. not learnable vectors) + +To achieve the same effect of learnable vectors α and β oe can use +the ElementwiseLinear layer +""" +function layernormalization(α=1.0, β=0.0) + function layer(y) + _mean = mean(y) + _std = sqrt.(sum((y.-_mean).^2) ./ (length(y)-1)) + _std /= α + _mean -= β*_std + return (y .- _mean) ./ _std + end +end From b06884b9123d9168104602c9855e4bc046bdecab Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Mon, 23 Oct 2017 12:53:07 +0100 Subject: [PATCH 16/26] LayerNorm tweaks --- docs/src/models/layers.md | 1 + src/Flux.jl | 2 +- src/layers/basic.jl | 19 +++++++++---------- src/layers/normalisation.jl | 22 ++++++++++++++++++++++ src/layers/stateless.jl | 24 ++++++------------------ 5 files changed, 39 insertions(+), 29 deletions(-) diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md index f92f751a..1fd87d41 100644 --- a/docs/src/models/layers.md +++ b/docs/src/models/layers.md @@ -38,4 +38,5 @@ These layers don't affect the structure of the network but may improve training ```@docs Flux.testmode! Dropout +LayerNorm ``` diff --git a/src/Flux.jl b/src/Flux.jl index acefff19..df4b1636 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -7,7 +7,7 @@ module Flux using Juno, Requires using Lazy: @forward -export Chain, Dense, RNN, LSTM, Dropout, +export Chain, Dense, RNN, LSTM, Dropout, LayerNorm, SGD, ADAM, Momentum, Nesterov, param, params, mapleaves diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 03a340df..3c47b595 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -80,31 +80,30 @@ function Base.show(io::IO, l::Dense) end """ - ElementwiseLinear(in::Integer) + Diagonal(in::Integer) Creates an element-wise linear transformation layer with learnable vectors α and β: y = α .* x .+ b -The input `x` must be a vector of length `in`, or a batch of vectors represented -as an `in × N` matrix. The out `y` will be a vector or batch of length `in`. +The input `x` must be a array where `size(x, 1) == in`. """ -struct ElementwiseLinear{T} +struct Diagonal{T} α::T β::T end -ElementwiseLinear(in::Integer; initα = ones, initβ = zeros) = - ElementwiseLinear(param(initα(in)), param(initβ(in))) +Diagonal(in::Integer; initα = ones, initβ = zeros) = + Diagonal(param(initα(in)), param(initβ(in))) -treelike(ElementwiseLinear) +treelike(Diagonal) -function (a::ElementwiseLinear)(x) +function (a::Diagonal)(x) α, β = a.α, a.β α.*x .+ β end -function Base.show(io::IO, l::ElementwiseLinear) - print(io, "ElementwiseLinear(", length(l.α), ")") +function Base.show(io::IO, l::Diagonal) + print(io, "Diagonal(", length(l.α), ")") end diff --git a/src/layers/normalisation.jl b/src/layers/normalisation.jl index 08c21428..d296b0a3 100644 --- a/src/layers/normalisation.jl +++ b/src/layers/normalisation.jl @@ -43,3 +43,25 @@ function (a::Dropout)(x) end _testmode!(a::Dropout, test) = (a.active = !test) + +""" + LayerNorm(h::Integer) + +A [normalisation layer](https://arxiv.org/pdf/1607.06450.pdf) designed to be +used with recurrent hidden states of size `h`. Normalises the mean/stddev of +each input before applying a per-neuron gain/bias. +""" +struct LayerNorm{T} + diag::Diagonal{T} +end + +LayerNorm(h::Integer) = + LayerNorm(Diagonal(h)) + +treelike(LayerNorm) + +(a::LayerNorm)(x) = a.diag(normalise(x)) + +function Base.show(io::IO, l::LayerNorm) + print(io, "LayerNorm(", length(l.diag.α), ")") +end diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index 8d0276e8..2a4b9a7c 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -14,24 +14,12 @@ function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat) end """ - layernormalization(α=1.0, β=0.0) + normalise(x::AbstractVecOrMat) -Creates a normalization layer based on https://arxiv.org/pdf/1607.06450.pdf - -The differences are: - -1) std here divides by N-1 (as does std in Julia) vs the paper N -2) this layer α and β are constant numbers (i.e. not learnable vectors) - -To achieve the same effect of learnable vectors α and β oe can use -the ElementwiseLinear layer +Normalise each column of `x` to mean 0 and standard deviation 1. """ -function layernormalization(α=1.0, β=0.0) - function layer(y) - _mean = mean(y) - _std = sqrt.(sum((y.-_mean).^2) ./ (length(y)-1)) - _std /= α - _mean -= β*_std - return (y .- _mean) ./ _std - end +function normalise(x::AbstractVecOrMat) + μ′ = mean(x, 1) + σ′ = std(x, 1, mean = μ′) + return (x .- μ′) ./ σ′ end From 351d3d4771da08e53d2a2f89547f91d5fdb47beb Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 21 Nov 2017 17:04:04 +0100 Subject: [PATCH 17/26] std derivative --- src/layers/basic.jl | 4 ++-- src/tracker/lib.jl | 6 ++++++ test/tracker.jl | 3 +++ 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 3c47b595..aa101c43 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -83,9 +83,9 @@ end Diagonal(in::Integer) Creates an element-wise linear transformation layer with learnable -vectors α and β: +vectors `α` and `β`: - y = α .* x .+ b + y = α .* x .+ β The input `x` must be a array where `size(x, 1) == in`. """ diff --git a/src/tracker/lib.jl b/src/tracker/lib.jl index aab26dfe..5065a40d 100644 --- a/src/tracker/lib.jl +++ b/src/tracker/lib.jl @@ -58,6 +58,12 @@ Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...) Base.mean(xs::TrackedArray) = TrackedArray(Call(mean, xs), toarray(xs.data, mean(xs.data))) Base.mean(xs::TrackedArray, region) = TrackedArray(Call(mean, xs, region)) +# Hacks to get std working +Base.std(x::TrackedArray; mean = Base.mean(x)) = + sqrt.(sum((x .- mean).^2) ./ (length(x)-1)) +Base.std(x::TrackedArray, dim; mean = Base.mean(x, dim)) = + sqrt.(sum((x .- mean).^2, dim) ./ (size(x, dim)-1)) + back(::typeof(mean), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= Δ ./ length(xs.data)) back(::typeof(mean), Δ, xs::TrackedArray, region) = back(xs, similar(xs.data) .= Δ ./ prod(size(xs.data, region...))) diff --git a/test/tracker.jl b/test/tracker.jl index f2a369f8..81a72566 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -34,6 +34,9 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...) @test gradtest(x -> mean(x, [1, 2]), rand(2, 3, 4)) end +@test gradtest(x -> std(x), rand(5,5)) +@test gradtest(x -> std(x, 1), rand(5,5)) + @test gradtest(rand(5)) do x y = x.^2 2y + x From 13b934c2500b8e39ac24c834079b562057dede5a Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Thu, 12 Oct 2017 10:31:38 +0200 Subject: [PATCH 18/26] improve optimizers --- src/data/cmudict.jl | 3 +- src/optimise/interface.jl | 50 +++++++++++----------- src/optimise/optimisers.jl | 85 +++++++++++++++++++++----------------- src/tracker/Tracker.jl | 2 + test/optimise.jl | 19 +++++++++ test/runtests.jl | 1 + 6 files changed, 98 insertions(+), 62 deletions(-) create mode 100644 test/optimise.jl diff --git a/src/data/cmudict.jl b/src/data/cmudict.jl index 88b9c6c0..a23c6a3d 100644 --- a/src/data/cmudict.jl +++ b/src/data/cmudict.jl @@ -33,7 +33,8 @@ function rawdict() filter(!isempty, split.(split(readstring(deps("CMUDict", "cmudict")), "\n")))) end -validword(s) = ismatch(r"^[\w-\.]+$", s) +# validword(s) = ismatch(r"^[\w-\.]+$", s) +validword(s) = ismatch(r"^\[\w-\.\]+$", s) cmudict() = filter((s, ps) -> validword(s), rawdict()) diff --git a/src/optimise/interface.jl b/src/optimise/interface.jl index 0b2a25ae..47b0f62c 100644 --- a/src/optimise/interface.jl +++ b/src/optimise/interface.jl @@ -1,5 +1,7 @@ call(f, xs...) = f(xs...) +# note for optimisers: set to zero +# p.Δ at the end of the weigths update function optimiser(ps, fs...) ps = [Param(p) for p in ps] fs = map(ps) do p @@ -10,64 +12,64 @@ function optimiser(ps, fs...) end """ - SGD(params, η = 1; decay = 0) + SGD(params, η = 0.1; decay = 0) -Classic gradient descent optimiser. For each parameter `p` and its -gradient `δp`, this runs `p -= η*δp`. +Classic gradient descent optimiser with learning rate `η`. +For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`. -Supports decayed learning rate decay if the `decay` argument is provided. +Supports inverse decaying learning rate if the `decay` argument is provided. """ -SGD(ps, η = 1; decay = 0) = - optimiser(ps, p -> invdecay(p, decay), p -> descent(p, η)) +SGD(ps, η = 0.1; decay = 0) = + optimiser(ps, p -> invdecay(p, decay), p -> descent(p,η)) """ - Momentum(params, ρ, decay = 0) + Momentum(params, η = 0.01; ρ = 0.9, decay = 0) -SGD with momentum `ρ` and optional learning rate decay. +SGD with learning rate `η`, momentum `ρ` and optional learning rate inverse decay. """ -Momentum(ps, ρ; decay = 0) = - optimiser(ps, p -> momentum(p, ρ), p -> invdecay(p, decay), p -> descent(p, 1)) +Momentum(ps, η = 0.01; ρ = 0.9, decay = 0) = + optimiser(ps, p->invdecay(p,decay), p->momentum(p, ρ, η), p->descent(p,1)) """ - Nesterov(params, ρ, decay = 0) + Nesterov(params, η = 0.01; ρ = 0.9, decay = 0) -SGD with Nesterov momentum `ρ` and optional learning rate decay. +SGD with learning rate `η`, Nesterov momentum `ρ` and optional learning rate inverse decay. """ -Nesterov(ps, ρ; decay = 0) = - optimiser(ps, p -> nesterov(p, ρ), p -> invdecay(p, decay), p -> descent(p, 1)) +Nesterov(ps, η = 0.01; ρ = 0.9, decay = 0) = + optimiser(ps, p->invdecay(p,decay), p->nesterov(p, ρ, η), p->descent(p,1)) """ - RMSProp(params; η = 0.001, ρ = 0.9, ϵ = 1e-8, decay = 0) + RMSProp(params, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) [RMSProp](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) optimiser. Parameters other than learning rate don't need tuning. Often a good choice for recurrent networks. """ RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) = - optimiser(ps, p -> rmsprop(p; η = η, ρ = ρ, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1)) + optimiser(ps, p->rmsprop(p; η=η, ρ=ρ, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1)) """ - ADAM(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) + ADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) [ADAM](https://arxiv.org/abs/1412.6980v8) optimiser. """ ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) = - optimiser(ps, p -> adam(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1)) + optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1)) """ - ADAGrad(params; η = 0.01, ϵ = 1e-8, decay = 0) + ADAGrad(params, η = 0.01; ϵ = 1e-8, decay = 0) [ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser. Parameters don't need tuning. """ -ADAGrad(ps; η = 0.01, ϵ = 1e-8, decay = 0) = - optimiser(ps, p -> adagrad(p; η = η, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1)) +ADAGrad(ps, η = 0.01; ϵ = 1e-8, decay = 0) = + optimiser(ps, p->adagrad(p; η=η, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1)) """ - ADADelta(params; η = 0.01, ρ = 0.95, ϵ = 1e-8, decay = 0) + ADADelta(params; ρ = 0.9, ϵ = 1e-8, decay = 0) [ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need tuning. """ -ADADelta(ps; η = 0.01, ρ = 0.95, ϵ = 1e-8, decay = 0) = - optimiser(ps, p -> adadelta(p; ρ = ρ, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1)) +ADADelta(ps; ρ = 0.9, ϵ = 1e-8, decay = 0) = + optimiser(ps, p->adadelta(p; ρ=ρ, ϵ=ϵ), p->descent(p,1)) diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl index abc54090..7cf271b6 100644 --- a/src/optimise/optimisers.jl +++ b/src/optimise/optimisers.jl @@ -1,74 +1,85 @@ function descent(p::Param, η::Real) function () - p.x .-= p.Δ .* η - p.Δ .= 0 + @. p.x -= η * p.Δ + @. p.Δ = 0 end end -function momentum(p::Param, ρ::Real) - mo = zeros(p.x) - () -> p.Δ .= mo .= ρ .* mo .+ p.Δ -end - -function nesterov(p::Param, ρ::Real) - mo = zeros(p.x) +function momentum(p::Param, ρ, η) + v = zeros(p.x) function () - mo .= ρ .* mo .+ p.Δ - p.Δ .= ρ .* mo .+ p.Δ + @. v = ρ * v - η * p.Δ + @. p.Δ = -v end end -function clip(p::Param, thresh::Real) - () -> clamp!(p.Δ, -thresh, thresh) -end - -function weightdecay(p::Param, γ::Real) - () -> p.Δ .+= γ .* p.x -end - -function invdecay(p::Param, γ::Real) - n = 0 +# Ref. https://arxiv.org/pdf/1212.0901.pdf +function nesterov(p::Param, ρ, η) + v = zeros(p.x) function () - p.Δ .*= 1 / (1 + γ * n) - n += 1 + d = @. ρ^2 * v - (1+ρ) * η * p.Δ + @. v = ρ*v - η*p.Δ + @. p.Δ = -d end end function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8) - acc = zeros(p.x) .+ ϵ + acc = zeros(p.x) function () - @. acc = ρ * acc + (1 - ρ) * p.Δ ^ 2 - @. p.Δ *= η / √acc + @. acc = ρ * acc + (1 - ρ) * p.Δ^2 + @. p.Δ *= η / (√acc + ϵ) end end function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8) acc = zeros(p.x) .+ ϵ function () - @. acc += p.Δ ^ 2 + @. acc += p.Δ^2 @. p.Δ *= η / √acc end end -function adadelta(p::Param; ρ::Real = 0.95, ϵ::Real = 1e-8) - acc = zeros(p.x) .+ ϵ - Δacc = zeros(p.x) .+ ϵ +function adadelta(p::Param; ρ::Real = 0.9, ϵ::Real = 1e-8) + acc = zeros(p.x) + Δacc = zeros(p.x) function () - @. acc = ρ * acc + (1 - ρ) * p.Δ ^ 2 - @. p.Δ *= √Δacc / √acc - @. Δacc = ρ * Δacc + (1 - ρ) * p.Δ ^ 2 - end + @. acc = ρ * acc + (1 - ρ) * p.Δ^2 + @. p.Δ *= √(Δacc + ϵ) / √(acc + ϵ) + @. Δacc = ρ * Δacc + (1 - ρ) * p.Δ^2 + end end function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8) mt = zeros(p.x) - vt = zeros(p.x) .+ ϵ + vt = zeros(p.x) β1p, β2p = β1, β2 function () @. mt = β1 * mt + (1 - β1) * p.Δ - @. vt = β2 * vt + (1 - β2) * p.Δ ^ 2 - @. p.Δ = √(1 - β2p) / √(1 - β1p) * mt / √vt * η + @. vt = β2 * vt + (1 - β2) * p.Δ^2 + @. p.Δ = mt / (1 - β1p) / (sqrt(vt / (1 - β2p)) + ϵ) * η β1p *= β1 β2p *= β2 end end + +clip(p::Param, thresh::Real) = () -> clamp!(p.Δ, -thresh, thresh) + +function expdecay(p::Param, γ::Real) + if γ != 0 + return () -> p.Δ .+= γ .* p.x + else + return () -> nothing + end +end + +function invdecay(p::Param, γ::Real) + if γ != 0 + n = 0 + return () -> begin + p.Δ .*= 1 / (1 + γ * n) + n += 1 + end + else + return () -> nothing + end +end \ No newline at end of file diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 3a64fcb7..57bdc447 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -58,6 +58,7 @@ Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) = Base.similar(x::TrackedArray, T::Type) = similar(data(x), T) +# TODO decide if keeping both data and value. The problem is TrackedScalar value(x) = x value(x::TrackedArray) = data(x) value(x::TrackedScalar) = data(x)[] @@ -69,6 +70,7 @@ Base.:(==)(x::TrackedArray, y::TrackedArray) = value(x) == value(x) Base.isless(x::TrackedScalar, y) = isless(value(x), y) Base.isless(x, y::TrackedScalar) = isless(x, value(y)) Base.isless(x::TrackedScalar, y::TrackedScalar) = isless(value(x), value(y)) +Base.isapprox(x::TrackedScalar, y; kws...) = isapprox(x.data[], y; kws...) Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} = print(io, "TrackedArray{…,$A}") diff --git a/test/optimise.jl b/test/optimise.jl new file mode 100644 index 00000000..85fd53f9 --- /dev/null +++ b/test/optimise.jl @@ -0,0 +1,19 @@ +using Flux.Optimise +using Flux.Tracker + +@testset "Optimise" begin + loss(x) = sum(x.^2) + η = 0.1 + # RMSProp gets stuck + for OPT in [SGD, Nesterov, Momentum, ADAM, ADAGrad, ADADelta] + x = param(randn(10)) + opt = OPT == ADADelta ? OPT([x]) : OPT([x], η) + for t=1:10000 + l = loss(x) + back!(l) + opt() + l.data[] < 1e-10 && break + end + @test loss(x) ≈ 0. atol=1e-7 + end +end diff --git a/test/runtests.jl b/test/runtests.jl index efd1a462..bdd1f2d0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,5 +5,6 @@ using Flux, Base.Test include("utils.jl") include("tracker.jl") include("layers/normalisation.jl") +include("optimise.jl") end From 2d33f19346b48dd76559926b62ba1dd7cd978ba7 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 29 Nov 2017 16:45:50 +0000 Subject: [PATCH 19/26] onehot unk arg --- src/onehot.jl | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/onehot.jl b/src/onehot.jl index f8061063..f94fb93e 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -42,7 +42,14 @@ function onehot(l, labels) OneHotVector(i, length(labels)) end -onehotbatch(ls, labels) = OneHotMatrix(length(labels), [onehot(l, labels) for l in ls]) +function onehot(l, labels, unk) + i = findfirst(labels, l) + i > 0 || return onehot(unk, labels) + OneHotVector(i, length(labels)) +end + +onehotbatch(ls, labels, unk...) = + OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls]) argmax(y::AbstractVector, labels = 1:length(y)) = labels[findfirst(y, maximum(y))] From 19039f48819835bf01ea6f2f69792f53dfe7d4f8 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Thu, 30 Nov 2017 13:37:38 +0000 Subject: [PATCH 20/26] export sigmoid --- src/Flux.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Flux.jl b/src/Flux.jl index df4b1636..7671ddd2 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -12,7 +12,7 @@ export Chain, Dense, RNN, LSTM, Dropout, LayerNorm, param, params, mapleaves using NNlib -export σ, relu, leakyrelu, elu, swish, softmax +export σ, sigmoid, relu, leakyrelu, elu, swish, softmax include("tracker/Tracker.jl") using .Tracker From cab235a57863558aa060a28776f8934d5a0a0ed4 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Thu, 30 Nov 2017 13:51:31 +0000 Subject: [PATCH 21/26] gpu compat --- src/tracker/Tracker.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 3a64fcb7..74ed2d75 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -40,7 +40,7 @@ TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x)) isleaf(x::TrackedArray) = x.f == Call(nothing) -param(xs) = TrackedArray(AbstractFloat.(xs)) +param(xs) = TrackedArray(map(x -> AbstractFloat(x), xs)) param(xs::Real) = param(fill(xs)) istracked(x::TrackedArray) = true From 36001d085a3f9175eaee572e8b8532410a8ebf50 Mon Sep 17 00:00:00 2001 From: baggepinnen Date: Mon, 4 Dec 2017 09:17:05 +0100 Subject: [PATCH 22/26] Implement AMSGrad optimiser --- src/optimise/Optimise.jl | 2 +- src/optimise/interface.jl | 9 +++++++++ src/optimise/optimisers.jl | 14 +++++++++++++- 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl index 5f144b65..acec542e 100644 --- a/src/optimise/Optimise.jl +++ b/src/optimise/Optimise.jl @@ -1,7 +1,7 @@ module Optimise export update!, params, train!, - SGD, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta + SGD, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad struct Param{T} x::T diff --git a/src/optimise/interface.jl b/src/optimise/interface.jl index 0b2a25ae..c6f98553 100644 --- a/src/optimise/interface.jl +++ b/src/optimise/interface.jl @@ -71,3 +71,12 @@ tuning. """ ADADelta(ps; η = 0.01, ρ = 0.95, ϵ = 1e-8, decay = 0) = optimiser(ps, p -> adadelta(p; ρ = ρ, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1)) + + """ + AMSGrad(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) + + [AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need + tuning. + """ + AMSGrad(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) = + optimiser(ps, p -> amsgrad(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1)) diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl index abc54090..12a14df4 100644 --- a/src/optimise/optimisers.jl +++ b/src/optimise/optimisers.jl @@ -67,8 +67,20 @@ function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ function () @. mt = β1 * mt + (1 - β1) * p.Δ @. vt = β2 * vt + (1 - β2) * p.Δ ^ 2 - @. p.Δ = √(1 - β2p) / √(1 - β1p) * mt / √vt * η + @. p.Δ = √(1 - β2p) / (1 - β1p) * mt / √vt * η β1p *= β1 β2p *= β2 end end + +function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8) + mt = zeros(p.x) + vt = zeros(p.x) .+ ϵ + v̂t = zeros(p.x) .+ ϵ + function () + @. mt = β1 * mt + (1 - β1) * p.Δ + @. vt = β2 * vt + (1 - β2) * p.Δ ^ 2 + @. v̂t = max.(v̂t, vt) + @. p.Δ = η * mt / √v̂t + end +end From 41febee9c171e610336afd79e1d1480100f29a53 Mon Sep 17 00:00:00 2001 From: baggepinnen Date: Mon, 4 Dec 2017 09:34:27 +0100 Subject: [PATCH 23/26] Export and indent --- src/Flux.jl | 2 +- src/optimise/interface.jl | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/Flux.jl b/src/Flux.jl index df4b1636..2ae8879f 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -8,7 +8,7 @@ using Juno, Requires using Lazy: @forward export Chain, Dense, RNN, LSTM, Dropout, LayerNorm, - SGD, ADAM, Momentum, Nesterov, + SGD, ADAM, Momentum, Nesterov, AMSGrad, param, params, mapleaves using NNlib diff --git a/src/optimise/interface.jl b/src/optimise/interface.jl index c6f98553..679134fe 100644 --- a/src/optimise/interface.jl +++ b/src/optimise/interface.jl @@ -47,7 +47,7 @@ RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) = optimiser(ps, p -> rmsprop(p; η = η, ρ = ρ, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1)) """ - ADAM(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) + ADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) [ADAM](https://arxiv.org/abs/1412.6980v8) optimiser. """ @@ -72,11 +72,11 @@ tuning. ADADelta(ps; η = 0.01, ρ = 0.95, ϵ = 1e-8, decay = 0) = optimiser(ps, p -> adadelta(p; ρ = ρ, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1)) - """ - AMSGrad(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) +""" + AMSGrad(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) - [AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need - tuning. - """ - AMSGrad(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) = - optimiser(ps, p -> amsgrad(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1)) +[AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need +tuning. +""" +AMSGrad(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) = + optimiser(ps, p -> amsgrad(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1)) From 951c21366a54ab60899f2e9955c05bd8ebaedf5b Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Fri, 8 Dec 2017 16:42:30 +0000 Subject: [PATCH 24/26] fix regex --- src/data/cmudict.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/data/cmudict.jl b/src/data/cmudict.jl index a23c6a3d..9ec567b4 100644 --- a/src/data/cmudict.jl +++ b/src/data/cmudict.jl @@ -33,8 +33,7 @@ function rawdict() filter(!isempty, split.(split(readstring(deps("CMUDict", "cmudict")), "\n")))) end -# validword(s) = ismatch(r"^[\w-\.]+$", s) -validword(s) = ismatch(r"^\[\w-\.\]+$", s) +validword(s) = ismatch(r"^[\w\-\.]+$", s) cmudict() = filter((s, ps) -> validword(s), rawdict()) From 69cc5642b48b685bbbf109af310384f8eae917e4 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Fri, 8 Dec 2017 17:10:29 +0000 Subject: [PATCH 25/26] regression testing --- test/optimise.jl | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/test/optimise.jl b/test/optimise.jl index 85fd53f9..65bb65be 100644 --- a/test/optimise.jl +++ b/test/optimise.jl @@ -2,18 +2,16 @@ using Flux.Optimise using Flux.Tracker @testset "Optimise" begin - loss(x) = sum(x.^2) - η = 0.1 - # RMSProp gets stuck - for OPT in [SGD, Nesterov, Momentum, ADAM, ADAGrad, ADADelta] - x = param(randn(10)) - opt = OPT == ADADelta ? OPT([x]) : OPT([x], η) - for t=1:10000 - l = loss(x) - back!(l) - opt() - l.data[] < 1e-10 && break - end - @test loss(x) ≈ 0. atol=1e-7 + w = randn(10, 10) + for Opt in [SGD, Nesterov, Momentum, ADAM, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta] + w′ = param(randn(10, 10)) + loss(x) = Flux.mse(w*x, w′*x) + opt = Opt([w′]) + for t=1:10^5 + l = loss(rand(10)) + back!(l) + opt() end + @test Flux.mse(w, w′) < 0.01 + end end From 55bbe50f32d7dfe58360da9da3832add38a8cc38 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Fri, 8 Dec 2017 18:24:07 +0000 Subject: [PATCH 26/26] regression test --- test/optimise.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/optimise.jl b/test/optimise.jl index 65bb65be..526f0534 100644 --- a/test/optimise.jl +++ b/test/optimise.jl @@ -3,7 +3,7 @@ using Flux.Tracker @testset "Optimise" begin w = randn(10, 10) - for Opt in [SGD, Nesterov, Momentum, ADAM, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta] + for Opt in [SGD, Nesterov, Momentum, ADAM, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad] w′ = param(randn(10, 10)) loss(x) = Flux.mse(w*x, w′*x) opt = Opt([w′])