From 88cf6d9e612cdf893a706d2d9607503b5a5f8094 Mon Sep 17 00:00:00 2001
From: Mike J Innes
Date: Tue, 18 Apr 2017 20:55:59 +0100
Subject: [PATCH 1/2] sensible default for unsqueeze

---
 src/dims/utils.jl | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/dims/utils.jl b/src/dims/utils.jl
index 680c85a1..7cc00476 100644
--- a/src/dims/utils.jl
+++ b/src/dims/utils.jl
@@ -1,4 +1,5 @@
-unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
+unsqueeze(xs, dim = 1) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
+Base.squeeze(xs) = squeeze(xs, 1)
 
 stack(xs, dim = 1) = cat(dim, unsqueeze.(xs, dim)...)
 unstack(xs, dim = 1) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)]
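
With the new default, unsqueeze(xs) inserts a leading singleton dimension, and
the squeeze method added alongside it drops one. A quick sketch of the
resulting behaviour (illustrative only, not part of the patch; assumes plain
Julia 0.6 arrays):

    using Flux: unsqueeze  # unexported, so imported explicitly,
                           # as test/runtests.jl now does

    xs = rand(10)               # size (10,)
    size(unsqueeze(xs))         # (1, 10): dim now defaults to 1
    size(unsqueeze(xs, 2))      # (10, 1)
    size(squeeze(ones(1, 10)))  # (10,): the new one-arg method drops dim 1
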
From 228f7d487c2b477b73d9403f1c3a60e879120106 Mon Sep 17 00:00:00 2001
From: Mike J Innes
Date: Tue, 18 Apr 2017 21:04:21 +0100
Subject: [PATCH 2/2] remove fake batch semantics

---
 src/backend/mxnet/model.jl      | 22 ++++++++--------------
 src/backend/tensorflow/model.jl |  6 +-----
 src/compiler/code.jl            |  3 +--
 src/compiler/interp.jl          | 10 ++--------
 src/dims/batching.jl            | 14 --------------
 src/model.jl                    | 18 ++++++------------
 test/backend/mxnet.jl           |  8 ++++----
 test/backend/tensorflow.jl      |  6 +++---
 test/basic.jl                   | 10 +++++-----
 test/recurrent.jl               |  8 ++++----
 test/runtests.jl                |  2 +-
 11 files changed, 35 insertions(+), 72 deletions(-)

diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl
index 304ba170..904b90dd 100644
--- a/src/backend/mxnet/model.jl
+++ b/src/backend/mxnet/model.jl
@@ -1,5 +1,3 @@
-using Flux: runrawbatched
-
 struct AlterParam
   param
   load
@@ -109,22 +107,18 @@ import Base: @get!
 # TODO: dims having its own type would be useful
 executor(m::Model, input...) = @get!(m.execs, mapt(size, input), executor(m.graph, input...))
 
-function Flux.runmodel(m::Model, xs...)
-  !isdefined(m, :graph) &&
-    (m.graph = tograph(m.model, mapt(_ -> gensym("input"), xs)...))
-  m.last = exec = executor(m, xs...)
-  exec(xs...)
-end
-
 function (m::Model)(xs...)
-  @mxerr m.graph.stacks runrawbatched(xs -> Flux.runmodel(m, xs...), xs)
+  @mxerr m.graph.stacks begin
+    !isdefined(m, :graph) &&
+      (m.graph = tograph(m.model, mapt(_ -> gensym("input"), xs)...))
+    m.last = exec = executor(m, xs...)
+    exec(xs...)
+  end
 end
 
 function Flux.back!(m::Model, Δ, xs...)
-  runrawbatched(Δ, xs) do Δ, xs
-    m.last = exec = m.execs[mapt(size, xs)]
-    back!(exec, Δ)
-  end
+  m.last = exec = m.execs[mapt(size, xs)]
+  back!(exec, Δ)
 end
 
 Flux.update!(m::Model, η) = (update!(m.last, η); m)
diff --git a/src/backend/tensorflow/model.jl b/src/backend/tensorflow/model.jl
index dde4df4a..63dbc576 100644
--- a/src/backend/tensorflow/model.jl
+++ b/src/backend/tensorflow/model.jl
@@ -32,12 +32,8 @@ function runmodel(m::Model, args...)
   run(m.session, m.output, Dict(zip(m.inputs, args)))
 end
 
-using Flux: runrawbatched
-
 function (m::Model)(x)
-  @tferr m.stacks runrawbatched(convertel(Float32, x)) do x
-    output = runmodel(m, x)
-  end
+  @tferr m.stacks runmodel(m, convert.(Float32, x))
 end
 
 for f in :[back!, update!].args
diff --git a/src/compiler/code.jl b/src/compiler/code.jl
index b18aa489..f7bd77d2 100644
--- a/src/compiler/code.jl
+++ b/src/compiler/code.jl
@@ -71,8 +71,7 @@ function process_type(ex)
   self = esc(:self)
   quote
     $(build_type(T, params))
-    $(@q $(esc(:(Flux.runmodel(self::$T, $(args...)) = $(build_forward(body, args))))))
-    ($self::$(esc(T)))($(args...)) = runrawbatched((xs...) -> runmodel($self, xs...), $(args...))
+    $(esc(:((self::$T)($(args...)) = $(build_forward(body, args)))))
     $(esc(:(Flux.update!(self::$T, η)))) = ($(map(p -> :(update!($self.$p, η)), pnames)...);)
     $(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(mapconst(esc, makegraph(body, args))))
     nothing
diff --git a/src/compiler/interp.jl b/src/compiler/interp.jl
index 4419e680..129cce79 100644
--- a/src/compiler/interp.jl
+++ b/src/compiler/interp.jl
@@ -30,13 +30,7 @@ function interp(ctx, f, xs...)
     f(xs...))
 end
 
-# TODO: batching should be secondary
-
-function interpmodel_(m, args...)
+function interpmodel(m, args...)
   ctx = Context(mux(iline, ilambda, iconst, iargs, ituple, interp))
-  interp(ctx, m, args...)
+  @ithrow interp(ctx, m, args...)
 end
-
-interpmodel(m, args...) = @ithrow runrawbatched((xs...) -> interpmodel_(m, xs...), args...)
-
-runmodel(m::Capacitor, xs...) = @ithrow interpmodel_(m, xs...)
diff --git a/src/dims/batching.jl b/src/dims/batching.jl
index d204bd18..8b9da07a 100644
--- a/src/dims/batching.jl
+++ b/src/dims/batching.jl
@@ -45,17 +45,3 @@ end
 isbatched(x) = false
 isbatched(x::Batch) = true
 isbatched(xs::Tuple) = any(isbatched, xs)
-
-batchify(xs) = isbatched(xs) ? (xs, true) : (mapt(batchone, xs), false)
-
-function runbatched(f, xs...)
-  # TODO: decide what to do with mixed inputs
-  xs, batched = batchify(xs)
-  ys = f(xs...)
-  batched ? ys : mapt(unbatchone, ys)
-end
-
-runrawbatched(f, xs...) =
-  runbatched((xs...) -> mapt(rebatch,
-                             f(mapt(rawbatch, xs)...)),
-             xs...)
diff --git a/src/model.jl b/src/model.jl
index d1688141..9f17d502 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -133,11 +133,9 @@ struct Stateful <: Model
 end
 
 function (m::Stateful)(x)
-  runrawbatched(x) do x
-    state, y = runmodel(m.model, (m.state...,), x)
-    m.state .= state
-    return y
-  end
+  state, y = runmodel(m.model, (m.state...,), x)
+  m.state .= state
+  return y
 end
 
 stateless(m) = m
@@ -150,11 +148,7 @@ end
 
 (m::SeqModel)(x::Tuple) = m.model(x)
 
-splitseq(xs) = rebatch.(unstack(rawbatch(xs), 2))
-joinseq(xs) = rebatchseq(stack(rawbatch.(xs), 2))
+splitseq(xs) = unstack(rawbatch(xs), 2)
+joinseq(xs) = rebatchseq(stack(xs, 2))
 
-function (m::SeqModel)(x::Union{Seq,BatchSeq})
-  runbatched(x) do x
-    joinseq(m.model((splitseq(x)...,)))
-  end
-end
+(m::SeqModel)(x::BatchSeq) = joinseq(m.model((splitseq(x)...,)))
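
The upshot for callers, reflected in the test updates below: model inputs now
always carry an explicit batch dimension, so a single example is a 1×N row
rather than a length-N vector that the backend silently wraps and unwraps.
A minimal sketch of the new convention (illustrative only, reusing the Affine
layer from the tests):

    d = Affine(20, 10)
    x = rand(1, 20)  # one example, batch dimension written out explicitly
    y = d(x)         # 1×10 output; a bare rand(20) is no longer auto-batched
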
diff --git a/test/backend/mxnet.jl b/test/backend/mxnet.jl
index e27bbea6..b9527e2b 100644
--- a/test/backend/mxnet.jl
+++ b/test/backend/mxnet.jl
@@ -3,7 +3,7 @@ Flux.loadmx()
 
 @testset "MXNet" begin
 
-xs, ys = rand(20), rand(20)
+xs, ys = rand(1, 20), rand(1, 20)
 d = Affine(20, 10)
 
 dm = mxnet(d)
@@ -14,7 +14,7 @@ mm = mxnet(m)
 @test all(isapprox.(mm(xs, ys), m(xs, ys)))
 
 @testset "Recurrence" begin
-  seq = Seq(rand(10) for i = 1:3)
+  seq = batchone(Seq(rand(10) for i = 1:3))
   r = unroll(Recurrent(10, 5), 3)
   rm = mxnet(r)
   @test r(seq) ≈ rm(seq)
@@ -25,7 +25,7 @@ end
   @test dm(xs) ≈ d(xs)
   @test dm(xs) ≈ d′(xs)
 
-  Δ = back!(dm, randn(10), xs)
+  Δ = back!(dm, randn(1, 10), xs)
   @test length(Δ[1]) == 20
 
   update!(dm, 0.1)
@@ -48,7 +48,7 @@ end
   model = TLP(Affine(10, 20), Affine(21, 15))
   info("The following warning is normal")
   dm = mxnet(model)
-  e = try dm(rand(10))
+  e = try dm(rand(1, 10))
   catch e e end
   @test isa(e, DataFlow.Interpreter.Exception)
 
diff --git a/test/backend/tensorflow.jl b/test/backend/tensorflow.jl
index f9df5c5f..459402cd 100644
--- a/test/backend/tensorflow.jl
+++ b/test/backend/tensorflow.jl
@@ -3,7 +3,7 @@ Flux.loadtf()
 
 @testset "TensorFlow" begin
 
-xs = rand(20)
+xs = rand(1, 20)
 d = Affine(20, 10)
 dt = tf(d)
 
@@ -15,13 +15,13 @@ dt = tf(d)
   Y = Tensor(d, X)
   run(sess, initialize_all_variables())
 
-  @test run(sess, Y, Dict(X=>Float32.(xs'))) ≈ d(xs)'
+  @test run(sess, Y, Dict(X=>Float32.(xs))) ≈ d(xs)
 end
 
 @testset "Stack Traces" begin
   model = TLP(Affine(10, 20), Affine(21, 15))
   dm = tf(model)
-  e = try dm(rand(10))
+  e = try dm(rand(1, 10))
   catch e e end
   @test isa(e, DataFlow.Interpreter.Exception)
 
diff --git a/test/basic.jl b/test/basic.jl
index de9acaa8..134e7ff7 100644
--- a/test/basic.jl
+++ b/test/basic.jl
@@ -1,9 +1,9 @@
 @testset "Basics" begin
 
-xs = randn(10)
+xs = randn(1, 10)
 d = Affine(10, 20)
 
-@test d(xs) ≈ (xs'*d.W.x + d.b.x)[1,:]
+@test d(xs) ≈ (xs*d.W.x + d.b.x)
 
 d1 = @net x -> x * d.W + d.b
 
@@ -24,7 +24,7 @@ end
 
 let tlp = TLP(Affine(10, 21), Affine(20, 15))
   e = try
-    Flux.interpmodel(tlp, rand(10))
+    Flux.interpmodel(tlp, rand(1, 10))
   catch e
     e
   end
@@ -33,8 +33,8 @@ let tlp = TLP(Affine(10, 21), Affine(20, 15))
 end
 
 let m = Multi(10, 15)
-  x, y = rand(10), rand(10)
-  @test all(isapprox.(m(x, y), (m.W.x' * x, m.V.x' * y)))
+  x, y = rand(1, 10), rand(1, 10)
+  @test all(isapprox.(m(x, y), (x * m.W.x, y * m.V.x)))
   @test all(isapprox.(m(x, y), Flux.interpmodel(m, x, y)))
 end
 
diff --git a/test/recurrent.jl b/test/recurrent.jl
index 39241ec4..42627b9b 100644
--- a/test/recurrent.jl
+++ b/test/recurrent.jl
@@ -11,9 +11,9 @@ end
 
 @testset "RNN unrolling" begin
   r = Recurrent(10, 5)
-  xs = [rand(10) for _ = 1:3]
-  _, ys = apply(stateless(unroll1(r)), xs, (squeeze(r.y.x, 1),))
-  @test ys[1] == squeeze(tanh(reshape(xs[1],(1,10)) * r.Wxy.x .+ r.y.x * r.Wyy.x .+ r.by.x), 1)
+  xs = [rand(1, 10) for _ = 1:3]
+  _, ys = apply(stateless(unroll1(r)), xs, (r.y.x,))
+  @test ys[1] == tanh(xs[1] * r.Wxy.x .+ r.y.x * r.Wyy.x .+ r.by.x)
   ru = unroll(r, 3)
-  @test ru(Seq(xs)) == ys
+  @test ru(batchone(Seq(squeeze.(xs))))[1] == squeeze.(ys)
 end
diff --git a/test/runtests.jl b/test/runtests.jl
index 97bb9e70..ac01c5ae 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,5 +1,5 @@
 using Flux, DataFlow, MacroTools, Base.Test
-using Flux: graph, Param
+using Flux: graph, Param, unsqueeze
 using DataFlow: Line, Frame
 
 syntax(v::Vertex) = prettify(DataFlow.syntax(v))
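
Sequences follow the same rule: the SeqModel call path now accepts only
explicitly batched sequences (BatchSeq), so a single sequence is wrapped with
batchone before being fed to an unrolled model, as the updated recurrence
tests do. For instance (a sketch along the lines of test/backend/mxnet.jl):

    r = unroll(Recurrent(10, 5), 3)
    seq = batchone(Seq(rand(10) for i = 1:3))  # batch containing one sequence
    r(seq)
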