From b6508e241679c76b6a02200458b46463b66d3074 Mon Sep 17 00:00:00 2001 From: Sujeet Akula Date: Thu, 26 Apr 2018 17:37:24 +1000 Subject: [PATCH 01/28] add adamax --- src/optimise/interface.jl | 9 +++++++++ src/optimise/optimisers.jl | 12 ++++++++++++ 2 files changed, 21 insertions(+) diff --git a/src/optimise/interface.jl b/src/optimise/interface.jl index 42b05dc8..29068983 100644 --- a/src/optimise/interface.jl +++ b/src/optimise/interface.jl @@ -56,6 +56,15 @@ RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) = ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) = optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1)) +""" + AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) + +[AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser. Variant of ADAM based on +the ∞-norm. +""" +AdaMax(ps, η = 0.002; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) = + optimiser(ps, p->adamax(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1)) + """ ADAGrad(params, η = 0.01; ϵ = 1e-8, decay = 0) diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl index c09e6131..569e69aa 100644 --- a/src/optimise/optimisers.jl +++ b/src/optimise/optimisers.jl @@ -62,6 +62,18 @@ function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ end end +function adamax(p::Param; η::Real = 0.002, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8) + mt = zeros(p.x) + ut = zero(p.x) + β1p = β1 + function () + @. mt = β1 * mt + (1 - β1) * p.Δ + ut = max(β2 * ut, norm(p.Δ, Inf)) + @. p.Δ = (η/(1 - β1p)) * mt/(ut + ϵ) + β1p *= β1 + end +end + function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8) mt = zeros(p.x) vt = zeros(p.x) .+ ϵ From 4586bda5abc26498734b0f2ec3452f0fa216d9ec Mon Sep 17 00:00:00 2001 From: Sujeet Akula Date: Thu, 26 Apr 2018 17:40:11 +1000 Subject: [PATCH 02/28] export/test adamax --- src/optimise/Optimise.jl | 2 +- test/optimise.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl index acec542e..b9b5949e 100644 --- a/src/optimise/Optimise.jl +++ b/src/optimise/Optimise.jl @@ -1,7 +1,7 @@ module Optimise export update!, params, train!, - SGD, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad + SGD, ADAM, Adamax, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad struct Param{T} x::T diff --git a/test/optimise.jl b/test/optimise.jl index d57e4985..ae7ec8fe 100644 --- a/test/optimise.jl +++ b/test/optimise.jl @@ -3,7 +3,7 @@ using Flux.Tracker @testset "Optimise" begin w = randn(10, 10) - @testset for Opt in [SGD, Nesterov, Momentum, ADAM, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad] + @testset for Opt in [SGD, Nesterov, Momentum, ADAM, AdaMax, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad] w′ = param(randn(10, 10)) loss(x) = Flux.mse(w*x, w′*x) opt = Opt([w′]) From 5e5f255f81220c4180e6f6b3bd939cf952d1401a Mon Sep 17 00:00:00 2001 From: Sujeet Akula Date: Thu, 26 Apr 2018 17:42:04 +1000 Subject: [PATCH 03/28] export typo --- src/optimise/Optimise.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl index b9b5949e..c07ba218 100644 --- a/src/optimise/Optimise.jl +++ b/src/optimise/Optimise.jl @@ -1,7 +1,7 @@ module Optimise export update!, params, train!, - SGD, ADAM, Adamax, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad + SGD, ADAM, AdaMax, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad struct Param{T} x::T From 8c042bd522ffa76bd8953675c024cfbc809d46c5 Mon Sep 17 00:00:00 2001 From: Sujeet Akula Date: Thu, 26 Apr 2018 21:12:31 +1000 Subject: [PATCH 04/28] element wise max() --- src/optimise/optimisers.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl index 569e69aa..29b058ba 100644 --- a/src/optimise/optimisers.jl +++ b/src/optimise/optimisers.jl @@ -64,11 +64,11 @@ end function adamax(p::Param; η::Real = 0.002, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8) mt = zeros(p.x) - ut = zero(p.x) + ut = zeros(p.x) β1p = β1 function () @. mt = β1 * mt + (1 - β1) * p.Δ - ut = max(β2 * ut, norm(p.Δ, Inf)) + @. ut = max(β2 * ut, abs(p.Δ)) @. p.Δ = (η/(1 - β1p)) * mt/(ut + ϵ) β1p *= β1 end From cfd29b9c767edc9c0472ad9048a7eb4fedfbbfb9 Mon Sep 17 00:00:00 2001 From: Pontus Stenetorp Date: Fri, 27 Apr 2018 22:14:01 +0100 Subject: [PATCH 05/28] Backpropagation for `maximum` and `minimum` --- src/tracker/array.jl | 31 ++++++++++++++++++++++++++++++- test/tracker.jl | 20 ++++++++++++++++++++ 2 files changed, 50 insertions(+), 1 deletion(-) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index bb55ef73..610675a3 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -156,12 +156,16 @@ Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs)) back(::typeof(prod), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= (prod(xs.data, dim...) ./ xs.data) .* Δ) back(::typeof(prod), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= (reshape(.*(circshift.([reshape(xs.data, length(xs.data))], 1:length(xs.data)-1)...), size(xs.data))) .* Δ) -Base.maximum(xs::TrackedArray, args...) = maximum(xs.data, args...) Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...) Base.mean(xs::TrackedArray) = track(mean, xs) Base.mean(xs::TrackedArray, region) = track(mean, xs, region) +Base.maximum(xs::TrackedArray) = track(maximum, xs) +Base.maximum(xs::TrackedArray, region) = track(maximum, xs, region) +Base.minimum(xs::TrackedArray) = track(minimum, xs) +Base.minimum(xs::TrackedArray, region) = track(minimum, xs, region) + LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = track(dot, xs, ys) LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = track(dot, xs, ys) LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys) @@ -184,6 +188,31 @@ back(::typeof(mean), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= Δ ./ back(::typeof(mean), Δ, xs::TrackedArray, region) = back(xs, similar(xs.data) .= Δ ./ prod(size(xs.data, region...))) +function back(::typeof(maximum), Δ, xs::TrackedArray) + Δ′ = zeros(xs.data) + _, i = findmax(xs.data) + Δ′[i] = Δ + @back(xs, Δ′) +end +function back(::typeof(maximum), Δ, xs::TrackedArray, region) + Δ′ = zeros(xs.data) + _, is = findmax(xs.data, region) + Δ′[is] = Δ + @back(xs, Δ′) +end +function back(::typeof(minimum), Δ, xs::TrackedArray) + Δ′ = zeros(xs.data) + _, i = findmin(xs.data) + Δ′[i] = Δ + @back(xs, Δ′) +end +function back(::typeof(minimum), Δ, xs::TrackedArray, region) + Δ′ = zeros(xs.data) + _, is = findmin(xs.data, region) + Δ′[is] = Δ + @back(xs, Δ′) +end + # BLAS Base.diagm(x::TrackedVector) = track(diagm, x) diff --git a/test/tracker.jl b/test/tracker.jl index 0f5b6189..12ed02e5 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -55,6 +55,26 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...) @test gradtest(x -> mean(x, [1, 2]), rand(2, 3, 4)) end +@testset "maximum" begin + @test gradtest(maximum, rand(2, 3)) + + @test gradtest(x -> maximum(x, 1), rand(2, 3)) + @test gradtest(x -> maximum(x, 2), rand(2, 3)) + @test gradtest(x -> maximum(x, 3), rand(2, 3, 4)) + + @test gradtest(x -> maximum(x, [1, 2]), rand(2, 3, 4)) +end + +@testset "minimum" begin + @test gradtest(minimum, rand(2, 3)) + + @test gradtest(x -> minimum(x, 1), rand(2, 3)) + @test gradtest(x -> minimum(x, 2), rand(2, 3)) + @test gradtest(x -> minimum(x, 3), rand(2, 3, 4)) + + @test gradtest(x -> minimum(x, [1, 2]), rand(2, 3, 4)) +end + @test gradtest(x -> std(x), rand(5,5)) @test gradtest(x -> std(x, 1), rand(5,5)) From 73a51400b682bd9cee981c18ed002fc81f8c2ce9 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Mon, 30 Apr 2018 12:09:15 +0100 Subject: [PATCH 06/28] better error message --- src/tracker/array.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index bb55ef73..89dcfac9 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -41,7 +41,7 @@ end Base.setindex!(xs::TrackedArray, v, i...) = error("Can't differentiate `setindex!`") -back!(::TrackedArray) = error("Use back!(x, Δ)") +back!(::TrackedArray) = error("Value is not scalar; use `back!(sum(x))` or `back!(x, Δ)`") # Fallthrough methods From 4fb6bc7feadacc62a28dd96db172e9c306bebd1d Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Mon, 30 Apr 2018 18:04:13 +0100 Subject: [PATCH 07/28] add note on metalhead --- README.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/README.md b/README.md index 4785c55c..79e37c5f 100644 --- a/README.md +++ b/README.md @@ -79,3 +79,9 @@ For general questions and help, check out Julia's [community forum](https://disc Flux development is carried out via our [GitHub issues](https://github.com/FluxML/Flux.jl/issues), so feel free to open feature requests or PRs here. For more informal discussions we'd love to have you on the [Julia slack](https://slackinvite.julialang.org/), where we hang out on the #machine-learning channel. + +## Related Packages + +Check out [Metalhead.jl](https://github.com/FluxML/Metalhead.jl) for common computer vision datasets and trained models. + +[MLDatasets.jl](https://github.com/JuliaML/MLDatasets.jl) provides further common datasets. From e186b958ddd01c04512bc6d3bbf7556c2550ba5c Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Thu, 22 Feb 2018 11:25:47 -0500 Subject: [PATCH 08/28] more exports --- src/Flux.jl | 16 +++++++++------- src/optimise/Optimise.jl | 2 +- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/Flux.jl b/src/Flux.jl index 7746ecff..5f7bbdfe 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -7,26 +7,25 @@ module Flux using Juno, Requires, Reexport using MacroTools: @forward -export Chain, Dense, RNN, LSTM, GRU, Conv, Conv2D, - Dropout, LayerNorm, BatchNorm, - SGD, ADAM, Momentum, Nesterov, AMSGrad, - param, params, mapleaves, cpu, gpu - @reexport using NNlib using NNlib: @fix include("tracker/Tracker.jl") using .Tracker -export Tracker -import .Tracker: data +using .Tracker: data +export TrackedArray, TrackedVector, TrackedMatrix, param, back! include("optimise/Optimise.jl") using .Optimise using .Optimise: @epochs +export train!, + SGD, ADAM, Momentum, Nesterov, + RMSProp, ADAGrad, ADADelta, AMSGrad include("utils.jl") include("onehot.jl") include("treelike.jl") +export params, mapleaves, cpu, gpu, onehot, batch, glorot_normal, glorot_uniform include("layers/stateless.jl") include("layers/basic.jl") @@ -34,6 +33,9 @@ include("layers/conv.jl") include("layers/recurrent.jl") include("layers/normalise.jl") +export Chain, Dense, RNN, LSTM, GRU, Conv2D, + Dropout, LayerNorm, BatchNorm + include("data/Data.jl") @require CuArrays include("cuda/cuda.jl") diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl index c07ba218..5d5d9ea0 100644 --- a/src/optimise/Optimise.jl +++ b/src/optimise/Optimise.jl @@ -1,6 +1,6 @@ module Optimise -export update!, params, train!, +export train!, SGD, ADAM, AdaMax, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad struct Param{T} From 9a7e6e9c5c012f42e5be2e3956af9179971f35fe Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 1 May 2018 12:16:56 +0100 Subject: [PATCH 09/28] hold off on some things --- src/Flux.jl | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/Flux.jl b/src/Flux.jl index 5f7bbdfe..6cb92c50 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -7,25 +7,27 @@ module Flux using Juno, Requires, Reexport using MacroTools: @forward +export Chain, Dense, RNN, LSTM, GRU, Conv, + Dropout, LayerNorm, BatchNorm, + params, mapleaves, cpu, gpu + @reexport using NNlib using NNlib: @fix include("tracker/Tracker.jl") using .Tracker using .Tracker: data -export TrackedArray, TrackedVector, TrackedMatrix, param, back! +export TrackedArray, TrackedVector, TrackedMatrix, param include("optimise/Optimise.jl") using .Optimise using .Optimise: @epochs -export train!, - SGD, ADAM, Momentum, Nesterov, +export SGD, ADAM, AdaMax, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad include("utils.jl") include("onehot.jl") include("treelike.jl") -export params, mapleaves, cpu, gpu, onehot, batch, glorot_normal, glorot_uniform include("layers/stateless.jl") include("layers/basic.jl") @@ -33,9 +35,6 @@ include("layers/conv.jl") include("layers/recurrent.jl") include("layers/normalise.jl") -export Chain, Dense, RNN, LSTM, GRU, Conv2D, - Dropout, LayerNorm, BatchNorm - include("data/Data.jl") @require CuArrays include("cuda/cuda.jl") From 7d7d89569c74f1e0f542568111fd3ec10a785332 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 1 May 2018 12:20:36 +0100 Subject: [PATCH 10/28] rm this deprecation for 0.6 --- src/layers/conv.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 994648c2..39d3394d 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -45,6 +45,3 @@ function Base.show(io::IO, l::Conv) l.σ == identity || print(io, ", ", l.σ) print(io, ")") end - -# v0.5 -@deprecate Conv2D(args...; kw...) Conv(args...; kw...) From 51e7e1b40fa35949054fefa70f1ca9592821a362 Mon Sep 17 00:00:00 2001 From: Johan Gustafsson Date: Wed, 2 May 2018 15:51:04 +0200 Subject: [PATCH 11/28] cat tests #184 Co-authored-by: pevnak --- test/tracker.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/tracker.jl b/test/tracker.jl index 12ed02e5..2b0e04d7 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -32,6 +32,14 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...) @test gradtest(vcat, rand(5), rand(3)) @test gradtest(vcat, rand(5), rand(3), rand(8)) @test gradtest(vcat, rand(5,2), rand(3,2), rand(8,2)) + +@test gradtest((i...) -> cat(1,i...), rand(5), rand(3)) +@test gradtest((i...) -> cat(1,i...), rand(5), rand(8)) +@test gradtest((i...) -> cat(1,i...), rand(5,2),rand(3,2), rand(8,2)) +@test gradtest((i...) -> cat(2,i...), rand(5,1), rand(5,1)) +@test gradtest((i...) -> cat(2,i...), rand(5,1), rand(5,4)) +@test gradtest((i...) -> cat(2,i...), rand(5,2),rand(5,4), rand(5,8)) + @test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6)) @test gradtest(x -> repmat(x, 5,5), rand(4,5)) From 59324c0f91a51b5de7660f4116ee5738f989cde5 Mon Sep 17 00:00:00 2001 From: Johan Gustafsson Date: Wed, 2 May 2018 15:22:59 +0200 Subject: [PATCH 12/28] hcat tests #194 Co-authored-by: Elliot Saba --- test/tracker.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/test/tracker.jl b/test/tracker.jl index 2b0e04d7..f39546ea 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -33,6 +33,10 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...) @test gradtest(vcat, rand(5), rand(3), rand(8)) @test gradtest(vcat, rand(5,2), rand(3,2), rand(8,2)) +@test gradtest(vcat, rand(5,2,3), rand(3,2,3), rand(8,2,3)) +@test gradtest(hcat, rand(5), rand(5), rand(5,2)) +@test gradtest(hcat, rand(5,2), rand(5,3), rand(5,5)) +@test gradtest(hcat, rand(5,2,3), rand(5,3,3), rand(5,5,3)) @test gradtest((i...) -> cat(1,i...), rand(5), rand(3)) @test gradtest((i...) -> cat(1,i...), rand(5), rand(8)) @test gradtest((i...) -> cat(1,i...), rand(5,2),rand(3,2), rand(8,2)) @@ -45,9 +49,9 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...) @test gradtest(x -> repmat(x, 5,5), rand(4,5)) @test gradtest(x -> repmat(x, 5), rand(4,5)) -@test gradtest(kron,rand(5), rand(3)) +@test gradtest(kron, rand(5), rand(3)) @test gradtest(kron, rand(5), rand(3), rand(8)) -@test gradtest(kron,rand(5,1), rand(3,1)) +@test gradtest(kron, rand(5,1), rand(3,1)) @test gradtest(kron, rand(5,1), rand(3,1), rand(8,1)) @test gradtest(kron, rand(5,2), rand(3,2), rand(8,2)) From 13daaec1cbafe0b74669aee345497556f7a56623 Mon Sep 17 00:00:00 2001 From: Johan Gustafsson Date: Wed, 2 May 2018 15:54:40 +0200 Subject: [PATCH 13/28] Refactored tests --- test/tracker.jl | 33 +++++++++++++++++++-------------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/test/tracker.jl b/test/tracker.jl index f39546ea..27d395b1 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -29,20 +29,25 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...) @test gradtest(x -> x', rand(5)) -@test gradtest(vcat, rand(5), rand(3)) -@test gradtest(vcat, rand(5), rand(3), rand(8)) -@test gradtest(vcat, rand(5,2), rand(3,2), rand(8,2)) - -@test gradtest(vcat, rand(5,2,3), rand(3,2,3), rand(8,2,3)) -@test gradtest(hcat, rand(5), rand(5), rand(5,2)) -@test gradtest(hcat, rand(5,2), rand(5,3), rand(5,5)) -@test gradtest(hcat, rand(5,2,3), rand(5,3,3), rand(5,5,3)) -@test gradtest((i...) -> cat(1,i...), rand(5), rand(3)) -@test gradtest((i...) -> cat(1,i...), rand(5), rand(8)) -@test gradtest((i...) -> cat(1,i...), rand(5,2),rand(3,2), rand(8,2)) -@test gradtest((i...) -> cat(2,i...), rand(5,1), rand(5,1)) -@test gradtest((i...) -> cat(2,i...), rand(5,1), rand(5,4)) -@test gradtest((i...) -> cat(2,i...), rand(5,2),rand(5,4), rand(5,8)) +@testset "concat" begin + @testset "vcat $i" for (i,vcatf) in enumerate((vcat, (x...) -> cat(1, x...))) + @test gradtest(vcatf, rand(5), rand(3)) + @test gradtest(vcatf, rand(5), rand(3), rand(8)) + @test gradtest(vcatf, rand(5,2), rand(3,2), rand(8,2)) + @test gradtest(vcatf, rand(5,2,3), rand(3,2,3), rand(8,2,3)) + end + @testset "hcat $i" for (i,hcatf) in enumerate((hcat, (x...) -> cat(2, x...))) + @test gradtest(hcatf, rand(5), rand(5)) + @test gradtest(hcatf, rand(2,5), rand(2,3), rand(2,8)) + @test gradtest(hcatf, rand(2,5,3), rand(2,3,3), rand(2,8,3)) + end + @test gradtest((x...) -> cat(3, x...), rand(2,5,2), rand(2,5,3), rand(2,5,4)) + @testset "cat($dim, ...)" for dim in 1:5 + catdim = (x...) -> cat(dim, x...) + @test gradtest(catdim, rand(5), rand(5)) + @test gradtest(catdim, rand(2,5), rand(2,5), rand(2,5)) + end +end @test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6)) From bcef5c4ab512fd84ef44a1e97c465a24f1001977 Mon Sep 17 00:00:00 2001 From: Johan Gustafsson Date: Wed, 2 May 2018 15:56:08 +0200 Subject: [PATCH 14/28] Support hcat and cat --- src/tracker/array.jl | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 4dfb2c6d..0bfabf36 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -96,6 +96,14 @@ Base.vcat(a::TrackedMatrix, b::TrackedMatrix...) = track(vcat, a, b...) Base.vcat(a::TrackedMatrix, b::AbstractMatrix) = track(vcat, a, b) Base.vcat(a::AbstractMatrix, b::TrackedMatrix) = track(vcat, a, b) +Base.hcat(a::A, b::B...) where {A <: TrackedArray, B <: TrackedArray} = track(hcat, a, b...) +Base.hcat(a::A, b::B) where {A <: TrackedArray, B <: AbstractArray} = track(hcat, a, b) +Base.hcat(a::A, b::B) where {A <: AbstractArray, B <: TrackedArray} = track(hcat, a, b) + +Base.cat(dim::Int, a::A, b::B...) where {A <: TrackedArray, B <: TrackedArray} = track(cat, dim, a, b...) +Base.cat(dim::Int, a::A, b::B) where {A <: TrackedArray, B <: AbstractArray} = track(cat, dim, a, b) +Base.cat(dim::Int, a::A, b::B) where {A <: AbstractArray, B <: TrackedArray} = track(cat, dim, a, b) + function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1) Δ′ = similar(xs.data) S = size(xs.data) @@ -117,6 +125,34 @@ function back(::typeof(vcat), Δ, xs...) end end +function back(::typeof(hcat), Δ, xs...) + i = fill(:, ndims(Δ)-2) + start = 0 + for xsi in xs + if ndims(xsi) == 1 + @back(xsi, Δ[:, start+1]) + else + @back(xsi, Δ[:, start+1:start+size(xsi,2), i...]) + end + start += size(xsi, 2) + end +end + +function back(::typeof(cat), Δ, dim, xs...) + i = fill(:, dim-1) + j = fill(:, ndims(Δ)-dim) + start = 0 + for xsi in xs + if ndims(xsi) < dim + a = [fill(:, ndims(xsi)); ones(Int, dim-ndims(xsi)-1)] + @back(xsi, Δ[a..., start+1]) + else + @back(xsi, Δ[i..., start+1:start+size(xsi,dim), j...]) + end + start += size(xsi, dim) + end +end + Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) = reshape(xs, dims) Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Union{Int64,Colon}}}) = reshape(xs, Base._reshape_uncolon(xs, dims)) Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64}}) = track(reshape, xs, dims) From eaaf5fd34c1b41b780e13d00f9ff8186e2cb4035 Mon Sep 17 00:00:00 2001 From: Johan Gustafsson Date: Sat, 21 Apr 2018 01:10:34 +0200 Subject: [PATCH 15/28] vcat arrays with ndims>2 --- src/tracker/array.jl | 19 ++++--------------- 1 file changed, 4 insertions(+), 15 deletions(-) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 0bfabf36..61c2d5ce 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -81,20 +81,9 @@ back(::typeof(ctranspose), Δ, xs) = @back(xs, trim(xs, Δ')) Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...) Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...) -Base.vcat(a::TrackedVector, b::TrackedVector) = track(vcat, a, b) -Base.vcat(a::TrackedVector, b::TrackedVector...) = track(vcat, a, b...) -Base.vcat(a::TrackedVector, b::AbstractVector) = track(vcat, a, b) -Base.vcat(a::AbstractVector, b::TrackedVector) = track(vcat, a, b) - -Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat) = track(vcat, a, b) -Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat...) = track(vcat, a, b...) -Base.vcat(a::TrackedVecOrMat, b::AbstractVecOrMat) = track(vcat, a, b) -Base.vcat(a::AbstractVecOrMat, b::TrackedVecOrMat) = track(vcat, a, b) - -Base.vcat(a::TrackedMatrix, b::TrackedMatrix) = track(vcat, a, b) -Base.vcat(a::TrackedMatrix, b::TrackedMatrix...) = track(vcat, a, b...) -Base.vcat(a::TrackedMatrix, b::AbstractMatrix) = track(vcat, a, b) -Base.vcat(a::AbstractMatrix, b::TrackedMatrix) = track(vcat, a, b) +Base.vcat(a::A, b::B...) where {A <: TrackedArray, B <: TrackedArray} = track(vcat, a, b...) +Base.vcat(a::A, b::B) where {A <: TrackedArray, B <: AbstractArray} = track(vcat, a, b) +Base.vcat(a::A, b::B) where {A <: AbstractArray, B <: TrackedArray} = track(vcat, a, b) Base.hcat(a::A, b::B...) where {A <: TrackedArray, B <: TrackedArray} = track(hcat, a, b...) Base.hcat(a::A, b::B) where {A <: TrackedArray, B <: AbstractArray} = track(hcat, a, b) @@ -117,7 +106,7 @@ function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1) end function back(::typeof(vcat), Δ, xs...) - i = Base.tail(map(_ -> :, size(Δ))) + i = fill(:, ndims(Δ)-1) start = 0 for xsi in xs @back(xsi, Δ[start+1:start+size(xsi,1), i...]) From 509a2e59f6e6e8cbfcf2f3283884fe005a33fce4 Mon Sep 17 00:00:00 2001 From: Johan Gustafsson Date: Wed, 2 May 2018 08:30:11 +0200 Subject: [PATCH 16/28] cat promotions and mixed ranks --- src/tracker/array.jl | 32 +++++++++++++++++--------------- test/tracker.jl | 27 +++++++++++++++++++++++++-- 2 files changed, 42 insertions(+), 17 deletions(-) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 61c2d5ce..89fce39e 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -81,17 +81,18 @@ back(::typeof(ctranspose), Δ, xs) = @back(xs, trim(xs, Δ')) Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...) Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...) -Base.vcat(a::A, b::B...) where {A <: TrackedArray, B <: TrackedArray} = track(vcat, a, b...) -Base.vcat(a::A, b::B) where {A <: TrackedArray, B <: AbstractArray} = track(vcat, a, b) -Base.vcat(a::A, b::B) where {A <: AbstractArray, B <: TrackedArray} = track(vcat, a, b) +for f in [:vcat, :hcat] + @eval begin + Base.$f(a::TrackedArray...) = track($f, a...) + Base.$f(a::TrackedArray, b::Array...) = track($f, a, b...) -Base.hcat(a::A, b::B...) where {A <: TrackedArray, B <: TrackedArray} = track(hcat, a, b...) -Base.hcat(a::A, b::B) where {A <: TrackedArray, B <: AbstractArray} = track(hcat, a, b) -Base.hcat(a::A, b::B) where {A <: AbstractArray, B <: TrackedArray} = track(hcat, a, b) + # assumes there is another function to capture Union{Matrix,Vector}... without any TrackedMatrix or TrackedVector + Base.$f(a::Union{TrackedMatrix,TrackedVector,Matrix,Vector}...) = track($f, a...) + end +end -Base.cat(dim::Int, a::A, b::B...) where {A <: TrackedArray, B <: TrackedArray} = track(cat, dim, a, b...) -Base.cat(dim::Int, a::A, b::B) where {A <: TrackedArray, B <: AbstractArray} = track(cat, dim, a, b) -Base.cat(dim::Int, a::A, b::B) where {A <: AbstractArray, B <: TrackedArray} = track(cat, dim, a, b) +Base.cat(dim::Int, a::TrackedArray...) = track(Base.cat, dim, a...) +Base.cat(dim::Int, a::TrackedArray, b::Array...) = track(Base.cat, dim, a, b...) function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1) Δ′ = similar(xs.data) @@ -106,21 +107,21 @@ function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1) end function back(::typeof(vcat), Δ, xs...) - i = fill(:, ndims(Δ)-1) start = 0 for xsi in xs + i = map(_ -> :, size(xsi)) |> Base.tail @back(xsi, Δ[start+1:start+size(xsi,1), i...]) start += size(xsi, 1) end end function back(::typeof(hcat), Δ, xs...) - i = fill(:, ndims(Δ)-2) start = 0 for xsi in xs if ndims(xsi) == 1 @back(xsi, Δ[:, start+1]) else + i = map(_ -> :, size(xsi)) |> Base.tail |> Base.tail @back(xsi, Δ[:, start+1:start+size(xsi,2), i...]) end start += size(xsi, 2) @@ -128,14 +129,15 @@ function back(::typeof(hcat), Δ, xs...) end function back(::typeof(cat), Δ, dim, xs...) - i = fill(:, dim-1) - j = fill(:, ndims(Δ)-dim) start = 0 for xsi in xs if ndims(xsi) < dim - a = [fill(:, ndims(xsi)); ones(Int, dim-ndims(xsi)-1)] - @back(xsi, Δ[a..., start+1]) + i = map(_ -> :, size(xsi)) + j = ones(Int, dim-ndims(xsi)-1) + @back(xsi, Δ[i..., j..., start+1]) else + i = fill(:, dim-1) + j = fill(:, ndims(xsi)-dim) @back(xsi, Δ[i..., start+1:start+size(xsi,dim), j...]) end start += size(xsi, dim) diff --git a/test/tracker.jl b/test/tracker.jl index 27d395b1..d0b2375d 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -29,19 +29,42 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...) @test gradtest(x -> x', rand(5)) +function simplepromotioncheck(f, A, B) + r0 = f(A, B) + r1 = f(param(A), B) + r2 = f(A, param(B)) + r3 = f(param(A), param(B)) + + r1 == r2 == r3 && r0 == Flux.data(r1) +end + @testset "concat" begin - @testset "vcat $i" for (i,vcatf) in enumerate((vcat, (x...) -> cat(1, x...))) + cat1(x...) = cat(1, x...) + cat2(x...) = cat(2, x...) + + @testset for vcatf in [vcat, cat1] @test gradtest(vcatf, rand(5), rand(3)) @test gradtest(vcatf, rand(5), rand(3), rand(8)) @test gradtest(vcatf, rand(5,2), rand(3,2), rand(8,2)) @test gradtest(vcatf, rand(5,2,3), rand(3,2,3), rand(8,2,3)) + @test gradtest(vcatf, rand(5), rand(3,1)) + @test gradtest(vcatf, rand(5)', rand(2,5)) end - @testset "hcat $i" for (i,hcatf) in enumerate((hcat, (x...) -> cat(2, x...))) + + @test simplepromotioncheck(vcat, rand(5), rand(5)) + + @testset for hcatf in [hcat, cat2] @test gradtest(hcatf, rand(5), rand(5)) @test gradtest(hcatf, rand(2,5), rand(2,3), rand(2,8)) @test gradtest(hcatf, rand(2,5,3), rand(2,3,3), rand(2,8,3)) + @test gradtest(hcatf, rand(5)', rand(1,3)) + @test gradtest(hcatf, rand(5), rand(5,2)) end + + @test simplepromotioncheck(hcat, rand(5), rand(5)) + @test gradtest((x...) -> cat(3, x...), rand(2,5,2), rand(2,5,3), rand(2,5,4)) + @testset "cat($dim, ...)" for dim in 1:5 catdim = (x...) -> cat(dim, x...) @test gradtest(catdim, rand(5), rand(5)) From fb685291693cf0c9dbc466557af6d3ab7f078e88 Mon Sep 17 00:00:00 2001 From: Johan Gustafsson Date: Wed, 2 May 2018 08:37:30 +0200 Subject: [PATCH 17/28] define back function right after forward function --- src/tracker/array.jl | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 89fce39e..71a2d530 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -81,19 +81,6 @@ back(::typeof(ctranspose), Δ, xs) = @back(xs, trim(xs, Δ')) Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...) Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...) -for f in [:vcat, :hcat] - @eval begin - Base.$f(a::TrackedArray...) = track($f, a...) - Base.$f(a::TrackedArray, b::Array...) = track($f, a, b...) - - # assumes there is another function to capture Union{Matrix,Vector}... without any TrackedMatrix or TrackedVector - Base.$f(a::Union{TrackedMatrix,TrackedVector,Matrix,Vector}...) = track($f, a...) - end -end - -Base.cat(dim::Int, a::TrackedArray...) = track(Base.cat, dim, a...) -Base.cat(dim::Int, a::TrackedArray, b::Array...) = track(Base.cat, dim, a, b...) - function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1) Δ′ = similar(xs.data) S = size(xs.data) @@ -106,6 +93,16 @@ function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1) back(xs, Δ′) end +for f in [:vcat, :hcat] + @eval begin + Base.$f(a::TrackedArray...) = track($f, a...) + Base.$f(a::TrackedArray, b::Array...) = track($f, a, b...) + + # assumes there is another function to capture Union{Matrix,Vector}... without any TrackedMatrix or TrackedVector + Base.$f(a::Union{TrackedMatrix,TrackedVector,Matrix,Vector}...) = track($f, a...) + end +end + function back(::typeof(vcat), Δ, xs...) start = 0 for xsi in xs @@ -128,6 +125,9 @@ function back(::typeof(hcat), Δ, xs...) end end +Base.cat(dim::Int, a::TrackedArray...) = track(Base.cat, dim, a...) +Base.cat(dim::Int, a::TrackedArray, b::Array...) = track(Base.cat, dim, a, b...) + function back(::typeof(cat), Δ, dim, xs...) start = 0 for xsi in xs From 1c189c62ed9acc2a3f6784e90dee9293f5c3b965 Mon Sep 17 00:00:00 2001 From: Johan Gustafsson Date: Wed, 2 May 2018 09:03:54 +0200 Subject: [PATCH 18/28] cat with multiple dims #156 Co-authored-by: americast --- src/tracker/array.jl | 31 +++++++++++++++---------------- test/tracker.jl | 2 ++ 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 71a2d530..4650e916 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -98,7 +98,7 @@ for f in [:vcat, :hcat] Base.$f(a::TrackedArray...) = track($f, a...) Base.$f(a::TrackedArray, b::Array...) = track($f, a, b...) - # assumes there is another function to capture Union{Matrix,Vector}... without any TrackedMatrix or TrackedVector + # assumes there is another function to match Union{Matrix,Vector}... without any TrackedMatrix or TrackedVector Base.$f(a::Union{TrackedMatrix,TrackedVector,Matrix,Vector}...) = track($f, a...) end end @@ -125,22 +125,21 @@ function back(::typeof(hcat), Δ, xs...) end end -Base.cat(dim::Int, a::TrackedArray...) = track(Base.cat, dim, a...) -Base.cat(dim::Int, a::TrackedArray, b::Array...) = track(Base.cat, dim, a, b...) +Base.cat(dims, a::TrackedArray...) = track(Base.cat, dims, a...) +Base.cat(dims, a::TrackedArray, b::Array...) = track(Base.cat, dims, a, b...) +Base.cat(dims, a::Array, b::TrackedArray...) = track(Base.cat, dims, a, b...) -function back(::typeof(cat), Δ, dim, xs...) - start = 0 - for xsi in xs - if ndims(xsi) < dim - i = map(_ -> :, size(xsi)) - j = ones(Int, dim-ndims(xsi)-1) - @back(xsi, Δ[i..., j..., start+1]) - else - i = fill(:, dim-1) - j = fill(:, ndims(xsi)-dim) - @back(xsi, Δ[i..., start+1:start+size(xsi,dim), j...]) - end - start += size(xsi, dim) +function back(::typeof(cat), Δ, dims, Xs...) + start = ntuple(i -> 0, Val{ndims(Δ)}) + for xs in Xs + dim_xs = 1:ndims(xs) + till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val{ndims(Δ)}) + + xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val{ndims(Δ)}) + + @back(xs, reshape(Δ[xs_in_Δ...],size(xs))) + + start = start .+ till_xs end end diff --git a/test/tracker.jl b/test/tracker.jl index d0b2375d..01aa03de 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -70,6 +70,8 @@ end @test gradtest(catdim, rand(5), rand(5)) @test gradtest(catdim, rand(2,5), rand(2,5), rand(2,5)) end + + @test gradtest((a,b)->cat((2,3,5), a, b), rand(2,3), rand(2,4,2,1)) end @test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6)) From cfdb16e609438719e480d053d90b1be88c5a48aa Mon Sep 17 00:00:00 2001 From: Johan Gustafsson Date: Wed, 2 May 2018 15:46:01 +0200 Subject: [PATCH 19/28] vcat test #213 Co-authored-by: improbable22 --- test/tracker.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/tracker.jl b/test/tracker.jl index 01aa03de..7f84fdf9 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -72,6 +72,11 @@ end end @test gradtest((a,b)->cat((2,3,5), a, b), rand(2,3), rand(2,4,2,1)) + + @testset "issue #213" begin + A, B, C = rand(2,2), rand(2,2), rand(2,2) + @test vcat(A, B, C |> param) == vcat(param.((A,B,C))...) + end end @test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6)) From 94bb064a0f3ddc70774d54440ee0df312c941dea Mon Sep 17 00:00:00 2001 From: Johan Gustafsson Date: Wed, 2 May 2018 15:47:30 +0200 Subject: [PATCH 20/28] more tests of array promotion for concatenation # Conflicts: # test/tracker.jl --- src/tracker/array.jl | 5 ++--- test/tracker.jl | 33 ++++++++++++++++++++------------- 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 4650e916..1139a903 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -96,10 +96,9 @@ end for f in [:vcat, :hcat] @eval begin Base.$f(a::TrackedArray...) = track($f, a...) - Base.$f(a::TrackedArray, b::Array...) = track($f, a, b...) - # assumes there is another function to match Union{Matrix,Vector}... without any TrackedMatrix or TrackedVector - Base.$f(a::Union{TrackedMatrix,TrackedVector,Matrix,Vector}...) = track($f, a...) + # assumes there are other functions to match the more conservative signature without TrackedArray; ie `Base.$f(::Union{Matrix,Vector,RowVector}...)` + Base.$f(a::Union{TrackedArray,Matrix,Vector,RowVector}...) = track($f, a...) end end diff --git a/test/tracker.jl b/test/tracker.jl index 7f84fdf9..3185406a 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -29,13 +29,18 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...) @test gradtest(x -> x', rand(5)) -function simplepromotioncheck(f, A, B) - r0 = f(A, B) - r1 = f(param(A), B) - r2 = f(A, param(B)) - r3 = f(param(A), param(B)) +function promotiontest(f, A, B, C) + r0 = f(A, B, C) + r1 = f(param(A), B, C) + if ndims(A) <= 2 + r2 = f(A, param(B), C) + r3 = f(A, B, param(C)) + else + r2 = r3 = f(A, param(B), param(C)) + end + r4 = f(param(A), param(B), param(C)) - r1 == r2 == r3 && r0 == Flux.data(r1) + r1 == r2 == r3 == r4 && r0 == Flux.data(r4) end @testset "concat" begin @@ -51,18 +56,15 @@ end @test gradtest(vcatf, rand(5)', rand(2,5)) end - @test simplepromotioncheck(vcat, rand(5), rand(5)) - @testset for hcatf in [hcat, cat2] @test gradtest(hcatf, rand(5), rand(5)) @test gradtest(hcatf, rand(2,5), rand(2,3), rand(2,8)) @test gradtest(hcatf, rand(2,5,3), rand(2,3,3), rand(2,8,3)) + @test gradtest(hcatf, rand(5), rand(5), rand(5,2)) @test gradtest(hcatf, rand(5)', rand(1,3)) @test gradtest(hcatf, rand(5), rand(5,2)) end - @test simplepromotioncheck(hcat, rand(5), rand(5)) - @test gradtest((x...) -> cat(3, x...), rand(2,5,2), rand(2,5,3), rand(2,5,4)) @testset "cat($dim, ...)" for dim in 1:5 @@ -73,9 +75,14 @@ end @test gradtest((a,b)->cat((2,3,5), a, b), rand(2,3), rand(2,4,2,1)) - @testset "issue #213" begin - A, B, C = rand(2,2), rand(2,2), rand(2,2) - @test vcat(A, B, C |> param) == vcat(param.((A,B,C))...) + @testset "promotiontest" begin + @test promotiontest(vcat, rand(1,2), rand(2)', rand(2,2)) + @test promotiontest(hcat, rand(2,1), rand(2), rand(2,2)) + @test promotiontest(vcat, rand(3,4,5), rand(1,4,5), rand(2,4,5)) + @test promotiontest(hcat, rand(4,3,5), rand(4,1,5), rand(4,2,5)) + @testset "cat($dim, ...)" for dim in 1:5 + @test promotiontest((x...) -> cat(dim, x...), rand(3,4,5), rand(3,4,5), rand(3,4,5)) + end end end From 5fc61909563186e7db847ed67d84cf232e61c64b Mon Sep 17 00:00:00 2001 From: Johan Gustafsson Date: Wed, 2 May 2018 14:57:32 +0200 Subject: [PATCH 21/28] RowVector tests --- src/tracker/array.jl | 20 +++++++++++++------ test/tracker.jl | 47 +++++++++++++++++++++++++++++++++----------- 2 files changed, 49 insertions(+), 18 deletions(-) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 1139a903..967ce8dd 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -95,10 +95,19 @@ end for f in [:vcat, :hcat] @eval begin - Base.$f(a::TrackedArray...) = track($f, a...) + # This section is a bit of a hack since julia doesn't have a standardised promotion mechanism for concatenation yet https://github.com/JuliaLang/julia/pull/20815 - # assumes there are other functions to match the more conservative signature without TrackedArray; ie `Base.$f(::Union{Matrix,Vector,RowVector}...)` - Base.$f(a::Union{TrackedArray,Matrix,Vector,RowVector}...) = track($f, a...) + # It should support tracked concatenation with rank ∈ (1,2) with a TrackedArray anywhere among the arguments + # This works as long as base has other functions that captures `(::Union{Vector,RowVector,Matrix}...)`. + Base.$f(a::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a...) + + # It should support tracked concatenation with rank>2 if the TrackedArray is first + Base.$f(a::TrackedArray, b::AbstractArray...) = track($f, a, b...) + Base.$f(a::TrackedArray, b::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a, b...) # resolves ambiguity introduced by previous row + + # It should support tracked concatenation with rank>2 if the TrackedArray is second + Base.$f(a::Array, b::TrackedArray, c::AbstractArray...) = track($f, a, b, c...) + Base.$f(a::Union{Vector,RowVector,Matrix}, b::TrackedArray, c::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a, b, c...) # resolves ambiguity introduced by previous row end end @@ -124,9 +133,8 @@ function back(::typeof(hcat), Δ, xs...) end end -Base.cat(dims, a::TrackedArray...) = track(Base.cat, dims, a...) -Base.cat(dims, a::TrackedArray, b::Array...) = track(Base.cat, dims, a, b...) -Base.cat(dims, a::Array, b::TrackedArray...) = track(Base.cat, dims, a, b...) +Base.cat(dims, a::TrackedArray, b::AbstractArray...) = track(cat, dims, a, b...) +Base.cat(dims, a::Union{RowVector,Array}, b::TrackedArray, c::AbstractArray...) = track(cat, dims, a, b, c...) function back(::typeof(cat), Δ, dims, Xs...) start = ntuple(i -> 0, Val{ndims(Δ)}) diff --git a/test/tracker.jl b/test/tracker.jl index 3185406a..434148f0 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -32,15 +32,19 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...) function promotiontest(f, A, B, C) r0 = f(A, B, C) r1 = f(param(A), B, C) - if ndims(A) <= 2 - r2 = f(A, param(B), C) + r2 = f(A, param(B), C) + if all(ndims.((A,B,C)) .≤ 2) && f ∈ [hcat, vcat] r3 = f(A, B, param(C)) else - r2 = r3 = f(A, param(B), param(C)) + @test_throws MethodError f(A, B, param(C)) # until julia#20815 is resolved + r3 = r2 end r4 = f(param(A), param(B), param(C)) - r1 == r2 == r3 == r4 && r0 == Flux.data(r4) + @test !isa(r0, TrackedArray) + @test all(isa.([r1,r2,r3,r4], TrackedArray)) + @test r1 == r2 == r3 == r4 + @test r0 == Flux.data(r4) end @testset "concat" begin @@ -50,6 +54,7 @@ end @testset for vcatf in [vcat, cat1] @test gradtest(vcatf, rand(5), rand(3)) @test gradtest(vcatf, rand(5), rand(3), rand(8)) + @test gradtest(vcatf, rand(5)', rand(5)') @test gradtest(vcatf, rand(5,2), rand(3,2), rand(8,2)) @test gradtest(vcatf, rand(5,2,3), rand(3,2,3), rand(8,2,3)) @test gradtest(vcatf, rand(5), rand(3,1)) @@ -58,31 +63,49 @@ end @testset for hcatf in [hcat, cat2] @test gradtest(hcatf, rand(5), rand(5)) + @test gradtest(hcatf, rand(5)', rand(5)') @test gradtest(hcatf, rand(2,5), rand(2,3), rand(2,8)) @test gradtest(hcatf, rand(2,5,3), rand(2,3,3), rand(2,8,3)) @test gradtest(hcatf, rand(5), rand(5), rand(5,2)) @test gradtest(hcatf, rand(5)', rand(1,3)) @test gradtest(hcatf, rand(5), rand(5,2)) +end + + @testset for catf in [vcat, cat1, hcat, cat2, (x...) -> cat(3, x...), (x...) -> cat((1,2), x...)] + @test gradtest(catf, rand(5)) + @test gradtest(catf, rand(5)') + @test gradtest(catf, rand(2,5)) + @test gradtest(catf, rand(2,5,3)) end @test gradtest((x...) -> cat(3, x...), rand(2,5,2), rand(2,5,3), rand(2,5,4)) - @testset "cat($dim, ...)" for dim in 1:5 + @testset "cat($dim, ...)" for dim in 3:5 catdim = (x...) -> cat(dim, x...) - @test gradtest(catdim, rand(5), rand(5)) + @test gradtest(catdim, rand(5), rand(5), rand(5)) @test gradtest(catdim, rand(2,5), rand(2,5), rand(2,5)) + @test gradtest(catdim, rand(2,5,3), rand(2,5,3), rand(2,5,3)) end + @test !isa(vcat(rand(2)), TrackedArray) + @test !isa(hcat(rand(2)), TrackedArray) + @test !isa(cat(1,rand(2)), TrackedArray) + @test gradtest((a,b)->cat((2,3,5), a, b), rand(2,3), rand(2,4,2,1)) @testset "promotiontest" begin - @test promotiontest(vcat, rand(1,2), rand(2)', rand(2,2)) - @test promotiontest(hcat, rand(2,1), rand(2), rand(2,2)) - @test promotiontest(vcat, rand(3,4,5), rand(1,4,5), rand(2,4,5)) - @test promotiontest(hcat, rand(4,3,5), rand(4,1,5), rand(4,2,5)) - @testset "cat($dim, ...)" for dim in 1:5 - @test promotiontest((x...) -> cat(dim, x...), rand(3,4,5), rand(3,4,5), rand(3,4,5)) + @testset for fcat in [hcat, vcat, (x...) -> cat(3, x...), (x...) -> cat((1,2), x...)] + promotiontest(fcat, rand(2), rand(2), rand(2)) + promotiontest(fcat, rand(2)', rand(2)', rand(2)') + promotiontest(fcat, rand(2,2), rand(2,2), rand(2,2)) + promotiontest(fcat, rand(2,2,2), rand(2,2,2), rand(2,2,2)) end + + promotiontest(vcat, rand(1,2), rand(2)', rand(2,2)) + promotiontest(hcat, rand(2,1), rand(2), rand(2,2)) + promotiontest(vcat, rand(3,4,5), rand(1,4,5), rand(2,4,5)) + promotiontest(hcat, rand(4,3,5), rand(4,1,5), rand(4,2,5)) + promotiontest((x...) -> cat(3, x...), rand(4,5,3), rand(4,5,1), rand(4,5,2)) end end From cfbead633d578c2b4853aeb46c8f7fe70280eefc Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Thu, 3 May 2018 14:14:53 +0100 Subject: [PATCH 22/28] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 79e37c5f..f01df194 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@

-[![Build Status](https://travis-ci.org/FluxML/Flux.jl.svg?branch=master)](https://travis-ci.org/FluxML/Flux.jl) [![](https://img.shields.io/badge/docs-stable-blue.svg)](https://fluxml.github.io/Flux.jl/stable/) [![](https://img.shields.io/badge/chat-on%20slack-yellow.svg)](https://slackinvite.julialang.org/) +[![Build Status](https://travis-ci.org/FluxML/Flux.jl.svg?branch=master)](https://travis-ci.org/FluxML/Flux.jl) [![](https://img.shields.io/badge/docs-stable-blue.svg)](https://fluxml.github.io/Flux.jl/stable/) [![](https://img.shields.io/badge/chat-on%20slack-yellow.svg)](https://slackinvite.julialang.org/) [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.1240350.svg)](https://doi.org/10.5281/zenodo.1240350) Flux is an elegant approach to machine learning. It's a 100% pure-Julia stack, and provides lightweight abstractions on top of Julia's native GPU and AD support. Flux makes the easy things easy while remaining fully hackable. From 180c2433fe8e8d1ae748f933791d385aa74fb64d Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Thu, 3 May 2018 18:34:03 +0100 Subject: [PATCH 23/28] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index f01df194..b8fd360a 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@

-[![Build Status](https://travis-ci.org/FluxML/Flux.jl.svg?branch=master)](https://travis-ci.org/FluxML/Flux.jl) [![](https://img.shields.io/badge/docs-stable-blue.svg)](https://fluxml.github.io/Flux.jl/stable/) [![](https://img.shields.io/badge/chat-on%20slack-yellow.svg)](https://slackinvite.julialang.org/) [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.1240350.svg)](https://doi.org/10.5281/zenodo.1240350) +[![Build Status](https://travis-ci.org/FluxML/Flux.jl.svg?branch=master)](https://travis-ci.org/FluxML/Flux.jl) [![](https://img.shields.io/badge/docs-stable-blue.svg)](https://fluxml.github.io/Flux.jl/stable/) [![](https://img.shields.io/badge/chat-on%20slack-yellow.svg)](https://slackinvite.julialang.org/) [![DOI](http://joss.theoj.org/papers/10.21105/joss.00602/status.svg)](https://doi.org/10.21105/joss.00602) Flux is an elegant approach to machine learning. It's a 100% pure-Julia stack, and provides lightweight abstractions on top of Julia's native GPU and AD support. Flux makes the easy things easy while remaining fully hackable. From 2d3f00da29216f9ff33ec5af1362723360dc07c4 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Thu, 3 May 2018 18:50:28 +0100 Subject: [PATCH 24/28] Update README.md --- README.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/README.md b/README.md index b8fd360a..0baa74d4 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,18 @@ julia> Pkg.add("Flux") See the [documentation](http://fluxml.github.io/Flux.jl/) or the [model zoo](https://github.com/FluxML/model-zoo/) for examples. +If you use Flux in research, please cite the following paper: + +``` +@article{innes:2018, + author = {Mike Innes}, + title = {Flux: Elegant Machine Learning with Julia}, + journal = {Journal of Open Source Software}, + year = {2018}, + doi = {10.21105/joss.00602}, +} +``` + ## Features Flux has powerful high-level features, and common architectures can be defined in a few lines. From b35b27be6e198b0e4bee1c5794f7b4a008e9cf00 Mon Sep 17 00:00:00 2001 From: Mike Innes Date: Fri, 4 May 2018 15:05:02 +0100 Subject: [PATCH 25/28] doc fix --- docs/src/models/layers.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md index 379268b3..c2056bb4 100644 --- a/docs/src/models/layers.md +++ b/docs/src/models/layers.md @@ -5,7 +5,7 @@ These core layers form the foundation of almost all neural networks. ```@docs Chain Dense -Conv2D +Conv ``` ## Recurrent Layers From b59161a41eaf43ce2fb7c9b6a39c671a3d680133 Mon Sep 17 00:00:00 2001 From: Mike Innes Date: Sat, 5 May 2018 17:15:18 +0100 Subject: [PATCH 26/28] export Tracker again --- src/Flux.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Flux.jl b/src/Flux.jl index 6cb92c50..7125630f 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -17,7 +17,7 @@ using NNlib: @fix include("tracker/Tracker.jl") using .Tracker using .Tracker: data -export TrackedArray, TrackedVector, TrackedMatrix, param +export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param include("optimise/Optimise.jl") using .Optimise From ef9077d9fabc7a972fddd6afcbcd454cf2efae79 Mon Sep 17 00:00:00 2001 From: Mike Innes Date: Mon, 7 May 2018 13:03:52 +0100 Subject: [PATCH 27/28] style --- src/tracker/array.jl | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 967ce8dd..e11296ab 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -95,19 +95,26 @@ end for f in [:vcat, :hcat] @eval begin - # This section is a bit of a hack since julia doesn't have a standardised promotion mechanism for concatenation yet https://github.com/JuliaLang/julia/pull/20815 + # This section is a bit of a hack since julia doesn't have a standardised + # promotion mechanism for concatenation yet + # https://github.com/JuliaLang/julia/pull/20815 - # It should support tracked concatenation with rank ∈ (1,2) with a TrackedArray anywhere among the arguments - # This works as long as base has other functions that captures `(::Union{Vector,RowVector,Matrix}...)`. + # It should support tracked concatenation with rank ∈ (1,2) with a + # TrackedArray anywhere among the arguments This works as long as base has + # other functions that captures `(::Union{Vector,RowVector,Matrix}...)`. Base.$f(a::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a...) - # It should support tracked concatenation with rank>2 if the TrackedArray is first + # It should support tracked concatenation with rank>2 if the TrackedArray is + # first Base.$f(a::TrackedArray, b::AbstractArray...) = track($f, a, b...) Base.$f(a::TrackedArray, b::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a, b...) # resolves ambiguity introduced by previous row - # It should support tracked concatenation with rank>2 if the TrackedArray is second + # It should support tracked concatenation with rank>2 if the TrackedArray is + # second Base.$f(a::Array, b::TrackedArray, c::AbstractArray...) = track($f, a, b, c...) - Base.$f(a::Union{Vector,RowVector,Matrix}, b::TrackedArray, c::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a, b, c...) # resolves ambiguity introduced by previous row + Base.$f(a::Union{Vector,RowVector,Matrix}, b::TrackedArray, + c::Union{TrackedArray,Vector,RowVector,Matrix}...) = + track($f, a, b, c...) # resolves ambiguity introduced by previous row end end From 5685df169117c8b9169a5de05384ebd05bbea348 Mon Sep 17 00:00:00 2001 From: Mike Innes Date: Mon, 7 May 2018 16:12:55 +0100 Subject: [PATCH 28/28] tracker docs --- docs/make.jl | 2 + docs/src/internals/tracker.md | 156 ++++++++++++++++++++++++++++++++++ 2 files changed, 158 insertions(+) create mode 100644 docs/src/internals/tracker.md diff --git a/docs/make.jl b/docs/make.jl index d7f14d8e..ed6a8c8b 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -18,6 +18,8 @@ makedocs(modules=[Flux, NNlib], "One-Hot Encoding" => "data/onehot.md", "GPU Support" => "gpu.md", "Saving & Loading" => "saving.md", + "Internals" => + ["Backpropagation" => "internals/tracker.md"], "Community" => "community.md"]) deploydocs( diff --git a/docs/src/internals/tracker.md b/docs/src/internals/tracker.md new file mode 100644 index 00000000..b9addc34 --- /dev/null +++ b/docs/src/internals/tracker.md @@ -0,0 +1,156 @@ +# Flux.Tracker + +Backpropagation, or reverse-mode automatic differentiation, is handled by the `Flux.Tracker` module. + +```julia +julia> using Flux.Tracker +``` + +The `param` function converts a normal Julia array into a new object that, while behaving like an array, tracks extra information that allows us to calculate derivatives. For example, say we multiply two parameters: + +```julia +julia> W = param([1 2; 3 4]) +Tracked 2×2 Array{Float64,2}: + 1.0 2.0 + 3.0 4.0 + +julia> x = param([5, 6]) +Tracked 2-element Array{Float64,1}: + 5.0 + 6.0 + +julia> y = W*x +Tracked 2-element Array{Float64,1}: + 17.0 + 39.0 +``` + +The output `y` is also a `TrackedArray` object. We can now backpropagate sensitivities to `W` and `x` via the `back!` function, and see the gradients accumulated in the `W` and `x` tracked arrays: + +```julia +julia> Tracker.back!(y, [1, -1]) + +julia> W.grad +2×2 Array{Float64,2}: + 5.0 6.0 +-5.0 -6.0 + +julia> x.grad +2-element Array{Float64,1}: + -2.0 + -2.0 +``` + +## Internals + +All `Tracked*` objects (`TrackedArray`, `TrackedReal`) are light wrappers around the `Tracked` type, which you can access via the `.tracker` field. + +```julia +julia> x.tracker +Flux.Tracker.Tracked{Array{Float64,1}}(0x00000000, Flux.Tracker.Call{Void,Tuple{}}(nothing, ()), true, [5.0, 6.0], [-2.0, -2.0]) +``` + +The `Tracker` stores the value and gradient of a given object, which we've seen before. + +```julia +julia> x.tracker.data +2-element Array{Float64,1}: + 5.0 + 6.0 + +julia> x.tracker.grad +2-element Array{Float64,1}: + -2.0 + -2.0 +``` + +The tracker also contains a `Call` object, which simply represents a function call that was made at some point during the forward pass. For example, the `+` call would look like this: + +```julia +julia> Tracker.Call(+, 1, 2) +Flux.Tracker.Call{Base.#+,Tuple{Int64,Int64}}(+, (1, 2)) +``` + +In the case of the `y` we produced above, we can see that it stores the call that produced it -- that is, `W*x`. + +```julia +julia> y.tracker.f +Flux.Tracker.Call{...}(*, (param([1.0 2.0; 3.0 4.0]), param([5.0, 6.0]))) +``` + +Notice that because the arguments to the call may also be tracked arrays, storing their own calls, this means that `Tracker` ends up forming a data structure that records everything that happened during the forward pass (often known as a *tape*). + +When we call `back!(y, [1, -1])`, the sensitivities `[1, -1]` simply get forwarded to `y`'s call (`*`), effectively calling + +```julia +Tracker.back(*, [1, -1], W, x) +``` + +which in turn calculates the sensitivities of the arguments (`W` and `x`) and backpropagates through their calls. This is recursive, so it will walk the entire program graph and propagate gradients to the original model parameters. + +## Custom Gradients + +We can hook in to the processes above to implement custom gradients for a function or kernel. For a toy example, imagine a custom implementation of `minus`: + +```julia +julia> minus(a, b) = a - b +``` + +Firstly, we must tell the tracker system to stop when it sees a call to `minus`, and record it. We can do this using dispatch: + +```julia +julia> minus(a::TrackedArray, b::TrackedArray) = Tracker.track(minus, a, b) +minus (generic function with 2 methods) +``` + +`Tracker.track` does two things: (1) it makes sure `minus` is called with *normal* array, not tracked ones (you can use `@show` inside `minus` to verify this), and (2) it uses the result to add a `minus` node to the tape. Look inside the result of calling `minus` to see what happened: + +```julia +julia> a, b = param([6,5,4]), param([1,2,3]) +(param([6.0, 5.0, 4.0]), param([1.0, 2.0, 3.0])) + +julia> c = minus(a, b) +Tracked 3-element Array{Float64,1}: + 5.0 + 3.0 + 1.0 + +julia> c.tracker.f +Flux.Tracker.Call{...}(minus, (param([6.0, 5.0, 4.0]), param([1.0, 2.0, 3.0]))) +``` + +Finally, we have to specify the gradient of `minus`. + +```julia +julia> Tracker.back(::typeof(minus), Δ, a, b) = + (Tracker.@back(a, Δ); Tracker.@back(b, -Δ)) +``` + +`@back(x, Δ)` tells the tracker to continue propagating the sensitivity `Δ` through `x`. Now, AD will work with any program that calls `minus`. + +```julia +julia> Flux.back!(c, 1) + +julia> a.grad +3-element Array{Float64,1}: + 1.0 + 1.0 + 1.0 + +julia> b.grad +3-element Array{Float64,1}: + -1.0 + -1.0 + -1.0 +``` + +## Notes + +For multi-argument functions with custom gradients, you likely want to catch not just `minus(::TrackedArray, ::TrackedArray)` but also `minus(::Array, TrackedArray)` and so on. To do so, just define those extra signatures as needed: + +```julia +minus(a::AbstractArray, b::TrackedArray) = Tracker.track(minus, a, b) +minus(a::TrackedArray, b::AbstractArray) = Tracker.track(minus, a, b) +``` + +`@back` *must* be called exactly once on each tracked input argument. You do not need to do any special handling if one of the arguments is not tracked, as `@back` will just become a no-op.