From 15b3ce1adad9e07367e68c8c3758380b02d6a4a1 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Thu, 9 Mar 2017 00:13:26 +0000 Subject: [PATCH] factor out fake batching semantics --- src/backend/mxnet/model.jl | 23 +++++++++---------- src/compiler/code.jl | 3 +-- src/compiler/interp.jl | 6 ++--- src/dims/batching.jl | 47 ++++++++++++++++++++++++++------------ 4 files changed, 47 insertions(+), 32 deletions(-) diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl index eec367ae..6a5fa15c 100644 --- a/src/backend/mxnet/model.jl +++ b/src/backend/mxnet/model.jl @@ -1,4 +1,4 @@ -using Flux: batchone, unbatchone, rebatch +using Flux: runrawbatched type AlterParam param @@ -100,21 +100,20 @@ import Base: @get! executor(m::Model, input) = @get!(m.execs, input, executor(m.graph, input)) -function (m::Model)(x::Batch) - x′ = rawbatch(x) - m.last = exec = @mxerr m.graph.stacks executor(m, size(x′)) - rebatch(exec(x′)) +function (m::Model)(x) + runrawbatched(x) do x + m.last = exec = @mxerr m.graph.stacks executor(m, size(x)) + exec(x) + end end -(m::Model)(x) = unbatchone(m(batchone(x))) - -function Flux.back!(m::Model, Δ::Batch, x::Batch) - m.last = exec = m.execs[size(rawbatch(x))] - rebatch(back!(exec, rawbatch(Δ))) +function Flux.back!(m::Model, Δ, x) + runrawbatched(Δ, x) do Δ, x + m.last = exec = m.execs[size(x)] + back!(exec, Δ) + end end -Flux.back!(m::Model, Δ, x) = first(Flux.back!(m, batchone(Δ), batchone(x))) - Flux.update!(m::Model, η) = (update!(m.last, η); m) # MX FeedForward interface diff --git a/src/compiler/code.jl b/src/compiler/code.jl index 712ad9a8..e3e08f26 100644 --- a/src/compiler/code.jl +++ b/src/compiler/code.jl @@ -73,8 +73,7 @@ function process_type(ex) quote $(build_type(T, params)) $(esc(:(Flux.runmodel(self::$T, $(args...)) = $(build_forward(body, args))))) - ($self::$(esc(T)))($(map(arg -> :($arg::Batch), args)...)) = rebatch(runmodel($self, $(map(x->:(rawbatch($x)), args)...))) - ($self::$(esc(T)))($(args...)) = unbatchone($self(map(batchone, ($(args...),))...)) + ($self::$(esc(T)))($(args...)) = runrawbatched((xs...) -> runmodel($self, xs...), $(args...)) $(esc(:(Flux.update!(self::$T, η)))) = ($(map(p -> :(update!($self.$p, η)), pnames)...);) $(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(mapconst(esc, makegraph(body, args)))) nothing diff --git a/src/compiler/interp.jl b/src/compiler/interp.jl index 6f7e3d50..2d29f677 100644 --- a/src/compiler/interp.jl +++ b/src/compiler/interp.jl @@ -34,9 +34,9 @@ end # TODO: batching should be secondary -function interpmodel(m, args::Batch...) +function interpmodel_(m, args...) ctx = Context(mux(iline, ilambda, iconst, iargs, ituple, interp)) - rebatch(@ithrow interp(ctx, m, map(rawbatch, args)...)) + interp(ctx, m, args...) end -interpmodel(m, args...) = unbatchone(interpmodel(m, batchone(args)...)) +interpmodel(m, args...) = @ithrow runrawbatched((xs...) -> interpmodel_(m, xs...), args...) diff --git a/src/dims/batching.jl b/src/dims/batching.jl index ee78f717..14e6aded 100644 --- a/src/dims/batching.jl +++ b/src/dims/batching.jl @@ -17,19 +17,6 @@ convert{T,S}(::Type{Batch{T,S}},storage::S) = Juno.trim(collect(b))) end -# Convenience methods for batch size 1 - -batchone(x) = Batch((x,)) -batchone(x::Batch) = x -batchone(x::Tuple) = map(batchone, x) - -function unbatchone(xs::Batch) - @assert length(xs) == 1 - return first(xs) -end - -unbatchone(xs::Tuple) = map(unbatchone, xs) - function rebatch(xs) dims = ndims(xs)-1 T = Array{eltype(xs),dims} @@ -37,8 +24,38 @@ function rebatch(xs) Batch{T,B}(xs) end -rebatch(xs::Tuple) = map(rebatch, xs) - convertel(T::Type, xs::Batch) = isa(eltype(eltype(xs)), T) ? xs : Batch(map(x->convertel(T, x), xs)) + +# Add batching semantics to functions operating on raw arrays +# TODO: remove this in favour of full batching semantics + +mapt(f, x) = f(x) +mapt(f, xs::Tuple) = map(f, xs) + +batchone(x) = Batch((x,)) +batchone(x::Batch) = x + +function unbatchone(xs::Batch) + @assert length(xs) == 1 + return first(xs) +end + +isbatched(x) = false +isbatched(x::Batch) = true +isbatched(xs::Tuple) = any(isbatched, xs) + +batchify(xs) = isbatched(xs) ? (xs, true) : (mapt(batchone, xs), false) + +function runbatched(f, xs...) + # TODO: decide what to do with mixed inputs + xs, batched = batchify(xs) + ys = f(xs...) + batched ? ys : mapt(unbatchone, ys) +end + +runrawbatched(f, xs...) = + runbatched((xs...) -> mapt(rebatch, + f(mapt(rawbatch, xs)...)), + xs...)