Flux.jl/src/dims/batching.jl

62 lines
1.4 KiB
Julia
Raw Normal View History

2016-12-15 21:37:07 +00:00
export Batch, batchone
2016-10-25 12:48:30 +00:00
# A batch of samples, exposed as an `AbstractVector{T}` (one element per sample).
# The samples are held in a single `CatMat{T,S}` stored in `data`.
# NOTE(review): `S` appears to be the type of the underlying raw storage (see
# `convert` and `rebatch` in this file) — confirm against the `CatMat` definition.
immutable Batch{T,S} <: AbstractVector{T}
# Backing container holding every sample of the batch.
data::CatMat{T,S}
end
# Delegate the listed functions on a `Batch` to its `data` field.
# NOTE(review): assumes the usual `@forward T.field f, g` semantics (each listed
# function called on a `Batch` is re-dispatched on `batch.data`) — confirm against
# the macro's definition, which is not in this file.
@forward Batch.data size, eltype, getindex, setindex!, rawbatch
# Convenience constructor: wrap `xs` in a `CatMat` and build a `Batch` from it.
function Batch(xs)
  storage = CatMat(xs)
  return Batch(storage)
end
# Conversion from a value of the storage type `S` to `Batch{T,S}`: the value is
# handed directly to the `Batch{T,S}` constructor.
function convert{T,S}(::Type{Batch{T,S}}, storage::S)
  return Batch{T,S}(storage)
end
# Inline Juno display for a `Batch`: a tree whose header row reads
# "Batch of <eltype>" followed by the faded length in brackets, and whose
# children are the (trimmed) collected elements of the batch.
@render Juno.Inline b::Batch begin
  header = Row(Text("Batch of "), eltype(b), Juno.fade("[$(length(b))]"))
  Tree(header, Juno.trim(collect(b)))
end
2017-01-24 10:24:30 +00:00
2017-03-09 00:13:26 +00:00
# Wrap a raw array `xs` in a `Batch`.
#
# For an N-dimensional `xs` with element type E, the result is a
# `Batch{Array{E,N-1},Array{E,N}}`: the first type parameter is an array type
# with one dimension fewer than `xs`, the second is the array type of `xs`'s own
# rank.
function rebatch(xs)
  E = eltype(xs)
  N = ndims(xs)
  return Batch{Array{E,N-1},Array{E,N}}(xs)
end
# Convert the innermost element type of a batch to `T`.
#
# `eltype(eltype(xs))` is the element type of the batch's elements. When that
# type is already a subtype of `T`, the batch is returned unchanged; otherwise
# `convertel(T, x)` is applied to every element and the results are re-wrapped
# in a new `Batch`.
#
# The shortcut must be a subtype test (`<:`). The previous `isa T` asked whether
# the *type object itself* is an instance of `T`, which is false for ordinary
# element types (e.g. `Float32 isa Float32 == false`), so the shortcut never
# fired and the batch was always rebuilt.
convertel(T::Type, xs::Batch) =
  eltype(eltype(xs)) <: T ? xs :
    Batch(map(x -> convertel(T, x), xs))
# Add batching semantics to functions operating on raw arrays
# TODO: remove this in favour of full batching semantics
# Apply `f` to a single (non-tuple) value.
function mapt(f, x)
  return f(x)
end

# Apply `f` to each entry of a tuple, returning the tuple of results.
function mapt(f, xs::Tuple)
  return map(f, xs)
end
2017-01-24 10:24:30 +00:00
# A value that is already a `Batch` is returned as-is.
function batchone(x::Batch)
  return x
end

# Any other value becomes a `Batch` built from the one-element tuple `(x,)`.
function batchone(x)
  return Batch((x,))
end
# Extract the sole element of a length-1 batch.
# Fails the assertion when the batch does not contain exactly one element.
function unbatchone(xs::Batch)
  @assert length(xs) == 1
  sample = first(xs)
  return sample
end
2017-03-09 00:13:26 +00:00
# Fallback: arbitrary values are not batched.
function isbatched(x)
  return false
end

# A `Batch` is batched by definition.
function isbatched(x::Batch)
  return true
end

# A tuple counts as batched as soon as any of its entries is batched.
function isbatched(xs::Tuple)
  return any(isbatched, xs)
end
2017-01-24 10:24:30 +00:00
2017-03-09 00:13:26 +00:00
# Ensure `xs` is in batched form.
#
# Returns a pair `(ys, wasbatched)`: when `isbatched(xs)` holds, `xs` is passed
# through untouched with `true`; otherwise `batchone` is applied via `mapt`
# (to the value itself, or to every entry of a tuple) and the flag is `false`.
function batchify(xs)
  if isbatched(xs)
    return (xs, true)
  else
    return (mapt(batchone, xs), false)
  end
end
2017-01-26 18:07:06 +00:00
2017-03-09 00:13:26 +00:00
# Call `f` on batched versions of `xs...`.
#
# If the inputs were already batched, the outputs of `f` are returned as-is.
# Otherwise each input is wrapped by `batchify` before the call and the outputs
# are unwrapped again with `unbatchone`.
function runbatched(f, xs...)
  # TODO: decide what to do with mixed inputs
  inputs, wasbatched = batchify(xs)
  outputs = f(inputs...)
  wasbatched && return outputs
  return mapt(unbatchone, outputs)
end
2017-03-06 16:12:03 +00:00
2017-03-09 00:13:26 +00:00
# Like `runbatched`, but for an `f` that works on raw batch storage: every
# batched input is passed through `rawbatch` before `f` is called, and every
# output of `f` is passed through `rebatch` afterwards.
function runrawbatched(f, xs...)
  onraw = (bs...) -> mapt(rebatch, f(mapt(rawbatch, bs)...))
  return runbatched(onraw, xs...)
end