Flux.jl/src/dims/batching.jl

33 lines
684 B
Julia
Raw Normal View History

2016-12-15 21:37:07 +00:00
export Batch, batchone
2016-10-25 12:48:30 +00:00
2017-03-14 17:56:03 +00:00
"""
    Batch{T,S} <: AbstractVector{T}

A vector whose items have type `T`, stored in a single `CatMat{T,S}` held in
the `data` field.
"""
struct Batch{T,S} <: AbstractVector{T}
    data::CatMat{T,S}
end
# Forward these functions from `Batch` to its `data` field. Presumably (per the
# usual `@forward` convention, e.g. MacroTools) each listed function called on a
# `Batch` is re-dispatched on `b.data`. Confirm against the macro's definition.
@forward Batch.data size, eltype, getindex, setindex!, rawbatch
"""
    Batch(xs)

Build a `Batch` from an arbitrary collection `xs` by first packing it into a
`CatMat` and then wrapping that storage.
"""
function Batch(xs)
    storage = CatMat(xs)
    return Batch(storage)
end
# Convert a raw storage value of type `S` straight into a `Batch{T,S}`.
# Uses `where` syntax instead of the deprecated `convert{T,S}(...)` form,
# which was removed in Julia 1.0.
# NOTE(review): the `Batch{T,S}` field constructor must turn `storage::S` into a
# `CatMat{T,S}`. Presumably this relies on a matching `convert` method for
# `CatMat` defined elsewhere. Confirm.
convert(::Type{Batch{T,S}}, storage::S) where {T,S} =
    Batch{T,S}(storage)
# Inline display of a `Batch` in Juno.
# The header `Row` shows the text "Batch of ", the batch's element type, and the
# length in faded brackets, e.g. "[3]".
# The tree's children are `Juno.trim(collect(b))`, presumably truncated for long
# batches. Confirm `Juno.trim` semantics.
@render Juno.Inline b::Batch begin
Tree(Row(Text("Batch of "), eltype(b),
Juno.fade("[$(length(b))]")),
Juno.trim(collect(b)))
end
2017-01-24 10:24:30 +00:00
2017-03-09 00:13:26 +00:00
"""
    rebatch(xs)

Wrap an `N`-dimensional array `xs` as a `Batch{Array{E,N-1},Array{E,N}}`, where
`E = eltype(xs)`. Each item is an `(N-1)`-dimensional array and the storage has
the full `N` dimensions.
"""
function rebatch(xs)
    E = eltype(xs)
    N = ndims(xs)
    item_type = Array{E,N-1}
    storage_type = Array{E,N}
    return Batch{item_type,storage_type}(xs)
end
"""
    convertel(T::Type, xs::Batch)

Apply `convertel(T, item)` to every item of `xs` and wrap the results in a new
`Batch`. If the items' element type is already a subtype of `T`, `xs` is
returned unchanged.

The per-item `convertel` method is defined elsewhere.
"""
convertel(T::Type, xs::Batch) =
    # `eltype(eltype(xs))` is a *type*, so it must be compared with `<:`.
    # The previous `isa` check asked whether that type object is an instance
    # of `T` (e.g. `Float32 isa Float32` is false), so this fast path never
    # fired.
    eltype(eltype(xs)) <: T ? xs :
        Batch(map(x -> convertel(T, x), xs))
2017-04-19 12:26:37 +00:00
"""
    batchone(x)

Wrap the single value `x` as a `Batch` built from the one-element tuple `(x,)`.
A value that is already a `Batch` is returned unchanged.
"""
function batchone(x)
    single = (x,)
    return Batch(single)
end

batchone(b::Batch) = b