From 6743d52d08291e1a940cb0f922ed7be4c5561d07 Mon Sep 17 00:00:00 2001 From: Johnny Chen Date: Thu, 23 Aug 2018 21:34:11 +0800 Subject: [PATCH 01/25] Fix issue #354 --- src/layers/basic.jl | 3 ++- test/layers/basic.jl | 31 +++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 3 files changed, 34 insertions(+), 1 deletion(-) create mode 100644 test/layers/basic.jl diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 3e887472..123b041d 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -75,10 +75,11 @@ end @treelike Dense -function (a::Dense)(x) +function (a::Dense)(x::AbstractArray) W, b, σ = a.W, a.b, a.σ σ.(W*x .+ b) end +(a::Dense)(x::Number) = a([x]) # prevent broadcasting of scalar function Base.show(io::IO, l::Dense) print(io, "Dense(", size(l.W, 2), ", ", size(l.W, 1)) diff --git a/test/layers/basic.jl b/test/layers/basic.jl new file mode 100644 index 00000000..a37a4d12 --- /dev/null +++ b/test/layers/basic.jl @@ -0,0 +1,31 @@ +using Test, Random + + +@testset "basic" begin + @testset "Chain" begin + @test_nowarn Chain(Dense(10, 5, σ),Dense(5, 2), softmax) + @test_nowarn Chain(Dense(10, 5, σ),Dense(5, 2), softmax)(randn(10)) + end + + @testset "Dense" begin + @test length(Dense(10, 5)(randn(10))) == 5 + @test_throws DimensionMismatch Dense(10, 5)(randn(1)) + Random.seed!(0) + @test all(Dense(10, 1)(randn(10)).data .≈ 1.1774348382231168) + Random.seed!(0) + @test all(Dense(10, 2)(randn(10)).data .≈ [ -0.3624741476779616 + -0.46724765394534323]) + + @test_throws DimensionMismatch Dense(10, 5)(1) + end + + @testset "Diagonal" begin + @test length(Flux.Diagonal(10)(randn(10))) == 10 + @test length(Flux.Diagonal(10)(1)) == 10 + @test length(Flux.Diagonal(10)(randn(1))) == 10 + @test_throws DimensionMismatch Flux.Diagonal(10)(randn(2)) + Random.seed!(0) + @test all(Flux.Diagonal(2)(randn(2)).data .≈ [ 0.6791074260357777, + 0.8284134829000359]) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index a6230f45..0b37d5b4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -25,6 +25,7 @@ insert!(LOAD_PATH, 2, "@v#.#") include("utils.jl") include("tracker.jl") +include("layers/basic.jl") include("layers/normalisation.jl") include("layers/stateless.jl") include("optimise.jl") From c9d6b5648f8ffc75c76faf7550c55fc49e2bab87 Mon Sep 17 00:00:00 2001 From: Johnny Chen Date: Thu, 23 Aug 2018 21:56:32 +0800 Subject: [PATCH 02/25] Fix issue #354 --- src/layers/basic.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 123b041d..0c2d3715 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -79,7 +79,6 @@ function (a::Dense)(x::AbstractArray) W, b, σ = a.W, a.b, a.σ σ.(W*x .+ b) end -(a::Dense)(x::Number) = a([x]) # prevent broadcasting of scalar function Base.show(io::IO, l::Dense) print(io, "Dense(", size(l.W, 2), ", ", size(l.W, 1)) From 81e5f7c991edc8f548de306654a567891e2d33bb Mon Sep 17 00:00:00 2001 From: Johnny Chen Date: Thu, 23 Aug 2018 21:59:41 +0800 Subject: [PATCH 03/25] Update test/layers/basic.jl --- test/layers/basic.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/test/layers/basic.jl b/test/layers/basic.jl index a37a4d12..0cb9ad78 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -3,20 +3,23 @@ using Test, Random @testset "basic" begin @testset "Chain" begin - @test_nowarn Chain(Dense(10, 5, σ),Dense(5, 2), softmax) - @test_nowarn Chain(Dense(10, 5, σ),Dense(5, 2), softmax)(randn(10)) + @test_nowarn Chain(Dense(10, 5, σ),Dense(5, 2))(randn(10)) + @test_throws DimensionMismatch Chain(Dense(10, 5, σ),Dense(2, 1))(randn(10)) + # numeric test should be put into testset of corresponding layer end @testset "Dense" begin @test length(Dense(10, 5)(randn(10))) == 5 @test_throws DimensionMismatch Dense(10, 5)(randn(1)) + @test_throws DimensionMismatch Dense(10, 5)(1) # avoid broadcasting + @test_throws DimensionMismatch Dense(10, 5).(randn(10)) # avoid broadcasting + Random.seed!(0) @test all(Dense(10, 1)(randn(10)).data .≈ 1.1774348382231168) Random.seed!(0) @test all(Dense(10, 2)(randn(10)).data .≈ [ -0.3624741476779616 -0.46724765394534323]) - @test_throws DimensionMismatch Dense(10, 5)(1) end @testset "Diagonal" begin From 4baf85bbe29360ec1d4b849e251c16960d53e388 Mon Sep 17 00:00:00 2001 From: Johnny Chen Date: Thu, 23 Aug 2018 22:29:03 +0800 Subject: [PATCH 04/25] update Testset of basic.jl --- test/layers/basic.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/layers/basic.jl b/test/layers/basic.jl index 0cb9ad78..72051673 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -11,8 +11,8 @@ using Test, Random @testset "Dense" begin @test length(Dense(10, 5)(randn(10))) == 5 @test_throws DimensionMismatch Dense(10, 5)(randn(1)) - @test_throws DimensionMismatch Dense(10, 5)(1) # avoid broadcasting - @test_throws DimensionMismatch Dense(10, 5).(randn(10)) # avoid broadcasting + @test_throws MethodError Dense(10, 5)(1) # avoid broadcasting + @test_throws MethodError Dense(10, 5).(randn(10)) # avoid broadcasting Random.seed!(0) @test all(Dense(10, 1)(randn(10)).data .≈ 1.1774348382231168) From b35664c59f27f568a96d350e6a4504c70964bf2b Mon Sep 17 00:00:00 2001 From: Johnny Chen Date: Sat, 25 Aug 2018 16:30:46 +0800 Subject: [PATCH 05/25] Update testsets --- test/layers/basic.jl | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/test/layers/basic.jl b/test/layers/basic.jl index 72051673..dff2be0b 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -3,7 +3,7 @@ using Test, Random @testset "basic" begin @testset "Chain" begin - @test_nowarn Chain(Dense(10, 5, σ),Dense(5, 2))(randn(10)) + @test_nowarn Chain(Dense(10, 5, σ), Dense(5, 2))(randn(10)) @test_throws DimensionMismatch Chain(Dense(10, 5, σ),Dense(2, 1))(randn(10)) # numeric test should be put into testset of corresponding layer end @@ -14,11 +14,10 @@ using Test, Random @test_throws MethodError Dense(10, 5)(1) # avoid broadcasting @test_throws MethodError Dense(10, 5).(randn(10)) # avoid broadcasting - Random.seed!(0) - @test all(Dense(10, 1)(randn(10)).data .≈ 1.1774348382231168) - Random.seed!(0) - @test all(Dense(10, 2)(randn(10)).data .≈ [ -0.3624741476779616 - -0.46724765394534323]) + @test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,1)) == [10] + @test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,2)) == [10 10] + @test Dense(10, 2, identity, initW = ones, initb = zeros)(ones(10,1)) == [10; 10] + @test Dense(10, 2, identity, initW = ones, initb = zeros)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20] end @@ -27,8 +26,9 @@ using Test, Random @test length(Flux.Diagonal(10)(1)) == 10 @test length(Flux.Diagonal(10)(randn(1))) == 10 @test_throws DimensionMismatch Flux.Diagonal(10)(randn(2)) - Random.seed!(0) - @test all(Flux.Diagonal(2)(randn(2)).data .≈ [ 0.6791074260357777, - 0.8284134829000359]) + + @test Flux.Diagonal(2)([1 2]) == [1 2; 1 2] + @test Flux.Diagonal(2)([1,2]) == [1,2] + @test Flux.Diagonal(2)([1 2; 3 4]) == [1 2; 3 4] end end From 6bbed07e96048503fad3dfd9dd3000b37781506c Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Fri, 7 Sep 2018 02:05:03 +0100 Subject: [PATCH 06/25] enable nested broadcast --- src/tracker/array.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index ffa3a89e..16f91d22 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -351,9 +351,9 @@ end eltype(y) <: Real || return y eltype(y) == Bool && return y function back(Δ) - Δargs = ntuple(i -> partial.(f, data(Δ), i, args...), Val(N)) - dxs = unbroadcast.(args, Δargs) - return nobacksies(:broadcast, dxs) + Δargs = ntuple(i -> partial.(f, Δ, i, args...), Val(N)) + dxs = map(unbroadcast, args, Δargs) + return dxs end # So we can return non-tracked arrays track(Call(back, tracker.(args)), y) From 8b9a98ed0129efb87a8a1f4d63e5c49b33c85869 Mon Sep 17 00:00:00 2001 From: Sambit Kumar Dash Date: Tue, 11 Sep 2018 18:58:07 +0530 Subject: [PATCH 07/25] The sample gradient should not use the softdash While softdash is a very natural and mathematical way of representation, it can be very easily confused with the apostrophe used for LinAlg adjoint. Not worth and unnecessary confusion in a first example of the code. --- docs/src/models/basics.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/src/models/basics.md b/docs/src/models/basics.md index 88fa0a05..a0a39ab5 100644 --- a/docs/src/models/basics.md +++ b/docs/src/models/basics.md @@ -10,14 +10,14 @@ using Flux.Tracker f(x) = 3x^2 + 2x + 1 # df/dx = 6x + 2 -f′(x) = Tracker.gradient(f, x)[1] +df(x) = Tracker.gradient(f, x)[1] -f′(2) # 14.0 (tracked) +df(2) # 14.0 (tracked) # d²f/dx² = 6 -f′′(x) = Tracker.gradient(f′, x)[1] +d2f(x) = Tracker.gradient(df, x)[1] -f′′(2) # 6.0 (tracked) +d2f(2) # 6.0 (tracked) ``` (We'll learn more about why these numbers show up as `(tracked)` below.) From d797999fc5353a2b4973b606872f7dbd1bb86af6 Mon Sep 17 00:00:00 2001 From: Mike Innes Date: Fri, 14 Sep 2018 18:10:24 +0100 Subject: [PATCH 08/25] fix sentiment model --- src/data/sentiment.jl | 9 +++++---- test/data.jl | 2 ++ 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/data/sentiment.jl b/src/data/sentiment.jl index a58cd9d4..56c9e8ea 100644 --- a/src/data/sentiment.jl +++ b/src/data/sentiment.jl @@ -4,7 +4,7 @@ using ZipFile using ..Data: deps function load() - isfile(deps("sentiment.zip")) || return + isfile(deps("sentiment.zip")) && return @info "Downloading sentiment treebank dataset" download("https://cache.julialang.org/https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip", deps("sentiment.zip")) @@ -26,9 +26,10 @@ totree_(n, a, b) = Tree{Any}((parse(Int, n), nothing), totree(a), totree(b)) totree(t::Expr) = totree_(t.args...) function parsetree(s) - s = replace(s, r"\$", s -> "\\\$") - s = replace(s, r"[^\s\(\)]+", s -> "\"$s\"") - s = replace(s, " ", ", ") + s = replace(s, "\\" => "") + s = replace(s, "\$" => "\\\$") + s = replace(s, r"[^ \n\(\)]+" => s -> "\"$s\"") + s = replace(s, " " => ", ") return totree(Meta.parse(s)) end diff --git a/test/data.jl b/test/data.jl index 7a27c651..9c2901cb 100644 --- a/test/data.jl +++ b/test/data.jl @@ -9,3 +9,5 @@ using Test @test MNIST.images()[1] isa Matrix @test MNIST.labels() isa Vector{Int64} + +@test Data.Sentiment.train() isa Vector{Data.Tree{Any}} From e803117e2591b9dc5a074bfacb49ca1aa72295dd Mon Sep 17 00:00:00 2001 From: Isaac Tay Date: Sat, 15 Sep 2018 16:45:04 +0800 Subject: [PATCH 09/25] updated loadparams! function --- src/treelike.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/treelike.jl b/src/treelike.jl index 3d83d448..9b3518d3 100644 --- a/src/treelike.jl +++ b/src/treelike.jl @@ -54,7 +54,7 @@ function loadparams!(m, xs) for (p, x) in zip(params(m), xs) size(p) == size(x) || error("Expected param size $(size(p)), got $(size(x))") - copy!(data(p), data(x)) + copyto!(data(p), data(x)) end end From d1318535878e9bc9edb3157f0cda2442a1680c14 Mon Sep 17 00:00:00 2001 From: Alex Bird Date: Wed, 19 Sep 2018 13:08:30 +0100 Subject: [PATCH 10/25] add inv/ldivide/rdivide + test --- src/tracker/array.jl | 37 +++++++++++++++++++++++++++++++++++++ test/tracker.jl | 5 +++++ 2 files changed, 42 insertions(+) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 6d3c3b3f..3d9836d0 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -1,6 +1,8 @@ import Base: * import LinearAlgebra +import LinearAlgebra: inv, \, / + using Statistics using LinearAlgebra: Transpose, Adjoint, diagm, diag @@ -205,6 +207,41 @@ Base.kron(a::TrackedMatrix, b::TrackedMatrix) = _kron(a, b) Base.kron(a::TrackedMatrix, b::AbstractMatrix) = _kron(a, b) Base.kron(a::AbstractMatrix, b::TrackedMatrix) = _kron(a, b) + +inv(A::TrackedArray) = Tracker.track(inv, A) +@grad function inv(A) + return inv(Tracker.data(A)), function (Δ) + Ainv = inv(A) + ∇A = - Ainv' * Δ * Ainv' + return (∇A, ) + end +end + +# (/) rdivide +A::TrackedArray / B::TrackedArray = Tracker.track(/, A, B) +A::AbstractVecOrMat / B::TrackedArray = Tracker.track(/, A, B) +A::TrackedArray / B::AbstractVecOrMat = Tracker.track(/, A, B) +@grad function (A / B) + return Tracker.data(A) / Tracker.data(B), function (Δ) + Binv = inv(B) + ∇B = - Binv' * A' * Δ * Binv' + return (Δ * Binv', ∇B) + end +end + +# (\) ldivide (left vec divide needs more work to resolve dispatch ambiguity) +A::TrackedArray \ B::TrackedArray = Tracker.track(\, A, B) +A::AbstractArray \ B::TrackedArray = Tracker.track(\, A, B) +A::TrackedArray \ B::AbstractVecOrMat = Tracker.track(\, A, B) +@grad function (A \ B) + return Tracker.data(A) \ Tracker.data(B), function (Δ) + Ainv = inv(A) + ∇A = - Ainv' * Δ * B' * Ainv' + return (∇A, Ainv' * Δ) + end +end + + # Reductions Base.sum(xs::TrackedArray; dims = :) = track(sum, xs, dims = dims) diff --git a/test/tracker.jl b/test/tracker.jl index 9a4cb793..a4772f2e 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -129,6 +129,11 @@ end @test gradtest(f-> Matrix(Diagonal(f)), rand(3)) +@test gradtest(W -> inv(log.(W * W)), (5,5)) +@test gradtest((A, B) -> A / B , (1,5), (5,5)) +@test gradtest((A, B) -> log.(A * A) / exp.(B * B), (5,5), (5,5)) +@test gradtest((A, B) -> log.(A * A) \ exp.(B * B), (5,5), (5,5)) + @testset "mean" begin @test gradtest(mean, rand(2, 3)) From 079614adb21dbbce878b5b1d0fb065332bb6651f Mon Sep 17 00:00:00 2001 From: Harry Date: Wed, 19 Sep 2018 16:45:11 +0100 Subject: [PATCH 11/25] Fix typo --- docs/src/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/index.md b/docs/src/index.md index 4fc58f72..4b5668a1 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -3,7 +3,7 @@ Flux is a library for machine learning. It comes "batteries-included" with many useful tools built in, but also lets you use the full power of the Julia language where you need it. We follow a few key principles: * **Doing the obvious thing**. Flux has relatively few explicit APIs for features like regularisation or embeddings. Instead, writing down the mathematical form will work – and be fast. -* **You could have written Flux**. All of it, from [LSTMs](https://github.com/FluxML/Flux.jl/blob/ec16a2c77dbf6ab8b92b0eecd11661be7a62feef/src/layers/recurrent.jl#L131) to [GPU kernels](https://github.com/JuliaGPU/CuArrays.jl), is straightforward Julia code. When it doubt, it’s well worth looking at [the source](https://github.com/FluxML/Flux.jl/). If you need something different, you can easily roll your own. +* **You could have written Flux**. All of it, from [LSTMs](https://github.com/FluxML/Flux.jl/blob/ec16a2c77dbf6ab8b92b0eecd11661be7a62feef/src/layers/recurrent.jl#L131) to [GPU kernels](https://github.com/JuliaGPU/CuArrays.jl), is straightforward Julia code. When in doubt, it’s well worth looking at [the source](https://github.com/FluxML/Flux.jl/). If you need something different, you can easily roll your own. * **Play nicely with others**. Flux works well with Julia libraries from [data frames](https://github.com/JuliaComputing/JuliaDB.jl) and [images](https://github.com/JuliaImages/Images.jl) to [differential equation solvers](https://github.com/JuliaDiffEq/DifferentialEquations.jl), so you can easily build complex data processing pipelines that integrate Flux models. ## Installation From b20ae0546b7d77c48558d0555c2adc648386c5a3 Mon Sep 17 00:00:00 2001 From: JohnnyChen Date: Wed, 26 Sep 2018 20:30:13 +0800 Subject: [PATCH 12/25] rebase to pass the test --- test/layers/basic.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/layers/basic.jl b/test/layers/basic.jl index dff2be0b..f9015068 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -1,6 +1,5 @@ using Test, Random - @testset "basic" begin @testset "Chain" begin @test_nowarn Chain(Dense(10, 5, σ), Dense(5, 2))(randn(10)) From 3bf18347e0f74924660237b93b7c8e39464216f6 Mon Sep 17 00:00:00 2001 From: JohnnyChen Date: Wed, 26 Sep 2018 22:03:38 +0800 Subject: [PATCH 13/25] Fix dimensional error in test --- test/layers/basic.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/layers/basic.jl b/test/layers/basic.jl index f9015068..b8d9efd1 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -13,9 +13,9 @@ using Test, Random @test_throws MethodError Dense(10, 5)(1) # avoid broadcasting @test_throws MethodError Dense(10, 5).(randn(10)) # avoid broadcasting - @test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,1)) == [10] - @test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,2)) == [10 10] - @test Dense(10, 2, identity, initW = ones, initb = zeros)(ones(10,1)) == [10; 10] + @test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(1, 1) + @test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,2)) == 10*ones(1, 2) + @test Dense(10, 2, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(2, 1) @test Dense(10, 2, identity, initW = ones, initb = zeros)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20] end From d25e05d9eed5cde043a609bf6aca63bc545ee6b5 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 27 Sep 2018 10:40:44 +0200 Subject: [PATCH 14/25] evaluate both 2-ary DiffRules only when needed --- src/tracker/scalar.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/tracker/scalar.jl b/src/tracker/scalar.jl index 81ccb9a3..1b6098fb 100644 --- a/src/tracker/scalar.jl +++ b/src/tracker/scalar.jl @@ -63,7 +63,9 @@ for (M, f, arity) in DiffRules.diffrules() da, db = DiffRules.diffrule(M, f, :a, :b) f = :($M.$f) @eval begin - @grad $f(a::Real, b::Real) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db) + @grad $f(a::TrackedReal, b::TrackedReal) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db) + @grad $f(a::TrackedReal, b::Real) = $f(data(a), b), Δ -> (Δ * $da, zero(b)) + @grad $f(a::Real, b::TrackedReal) = $f(a, data(b)), Δ -> (zero(a), Δ * $db) $f(a::TrackedReal, b::TrackedReal) = track($f, a, b) $f(a::TrackedReal, b::Real) = track($f, a, b) $f(a::Real, b::TrackedReal) = track($f, a, b) From aff4c7898e9808a43f434d401448f3f88fc99d90 Mon Sep 17 00:00:00 2001 From: Christopher Murphy <6396338+c-p-murphy@users.noreply.github.com> Date: Mon, 1 Oct 2018 15:26:26 -0400 Subject: [PATCH 15/25] add FashionMNIST --- src/data/Data.jl | 3 + src/data/fashion-mnist.jl | 115 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 118 insertions(+) create mode 100644 src/data/fashion-mnist.jl diff --git a/src/data/Data.jl b/src/data/Data.jl index d5b5f38d..ddf0624b 100644 --- a/src/data/Data.jl +++ b/src/data/Data.jl @@ -13,6 +13,9 @@ end include("mnist.jl") export MNIST +include("fashion-mnist.jl") +export FashionMNIST + include("cmudict.jl") using .CMUDict diff --git a/src/data/fashion-mnist.jl b/src/data/fashion-mnist.jl new file mode 100644 index 00000000..4e697672 --- /dev/null +++ b/src/data/fashion-mnist.jl @@ -0,0 +1,115 @@ +module FashionMNIST + +using CodecZlib, Colors + +const Gray = Colors.Gray{Colors.N0f8} + +const dir = joinpath(@__DIR__, "../../deps/fashion-mnist") + +function gzopen(f, file) + open(file) do io + f(GzipDecompressorStream(io)) + end +end + +function load() + mkpath(dir) + cd(dir) do + for file in ["train-images-idx3-ubyte", + "train-labels-idx1-ubyte", + "t10k-images-idx3-ubyte", + "t10k-labels-idx1-ubyte"] + isfile(file) && continue + @info "Downloading Fashion-MNIST dataset" + download("http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/$file.gz", "$file.gz") + open(file, "w") do io + write(io, gzopen(read, "$file.gz")) + end + end + end +end + +const IMAGEOFFSET = 16 +const LABELOFFSET = 8 + +const NROWS = 28 +const NCOLS = 28 + +const TRAINIMAGES = joinpath(dir, "train-images-idx3-ubyte") +const TRAINLABELS = joinpath(dir, "train-labels-idx1-ubyte") +const TESTIMAGES = joinpath(dir, "t10k-images-idx3-ubyte") +const TESTLABELS = joinpath(dir, "t10k-labels-idx1-ubyte") + +function imageheader(io::IO) + magic_number = bswap(read(io, UInt32)) + total_items = bswap(read(io, UInt32)) + nrows = bswap(read(io, UInt32)) + ncols = bswap(read(io, UInt32)) + return magic_number, Int(total_items), Int(nrows), Int(ncols) +end + +function labelheader(io::IO) + magic_number = bswap(read(io, UInt32)) + total_items = bswap(read(io, UInt32)) + return magic_number, Int(total_items) +end + +function rawimage(io::IO) + img = Array{Gray}(undef, NCOLS, NROWS) + for i in 1:NCOLS, j in 1:NROWS + img[i, j] = reinterpret(Colors.N0f8, read(io, UInt8)) + end + return img +end + +function rawimage(io::IO, index::Integer) + seek(io, IMAGEOFFSET + NROWS * NCOLS * (index - 1)) + return rawimage(io) +end + +rawlabel(io::IO) = Int(read(io, UInt8)) + +function rawlabel(io::IO, index::Integer) + seek(io, LABELOFFSET + (index - 1)) + return rawlabel(io) +end + +getfeatures(io::IO, index::Integer) = vec(getimage(io, index)) + +""" + images() + images(:test) + +Load the MNIST images. + +Each image is a 28×28 array of `Gray` colour values (see Colors.jl). + +Returns the 60,000 training images by default; pass `:test` to retreive the +10,000 test images. +""" +function images(set = :train) + load() + io = IOBuffer(read(set == :train ? TRAINIMAGES : TESTIMAGES)) + _, N, nrows, ncols = imageheader(io) + [rawimage(io) for _ in 1:N] +end + +""" + labels() + labels(:test) + +Load the labels corresponding to each of the images returned from `images()`. +Each label is a number from 0-9. + +Returns the 60,000 training labels by default; pass `:test` to retreive the +10,000 test labels. +""" +function labels(set = :train) + load() + io = IOBuffer(read(set == :train ? TRAINLABELS : TESTLABELS)) + _, N = labelheader(io) + [rawlabel(io) for _ = 1:N] +end + + +end From 7e67bf06e1567bf7a8e802c2967d972fb3e66c6d Mon Sep 17 00:00:00 2001 From: Christopher Murphy <6396338+c-p-murphy@users.noreply.github.com> Date: Tue, 2 Oct 2018 15:00:45 -0400 Subject: [PATCH 16/25] update tests --- test/data.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/data.jl b/test/data.jl index 9c2901cb..a73d1ec3 100644 --- a/test/data.jl +++ b/test/data.jl @@ -10,4 +10,7 @@ using Test @test MNIST.images()[1] isa Matrix @test MNIST.labels() isa Vector{Int64} +@test FashionMNIST.images()[1] isa Matrix +@test FashionMNIST.labels() isa Vector{Int64} + @test Data.Sentiment.train() isa Vector{Data.Tree{Any}} From 95d72d7f793d316ab180a2fe034ddce47ba7bc55 Mon Sep 17 00:00:00 2001 From: Christopher Murphy <6396338+c-p-murphy@users.noreply.github.com> Date: Tue, 2 Oct 2018 15:31:44 -0400 Subject: [PATCH 17/25] update comments --- src/data/fashion-mnist.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/data/fashion-mnist.jl b/src/data/fashion-mnist.jl index 4e697672..d608d8bb 100644 --- a/src/data/fashion-mnist.jl +++ b/src/data/fashion-mnist.jl @@ -80,7 +80,7 @@ getfeatures(io::IO, index::Integer) = vec(getimage(io, index)) images() images(:test) -Load the MNIST images. +Load the Fashion-MNIST images. Each image is a 28×28 array of `Gray` colour values (see Colors.jl). From 252e34e173ea3e05a198fc37969d2542eaab8526 Mon Sep 17 00:00:00 2001 From: Robert Luciani Date: Tue, 2 Oct 2018 21:39:00 +0200 Subject: [PATCH 18/25] 1.0+ updates - indices to axes, Vector init with undef --- src/utils.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index c53f7864..6a970f0b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -24,7 +24,7 @@ julia> chunk(1:10, 3) """ chunk(xs, n) = collect(Iterators.partition(xs, ceil(Int, length(xs)/n))) -batchindex(xs, i) = (reverse(Base.tail(reverse(indices(xs))))..., i) +batchindex(xs, i) = (reverse(Base.tail(reverse(axes(xs))))..., i) """ frequencies(xs) @@ -66,7 +66,7 @@ julia> batch([[1,2,3],[4,5,6]]) function batch(xs) data = first(xs) isa AbstractArray ? similar(first(xs), size(first(xs))..., length(xs)) : - Vector{eltype(xs)}(length(xs)) + Vector{eltype(xs)}(undef, length(xs)) for (i, x) in enumerate(xs) data[batchindex(data, i)...] = x end From fe6793fde5b40430999c30d207570ce85d4d3fbc Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 3 Oct 2018 11:45:29 +0100 Subject: [PATCH 19/25] closes #411 --- src/layers/recurrent.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 3b40af04..40cd322a 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -148,7 +148,7 @@ Base.show(io::IO, l::LSTMCell) = print(io, "LSTMCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷4, ")") """ - LSTM(in::Integer, out::Integer, σ = tanh) + LSTM(in::Integer, out::Integer) Long Short Term Memory recurrent layer. Behaves like an RNN but generally exhibits a longer memory span over sequences. @@ -189,7 +189,7 @@ Base.show(io::IO, l::GRUCell) = print(io, "GRUCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷3, ")") """ - GRU(in::Integer, out::Integer, σ = tanh) + GRU(in::Integer, out::Integer) Gated Recurrent Unit layer. Behaves like an RNN but generally exhibits a longer memory span over sequences. From 73a526b1de465b0ad893d46fce09c0536d5a0d8b Mon Sep 17 00:00:00 2001 From: Christopher Murphy <6396338+c-p-murphy@users.noreply.github.com> Date: Wed, 3 Oct 2018 12:40:24 -0400 Subject: [PATCH 20/25] reuse utils from mnist.jl --- src/data/fashion-mnist.jl | 53 +-------------------------------------- 1 file changed, 1 insertion(+), 52 deletions(-) diff --git a/src/data/fashion-mnist.jl b/src/data/fashion-mnist.jl index d608d8bb..e4510b47 100644 --- a/src/data/fashion-mnist.jl +++ b/src/data/fashion-mnist.jl @@ -1,17 +1,9 @@ module FashionMNIST -using CodecZlib, Colors - -const Gray = Colors.Gray{Colors.N0f8} +using ..MNIST: gzopen, imageheader, rawimage, labelheader, rawlabel const dir = joinpath(@__DIR__, "../../deps/fashion-mnist") -function gzopen(f, file) - open(file) do io - f(GzipDecompressorStream(io)) - end -end - function load() mkpath(dir) cd(dir) do @@ -29,53 +21,11 @@ function load() end end -const IMAGEOFFSET = 16 -const LABELOFFSET = 8 - -const NROWS = 28 -const NCOLS = 28 - const TRAINIMAGES = joinpath(dir, "train-images-idx3-ubyte") const TRAINLABELS = joinpath(dir, "train-labels-idx1-ubyte") const TESTIMAGES = joinpath(dir, "t10k-images-idx3-ubyte") const TESTLABELS = joinpath(dir, "t10k-labels-idx1-ubyte") -function imageheader(io::IO) - magic_number = bswap(read(io, UInt32)) - total_items = bswap(read(io, UInt32)) - nrows = bswap(read(io, UInt32)) - ncols = bswap(read(io, UInt32)) - return magic_number, Int(total_items), Int(nrows), Int(ncols) -end - -function labelheader(io::IO) - magic_number = bswap(read(io, UInt32)) - total_items = bswap(read(io, UInt32)) - return magic_number, Int(total_items) -end - -function rawimage(io::IO) - img = Array{Gray}(undef, NCOLS, NROWS) - for i in 1:NCOLS, j in 1:NROWS - img[i, j] = reinterpret(Colors.N0f8, read(io, UInt8)) - end - return img -end - -function rawimage(io::IO, index::Integer) - seek(io, IMAGEOFFSET + NROWS * NCOLS * (index - 1)) - return rawimage(io) -end - -rawlabel(io::IO) = Int(read(io, UInt8)) - -function rawlabel(io::IO, index::Integer) - seek(io, LABELOFFSET + (index - 1)) - return rawlabel(io) -end - -getfeatures(io::IO, index::Integer) = vec(getimage(io, index)) - """ images() images(:test) @@ -111,5 +61,4 @@ function labels(set = :train) [rawlabel(io) for _ = 1:N] end - end From 2ff54ee0fd85b1641279cfcad041331986b34604 Mon Sep 17 00:00:00 2001 From: Tejan Karmali Date: Thu, 4 Oct 2018 11:31:29 -0400 Subject: [PATCH 21/25] cudnn_available() update --- src/cuda/cuda.jl | 2 +- test/cuda/cuda.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/cuda/cuda.jl b/src/cuda/cuda.jl index fe36bf5d..15126aca 100644 --- a/src/cuda/cuda.jl +++ b/src/cuda/cuda.jl @@ -2,6 +2,6 @@ module CUDA using ..CuArrays -CuArrays.cudnn_available() && include("cudnn.jl") +CuArrays.libcudnn != nothing && include("cudnn.jl") end diff --git a/test/cuda/cuda.jl b/test/cuda/cuda.jl index 16f90e89..1f54d1b9 100644 --- a/test/cuda/cuda.jl +++ b/test/cuda/cuda.jl @@ -36,4 +36,4 @@ Flux.back!(sum(l)) end -CuArrays.cudnn_available() && include("cudnn.jl") +CuArrays.libcudnn != nothing && include("cudnn.jl") From 69afdd61a672c8d92a8a121197b5e408e16f6279 Mon Sep 17 00:00:00 2001 From: Mike Innes Date: Fri, 5 Oct 2018 13:59:58 +0100 Subject: [PATCH 22/25] avoid a warning --- src/tracker/Tracker.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 190837ab..94f9a94c 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -108,10 +108,8 @@ param(xs::AbstractArray) = TrackedArray(float.(xs)) param(x::TrackedReal) = track(identity, x) param(x::TrackedArray) = track(identity, x) -import NNlib.cudata import Adapt.adapt -cudata(x::TrackedArray) = data(x) adapt(T, xs::TrackedArray) = param(adapt(T, data(xs))) end From 61fb6cdf053da66f29f1afb3161f8a86434b0572 Mon Sep 17 00:00:00 2001 From: Mike Innes Date: Fri, 5 Oct 2018 14:02:00 +0100 Subject: [PATCH 23/25] jit macro --- src/utils.jl | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/utils.jl b/src/utils.jl index 6a970f0b..74d479bd 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -153,3 +153,18 @@ function jacobian(m,x) end J' end + +""" + @jit ... + +The `@jit` annotation can be applied to any code, and the code will be compiled +for performance. + + @jit f(x) = @jit(x) + @jit(x) + +Note that compilation happens regardless of the `@jit` macro, so it should only +be used for aesthetic purposes, or by recovering Python users. +""" +macro jit(ex) + esc(ex) +end From c6740c5cdd735e91869cf7615e711cfa47679f8f Mon Sep 17 00:00:00 2001 From: Mike Innes Date: Fri, 5 Oct 2018 14:14:24 +0100 Subject: [PATCH 24/25] fix unbroadcast --- src/cuda/cudnn.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl index f033595a..61609b0d 100644 --- a/src/cuda/cudnn.jl +++ b/src/cuda/cudnn.jl @@ -328,7 +328,7 @@ end h_ = hBatch(x, data(h)) dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve) (dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve) - nobacksies(:RNN, (dx, unbroadcast(size(h), dh), transpose(dWi), transpose(dWh), db)) + nobacksies(:RNN, (dx, unbroadcast(h, dh), transpose(dWi), transpose(dWh), db)) end end @@ -342,7 +342,7 @@ end dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve) (dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve) nobacksies(:RNN, - (dx, unbroadcast(size(h), dh), unbroadcast(size(c), dc), + (dx, unbroadcast(h, dh), unbroadcast(c, dc), transpose(dWi), transpose(dWh), db)) end end From 3b391a1af6da614e59e6f48e30af377d6dd0c9b5 Mon Sep 17 00:00:00 2001 From: Proyag Date: Fri, 5 Oct 2018 14:47:06 +0100 Subject: [PATCH 25/25] #389