Merge pull request #21 from MikeInnes/no-batch
Breaking: Remove fake batching semantics
This commit is contained in: commit 52bd0f9b00
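In practice, models no longer wrap a bare vector in an implicit batch of one; callers now pass arrays with an explicit batch dimension, as the test updates below reflect. A minimal before/after sketch (illustrative only; it assumes the Affine layer and the batchone/Seq helpers that appear in the tests, and is not part of the commit):

using Flux

d = Affine(20, 10)

# Before this change: a plain vector was implicitly treated as a batch of one.
# xs = rand(20); d(xs)

# After this change: the batch dimension is explicit. A single sample is a 1×20 row,
# and d(xs) ≈ xs * d.W.x + d.b.x gives a 1×10 row (per the updated Basics test).
xs = rand(1, 20)
y = d(xs)

# Sequences are wrapped explicitly as well (see the recurrence tests):
# seq = batchone(Seq(rand(10) for i = 1:3))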
@@ -1,5 +1,3 @@
-using Flux: runrawbatched
-
 struct AlterParam
   param
   load
@@ -109,23 +107,19 @@ import Base: @get!
 # TODO: dims having its own type would be useful
 executor(m::Model, input...) = @get!(m.execs, mapt(size, input), executor(m.graph, input...))

-function Flux.runmodel(m::Model, xs...)
+function (m::Model)(xs...)
+  @mxerr m.graph.stacks begin
     !isdefined(m, :graph) &&
       (m.graph = tograph(m.model, mapt(_ -> gensym("input"), xs)...))
     m.last = exec = executor(m, xs...)
     exec(xs...)
+  end
 end
-
-function (m::Model)(xs...)
-  @mxerr m.graph.stacks runrawbatched(xs -> Flux.runmodel(m, xs...), xs)
-end

 function Flux.back!(m::Model, Δ, xs...)
-  runrawbatched(Δ, xs) do Δ, xs
   m.last = exec = m.execs[mapt(size, xs)]
   back!(exec, Δ)
-  end
 end

 Flux.update!(m::Model, η) = (update!(m.last, η); m)
@@ -32,12 +32,8 @@ function runmodel(m::Model, args...)
   run(m.session, m.output, Dict(zip(m.inputs, args)))
 end

-using Flux: runrawbatched
-
 function (m::Model)(x)
-  @tferr m.stacks runrawbatched(convertel(Float32, x)) do x
-    output = runmodel(m, x)
-  end
+  @tferr m.stacks runmodel(m, convert.(Float32, x))
 end

 for f in :[back!, update!].args
@@ -71,8 +71,7 @@ function process_type(ex)
   self = esc(:self)
   quote
     $(build_type(T, params))
-    $(@q $(esc(:(Flux.runmodel(self::$T, $(args...)) = $(build_forward(body, args))))))
-    ($self::$(esc(T)))($(args...)) = runrawbatched((xs...) -> runmodel($self, xs...), $(args...))
+    $(esc(:((self::$T)($(args...)) = $(build_forward(body, args)))))
     $(esc(:(Flux.update!(self::$T, η)))) = ($(map(p -> :(update!($self.$p, η)), pnames)...);)
     $(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(mapconst(esc, makegraph(body, args))))
     nothing
@@ -30,13 +30,7 @@ function interp(ctx, f, xs...)
     f(xs...))
 end

-# TODO: batching should be secondary
-
-function interpmodel_(m, args...)
+function interpmodel(m, args...)
   ctx = Context(mux(iline, ilambda, iconst, iargs, ituple, interp))
-  interp(ctx, m, args...)
+  @ithrow interp(ctx, m, args...)
 end
-
-interpmodel(m, args...) = @ithrow runrawbatched((xs...) -> interpmodel_(m, xs...), args...)
-
-runmodel(m::Capacitor, xs...) = @ithrow interpmodel_(m, xs...)
@@ -45,17 +45,3 @@ end
-isbatched(x) = false
-isbatched(x::Batch) = true
-isbatched(xs::Tuple) = any(isbatched, xs)
-
-batchify(xs) = isbatched(xs) ? (xs, true) : (mapt(batchone, xs), false)
-
-function runbatched(f, xs...)
-  # TODO: decide what to do with mixed inputs
-  xs, batched = batchify(xs)
-  ys = f(xs...)
-  batched ? ys : mapt(unbatchone, ys)
-end
-
-runrawbatched(f, xs...) =
-  runbatched((xs...) -> mapt(rebatch,
-                             f(mapt(rawbatch, xs)...)),
-             xs...)
@@ -1,4 +1,5 @@
-unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
+unsqueeze(xs, dim = 1) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
+Base.squeeze(xs) = squeeze(xs, 1)

 stack(xs, dim = 1) = cat(dim, unsqueeze.(xs, dim)...)
 unstack(xs, dim = 1) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)]
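For reference, the shapes these helpers produce follow directly from the definitions above; the snippet below is illustrative only and not part of the diff:

xs = rand(3, 4)
size(unsqueeze(xs))        # (1, 3, 4): dim now defaults to 1
size(unsqueeze(xs, 2))     # (3, 1, 4)

ys = [rand(4) for _ = 1:3]
size(stack(ys))            # (3, 4): each element is unsqueezed to 1×4, then concatenated along dim 1
length(unstack(rand(3, 4), 1))   # 3, one slice per index along dim 1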
src/model.jl
@@ -133,12 +133,10 @@ struct Stateful <: Model
 end

 function (m::Stateful)(x)
-  runrawbatched(x) do x
   state, y = runmodel(m.model, (m.state...,), x)
   m.state .= state
   return y
-  end
 end

 stateless(m) = m
 stateless(m::Stateful) = m.model
@@ -150,11 +148,7 @@ end

 (m::SeqModel)(x::Tuple) = m.model(x)

-splitseq(xs) = rebatch.(unstack(rawbatch(xs), 2))
-joinseq(xs) = rebatchseq(stack(rawbatch.(xs), 2))
+splitseq(xs) = unstack(rawbatch(xs), 2)
+joinseq(xs) = rebatchseq(stack(xs, 2))

-function (m::SeqModel)(x::Union{Seq,BatchSeq})
-  runbatched(x) do x
-    joinseq(m.model((splitseq(x)...,)))
-  end
-end
+(m::SeqModel)(x::BatchSeq) = joinseq(m.model((splitseq(x)...,)))
@@ -3,7 +3,7 @@ Flux.loadmx()

 @testset "MXNet" begin

-xs, ys = rand(20), rand(20)
+xs, ys = rand(1, 20), rand(1, 20)
 d = Affine(20, 10)

 dm = mxnet(d)
@@ -14,7 +14,7 @@ mm = mxnet(m)
 @test all(isapprox.(mm(xs, ys), m(xs, ys)))

 @testset "Recurrence" begin
-seq = Seq(rand(10) for i = 1:3)
+seq = batchone(Seq(rand(10) for i = 1:3))
 r = unroll(Recurrent(10, 5), 3)
 rm = mxnet(r)
 @test r(seq) ≈ rm(seq)
@@ -25,7 +25,7 @@ end
 @test dm(xs) ≈ d(xs)
 @test dm(xs) ≈ d′(xs)

-Δ = back!(dm, randn(10), xs)
+Δ = back!(dm, randn(1, 10), xs)
 @test length(Δ[1]) == 20
 update!(dm, 0.1)
@@ -48,7 +48,7 @@ end
 model = TLP(Affine(10, 20), Affine(21, 15))
 info("The following warning is normal")
 dm = mxnet(model)
-e = try dm(rand(10))
+e = try dm(rand(1, 10))
 catch e e end

 @test isa(e, DataFlow.Interpreter.Exception)
@@ -3,7 +3,7 @@ Flux.loadtf()

 @testset "TensorFlow" begin

-xs = rand(20)
+xs = rand(1, 20)
 d = Affine(20, 10)

 dt = tf(d)
@@ -15,13 +15,13 @@ dt = tf(d)
 Y = Tensor(d, X)
 run(sess, initialize_all_variables())

-@test run(sess, Y, Dict(X=>Float32.(xs'))) ≈ d(xs)'
+@test run(sess, Y, Dict(X=>Float32.(xs))) ≈ d(xs)
 end

 @testset "Stack Traces" begin
 model = TLP(Affine(10, 20), Affine(21, 15))
 dm = tf(model)
-e = try dm(rand(10))
+e = try dm(rand(1, 10))
 catch e e end

 @test isa(e, DataFlow.Interpreter.Exception)
@@ -1,9 +1,9 @@
 @testset "Basics" begin

-xs = randn(10)
+xs = randn(1, 10)
 d = Affine(10, 20)

-@test d(xs) ≈ (xs'*d.W.x + d.b.x)[1,:]
+@test d(xs) ≈ (xs*d.W.x + d.b.x)

 d1 = @net x -> x * d.W + d.b
@@ -24,7 +24,7 @@ end

 let tlp = TLP(Affine(10, 21), Affine(20, 15))
   e = try
-    Flux.interpmodel(tlp, rand(10))
+    Flux.interpmodel(tlp, rand(1, 10))
   catch e
     e
   end
@@ -33,8 +33,8 @@ let tlp = TLP(Affine(10, 21), Affine(20, 15))
 end

 let m = Multi(10, 15)
-  x, y = rand(10), rand(10)
-  @test all(isapprox.(m(x, y), (m.W.x' * x, m.V.x' * y)))
+  x, y = rand(1, 10), rand(1, 10)
+  @test all(isapprox.(m(x, y), (x * m.W.x, y * m.V.x)))
   @test all(isapprox.(m(x, y), Flux.interpmodel(m, x, y)))
 end
@@ -11,9 +11,9 @@ end

 @testset "RNN unrolling" begin
 r = Recurrent(10, 5)
-xs = [rand(10) for _ = 1:3]
-_, ys = apply(stateless(unroll1(r)), xs, (squeeze(r.y.x, 1),))
-@test ys[1] == squeeze(tanh(reshape(xs[1],(1,10)) * r.Wxy.x .+ r.y.x * r.Wyy.x .+ r.by.x), 1)
+xs = [rand(1, 10) for _ = 1:3]
+_, ys = apply(stateless(unroll1(r)), xs, (r.y.x,))
+@test ys[1] == tanh(xs[1] * r.Wxy.x .+ r.y.x * r.Wyy.x .+ r.by.x)
 ru = unroll(r, 3)
-@test ru(Seq(xs)) == ys
+ru(batchone(Seq(squeeze.(xs))))[1] == squeeze.(ys)
 end
@@ -1,5 +1,5 @@
 using Flux, DataFlow, MacroTools, Base.Test
-using Flux: graph, Param
+using Flux: graph, Param, unsqueeze
 using DataFlow: Line, Frame

 syntax(v::Vertex) = prettify(DataFlow.syntax(v))