factor out fake batching semantics
This commit is contained in:
parent
c4d815b5fc
commit
15b3ce1ada
@ -1,4 +1,4 @@
|
||||
using Flux: batchone, unbatchone, rebatch
|
||||
using Flux: runrawbatched
|
||||
|
||||
type AlterParam
|
||||
param
|
||||
@ -100,20 +100,19 @@ import Base: @get!
|
||||
|
||||
executor(m::Model, input) = @get!(m.execs, input, executor(m.graph, input))
|
||||
|
||||
function (m::Model)(x::Batch)
|
||||
x′ = rawbatch(x)
|
||||
m.last = exec = @mxerr m.graph.stacks executor(m, size(x′))
|
||||
rebatch(exec(x′))
|
||||
function (m::Model)(x)
|
||||
runrawbatched(x) do x
|
||||
m.last = exec = @mxerr m.graph.stacks executor(m, size(x))
|
||||
exec(x)
|
||||
end
|
||||
end
|
||||
|
||||
(m::Model)(x) = unbatchone(m(batchone(x)))
|
||||
|
||||
function Flux.back!(m::Model, Δ::Batch, x::Batch)
|
||||
m.last = exec = m.execs[size(rawbatch(x))]
|
||||
rebatch(back!(exec, rawbatch(Δ)))
|
||||
function Flux.back!(m::Model, Δ, x)
|
||||
runrawbatched(Δ, x) do Δ, x
|
||||
m.last = exec = m.execs[size(x)]
|
||||
back!(exec, Δ)
|
||||
end
|
||||
end
|
||||
|
||||
Flux.back!(m::Model, Δ, x) = first(Flux.back!(m, batchone(Δ), batchone(x)))
|
||||
|
||||
Flux.update!(m::Model, η) = (update!(m.last, η); m)
|
||||
|
||||
|
@ -73,8 +73,7 @@ function process_type(ex)
|
||||
quote
|
||||
$(build_type(T, params))
|
||||
$(esc(:(Flux.runmodel(self::$T, $(args...)) = $(build_forward(body, args)))))
|
||||
($self::$(esc(T)))($(map(arg -> :($arg::Batch), args)...)) = rebatch(runmodel($self, $(map(x->:(rawbatch($x)), args)...)))
|
||||
($self::$(esc(T)))($(args...)) = unbatchone($self(map(batchone, ($(args...),))...))
|
||||
($self::$(esc(T)))($(args...)) = runrawbatched((xs...) -> runmodel($self, xs...), $(args...))
|
||||
$(esc(:(Flux.update!(self::$T, η)))) = ($(map(p -> :(update!($self.$p, η)), pnames)...);)
|
||||
$(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(mapconst(esc, makegraph(body, args))))
|
||||
nothing
|
||||
|
@ -34,9 +34,9 @@ end
|
||||
|
||||
# TODO: batching should be secondary
|
||||
|
||||
function interpmodel(m, args::Batch...)
|
||||
function interpmodel_(m, args...)
|
||||
ctx = Context(mux(iline, ilambda, iconst, iargs, ituple, interp))
|
||||
rebatch(@ithrow interp(ctx, m, map(rawbatch, args)...))
|
||||
interp(ctx, m, args...)
|
||||
end
|
||||
|
||||
interpmodel(m, args...) = unbatchone(interpmodel(m, batchone(args)...))
|
||||
interpmodel(m, args...) = @ithrow runrawbatched((xs...) -> interpmodel_(m, xs...), args...)
|
||||
|
@ -17,19 +17,6 @@ convert{T,S}(::Type{Batch{T,S}},storage::S) =
|
||||
Juno.trim(collect(b)))
|
||||
end
|
||||
|
||||
# Convenience methods for batch size 1
|
||||
|
||||
batchone(x) = Batch((x,))
|
||||
batchone(x::Batch) = x
|
||||
batchone(x::Tuple) = map(batchone, x)
|
||||
|
||||
function unbatchone(xs::Batch)
|
||||
@assert length(xs) == 1
|
||||
return first(xs)
|
||||
end
|
||||
|
||||
unbatchone(xs::Tuple) = map(unbatchone, xs)
|
||||
|
||||
function rebatch(xs)
|
||||
dims = ndims(xs)-1
|
||||
T = Array{eltype(xs),dims}
|
||||
@ -37,8 +24,38 @@ function rebatch(xs)
|
||||
Batch{T,B}(xs)
|
||||
end
|
||||
|
||||
rebatch(xs::Tuple) = map(rebatch, xs)
|
||||
|
||||
convertel(T::Type, xs::Batch) =
|
||||
isa(eltype(eltype(xs)), T) ? xs :
|
||||
Batch(map(x->convertel(T, x), xs))
|
||||
|
||||
# Add batching semantics to functions operating on raw arrays
|
||||
# TODO: remove this in favour of full batching semantics
|
||||
|
||||
mapt(f, x) = f(x)
|
||||
mapt(f, xs::Tuple) = map(f, xs)
|
||||
|
||||
batchone(x) = Batch((x,))
|
||||
batchone(x::Batch) = x
|
||||
|
||||
function unbatchone(xs::Batch)
|
||||
@assert length(xs) == 1
|
||||
return first(xs)
|
||||
end
|
||||
|
||||
isbatched(x) = false
|
||||
isbatched(x::Batch) = true
|
||||
isbatched(xs::Tuple) = any(isbatched, xs)
|
||||
|
||||
batchify(xs) = isbatched(xs) ? (xs, true) : (mapt(batchone, xs), false)
|
||||
|
||||
function runbatched(f, xs...)
|
||||
# TODO: decide what to do with mixed inputs
|
||||
xs, batched = batchify(xs)
|
||||
ys = f(xs...)
|
||||
batched ? ys : mapt(unbatchone, ys)
|
||||
end
|
||||
|
||||
runrawbatched(f, xs...) =
|
||||
runbatched((xs...) -> mapt(rebatch,
|
||||
f(mapt(rawbatch, xs)...)),
|
||||
xs...)
|
||||
|
Loading…
Reference in New Issue
Block a user