factor out fake batching semantics

This commit is contained in:
Mike J Innes 2017-03-09 00:13:26 +00:00
parent c4d815b5fc
commit 15b3ce1ada
4 changed files with 47 additions and 32 deletions

View File

@ -1,4 +1,4 @@
using Flux: batchone, unbatchone, rebatch
using Flux: runrawbatched
type AlterParam
param
@ -100,21 +100,20 @@ import Base: @get!
executor(m::Model, input) = @get!(m.execs, input, executor(m.graph, input))
function (m::Model)(x::Batch)
x = rawbatch(x)
m.last = exec = @mxerr m.graph.stacks executor(m, size(x))
rebatch(exec(x))
function (m::Model)(x)
runrawbatched(x) do x
m.last = exec = @mxerr m.graph.stacks executor(m, size(x))
exec(x)
end
end
(m::Model)(x) = unbatchone(m(batchone(x)))
function Flux.back!(m::Model, Δ::Batch, x::Batch)
m.last = exec = m.execs[size(rawbatch(x))]
rebatch(back!(exec, rawbatch(Δ)))
function Flux.back!(m::Model, Δ, x)
runrawbatched(Δ, x) do Δ, x
m.last = exec = m.execs[size(x)]
back!(exec, Δ)
end
end
Flux.back!(m::Model, Δ, x) = first(Flux.back!(m, batchone(Δ), batchone(x)))
Flux.update!(m::Model, η) = (update!(m.last, η); m)
# MX FeedForward interface

View File

@ -73,8 +73,7 @@ function process_type(ex)
quote
$(build_type(T, params))
$(esc(:(Flux.runmodel(self::$T, $(args...)) = $(build_forward(body, args)))))
($self::$(esc(T)))($(map(arg -> :($arg::Batch), args)...)) = rebatch(runmodel($self, $(map(x->:(rawbatch($x)), args)...)))
($self::$(esc(T)))($(args...)) = unbatchone($self(map(batchone, ($(args...),))...))
($self::$(esc(T)))($(args...)) = runrawbatched((xs...) -> runmodel($self, xs...), $(args...))
$(esc(:(Flux.update!(self::$T, η)))) = ($(map(p -> :(update!($self.$p, η)), pnames)...);)
$(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(mapconst(esc, makegraph(body, args))))
nothing

View File

@ -34,9 +34,9 @@ end
# TODO: batching should be secondary
function interpmodel(m, args::Batch...)
function interpmodel_(m, args...)
ctx = Context(mux(iline, ilambda, iconst, iargs, ituple, interp))
rebatch(@ithrow interp(ctx, m, map(rawbatch, args)...))
interp(ctx, m, args...)
end
interpmodel(m, args...) = unbatchone(interpmodel(m, batchone(args)...))
interpmodel(m, args...) = @ithrow runrawbatched((xs...) -> interpmodel_(m, xs...), args...)

View File

@ -17,19 +17,6 @@ convert{T,S}(::Type{Batch{T,S}},storage::S) =
Juno.trim(collect(b)))
end
# Convenience methods for batch size 1
batchone(x) = Batch((x,))
batchone(x::Batch) = x
batchone(x::Tuple) = map(batchone, x)
function unbatchone(xs::Batch)
@assert length(xs) == 1
return first(xs)
end
unbatchone(xs::Tuple) = map(unbatchone, xs)
function rebatch(xs)
dims = ndims(xs)-1
T = Array{eltype(xs),dims}
@ -37,8 +24,38 @@ function rebatch(xs)
Batch{T,B}(xs)
end
rebatch(xs::Tuple) = map(rebatch, xs)
convertel(T::Type, xs::Batch) =
isa(eltype(eltype(xs)), T) ? xs :
Batch(map(x->convertel(T, x), xs))
# Add batching semantics to functions operating on raw arrays
# TODO: remove this in favour of full batching semantics
mapt(f, x) = f(x)
mapt(f, xs::Tuple) = map(f, xs)
batchone(x) = Batch((x,))
batchone(x::Batch) = x
function unbatchone(xs::Batch)
@assert length(xs) == 1
return first(xs)
end
isbatched(x) = false
isbatched(x::Batch) = true
isbatched(xs::Tuple) = any(isbatched, xs)
batchify(xs) = isbatched(xs) ? (xs, true) : (mapt(batchone, xs), false)
function runbatched(f, xs...)
# TODO: decide what to do with mixed inputs
xs, batched = batchify(xs)
ys = f(xs...)
batched ? ys : mapt(unbatchone, ys)
end
runrawbatched(f, xs...) =
runbatched((xs...) -> mapt(rebatch,
f(mapt(rawbatch, xs)...)),
xs...)