batch data structure
This commit is contained in:
parent
b4390e6a23
commit
6d53b7af47
@ -1,6 +1,7 @@
|
|||||||
module Flux
|
module Flux
|
||||||
|
|
||||||
using MacroTools, Lazy, Flow, Juno
|
using MacroTools, Lazy, Flow, Juno
|
||||||
|
import Juno: Tree, Row
|
||||||
|
|
||||||
# Zero Flux Given
|
# Zero Flux Given
|
||||||
|
|
||||||
|
@ -1,8 +1,35 @@
|
|||||||
export batch
|
export batch
|
||||||
|
|
||||||
# Treat the first dimension as the batch index
|
# TODO: support the Batch type only
|
||||||
# TODO: custom data type for this
|
|
||||||
batch(x) = reshape(x, (1,size(x)...))
|
batch(x) = reshape(x, (1,size(x)...))
|
||||||
batch(xs...) = vcat(map(batch, xs)...)
|
batch(xs...) = vcat(map(batch, xs)...)
|
||||||
|
|
||||||
unbatch(xs) = reshape(xs, size(xs)[2:end])
|
type Batch{T,T′} <: AbstractVector{T}
|
||||||
|
data::T′
|
||||||
|
end
|
||||||
|
|
||||||
|
Base.size(b::Batch) = (size(b.data, 1),)
|
||||||
|
|
||||||
|
Base.getindex(b::Batch, i) = slicedim(b.data, 1, i)::eltype(b)
|
||||||
|
|
||||||
|
Base.setindex!(b::Batch, v, i) = b[i, :] = v
|
||||||
|
|
||||||
|
function (::Type{Batch{T}}){T}(xs::T...)
|
||||||
|
length(xs) == 1 || @assert ==(map(size, xs)...)
|
||||||
|
batch = similar(xs[1], length(xs), size(xs[1])...)
|
||||||
|
for i = 1:length(xs)
|
||||||
|
batch[i, :] = xs[i]
|
||||||
|
end
|
||||||
|
return Batch{T,typeof(batch)}(batch)
|
||||||
|
end
|
||||||
|
|
||||||
|
function Batch(xs...)
|
||||||
|
xs′ = promote(xs...)
|
||||||
|
Batch{typeof(xs′[1])}(xs′...)
|
||||||
|
end
|
||||||
|
|
||||||
|
@render Juno.Inline b::Batch begin
|
||||||
|
Tree(Row(Text("Batch of "), eltype(b),
|
||||||
|
Juno.fade("[$(length(b))]")),
|
||||||
|
Juno.trim(collect(b)))
|
||||||
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user