factor out fake batching semantics
This commit is contained in:
parent
c4d815b5fc
commit
15b3ce1ada
@ -1,4 +1,4 @@
|
|||||||
using Flux: batchone, unbatchone, rebatch
|
using Flux: runrawbatched
|
||||||
|
|
||||||
type AlterParam
|
type AlterParam
|
||||||
param
|
param
|
||||||
@ -100,20 +100,19 @@ import Base: @get!
|
|||||||
|
|
||||||
executor(m::Model, input) = @get!(m.execs, input, executor(m.graph, input))
|
executor(m::Model, input) = @get!(m.execs, input, executor(m.graph, input))
|
||||||
|
|
||||||
function (m::Model)(x::Batch)
|
function (m::Model)(x)
|
||||||
x′ = rawbatch(x)
|
runrawbatched(x) do x
|
||||||
m.last = exec = @mxerr m.graph.stacks executor(m, size(x′))
|
m.last = exec = @mxerr m.graph.stacks executor(m, size(x))
|
||||||
rebatch(exec(x′))
|
exec(x)
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
(m::Model)(x) = unbatchone(m(batchone(x)))
|
function Flux.back!(m::Model, Δ, x)
|
||||||
|
runrawbatched(Δ, x) do Δ, x
|
||||||
function Flux.back!(m::Model, Δ::Batch, x::Batch)
|
m.last = exec = m.execs[size(x)]
|
||||||
m.last = exec = m.execs[size(rawbatch(x))]
|
back!(exec, Δ)
|
||||||
rebatch(back!(exec, rawbatch(Δ)))
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
Flux.back!(m::Model, Δ, x) = first(Flux.back!(m, batchone(Δ), batchone(x)))
|
|
||||||
|
|
||||||
Flux.update!(m::Model, η) = (update!(m.last, η); m)
|
Flux.update!(m::Model, η) = (update!(m.last, η); m)
|
||||||
|
|
||||||
|
@ -73,8 +73,7 @@ function process_type(ex)
|
|||||||
quote
|
quote
|
||||||
$(build_type(T, params))
|
$(build_type(T, params))
|
||||||
$(esc(:(Flux.runmodel(self::$T, $(args...)) = $(build_forward(body, args)))))
|
$(esc(:(Flux.runmodel(self::$T, $(args...)) = $(build_forward(body, args)))))
|
||||||
($self::$(esc(T)))($(map(arg -> :($arg::Batch), args)...)) = rebatch(runmodel($self, $(map(x->:(rawbatch($x)), args)...)))
|
($self::$(esc(T)))($(args...)) = runrawbatched((xs...) -> runmodel($self, xs...), $(args...))
|
||||||
($self::$(esc(T)))($(args...)) = unbatchone($self(map(batchone, ($(args...),))...))
|
|
||||||
$(esc(:(Flux.update!(self::$T, η)))) = ($(map(p -> :(update!($self.$p, η)), pnames)...);)
|
$(esc(:(Flux.update!(self::$T, η)))) = ($(map(p -> :(update!($self.$p, η)), pnames)...);)
|
||||||
$(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(mapconst(esc, makegraph(body, args))))
|
$(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(mapconst(esc, makegraph(body, args))))
|
||||||
nothing
|
nothing
|
||||||
|
@ -34,9 +34,9 @@ end
|
|||||||
|
|
||||||
# TODO: batching should be secondary
|
# TODO: batching should be secondary
|
||||||
|
|
||||||
function interpmodel(m, args::Batch...)
|
function interpmodel_(m, args...)
|
||||||
ctx = Context(mux(iline, ilambda, iconst, iargs, ituple, interp))
|
ctx = Context(mux(iline, ilambda, iconst, iargs, ituple, interp))
|
||||||
rebatch(@ithrow interp(ctx, m, map(rawbatch, args)...))
|
interp(ctx, m, args...)
|
||||||
end
|
end
|
||||||
|
|
||||||
interpmodel(m, args...) = unbatchone(interpmodel(m, batchone(args)...))
|
interpmodel(m, args...) = @ithrow runrawbatched((xs...) -> interpmodel_(m, xs...), args...)
|
||||||
|
@ -17,19 +17,6 @@ convert{T,S}(::Type{Batch{T,S}},storage::S) =
|
|||||||
Juno.trim(collect(b)))
|
Juno.trim(collect(b)))
|
||||||
end
|
end
|
||||||
|
|
||||||
# Convenience methods for batch size 1
|
|
||||||
|
|
||||||
batchone(x) = Batch((x,))
|
|
||||||
batchone(x::Batch) = x
|
|
||||||
batchone(x::Tuple) = map(batchone, x)
|
|
||||||
|
|
||||||
function unbatchone(xs::Batch)
|
|
||||||
@assert length(xs) == 1
|
|
||||||
return first(xs)
|
|
||||||
end
|
|
||||||
|
|
||||||
unbatchone(xs::Tuple) = map(unbatchone, xs)
|
|
||||||
|
|
||||||
function rebatch(xs)
|
function rebatch(xs)
|
||||||
dims = ndims(xs)-1
|
dims = ndims(xs)-1
|
||||||
T = Array{eltype(xs),dims}
|
T = Array{eltype(xs),dims}
|
||||||
@ -37,8 +24,38 @@ function rebatch(xs)
|
|||||||
Batch{T,B}(xs)
|
Batch{T,B}(xs)
|
||||||
end
|
end
|
||||||
|
|
||||||
rebatch(xs::Tuple) = map(rebatch, xs)
|
|
||||||
|
|
||||||
convertel(T::Type, xs::Batch) =
|
convertel(T::Type, xs::Batch) =
|
||||||
isa(eltype(eltype(xs)), T) ? xs :
|
isa(eltype(eltype(xs)), T) ? xs :
|
||||||
Batch(map(x->convertel(T, x), xs))
|
Batch(map(x->convertel(T, x), xs))
|
||||||
|
|
||||||
|
# Add batching semantics to functions operating on raw arrays
|
||||||
|
# TODO: remove this in favour of full batching semantics
|
||||||
|
|
||||||
|
mapt(f, x) = f(x)
|
||||||
|
mapt(f, xs::Tuple) = map(f, xs)
|
||||||
|
|
||||||
|
batchone(x) = Batch((x,))
|
||||||
|
batchone(x::Batch) = x
|
||||||
|
|
||||||
|
function unbatchone(xs::Batch)
|
||||||
|
@assert length(xs) == 1
|
||||||
|
return first(xs)
|
||||||
|
end
|
||||||
|
|
||||||
|
isbatched(x) = false
|
||||||
|
isbatched(x::Batch) = true
|
||||||
|
isbatched(xs::Tuple) = any(isbatched, xs)
|
||||||
|
|
||||||
|
batchify(xs) = isbatched(xs) ? (xs, true) : (mapt(batchone, xs), false)
|
||||||
|
|
||||||
|
function runbatched(f, xs...)
|
||||||
|
# TODO: decide what to do with mixed inputs
|
||||||
|
xs, batched = batchify(xs)
|
||||||
|
ys = f(xs...)
|
||||||
|
batched ? ys : mapt(unbatchone, ys)
|
||||||
|
end
|
||||||
|
|
||||||
|
runrawbatched(f, xs...) =
|
||||||
|
runbatched((xs...) -> mapt(rebatch,
|
||||||
|
f(mapt(rawbatch, xs)...)),
|
||||||
|
xs...)
|
||||||
|
Loading…
Reference in New Issue
Block a user