2016-12-15 21:37:07 +00:00
|
|
|
export Batch, batchone
|
2016-10-25 12:48:30 +00:00
|
|
|
|
|
|
|
# A vector-like collection of samples whose storage is a single `CatMat`.
#   E — the element (sample) type; a Batch{E,M} is an AbstractVector{E}
#   M — the storage type parameter passed through to `CatMat{E,M}`
# NOTE(review): the exact layout of samples inside CatMat is defined elsewhere.
immutable Batch{E,M} <: AbstractVector{E}
  data::CatMat{E,M}  # wrapped storage; array operations are forwarded to it
end
|
|
|
|
|
|
|
|
# Delegate the array interface and `rawbatch` to the wrapped `data` field,
# so a Batch behaves like its underlying CatMat for these operations.
@forward Batch.data size, eltype
@forward Batch.data getindex, setindex!, rawbatch
|
|
|
|
|
|
|
|
# Construct a Batch from an arbitrary collection of samples by first packing
# them into a CatMat, then wrapping that storage.
function Batch(xs)
  storage = CatMat(xs)
  return Batch(storage)
end
|
|
|
|
|
|
|
|
# Convert a raw storage value of type `S` into a `Batch{T,S}`.
# NOTE(review): `Batch{T,S}(storage)` goes through the default constructor,
# whose field is a `CatMat{T,S}`, so this presumably relies on an
# `S -> CatMat{T,S}` conversion defined alongside CatMat — confirm.
function convert{T,S}(::Type{Batch{T,S}}, storage::S)
  return Batch{T,S}(storage)
end
|
|
|
|
|
|
|
|
# Inline Juno display for a Batch: a collapsible tree whose header reads
# "Batch of <eltype>" followed by the faded length, and whose children are
# the (trimmed) collected samples.
@render Juno.Inline b::Batch begin
  header = Row(Text("Batch of "), eltype(b), Juno.fade("[$(length(b))]"))
  children = Juno.trim(collect(b))
  Tree(header, children)
end
|
2017-01-24 10:24:30 +00:00
|
|
|
|
2017-03-09 00:13:26 +00:00
|
|
|
# Wrap a plain array `xs` as a Batch without copying through CatMat(xs):
# the storage type is the array's own type `Array{E,n}`, and each sample is
# typed as an `Array{E,n-1}` (one dimension fewer than `xs`).
# NOTE(review): which axis indexes the samples is decided by CatMat — confirm.
function rebatch(xs)
  E = eltype(xs)
  n = ndims(xs)
  return Batch{Array{E,n-1},Array{E,n}}(xs)
end
|
|
|
|
|
|
|
|
# Convert the scalar element type of every sample in a Batch to `T`.
#
# If the samples' scalar element type (`eltype(eltype(xs))`) is already `T`
# or a subtype of it, the batch is returned unchanged. Otherwise each sample
# is converted with the `convertel` method for that sample's type (defined
# elsewhere) and the results are packed into a new Batch.
#
# Fix: the check previously used `isa`, which tests whether a *value* is an
# instance of `T`. `eltype(eltype(xs))` is itself a type (e.g. `Float32`), and
# `Float32 isa Float32` is false, so the no-op fast path could never trigger.
# `<:` is the correct type-to-type comparison.
convertel(T::Type, xs::Batch) =
  eltype(eltype(xs)) <: T ? xs :
    Batch(map(x->convertel(T, x), xs))
|
|
|
|
|
|
|
|
# Add batching semantics to functions operating on raw arrays
|
|
|
|
# TODO: remove this in favour of full batching semantics
|
|
|
|
|
|
|
|
# Apply `f` to a single value, or element-wise to each entry of a tuple
# (producing a tuple of the results).
function mapt(f, x)
  return f(x)
end
mapt(f, xs::Tuple) = map(el -> f(el), xs)
|
2017-01-24 10:24:30 +00:00
|
|
|
|
|
|
|
# Wrap a single sample as a Batch of length one.
# A value that is already a Batch is returned as-is.
function batchone(x)
  single = tuple(x)
  return Batch(single)
end
batchone(x::Batch) = x
|
|
|
|
|
|
|
|
# Inverse of `batchone`: return the only sample of a length-one Batch.
# Raises an AssertionError when the batch does not contain exactly one sample.
unbatchone(xs::Batch) = (@assert length(xs) == 1; first(xs))
|
|
|
|
|
2017-03-09 00:13:26 +00:00
|
|
|
# True for a Batch, false for any other single value; for a tuple, true as
# soon as any element (checked recursively) is batched, false if none is
# (including the empty tuple).
isbatched(x) = false
isbatched(::Batch) = true
function isbatched(xs::Tuple)
  for item in xs
    isbatched(item) && return true
  end
  return false
end
|
2017-01-24 10:24:30 +00:00
|
|
|
|
2017-03-09 00:13:26 +00:00
|
|
|
# Bring input into batched form, returning `(input, was_already_batched)`.
# If `isbatched(xs)` holds, `xs` is returned untouched with `true`; otherwise
# the value (or each tuple element) is wrapped by `batchone`, paired with `false`.
function batchify(xs)
  if isbatched(xs)
    return xs, true
  end
  return mapt(batchone, xs), false
end
|
2017-01-26 18:07:06 +00:00
|
|
|
|
2017-03-09 00:13:26 +00:00
|
|
|
# Call `f` with batched arguments, hiding batching from unbatched callers.
# If any argument is already a Batch, `f` receives the arguments unchanged and
# its result is returned as-is. Otherwise every argument is wrapped as a
# length-one Batch, `f` is called, and the result (or each element of a tuple
# result) is unwrapped with `unbatchone`.
function runbatched(f, xs...)
  # TODO: decide what to do with mixed inputs
  # (currently, if any argument is batched, the unbatched ones are passed raw)
  inputs, already = batchify(xs)
  result = f(inputs...)
  already && return result
  return mapt(unbatchone, result)
end
|
2017-03-06 16:12:03 +00:00
|
|
|
|
2017-03-09 00:13:26 +00:00
|
|
|
# Like `runbatched`, but for an `f` that works on raw arrays: each batched
# argument is unwrapped with `rawbatch` before the call, and the raw result
# (or each element of a tuple result) is re-wrapped with `rebatch`.
function runrawbatched(f, xs...)
  onbatches(bs...) = mapt(rebatch, f(mapt(rawbatch, bs)...))
  return runbatched(onbatches, xs...)
end
|