Merge pull request #21 from MikeInnes/no-batch

Breaking: Remove fake batching semantics
This commit is contained in:
Mike J Innes 2017-04-19 12:59:48 +01:00 committed by GitHub
commit 52bd0f9b00
12 changed files with 37 additions and 73 deletions

View File

@ -1,5 +1,3 @@
using Flux: runrawbatched
struct AlterParam
param
load
@ -109,23 +107,19 @@ import Base: @get!
# TODO: dims having its own type would be useful
executor(m::Model, input...) = @get!(m.execs, mapt(size, input), executor(m.graph, input...))
function Flux.runmodel(m::Model, xs...)
function (m::Model)(xs...)
@mxerr m.graph.stacks begin
!isdefined(m, :graph) &&
(m.graph = tograph(m.model, mapt(_ -> gensym("input"), xs)...))
m.last = exec = executor(m, xs...)
exec(xs...)
end
function (m::Model)(xs...)
@mxerr m.graph.stacks runrawbatched(xs -> Flux.runmodel(m, xs...), xs)
end
function Flux.back!(m::Model, Δ, xs...)
runrawbatched(Δ, xs) do Δ, xs
m.last = exec = m.execs[mapt(size, xs)]
back!(exec, Δ)
end
end
Flux.update!(m::Model, η) = (update!(m.last, η); m)

View File

@ -32,12 +32,8 @@ function runmodel(m::Model, args...)
run(m.session, m.output, Dict(zip(m.inputs, args)))
end
using Flux: runrawbatched
function (m::Model)(x)
@tferr m.stacks runrawbatched(convertel(Float32, x)) do x
output = runmodel(m, x)
end
@tferr m.stacks runmodel(m, convert.(Float32, x))
end
for f in :[back!, update!].args

View File

@ -71,8 +71,7 @@ function process_type(ex)
self = esc(:self)
quote
$(build_type(T, params))
$(@q $(esc(:(Flux.runmodel(self::$T, $(args...)) = $(build_forward(body, args))))))
($self::$(esc(T)))($(args...)) = runrawbatched((xs...) -> runmodel($self, xs...), $(args...))
$(esc(:((self::$T)($(args...)) = $(build_forward(body, args)))))
$(esc(:(Flux.update!(self::$T, η)))) = ($(map(p -> :(update!($self.$p, η)), pnames)...);)
$(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(mapconst(esc, makegraph(body, args))))
nothing

View File

@ -30,13 +30,7 @@ function interp(ctx, f, xs...)
f(xs...))
end
# TODO: batching should be secondary
function interpmodel_(m, args...)
function interpmodel(m, args...)
ctx = Context(mux(iline, ilambda, iconst, iargs, ituple, interp))
interp(ctx, m, args...)
@ithrow interp(ctx, m, args...)
end
interpmodel(m, args...) = @ithrow runrawbatched((xs...) -> interpmodel_(m, xs...), args...)
runmodel(m::Capacitor, xs...) = @ithrow interpmodel_(m, xs...)

View File

@ -45,17 +45,3 @@ end
isbatched(x) = false
isbatched(x::Batch) = true
isbatched(xs::Tuple) = any(isbatched, xs)
batchify(xs) = isbatched(xs) ? (xs, true) : (mapt(batchone, xs), false)
function runbatched(f, xs...)
# TODO: decide what to do with mixed inputs
xs, batched = batchify(xs)
ys = f(xs...)
batched ? ys : mapt(unbatchone, ys)
end
runrawbatched(f, xs...) =
runbatched((xs...) -> mapt(rebatch,
f(mapt(rawbatch, xs)...)),
xs...)

View File

@ -1,4 +1,5 @@
unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
unsqueeze(xs, dim = 1) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
Base.squeeze(xs) = squeeze(xs, 1)
stack(xs, dim = 1) = cat(dim, unsqueeze.(xs, dim)...)
unstack(xs, dim = 1) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)]

View File

@ -133,12 +133,10 @@ struct Stateful <: Model
end
function (m::Stateful)(x)
runrawbatched(x) do x
state, y = runmodel(m.model, (m.state...,), x)
m.state .= state
return y
end
end
stateless(m) = m
stateless(m::Stateful) = m.model
@ -150,11 +148,7 @@ end
(m::SeqModel)(x::Tuple) = m.model(x)
splitseq(xs) = rebatch.(unstack(rawbatch(xs), 2))
joinseq(xs) = rebatchseq(stack(rawbatch.(xs), 2))
splitseq(xs) = unstack(rawbatch(xs), 2)
joinseq(xs) = rebatchseq(stack(xs, 2))
function (m::SeqModel)(x::Union{Seq,BatchSeq})
runbatched(x) do x
joinseq(m.model((splitseq(x)...,)))
end
end
(m::SeqModel)(x::BatchSeq) = joinseq(m.model((splitseq(x)...,)))

View File

@ -3,7 +3,7 @@ Flux.loadmx()
@testset "MXNet" begin
xs, ys = rand(20), rand(20)
xs, ys = rand(1, 20), rand(1, 20)
d = Affine(20, 10)
dm = mxnet(d)
@ -14,7 +14,7 @@ mm = mxnet(m)
@test all(isapprox.(mm(xs, ys), m(xs, ys)))
@testset "Recurrence" begin
seq = Seq(rand(10) for i = 1:3)
seq = batchone(Seq(rand(10) for i = 1:3))
r = unroll(Recurrent(10, 5), 3)
rm = mxnet(r)
@test r(seq) rm(seq)
@ -25,7 +25,7 @@ end
@test dm(xs) d(xs)
@test dm(xs) d(xs)
Δ = back!(dm, randn(10), xs)
Δ = back!(dm, randn(1, 10), xs)
@test length(Δ[1]) == 20
update!(dm, 0.1)
@ -48,7 +48,7 @@ end
model = TLP(Affine(10, 20), Affine(21, 15))
info("The following warning is normal")
dm = mxnet(model)
e = try dm(rand(10))
e = try dm(rand(1, 10))
catch e e end
@test isa(e, DataFlow.Interpreter.Exception)

View File

@ -3,7 +3,7 @@ Flux.loadtf()
@testset "TensorFlow" begin
xs = rand(20)
xs = rand(1, 20)
d = Affine(20, 10)
dt = tf(d)
@ -15,13 +15,13 @@ dt = tf(d)
Y = Tensor(d, X)
run(sess, initialize_all_variables())
@test run(sess, Y, Dict(X=>Float32.(xs'))) d(xs)'
@test run(sess, Y, Dict(X=>Float32.(xs))) d(xs)
end
@testset "Stack Traces" begin
model = TLP(Affine(10, 20), Affine(21, 15))
dm = tf(model)
e = try dm(rand(10))
e = try dm(rand(1, 10))
catch e e end
@test isa(e, DataFlow.Interpreter.Exception)

View File

@ -1,9 +1,9 @@
@testset "Basics" begin
xs = randn(10)
xs = randn(1, 10)
d = Affine(10, 20)
@test d(xs) (xs'*d.W.x + d.b.x)[1,:]
@test d(xs) (xs*d.W.x + d.b.x)
d1 = @net x -> x * d.W + d.b
@ -24,7 +24,7 @@ end
let tlp = TLP(Affine(10, 21), Affine(20, 15))
e = try
Flux.interpmodel(tlp, rand(10))
Flux.interpmodel(tlp, rand(1, 10))
catch e
e
end
@ -33,8 +33,8 @@ let tlp = TLP(Affine(10, 21), Affine(20, 15))
end
let m = Multi(10, 15)
x, y = rand(10), rand(10)
@test all(isapprox.(m(x, y), (m.W.x' * x, m.V.x' * y)))
x, y = rand(1, 10), rand(1, 10)
@test all(isapprox.(m(x, y), (x * m.W.x, y * m.V.x)))
@test all(isapprox.(m(x, y), Flux.interpmodel(m, x, y)))
end

View File

@ -11,9 +11,9 @@ end
@testset "RNN unrolling" begin
r = Recurrent(10, 5)
xs = [rand(10) for _ = 1:3]
_, ys = apply(stateless(unroll1(r)), xs, (squeeze(r.y.x, 1),))
@test ys[1] == squeeze(tanh(reshape(xs[1],(1,10)) * r.Wxy.x .+ r.y.x * r.Wyy.x .+ r.by.x), 1)
xs = [rand(1, 10) for _ = 1:3]
_, ys = apply(stateless(unroll1(r)), xs, (r.y.x,))
@test ys[1] == tanh(xs[1] * r.Wxy.x .+ r.y.x * r.Wyy.x .+ r.by.x)
ru = unroll(r, 3)
@test ru(Seq(xs)) == ys
ru(batchone(Seq(squeeze.(xs))))[1] == squeeze.(ys)
end

View File

@ -1,5 +1,5 @@
using Flux, DataFlow, MacroTools, Base.Test
using Flux: graph, Param
using Flux: graph, Param, unsqueeze
using DataFlow: Line, Frame
syntax(v::Vertex) = prettify(DataFlow.syntax(v))