From c59b820bed91a214126f5d3b66461ee44c855be0 Mon Sep 17 00:00:00 2001 From: Elliot Saba Date: Mon, 4 Dec 2017 23:47:03 -0800 Subject: [PATCH 01/23] Add glorot (Xavier) initialization Set default `Dense` and `RNN` inits to `glorot_uniform()` for `W`, `zeros` for `b`. --- src/layers/basic.jl | 6 ++++-- src/layers/recurrent.jl | 12 ++++++------ src/utils.jl | 2 ++ test/utils.jl | 25 ++++++++++++++++++++++++- 4 files changed, 36 insertions(+), 9 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index aa101c43..9f458ab4 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -63,8 +63,10 @@ struct Dense{F,S,T} b::T end -Dense(in::Integer, out::Integer, σ = identity; init = initn) = - Dense(σ, param(init(out, in)), param(init(out))) +function Dense(in::Integer, out::Integer, σ = identity; + initW = glorot_uniform, initb = zeros) + return Dense(σ, param(initW(out, in)), param(initb(out))) +end treelike(Dense) diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 599776ce..781bd405 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -79,8 +79,8 @@ struct RNNCell{D,V} h::V end -RNNCell(in::Integer, out::Integer, σ = tanh; init = initn) = - RNNCell(Dense(in+out, out, σ, init = init), param(init(out))) +RNNCell(in::Integer, out::Integer, σ = tanh; initW = glorot_uniform, initb = zeros) = + RNNCell(Dense(in+out, out, σ, initW = initW, initb = initb), param(initW(out))) function (m::RNNCell)(h, x) h = m.d(combine(x, h)) @@ -113,10 +113,10 @@ struct LSTMCell{D1,D2,V} h::V; c::V end -function LSTMCell(in, out; init = initn) - cell = LSTMCell([Dense(in+out, out, σ, init = init) for _ = 1:3]..., - Dense(in+out, out, tanh, init = init), - param(init(out)), param(init(out))) +function LSTMCell(in, out; initW = glorot_uniform, initb = zeros) + cell = LSTMCell([Dense(in+out, out, σ, initW = initW, initb = initb) for _ = 1:3]..., + Dense(in+out, out, tanh, initW = initW, initb = initb), + param(initW(out)), param(initW(out))) 
cell.forget.b.data .= 1 return cell end diff --git a/src/utils.jl b/src/utils.jl index f822c111..944d35bf 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,6 +1,8 @@ # Arrays initn(dims...) = randn(dims...)/100 +glorot_uniform(dims...) = (rand(dims...) - 0.5)*sqrt(24.0/(sum(dims))) +glorot_normal(dims...) = (randn(dims...)*sqrt(2.0/sum(dims))) flatten(xs) = reshape(xs, size(xs, 1), :) diff --git a/test/utils.jl b/test/utils.jl index 7638fd2a..1c313a3d 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,4 +1,4 @@ -using Flux: throttle +using Flux: throttle, initn, glorot_uniform, glorot_normal @testset "Throttle" begin @testset "default behaviour" begin @@ -47,3 +47,26 @@ using Flux: throttle @test a == [1, 3] end end + +@testset "Initialization" begin + # Set random seed so that these tests don't fail randomly + srand(0) + # initn() should yield a kernel with stddev ~= 1e-2 + v = initn(10, 10) + @test std(v) > 0.9*1e-2 + @test std(v) < 1.1*1e-2 + + # glorot_uniform should yield a kernel spanning roughly ±sqrt(6/(n_in + n_out)), + # and glorot_normal should yield a kernel with stddev ~= sqrt(2/(n_in + n_out)) + for (n_in, n_out) in [(100, 100), (100, 400)] + v = glorot_uniform(n_in, n_out) + @test minimum(v) > -1.1*sqrt(6/(n_in + n_out)) + @test minimum(v) < -0.9*sqrt(6/(n_in + n_out)) + @test maximum(v) > 0.9*sqrt(6/(n_in + n_out)) + @test maximum(v) < 1.1*sqrt(6/(n_in + n_out)) + + v = glorot_normal(n_in, n_out) + @test std(v) > 0.9*sqrt(2/(n_in + n_out)) + @test std(v) < 1.1*sqrt(2/(n_in + n_out)) + end +end \ No newline at end of file From 41446d547fa5e1a80d6c928fe4a24ce8ae280dc3 Mon Sep 17 00:00:00 2001 From: Elliot Saba Date: Tue, 5 Dec 2017 15:38:15 -0800 Subject: [PATCH 02/23] Add `weighted_crossentropy` for imbalanced classification problems --- src/layers/stateless.jl | 11 +++++++++-- test/layers/stateless.jl | 26 ++++++++++++++++++++++++++ test/runtests.jl | 1 + 3 files changed, 36 insertions(+), 2 deletions(-) create mode 100644 test/layers/stateless.jl diff
--git a/src/layers/stateless.jl b/src/layers/stateless.jl index edbdec58..8d675735 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -4,8 +4,15 @@ using NNlib: log_fast mse(ŷ, y) = sum((ŷ .- y).^2)/length(y) -crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat) = - -sum(y .* log_fast.(ŷ)) / size(y, 2) +function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat) + return -sum(y .* log_fast.(ŷ)) / size(y, 2) +end + +function weighted_crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat, w::AbstractVecOrMat) + return -sum(y .* log_fast.(ŷ) .* w) / size(y, 2) +end + + @deprecate logloss(x, y) crossentropy(x, y) diff --git a/test/layers/stateless.jl b/test/layers/stateless.jl new file mode 100644 index 00000000..b7a42841 --- /dev/null +++ b/test/layers/stateless.jl @@ -0,0 +1,26 @@ +using Flux: onehotbatch, mse, crossentropy, weighted_crossentropy + +@testset "losses" begin + # First, regression-style y's + y = [1, 1, 0, 0] + y_hat = [.9, .1, .1, .9] + + @testset "mse" begin + @test mse(y_hat, y) ≈ (.1^2 + .9^2)/2 + end + + # Now onehot y's + y = onehotbatch([1, 1, 0, 0], 0:1) + y_hat = [.1 .9; .9 .1; .9 .1; .1 .9]' + y_logloss = 1.203972804325936 + + @testset "crossentropy" begin + @test crossentropy(y_hat, y) ≈ y_logloss + end + + @testset "weighted_crossentropy" begin + @test weighted_crossentropy(y_hat, y, ones(2)) ≈ y_logloss + @test weighted_crossentropy(y_hat, y, [.5, .5]) ≈ y_logloss/2 + @test weighted_crossentropy(y_hat, y, [2, .5]) ≈ 1.5049660054074199 + end +end diff --git a/test/runtests.jl b/test/runtests.jl index efd1a462..5c6ba549 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,5 +5,6 @@ using Flux, Base.Test include("utils.jl") include("tracker.jl") include("layers/normalisation.jl") +include("layers/stateless.jl") end From 385dee9d16e29bb37cc4083ffd5bad736a70f520 Mon Sep 17 00:00:00 2001 From: baggepinnen Date: Fri, 8 Dec 2017 14:46:12 +0100 Subject: [PATCH 03/23] Add jacobian function --- src/Flux.jl | 2 +- 
src/utils.jl | 19 +++++++++++++++++++ test/utils.jl | 9 +++++++++ 3 files changed, 29 insertions(+), 1 deletion(-) diff --git a/src/Flux.jl b/src/Flux.jl index df4b1636..2c79e426 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -9,7 +9,7 @@ using Lazy: @forward export Chain, Dense, RNN, LSTM, Dropout, LayerNorm, SGD, ADAM, Momentum, Nesterov, - param, params, mapleaves + param, params, mapleaves, jacobian using NNlib export σ, relu, leakyrelu, elu, swish, softmax diff --git a/src/utils.jl b/src/utils.jl index f822c111..755b54e9 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -120,3 +120,22 @@ function throttle(f, timeout; leading=true, trailing=false) nothing end end + +""" + J = jacobian(m,x) + +Calculate the output jacobian `J = d/dx m(x)` such that each row `i` of `J` corresponds to the gradient `J[i,:] = ∇ₓ(m(x)[i])` +""" +function jacobian(m,x) + xp = param(x) + y = m(xp) + k = length(y) + n = length(x) + J = Matrix{eltype(x)}(n,k) + for i = 1:k + Flux.back!(y[i]) # Populate gradient accumulator + J[:,i] = xp.grad + xp.grad .*= 0 # Reset gradient accumulator + end + J' +end diff --git a/test/utils.jl b/test/utils.jl index 7638fd2a..abee0f24 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -47,3 +47,12 @@ using Flux: throttle @test a == [1, 3] end end + +@testset "Jacobian" begin + A = param(randn(2,2)) + x = randn(2) + m(x) = A*x + y = m(x) + J = jacobian(m,x) + @test J ≈ A.data +end From b7b6c975bc91c3a9c531178ae7715f216698cdd4 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 12 Dec 2017 17:07:39 +0000 Subject: [PATCH 04/23] fixes #110 --- src/tracker/lib.jl | 9 +++++++-- test/tracker.jl | 1 + 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/tracker/lib.jl b/src/tracker/lib.jl index 5065a40d..ab250e39 100644 --- a/src/tracker/lib.jl +++ b/src/tracker/lib.jl @@ -70,7 +70,7 @@ back(::typeof(mean), Δ, xs::TrackedArray, region) = # BLAS -for f in :[*, Ac_mul_B].args +for f in :[*, Ac_mul_B, A_mul_Bc].args @eval begin import Base.$f 
$f(a::TrackedMatrix, b::TrackedMatrix) = TrackedArray(Call($f, a, b)) @@ -94,7 +94,12 @@ end function back(::typeof(Ac_mul_B), Δ, a::AbstractVecOrMat{<:Real}, b::AbstractVecOrMat{<:Real}) @back(a, A_mul_Bt(Δ, data(b))') - @back(b, *(data(a), Δ)) + @back(b, data(a)*Δ) +end + +function back(::typeof(A_mul_Bc), Δ, a::AbstractVecOrMat{<:Real}, b::AbstractVecOrMat{<:Real}) + @back(a, Δ * data(b)) + @back(b, At_mul_B(data(a), Δ)') end # Fast path for matrix-vector diff --git a/test/tracker.jl b/test/tracker.jl index 81a72566..7d9ef4f5 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -10,6 +10,7 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...) @test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2) @test gradtest((w, x) -> w'*x, randn(10, 2), randn(10)) +@test gradtest((w, x) -> w*x', randn(5,5), randn(5,5)) @test gradtest(x -> sin.(sum(x, (2, 3))), (3,4,5)) From 29787eba452a0e12e7c152fe7ded67393f18a8b7 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 12 Dec 2017 17:23:15 +0000 Subject: [PATCH 05/23] fixes #114 --- src/tracker/lib.jl | 9 +++++++++ test/tracker.jl | 2 ++ 2 files changed, 11 insertions(+) diff --git a/src/tracker/lib.jl b/src/tracker/lib.jl index ab250e39..f3221bd8 100644 --- a/src/tracker/lib.jl +++ b/src/tracker/lib.jl @@ -58,6 +58,15 @@ Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...) 
Base.mean(xs::TrackedArray) = TrackedArray(Call(mean, xs), toarray(xs.data, mean(xs.data))) Base.mean(xs::TrackedArray, region) = TrackedArray(Call(mean, xs, region)) +LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = TrackedArray(Call(dot, xs, ys), toarray(xs.data, dot(data(xs), data(ys)))) +LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = TrackedArray(Call(dot, xs, ys), toarray(xs.data, dot(data(xs), data(ys)))) +LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = TrackedArray(Call(dot, xs, ys), toarray(xs.data, dot(data(xs), data(ys)))) + +function back(::typeof(dot), Δ, xs, ys) + @back(xs, Δ.*ys) + @back(ys, Δ.*xs) +end + # Hacks to get std working Base.std(x::TrackedArray; mean = Base.mean(x)) = sqrt.(sum((x .- mean).^2) ./ (length(x)-1)) diff --git a/test/tracker.jl b/test/tracker.jl index 7d9ef4f5..ac031915 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -38,6 +38,8 @@ end @test gradtest(x -> std(x), rand(5,5)) @test gradtest(x -> std(x, 1), rand(5,5)) +@test gradtest((x, y) -> x .* y, rand(5), rand(5)) + @test gradtest(rand(5)) do x y = x.^2 2y + x From e3a688e70646cc832e6a69acdb1efe6cdbe5eb36 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 13 Dec 2017 15:27:15 +0000 Subject: [PATCH 06/23] use kwarg --- src/layers/stateless.jl | 10 ++-------- test/layers/stateless.jl | 8 ++++---- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index 8d675735..63c40cb8 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -4,16 +4,10 @@ using NNlib: log_fast mse(ŷ, y) = sum((ŷ .- y).^2)/length(y) -function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat) - return -sum(y .* log_fast.(ŷ)) / size(y, 2) +function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1) + return -sum(y .* log_fast.(ŷ) .* weight) / size(y, 2) end -function weighted_crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat, w::AbstractVecOrMat) - return -sum(y .* log_fast.(ŷ) .* w) / 
size(y, 2) -end - - - @deprecate logloss(x, y) crossentropy(x, y) function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat) diff --git a/test/layers/stateless.jl b/test/layers/stateless.jl index b7a42841..23304eb1 100644 --- a/test/layers/stateless.jl +++ b/test/layers/stateless.jl @@ -1,4 +1,4 @@ -using Flux: onehotbatch, mse, crossentropy, weighted_crossentropy +using Flux: onehotbatch, mse, crossentropy @testset "losses" begin # First, regression-style y's @@ -19,8 +19,8 @@ using Flux: onehotbatch, mse, crossentropy, weighted_crossentropy end @testset "weighted_crossentropy" begin - @test weighted_crossentropy(y_hat, y, ones(2)) ≈ y_logloss - @test weighted_crossentropy(y_hat, y, [.5, .5]) ≈ y_logloss/2 - @test weighted_crossentropy(y_hat, y, [2, .5]) ≈ 1.5049660054074199 + @test crossentropy(y_hat, y, weight = ones(2)) ≈ y_logloss + @test crossentropy(y_hat, y, weight = [.5, .5]) ≈ y_logloss/2 + @test crossentropy(y_hat, y, weight = [2, .5]) ≈ 1.5049660054074199 end end From 23096824d5e3eaf00688d8a91a4c881b3e3e7898 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 13 Dec 2017 17:29:32 +0000 Subject: [PATCH 07/23] import jacobian --- test/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/utils.jl b/test/utils.jl index 34762adf..7a00b57d 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,4 +1,4 @@ -using Flux: throttle, initn, glorot_uniform, glorot_normal +using Flux: throttle, initn, glorot_uniform, glorot_normal, jacobian @testset "Throttle" begin @testset "default behaviour" begin From 5b97d2ba04dec8c31d600518bd404adca0a454b9 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 13 Dec 2017 18:24:56 +0000 Subject: [PATCH 08/23] closes #127 --- src/optimise/train.jl | 23 ++++++++++++++++------- src/utils.jl | 7 ++++--- test/optimise.jl | 12 ++++++++++++ 3 files changed, 32 insertions(+), 10 deletions(-) diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 0809e86b..31812fa0 100644 --- 
a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -1,15 +1,24 @@ using Juno -using Flux.Tracker: back! +using Flux.Tracker: back!, value runall(f) = f runall(fs::AbstractVector) = () -> foreach(call, fs) """ - train!(loss, data, opt; cb = () -> ()) + train!(loss, data, opt) For each datapoint `d` in `data` computes the gradient of `loss(d...)` through -backpropagation and calls the optimizer `opt` and the callback `cb` -(i.e. `opt()` and `cb()`). +backpropagation and calls the optimizer `opt`. + +Takes a callback as keyword argument `cb`. For example, this will print "training" +every 10 seconds: + +```julia +Flux.train!(loss, data, opt, + cb = throttle(() -> println("training"), 10)) +``` + +The callback can return `:stop` to interrupt the training loop. Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays. """ @@ -18,10 +27,10 @@ function train!(loss, data, opt; cb = () -> ()) opt = runall(opt) @progress for d in data l = loss(d...) - isinf(l.data[]) && error("Loss is Inf") - isnan(l.data[]) && error("Loss is NaN") + isinf(value(l)) && error("Loss is Inf") + isnan(value(l)) && error("Loss is NaN") back!(l) opt() - cb() + cb() == :stop && break end end diff --git a/src/utils.jl b/src/utils.jl index 9a03ae4f..afe926d9 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -95,13 +95,14 @@ but if you'd like to disable the execution on the leading edge, pass function throttle(f, timeout; leading=true, trailing=false) cooldown = true later = nothing + result = nothing function throttled(args...; kwargs...) yield() if cooldown if leading - f(args...; kwargs...) + result = f(args...; kwargs...) else later = () -> f(args...; kwargs...) end @@ -116,10 +117,10 @@ function throttle(f, timeout; leading=true, trailing=false) cooldown = true end elseif trailing - later = () -> f(args...; kwargs...) 
+ later = () -> (result = f(args...; kwargs...)) end - nothing + return result end end diff --git a/test/optimise.jl b/test/optimise.jl index 526f0534..66c50037 100644 --- a/test/optimise.jl +++ b/test/optimise.jl @@ -15,3 +15,15 @@ using Flux.Tracker @test Flux.mse(w, w′) < 0.01 end end + +@testset "Training Loop" begin + i = 0 + l = param(1) + + Flux.train!(() -> (sleep(0.1); i += 1; l), + Iterators.repeated((), 100), + ()->(), + cb = Flux.throttle(() -> (i > 3 && :stop), 1)) + + @test 3 < i < 50 +end From d949b31aa5548e4fb87625def9662765a92879c9 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Thu, 14 Dec 2017 18:48:38 +0000 Subject: [PATCH 09/23] conv gradient --- src/tracker/lib.jl | 12 +++++++++++- src/tracker/numeric.jl | 2 +- test/tracker.jl | 3 +++ 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/src/tracker/lib.jl b/src/tracker/lib.jl index f3221bd8..cac8d7d1 100644 --- a/src/tracker/lib.jl +++ b/src/tracker/lib.jl @@ -123,12 +123,22 @@ end # NNlib -import NNlib: softmax, ∇softmax +using NNlib +import NNlib: softmax, ∇softmax, conv2d softmax(xs::TrackedArray) = TrackedArray(Call(softmax, xs)) back(::typeof(softmax), Δ, xs) = @back(xs, ∇softmax(Δ, data(xs))) +conv2d(x::TrackedArray{<:Any,4}, w::TrackedArray{<:Any,4}) = TrackedArray(Call(conv2d, x, w)) +conv2d(x::AbstractArray{<:Any,4}, w::TrackedArray{<:Any,4}) = TrackedArray(Call(conv2d, x, w)) +conv2d(x::TrackedArray{<:Any,4}, w::AbstractArray{<:Any,4}) = TrackedArray(Call(conv2d, x, w)) + +function back(::typeof(conv2d), Δ, x, w) + @back(x, NNlib.conv2d_grad_x(data(x), data(w), Δ)) + @back(w, NNlib.conv2d_grad_w(data(x), data(w), Δ)) +end + # Broadcasting using ForwardDiff: Dual, partials diff --git a/src/tracker/numeric.jl b/src/tracker/numeric.jl index 68211aa3..cbcd3ad8 100644 --- a/src/tracker/numeric.jl +++ b/src/tracker/numeric.jl @@ -19,4 +19,4 @@ function ngradient(f, xs::AbstractArray...) return grads end -gradcheck(f, xs...) 
= all(isapprox.(ngradient(f, xs...), gradient(f, xs...), rtol = 1e-6)) +gradcheck(f, xs...) = all(isapprox.(ngradient(f, xs...), gradient(f, xs...), rtol = 1e-5)) diff --git a/test/tracker.jl b/test/tracker.jl index ac031915..a3d9563b 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -1,5 +1,6 @@ using Flux.Tracker, Base.Test, NNlib using Flux.Tracker: gradcheck +using NNlib gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(f(xs...)), xs...) gradtest(f, dims...) = gradtest(f, rand.(dims)...) @@ -45,4 +46,6 @@ end 2y + x end +@test gradtest(conv2d, rand(10, 10, 3, 2), randn(2, 2, 3, 2)) + end #testset From 0bf22dfb8ea37332543867dc00431e22313f61b5 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Fri, 15 Dec 2017 02:29:14 +0000 Subject: [PATCH 10/23] pool gradients --- src/tracker/back.jl | 9 +++++---- src/tracker/lib.jl | 10 +++++++++- test/tracker.jl | 2 ++ 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/src/tracker/back.jl b/src/tracker/back.jl index 39810069..d6a48409 100644 --- a/src/tracker/back.jl +++ b/src/tracker/back.jl @@ -12,16 +12,17 @@ function scan(x::TrackedArray) return end -back(c::Call, Δ) = back(c.func, Δ, c.args...) -back(::Call{Void}, Δ) = nothing +back_(f, y, args...) = back(f, args...) +back_(c::Call, y, Δ) = back_(c.func, y, Δ, c.args...) 
+back_(::Call{Void}, y, Δ) = nothing function back(x::TrackedArray, Δ) ref = x.ref -= 1 if isdefined(x, :grad) x.grad .+= Δ - ref == 0 && back(x.f, x.grad) + ref == 0 && back_(x.f, x.data, x.grad) else - ref == 0 && back(x.f, Δ) + ref == 0 && back_(x.f, x.data, Δ) end return end diff --git a/src/tracker/lib.jl b/src/tracker/lib.jl index cac8d7d1..57474933 100644 --- a/src/tracker/lib.jl +++ b/src/tracker/lib.jl @@ -124,7 +124,7 @@ end # NNlib using NNlib -import NNlib: softmax, ∇softmax, conv2d +import NNlib: softmax, ∇softmax, conv2d, pool softmax(xs::TrackedArray) = TrackedArray(Call(softmax, xs)) @@ -139,6 +139,14 @@ function back(::typeof(conv2d), Δ, x, w) @back(w, NNlib.conv2d_grad_w(data(x), data(w), Δ)) end +_pool(x, k, mode) = pool(x, window = k, mode = mode) + +pool(x::TrackedArray{<:Any,4}; window = 2, mode = 0) = + TrackedArray(Call(_pool, x, window, mode)) + +back_(::typeof(_pool), y, Δ, x, k, mode) = + back(x, NNlib.pool_grad(data(x), y, Δ, window = k, mode = mode)) + # Broadcasting using ForwardDiff: Dual, partials diff --git a/test/tracker.jl b/test/tracker.jl index a3d9563b..dc11420b 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -47,5 +47,7 @@ end end @test gradtest(conv2d, rand(10, 10, 3, 2), randn(2, 2, 3, 2)) +@test gradtest(x -> maxpool2d(x, 2), rand(10, 10, 3, 2)) +@test gradtest(x -> avgpool2d(x, 2), rand(10, 10, 3, 2)) end #testset From 9d0dd9fb7e2a49c23783ed6b47a10eed97865f74 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Fri, 15 Dec 2017 13:22:57 +0000 Subject: [PATCH 11/23] layer wip --- src/Flux.jl | 6 ++++-- src/layers/conv.jl | 14 ++++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) create mode 100644 src/layers/conv.jl diff --git a/src/Flux.jl b/src/Flux.jl index 526d6bb8..2acdb177 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -7,13 +7,14 @@ module Flux using Juno, Requires using Lazy: @forward -export Chain, Dense, RNN, LSTM, +export Chain, Dense, RNN, LSTM, Conv2D, Dropout, LayerNorm, BatchNorm, SGD, ADAM, 
Momentum, Nesterov, AMSGrad, param, params, mapleaves using NNlib -export σ, sigmoid, relu, leakyrelu, elu, swish, softmax +export σ, sigmoid, relu, leakyrelu, elu, swish, softmax, + conv2d, maxpool2d, avgpool2d include("tracker/Tracker.jl") using .Tracker @@ -27,6 +28,7 @@ include("treelike.jl") include("layers/stateless.jl") include("layers/basic.jl") +include("layers/conv.jl") include("layers/recurrent.jl") include("layers/normalisation.jl") diff --git a/src/layers/conv.jl b/src/layers/conv.jl new file mode 100644 index 00000000..f7ca6f02 --- /dev/null +++ b/src/layers/conv.jl @@ -0,0 +1,14 @@ +struct Conv2D{F,A} + σ::F + weight::A + stride::Int +end + +Conv2D(k::NTuple{2,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; + init = initn, stride = 1) = + Conv2D(σ, param(initn(k..., ch...)), stride) + +Flux.treelike(Conv2D) + +# (c::Conv2D)(x) = c.σ.(conv2d(x, c.weight, stride = c.stride)) +(c::Conv2D)(x) = c.σ.(conv2d(x, c.weight)) From 9b833a434525bc7afc00dd95c3799b71784f84d1 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Fri, 15 Dec 2017 16:17:39 +0000 Subject: [PATCH 12/23] more onehot indexing --- src/onehot.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/onehot.jl b/src/onehot.jl index f94fb93e..4f121958 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -18,7 +18,9 @@ end Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data)) -Base.getindex(xs::OneHotMatrix, i::Int, j::Int) = xs.data[j][i] +Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs.data[j][i] +Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i] +Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i]) A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)] From 6890a615879464aa0a2f4efdee9cb4406eb14e9f Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Fri, 15 Dec 2017 16:17:45 +0000 Subject: [PATCH 13/23] todo --- src/tracker/back.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git 
a/src/tracker/back.jl b/src/tracker/back.jl index d6a48409..b4cd27c6 100644 --- a/src/tracker/back.jl +++ b/src/tracker/back.jl @@ -36,6 +36,9 @@ end # Interface methods +# TODO: if an error occurs in `back` the refcounts will be broken +# and `back` will silently fail to update. + function back!(x::TrackedArray, Δ) scan(x) back(x, Δ) From 73ae25289d9b902fab998de686b31a4005ea2858 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Fri, 15 Dec 2017 16:18:01 +0000 Subject: [PATCH 14/23] remove old util --- src/utils.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index afe926d9..bba3e416 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -4,8 +4,6 @@ initn(dims...) = randn(dims...)/100 glorot_uniform(dims...) = (rand(dims...) - 0.5)*sqrt(24.0/(sum(dims))) glorot_normal(dims...) = (randn(dims...)*sqrt(2.0/sum(dims))) -flatten(xs) = reshape(xs, size(xs, 1), :) - unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...)) stack(xs, dim) = cat(dim, unsqueeze.(xs, dim)...) From 386eafc44393c4ad2d2c9a60438d6355e1702760 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Fri, 15 Dec 2017 16:18:16 +0000 Subject: [PATCH 15/23] reshape --- src/tracker/lib.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/tracker/lib.jl b/src/tracker/lib.jl index 57474933..71d93e88 100644 --- a/src/tracker/lib.jl +++ b/src/tracker/lib.jl @@ -44,6 +44,12 @@ function back(::typeof(vcat), Δ, xs, ys) @back(ys, Δ[size(xs,1)+1:end, i...]) end +Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) = + TrackedArray(Call(reshape, xs, dims...)) + +back(::typeof(reshape), Δ, xs::TrackedArray, _...) 
= + back(xs, reshape(Δ, size(xs))) + # Reductions Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim)) From 51f93d9f0e0d3393da4adcde58ba4eb7e12225b0 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Fri, 15 Dec 2017 16:24:45 +0000 Subject: [PATCH 16/23] conv polish --- src/layers/conv.jl | 10 ++++++++-- src/tracker/lib.jl | 17 +++++++++++------ 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index f7ca6f02..82d90029 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -10,5 +10,11 @@ Conv2D(k::NTuple{2,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; Flux.treelike(Conv2D) -# (c::Conv2D)(x) = c.σ.(conv2d(x, c.weight, stride = c.stride)) -(c::Conv2D)(x) = c.σ.(conv2d(x, c.weight)) +(c::Conv2D)(x) = c.σ.(conv2d(x, c.weight, stride = c.stride)) + +function Base.show(io::IO, l::Conv2D) + print(io, "Conv2D((", size(l.weight, 1), ", ", size(l.weight, 2), ")") + print(io, ", ", size(l.weight, 3), "=>", size(l.weight, 4)) + l.σ == identity || print(io, ", ", l.σ) + print(io, ")") +end diff --git a/src/tracker/lib.jl b/src/tracker/lib.jl index 71d93e88..580992ef 100644 --- a/src/tracker/lib.jl +++ b/src/tracker/lib.jl @@ -136,13 +136,18 @@ softmax(xs::TrackedArray) = TrackedArray(Call(softmax, xs)) back(::typeof(softmax), Δ, xs) = @back(xs, ∇softmax(Δ, data(xs))) -conv2d(x::TrackedArray{<:Any,4}, w::TrackedArray{<:Any,4}) = TrackedArray(Call(conv2d, x, w)) -conv2d(x::AbstractArray{<:Any,4}, w::TrackedArray{<:Any,4}) = TrackedArray(Call(conv2d, x, w)) -conv2d(x::TrackedArray{<:Any,4}, w::AbstractArray{<:Any,4}) = TrackedArray(Call(conv2d, x, w)) +_conv2d(x, w, stride) = conv2d(x, w, stride = stride) -function back(::typeof(conv2d), Δ, x, w) - @back(x, NNlib.conv2d_grad_x(data(x), data(w), Δ)) - @back(w, NNlib.conv2d_grad_w(data(x), data(w), Δ)) +conv2d(x::TrackedArray{<:Any,4}, w::TrackedArray{<:Any,4}; stride = 1) = + TrackedArray(Call(_conv2d, x, w, stride)) +conv2d(x::AbstractArray{<:Any,4}, 
w::TrackedArray{<:Any,4}; stride = 1) = + TrackedArray(Call(_conv2d, x, w, stride)) +conv2d(x::TrackedArray{<:Any,4}, w::AbstractArray{<:Any,4}; stride = 1) = + TrackedArray(Call(_conv2d, x, w, stride)) + +function back(::typeof(_conv2d), Δ, x, w, stride) + @back(x, NNlib.conv2d_grad_x(data(x), data(w), Δ; stride = stride)) + @back(w, NNlib.conv2d_grad_w(data(x), data(w), Δ; stride = stride)) end _pool(x, k, mode) = pool(x, window = k, mode = mode) From 269d8f36b9a766301197d0176a405acdd841b890 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Mon, 18 Dec 2017 18:05:38 +0000 Subject: [PATCH 17/23] conv padding --- src/layers/conv.jl | 7 ++++--- src/tracker/lib.jl | 21 +++++++++++---------- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 82d90029..d73d1ad9 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -2,15 +2,16 @@ struct Conv2D{F,A} σ::F weight::A stride::Int + pad::Int end Conv2D(k::NTuple{2,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; - init = initn, stride = 1) = - Conv2D(σ, param(initn(k..., ch...)), stride) + init = initn, stride = 1, pad = 0) = + Conv2D(σ, param(initn(k..., ch...)), stride, pad) Flux.treelike(Conv2D) -(c::Conv2D)(x) = c.σ.(conv2d(x, c.weight, stride = c.stride)) +(c::Conv2D)(x) = c.σ.(conv2d(x, c.weight, stride = c.stride, padding = c.pad)) function Base.show(io::IO, l::Conv2D) print(io, "Conv2D((", size(l.weight, 1), ", ", size(l.weight, 2), ")") diff --git a/src/tracker/lib.jl b/src/tracker/lib.jl index 580992ef..2dc25e52 100644 --- a/src/tracker/lib.jl +++ b/src/tracker/lib.jl @@ -136,18 +136,19 @@ softmax(xs::TrackedArray) = TrackedArray(Call(softmax, xs)) back(::typeof(softmax), Δ, xs) = @back(xs, ∇softmax(Δ, data(xs))) -_conv2d(x, w, stride) = conv2d(x, w, stride = stride) +# TODO: can store kwargs efficiently in namedtuples +_conv2d(x, w, stride, pad) = conv2d(x, w, stride = stride, padding = pad) -conv2d(x::TrackedArray{<:Any,4}, 
w::TrackedArray{<:Any,4}; stride = 1) = - TrackedArray(Call(_conv2d, x, w, stride)) -conv2d(x::AbstractArray{<:Any,4}, w::TrackedArray{<:Any,4}; stride = 1) = - TrackedArray(Call(_conv2d, x, w, stride)) -conv2d(x::TrackedArray{<:Any,4}, w::AbstractArray{<:Any,4}; stride = 1) = - TrackedArray(Call(_conv2d, x, w, stride)) +conv2d(x::TrackedArray{<:Any,4}, w::TrackedArray{<:Any,4}; stride = 1, padding = 0) = + TrackedArray(Call(_conv2d, x, w, stride, padding)) +conv2d(x::AbstractArray{<:Any,4}, w::TrackedArray{<:Any,4}; stride = 1, padding = 0) = + TrackedArray(Call(_conv2d, x, w, stride, padding)) +conv2d(x::TrackedArray{<:Any,4}, w::AbstractArray{<:Any,4}; stride = 1, padding = 0) = + TrackedArray(Call(_conv2d, x, w, stride, padding)) -function back(::typeof(_conv2d), Δ, x, w, stride) - @back(x, NNlib.conv2d_grad_x(data(x), data(w), Δ; stride = stride)) - @back(w, NNlib.conv2d_grad_w(data(x), data(w), Δ; stride = stride)) +function back(::typeof(_conv2d), Δ, x, w, stride, pad) + @back(x, NNlib.conv2d_grad_x(data(x), data(w), Δ; stride = stride, padding = pad)) + @back(w, NNlib.conv2d_grad_w(data(x), data(w), Δ; stride = stride, padding = pad)) end _pool(x, k, mode) = pool(x, window = k, mode = mode) From e3577d759cf2a3b3070333275ffabb3dd5b1a566 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Mon, 18 Dec 2017 18:05:48 +0000 Subject: [PATCH 18/23] conv docs --- docs/src/models/layers.md | 1 + src/layers/conv.jl | 12 ++++++++++++ 2 files changed, 13 insertions(+) diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md index d92388e1..cb0c6615 100644 --- a/docs/src/models/layers.md +++ b/docs/src/models/layers.md @@ -5,6 +5,7 @@ These core layers form the foundation of almost all neural networks. 
```@docs Chain +Dense +Conv2D ``` ## Recurrent Layers diff --git a/src/layers/conv.jl b/src/layers/conv.jl index d73d1ad9..e267510b 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1,3 +1,15 @@ +""" + Conv2D(size, in=>out) + Conv2D(size, in=>out, relu) + +Standard convolutional layer. `size` should be a tuple like `(2, 2)`. +`in` and `out` specify the number of input and output channels respectively. + +Data should be stored in HWCN order. In other words, a 100×100 RGB image would +be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array. + +Takes the keyword arguments `pad` and `stride`. +""" struct Conv2D{F,A} σ::F weight::A From 98b362729dd39c70d327f6f03c82fe949b4a2396 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Mon, 18 Dec 2017 18:18:14 +0000 Subject: [PATCH 19/23] pool padding --- src/tracker/lib.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/tracker/lib.jl b/src/tracker/lib.jl index 2dc25e52..40ee0458 100644 --- a/src/tracker/lib.jl +++ b/src/tracker/lib.jl @@ -151,13 +151,13 @@ function back(::typeof(_conv2d), Δ, x, w, stride, pad) @back(w, NNlib.conv2d_grad_w(data(x), data(w), Δ; stride = stride, padding = pad)) end -_pool(x, k, mode) = pool(x, window = k, mode = mode) +_pool(x, k, pad, mode) = pool(x, window = k, mode = mode, padding = pad) -pool(x::TrackedArray{<:Any,4}; window = 2, mode = 0) = - TrackedArray(Call(_pool, x, window, mode)) +pool(x::TrackedArray{<:Any,4}; window = 2, mode = 0, padding = 0) = + TrackedArray(Call(_pool, x, window, padding, mode)) -back_(::typeof(_pool), y, Δ, x, k, mode) = - back(x, NNlib.pool_grad(data(x), y, Δ, window = k, mode = mode)) +back_(::typeof(_pool), y, Δ, x, k, pad, mode) = + back(x, NNlib.pool_grad(data(x), y, Δ, window=k, mode=mode, padding=pad)) # Broadcasting From 7ac6b3fccf5d7b4a746160e45fc8ee5398a698bd Mon Sep 17 00:00:00 2001 From: Robin Deits Date: Wed, 3 Jan 2018 14:41:20 -0500 Subject: [PATCH 20/23] explicit forwarddiff requirement
--- REQUIRE | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/REQUIRE b/REQUIRE index d124b931..7c084129 100644 --- a/REQUIRE +++ b/REQUIRE @@ -3,5 +3,5 @@ DataFlow 0.2.1 Juno MacroTools 0.3.3 NNlib -ForwardDiff +ForwardDiff 0.5.0 Requires From 468f641f667faaa0aef632855c272c27df34cf60 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Mon, 8 Jan 2018 16:31:23 +0000 Subject: [PATCH 21/23] use Adapt --- REQUIRE | 1 + src/onehot.jl | 2 +- src/tracker/Tracker.jl | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/REQUIRE b/REQUIRE index 7c084129..8e718a92 100644 --- a/REQUIRE +++ b/REQUIRE @@ -5,3 +5,4 @@ MacroTools 0.3.3 NNlib ForwardDiff 0.5.0 Requires +Adapt diff --git a/src/onehot.jl b/src/onehot.jl index 4f121958..b1a1a970 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -28,7 +28,7 @@ Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(length(first(xs)), xs) -import NNlib.adapt +import Adapt.adapt adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data)) diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 376cc617..aa2bc6ea 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -93,7 +93,7 @@ include("back.jl") include("lib.jl") include("numeric.jl") -import NNlib.adapt +import Adapt.adapt adapt(T, xs::TrackedArray) = TrackedArray(xs.f, adapt(T, xs.data), adapt(T, xs.grad)) From 2fef7991099d3a0201673bbb4fd328f78ecf12f5 Mon Sep 17 00:00:00 2001 From: Mehul Tikekar Date: Mon, 8 Jan 2018 16:45:06 -0500 Subject: [PATCH 22/23] fix typo in conv.jl (fixes #133) --- src/layers/conv.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index e267510b..85b05894 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -19,7 +19,7 @@ end Conv2D(k::NTuple{2,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn, stride = 1, pad = 0) = - Conv2D(σ, 
param(initn(k..., ch...)), stride, pad) + Conv2D(σ, param(init(k..., ch...)), stride, pad) Flux.treelike(Conv2D) From 805cb9178f7f651d35d1804fabb89c4604ee8db6 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 10 Jan 2018 12:48:30 +0000 Subject: [PATCH 23/23] fixes #146 --- src/data/cmudict.jl | 4 ++-- test/data.jl | 5 +++++ test/runtests.jl | 1 + 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/data/cmudict.jl b/src/data/cmudict.jl index 9ec567b4..4307f211 100644 --- a/src/data/cmudict.jl +++ b/src/data/cmudict.jl @@ -23,14 +23,14 @@ end function symbols() load() - Symbol.(split(readstring(deps("CMUDict", "cmudict.symbols")), + Symbol.(split(readstring(deps("cmudict", "cmudict.symbols")), "\n", keep = false)) end function rawdict() load() Dict(String(xs[1]) => Symbol.(xs[2:end]) for xs in - filter(!isempty, split.(split(readstring(deps("CMUDict", "cmudict")), "\n")))) + filter(!isempty, split.(split(readstring(deps("cmudict", "cmudict")), "\n")))) end validword(s) = ismatch(r"^[\w\-\.]+$", s) diff --git a/test/data.jl b/test/data.jl index 1b93ab3c..5a4c9ce6 100644 --- a/test/data.jl +++ b/test/data.jl @@ -1,3 +1,8 @@ using Flux.Data +using Base.Test @test cmudict()["CATASTROPHE"] == :[K,AH0,T,AE1,S,T,R,AH0,F,IY0].args + +@test length(CMUDict.phones()) == 39 + +@test length(CMUDict.symbols()) == 84 diff --git a/test/runtests.jl b/test/runtests.jl index 38ddb85f..553545e9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,5 +7,6 @@ include("tracker.jl") include("layers/normalisation.jl") include("layers/stateless.jl") include("optimise.jl") +include("data.jl") end