Merge pull request #21 from MikeInnes/no-batch

Breaking: Remove fake batching semantics
This commit is contained in:
Mike J Innes 2017-04-19 12:59:48 +01:00 committed by GitHub
commit 52bd0f9b00
12 changed files with 37 additions and 73 deletions

View File

@ -1,5 +1,3 @@
using Flux: runrawbatched
struct AlterParam struct AlterParam
param param
load load
@ -109,23 +107,19 @@ import Base: @get!
# TODO: dims having its own type would be useful # TODO: dims having its own type would be useful
executor(m::Model, input...) = @get!(m.execs, mapt(size, input), executor(m.graph, input...)) executor(m::Model, input...) = @get!(m.execs, mapt(size, input), executor(m.graph, input...))
function Flux.runmodel(m::Model, xs...) function (m::Model)(xs...)
@mxerr m.graph.stacks begin
!isdefined(m, :graph) && !isdefined(m, :graph) &&
(m.graph = tograph(m.model, mapt(_ -> gensym("input"), xs)...)) (m.graph = tograph(m.model, mapt(_ -> gensym("input"), xs)...))
m.last = exec = executor(m, xs...) m.last = exec = executor(m, xs...)
exec(xs...) exec(xs...)
end end
function (m::Model)(xs...)
@mxerr m.graph.stacks runrawbatched(xs -> Flux.runmodel(m, xs...), xs)
end end
function Flux.back!(m::Model, Δ, xs...) function Flux.back!(m::Model, Δ, xs...)
runrawbatched(Δ, xs) do Δ, xs
m.last = exec = m.execs[mapt(size, xs)] m.last = exec = m.execs[mapt(size, xs)]
back!(exec, Δ) back!(exec, Δ)
end end
end
Flux.update!(m::Model, η) = (update!(m.last, η); m) Flux.update!(m::Model, η) = (update!(m.last, η); m)

View File

@ -32,12 +32,8 @@ function runmodel(m::Model, args...)
run(m.session, m.output, Dict(zip(m.inputs, args))) run(m.session, m.output, Dict(zip(m.inputs, args)))
end end
using Flux: runrawbatched
function (m::Model)(x) function (m::Model)(x)
@tferr m.stacks runrawbatched(convertel(Float32, x)) do x @tferr m.stacks runmodel(m, convert.(Float32, x))
output = runmodel(m, x)
end
end end
for f in :[back!, update!].args for f in :[back!, update!].args

View File

@ -71,8 +71,7 @@ function process_type(ex)
self = esc(:self) self = esc(:self)
quote quote
$(build_type(T, params)) $(build_type(T, params))
$(@q $(esc(:(Flux.runmodel(self::$T, $(args...)) = $(build_forward(body, args)))))) $(esc(:((self::$T)($(args...)) = $(build_forward(body, args)))))
($self::$(esc(T)))($(args...)) = runrawbatched((xs...) -> runmodel($self, xs...), $(args...))
$(esc(:(Flux.update!(self::$T, η)))) = ($(map(p -> :(update!($self.$p, η)), pnames)...);) $(esc(:(Flux.update!(self::$T, η)))) = ($(map(p -> :(update!($self.$p, η)), pnames)...);)
$(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(mapconst(esc, makegraph(body, args)))) $(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(mapconst(esc, makegraph(body, args))))
nothing nothing

View File

@ -30,13 +30,7 @@ function interp(ctx, f, xs...)
f(xs...)) f(xs...))
end end
# TODO: batching should be secondary function interpmodel(m, args...)
function interpmodel_(m, args...)
ctx = Context(mux(iline, ilambda, iconst, iargs, ituple, interp)) ctx = Context(mux(iline, ilambda, iconst, iargs, ituple, interp))
interp(ctx, m, args...) @ithrow interp(ctx, m, args...)
end end
interpmodel(m, args...) = @ithrow runrawbatched((xs...) -> interpmodel_(m, xs...), args...)
runmodel(m::Capacitor, xs...) = @ithrow interpmodel_(m, xs...)

View File

@ -45,17 +45,3 @@ end
isbatched(x) = false isbatched(x) = false
isbatched(x::Batch) = true isbatched(x::Batch) = true
isbatched(xs::Tuple) = any(isbatched, xs) isbatched(xs::Tuple) = any(isbatched, xs)
batchify(xs) = isbatched(xs) ? (xs, true) : (mapt(batchone, xs), false)
function runbatched(f, xs...)
# TODO: decide what to do with mixed inputs
xs, batched = batchify(xs)
ys = f(xs...)
batched ? ys : mapt(unbatchone, ys)
end
runrawbatched(f, xs...) =
runbatched((xs...) -> mapt(rebatch,
f(mapt(rawbatch, xs)...)),
xs...)

View File

@ -1,4 +1,5 @@
unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...)) unsqueeze(xs, dim = 1) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
Base.squeeze(xs) = squeeze(xs, 1)
stack(xs, dim = 1) = cat(dim, unsqueeze.(xs, dim)...) stack(xs, dim = 1) = cat(dim, unsqueeze.(xs, dim)...)
unstack(xs, dim = 1) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)] unstack(xs, dim = 1) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)]

View File

@ -133,12 +133,10 @@ struct Stateful <: Model
end end
function (m::Stateful)(x) function (m::Stateful)(x)
runrawbatched(x) do x
state, y = runmodel(m.model, (m.state...,), x) state, y = runmodel(m.model, (m.state...,), x)
m.state .= state m.state .= state
return y return y
end end
end
stateless(m) = m stateless(m) = m
stateless(m::Stateful) = m.model stateless(m::Stateful) = m.model
@ -150,11 +148,7 @@ end
(m::SeqModel)(x::Tuple) = m.model(x) (m::SeqModel)(x::Tuple) = m.model(x)
splitseq(xs) = rebatch.(unstack(rawbatch(xs), 2)) splitseq(xs) = unstack(rawbatch(xs), 2)
joinseq(xs) = rebatchseq(stack(rawbatch.(xs), 2)) joinseq(xs) = rebatchseq(stack(xs, 2))
function (m::SeqModel)(x::Union{Seq,BatchSeq}) (m::SeqModel)(x::BatchSeq) = joinseq(m.model((splitseq(x)...,)))
runbatched(x) do x
joinseq(m.model((splitseq(x)...,)))
end
end

View File

@ -3,7 +3,7 @@ Flux.loadmx()
@testset "MXNet" begin @testset "MXNet" begin
xs, ys = rand(20), rand(20) xs, ys = rand(1, 20), rand(1, 20)
d = Affine(20, 10) d = Affine(20, 10)
dm = mxnet(d) dm = mxnet(d)
@ -14,7 +14,7 @@ mm = mxnet(m)
@test all(isapprox.(mm(xs, ys), m(xs, ys))) @test all(isapprox.(mm(xs, ys), m(xs, ys)))
@testset "Recurrence" begin @testset "Recurrence" begin
seq = Seq(rand(10) for i = 1:3) seq = batchone(Seq(rand(10) for i = 1:3))
r = unroll(Recurrent(10, 5), 3) r = unroll(Recurrent(10, 5), 3)
rm = mxnet(r) rm = mxnet(r)
@test r(seq) rm(seq) @test r(seq) rm(seq)
@ -25,7 +25,7 @@ end
@test dm(xs) d(xs) @test dm(xs) d(xs)
@test dm(xs) d(xs) @test dm(xs) d(xs)
Δ = back!(dm, randn(10), xs) Δ = back!(dm, randn(1, 10), xs)
@test length(Δ[1]) == 20 @test length(Δ[1]) == 20
update!(dm, 0.1) update!(dm, 0.1)
@ -48,7 +48,7 @@ end
model = TLP(Affine(10, 20), Affine(21, 15)) model = TLP(Affine(10, 20), Affine(21, 15))
info("The following warning is normal") info("The following warning is normal")
dm = mxnet(model) dm = mxnet(model)
e = try dm(rand(10)) e = try dm(rand(1, 10))
catch e e end catch e e end
@test isa(e, DataFlow.Interpreter.Exception) @test isa(e, DataFlow.Interpreter.Exception)

View File

@ -3,7 +3,7 @@ Flux.loadtf()
@testset "TensorFlow" begin @testset "TensorFlow" begin
xs = rand(20) xs = rand(1, 20)
d = Affine(20, 10) d = Affine(20, 10)
dt = tf(d) dt = tf(d)
@ -15,13 +15,13 @@ dt = tf(d)
Y = Tensor(d, X) Y = Tensor(d, X)
run(sess, initialize_all_variables()) run(sess, initialize_all_variables())
@test run(sess, Y, Dict(X=>Float32.(xs'))) d(xs)' @test run(sess, Y, Dict(X=>Float32.(xs))) d(xs)
end end
@testset "Stack Traces" begin @testset "Stack Traces" begin
model = TLP(Affine(10, 20), Affine(21, 15)) model = TLP(Affine(10, 20), Affine(21, 15))
dm = tf(model) dm = tf(model)
e = try dm(rand(10)) e = try dm(rand(1, 10))
catch e e end catch e e end
@test isa(e, DataFlow.Interpreter.Exception) @test isa(e, DataFlow.Interpreter.Exception)

View File

@ -1,9 +1,9 @@
@testset "Basics" begin @testset "Basics" begin
xs = randn(10) xs = randn(1, 10)
d = Affine(10, 20) d = Affine(10, 20)
@test d(xs) (xs'*d.W.x + d.b.x)[1,:] @test d(xs) (xs*d.W.x + d.b.x)
d1 = @net x -> x * d.W + d.b d1 = @net x -> x * d.W + d.b
@ -24,7 +24,7 @@ end
let tlp = TLP(Affine(10, 21), Affine(20, 15)) let tlp = TLP(Affine(10, 21), Affine(20, 15))
e = try e = try
Flux.interpmodel(tlp, rand(10)) Flux.interpmodel(tlp, rand(1, 10))
catch e catch e
e e
end end
@ -33,8 +33,8 @@ let tlp = TLP(Affine(10, 21), Affine(20, 15))
end end
let m = Multi(10, 15) let m = Multi(10, 15)
x, y = rand(10), rand(10) x, y = rand(1, 10), rand(1, 10)
@test all(isapprox.(m(x, y), (m.W.x' * x, m.V.x' * y))) @test all(isapprox.(m(x, y), (x * m.W.x, y * m.V.x)))
@test all(isapprox.(m(x, y), Flux.interpmodel(m, x, y))) @test all(isapprox.(m(x, y), Flux.interpmodel(m, x, y)))
end end

View File

@ -11,9 +11,9 @@ end
@testset "RNN unrolling" begin @testset "RNN unrolling" begin
r = Recurrent(10, 5) r = Recurrent(10, 5)
xs = [rand(10) for _ = 1:3] xs = [rand(1, 10) for _ = 1:3]
_, ys = apply(stateless(unroll1(r)), xs, (squeeze(r.y.x, 1),)) _, ys = apply(stateless(unroll1(r)), xs, (r.y.x,))
@test ys[1] == squeeze(tanh(reshape(xs[1],(1,10)) * r.Wxy.x .+ r.y.x * r.Wyy.x .+ r.by.x), 1) @test ys[1] == tanh(xs[1] * r.Wxy.x .+ r.y.x * r.Wyy.x .+ r.by.x)
ru = unroll(r, 3) ru = unroll(r, 3)
@test ru(Seq(xs)) == ys ru(batchone(Seq(squeeze.(xs))))[1] == squeeze.(ys)
end end

View File

@ -1,5 +1,5 @@
using Flux, DataFlow, MacroTools, Base.Test using Flux, DataFlow, MacroTools, Base.Test
using Flux: graph, Param using Flux: graph, Param, unsqueeze
using DataFlow: Line, Frame using DataFlow: Line, Frame
syntax(v::Vertex) = prettify(DataFlow.syntax(v)) syntax(v::Vertex) = prettify(DataFlow.syntax(v))