Merge pull request #21 from MikeInnes/no-batch
Breaking: Remove fake batching semantics
Commit: 52bd0f9b00
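The practical effect of this change, as reflected in the test updates below: models no longer wrap and unwrap plain vectors through `runrawbatched`; callers pass explicitly batched arrays, one sample per row. A minimal sketch, assuming the `Affine` layer used in the tests:

    d = Affine(20, 10)

    # Before: a bare vector was implicitly treated as a batch of one.
    # d(rand(20))

    # After: the batch dimension is explicit (a single 1×20 row),
    # and the output is a 1×10 row.
    d(rand(1, 20))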
@@ -1,5 +1,3 @@
-using Flux: runrawbatched
-
 struct AlterParam
   param
   load
@@ -109,22 +107,18 @@ import Base: @get!
 # TODO: dims having its own type would be useful
 executor(m::Model, input...) = @get!(m.execs, mapt(size, input), executor(m.graph, input...))
 
-function Flux.runmodel(m::Model, xs...)
+function (m::Model)(xs...)
+  @mxerr m.graph.stacks begin
   !isdefined(m, :graph) &&
     (m.graph = tograph(m.model, mapt(_ -> gensym("input"), xs)...))
   m.last = exec = executor(m, xs...)
   exec(xs...)
 end
-
-function (m::Model)(xs...)
-  @mxerr m.graph.stacks runrawbatched(xs -> Flux.runmodel(m, xs...), xs)
 end
 
 function Flux.back!(m::Model, Δ, xs...)
-  runrawbatched(Δ, xs) do Δ, xs
   m.last = exec = m.execs[mapt(size, xs)]
   back!(exec, Δ)
-  end
 end
 
 Flux.update!(m::Model, η) = (update!(m.last, η); m)
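Note: with `runrawbatched` gone, calling an MXNet-backed model runs the executor for the input's exact shape directly; `executor` above memoises compiled executors by input size. A rough usage sketch, reusing names from the tests (`Affine`, `mxnet`):

    dm = mxnet(Affine(20, 10))
    dm(rand(1, 20))   # builds an executor for size (1, 20) and caches it in m.execs
    dm(rand(1, 20))   # same input size, so the cached executor is reused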
@@ -32,12 +32,8 @@ function runmodel(m::Model, args...)
   run(m.session, m.output, Dict(zip(m.inputs, args)))
 end
 
-using Flux: runrawbatched
-
 function (m::Model)(x)
-  @tferr m.stacks runrawbatched(convertel(Float32, x)) do x
-    output = runmodel(m, x)
-  end
+  @tferr m.stacks runmodel(m, convert.(Float32, x))
 end
 
 for f in :[back!, update!].args
@@ -71,8 +71,7 @@ function process_type(ex)
   self = esc(:self)
   quote
     $(build_type(T, params))
-    $(@q $(esc(:(Flux.runmodel(self::$T, $(args...)) = $(build_forward(body, args))))))
-    ($self::$(esc(T)))($(args...)) = runrawbatched((xs...) -> runmodel($self, xs...), $(args...))
+    $(esc(:((self::$T)($(args...)) = $(build_forward(body, args)))))
     $(esc(:(Flux.update!(self::$T, η)))) = ($(map(p -> :(update!($self.$p, η)), pnames)...);)
     $(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(mapconst(esc, makegraph(body, args))))
     nothing
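Note: the generated call overload now goes straight to the forward pass instead of routing through `Flux.runmodel` and `runrawbatched`. A hand-written sketch of roughly what `process_type` now emits for a simple `@net`-style type (illustrative names, not the literal macro expansion):

    struct MyAffine
      W
      b
    end

    # Direct call overload; the body corresponds to build_forward's output.
    (self::MyAffine)(x) = x * self.W + self.b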
@@ -30,13 +30,7 @@ function interp(ctx, f, xs...)
          f(xs...))
 end
 
-# TODO: batching should be secondary
-
-function interpmodel_(m, args...)
+function interpmodel(m, args...)
   ctx = Context(mux(iline, ilambda, iconst, iargs, ituple, interp))
-  interp(ctx, m, args...)
+  @ithrow interp(ctx, m, args...)
 end
-
-interpmodel(m, args...) = @ithrow runrawbatched((xs...) -> interpmodel_(m, xs...), args...)
-
-runmodel(m::Capacitor, xs...) = @ithrow interpmodel_(m, xs...)
@@ -45,17 +45,3 @@ end
 isbatched(x) = false
 isbatched(x::Batch) = true
 isbatched(xs::Tuple) = any(isbatched, xs)
-
-batchify(xs) = isbatched(xs) ? (xs, true) : (mapt(batchone, xs), false)
-
-function runbatched(f, xs...)
-  # TODO: decide what to do with mixed inputs
-  xs, batched = batchify(xs)
-  ys = f(xs...)
-  batched ? ys : mapt(unbatchone, ys)
-end
-
-runrawbatched(f, xs...) =
-  runbatched((xs...) -> mapt(rebatch,
-                             f(mapt(rawbatch, xs)...)),
-             xs...)
@@ -1,4 +1,5 @@
-unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
+unsqueeze(xs, dim = 1) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
+Base.squeeze(xs) = squeeze(xs, 1)
 
 stack(xs, dim = 1) = cat(dim, unsqueeze.(xs, dim)...)
 unstack(xs, dim = 1) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)]
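Note: a quick shape sketch of the utilities above, assuming only the definitions in this hunk:

    x = rand(3, 4)
    size(unsqueeze(x, 1))                      # (1, 3, 4): insert a singleton dim in front
    size(stack([rand(3, 4) for _ = 1:2], 1))   # (2, 3, 4): unsqueeze each, then cat along dim 1
    length(unstack(rand(2, 3, 4), 1))          # 2: one slice per index along dim 1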
src/model.jl
@@ -133,11 +133,9 @@ struct Stateful <: Model
 end
 
 function (m::Stateful)(x)
-  runrawbatched(x) do x
   state, y = runmodel(m.model, (m.state...,), x)
   m.state .= state
   return y
-  end
 end
 
 stateless(m) = m
@@ -150,11 +148,7 @@ end
 
 (m::SeqModel)(x::Tuple) = m.model(x)
 
-splitseq(xs) = rebatch.(unstack(rawbatch(xs), 2))
-joinseq(xs) = rebatchseq(stack(rawbatch.(xs), 2))
+splitseq(xs) = unstack(rawbatch(xs), 2)
+joinseq(xs) = rebatchseq(stack(xs, 2))
 
-function (m::SeqModel)(x::Union{Seq,BatchSeq})
-  runbatched(x) do x
-    joinseq(m.model((splitseq(x)...,)))
-  end
-end
+(m::SeqModel)(x::BatchSeq) = joinseq(m.model((splitseq(x)...,)))
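Note: `splitseq`/`joinseq` treat dimension 2 of the raw batch as time. A rough shape sketch, assuming a batched sequence is stored as a (batch, time, feature) array (`rawbatch`/`rebatchseq` are Flux internals and are omitted here):

    raw   = rand(1, 3, 10)      # batch of 1 sequence, 3 timesteps, 10 features
    steps = unstack(raw, 2)     # 3 per-timestep arrays, passed to m.model as a tuple
    out   = stack(steps, 2)     # stacked back along the time dimension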
@@ -3,7 +3,7 @@ Flux.loadmx()
 
 @testset "MXNet" begin
 
-xs, ys = rand(20), rand(20)
+xs, ys = rand(1, 20), rand(1, 20)
 d = Affine(20, 10)
 
 dm = mxnet(d)
@@ -14,7 +14,7 @@ mm = mxnet(m)
 @test all(isapprox.(mm(xs, ys), m(xs, ys)))
 
 @testset "Recurrence" begin
-  seq = Seq(rand(10) for i = 1:3)
+  seq = batchone(Seq(rand(10) for i = 1:3))
   r = unroll(Recurrent(10, 5), 3)
   rm = mxnet(r)
   @test r(seq) ≈ rm(seq)
@@ -25,7 +25,7 @@ end
 @test dm(xs) ≈ d(xs)
 @test dm(xs) ≈ d′(xs)
 
-Δ = back!(dm, randn(10), xs)
+Δ = back!(dm, randn(1, 10), xs)
 @test length(Δ[1]) == 20
 update!(dm, 0.1)
 
@@ -48,7 +48,7 @@ end
 model = TLP(Affine(10, 20), Affine(21, 15))
 info("The following warning is normal")
 dm = mxnet(model)
-e = try dm(rand(10))
+e = try dm(rand(1, 10))
 catch e e end
 
 @test isa(e, DataFlow.Interpreter.Exception)
@@ -3,7 +3,7 @@ Flux.loadtf()
 
 @testset "TensorFlow" begin
 
-xs = rand(20)
+xs = rand(1, 20)
 d = Affine(20, 10)
 
 dt = tf(d)
@@ -15,13 +15,13 @@ dt = tf(d)
 Y = Tensor(d, X)
 run(sess, initialize_all_variables())
 
-@test run(sess, Y, Dict(X=>Float32.(xs'))) ≈ d(xs)'
+@test run(sess, Y, Dict(X=>Float32.(xs))) ≈ d(xs)
 end
 
 @testset "Stack Traces" begin
 model = TLP(Affine(10, 20), Affine(21, 15))
 dm = tf(model)
-e = try dm(rand(10))
+e = try dm(rand(1, 10))
 catch e e end
 
 @test isa(e, DataFlow.Interpreter.Exception)
@@ -1,9 +1,9 @@
 @testset "Basics" begin
 
-xs = randn(10)
+xs = randn(1, 10)
 d = Affine(10, 20)
 
-@test d(xs) ≈ (xs'*d.W.x + d.b.x)[1,:]
+@test d(xs) ≈ (xs*d.W.x + d.b.x)
 
 d1 = @net x -> x * d.W + d.b
 
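Note: a shape check for the updated `Affine` test above (a sketch; `W` and `b` stand in for `d.W.x` and `d.b.x` of `Affine(10, 20)`):

    x = randn(1, 10)    # one sample as a row
    W = randn(10, 20)
    b = randn(1, 20)
    y = x*W + b         # 1×20, matching @test d(xs) ≈ (xs*d.W.x + d.b.x)
    size(y) == (1, 20)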
@@ -24,7 +24,7 @@ end
 
 let tlp = TLP(Affine(10, 21), Affine(20, 15))
   e = try
-    Flux.interpmodel(tlp, rand(10))
+    Flux.interpmodel(tlp, rand(1, 10))
   catch e
     e
   end
@@ -33,8 +33,8 @@ let tlp = TLP(Affine(10, 21), Affine(20, 15))
 end
 
 let m = Multi(10, 15)
-  x, y = rand(10), rand(10)
-  @test all(isapprox.(m(x, y), (m.W.x' * x, m.V.x' * y)))
+  x, y = rand(1, 10), rand(1, 10)
+  @test all(isapprox.(m(x, y), (x * m.W.x, y * m.V.x)))
   @test all(isapprox.(m(x, y), Flux.interpmodel(m, x, y)))
 end
 
@@ -11,9 +11,9 @@ end
 
 @testset "RNN unrolling" begin
 r = Recurrent(10, 5)
-xs = [rand(10) for _ = 1:3]
-_, ys = apply(stateless(unroll1(r)), xs, (squeeze(r.y.x, 1),))
-@test ys[1] == squeeze(tanh(reshape(xs[1],(1,10)) * r.Wxy.x .+ r.y.x * r.Wyy.x .+ r.by.x), 1)
+xs = [rand(1, 10) for _ = 1:3]
+_, ys = apply(stateless(unroll1(r)), xs, (r.y.x,))
+@test ys[1] == tanh(xs[1] * r.Wxy.x .+ r.y.x * r.Wyy.x .+ r.by.x)
 ru = unroll(r, 3)
-@test ru(Seq(xs)) == ys
+ru(batchone(Seq(squeeze.(xs))))[1] == squeeze.(ys)
 end
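Note: a shape sketch of the single recurrence step checked above (`Recurrent(10, 5)`: 10 inputs, 5 hidden units; the names mirror `r.Wxy`, `r.Wyy`, `r.by`):

    x1  = rand(1, 10)                         # one timestep, batch of 1
    Wxy = rand(10, 5); Wyy = rand(5, 5)
    by  = rand(1, 5);  y0  = rand(1, 5)
    y1  = tanh.(x1 * Wxy .+ y0 * Wyy .+ by)   # 1×5 hidden state / output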
@@ -1,5 +1,5 @@
 using Flux, DataFlow, MacroTools, Base.Test
-using Flux: graph, Param
+using Flux: graph, Param, unsqueeze
 using DataFlow: Line, Frame
 
 syntax(v::Vertex) = prettify(DataFlow.syntax(v))