batch tweaks

This commit is contained in:
Mike J Innes 2016-10-15 18:16:04 +01:00
parent 6d53b7af47
commit 183c3b0680

View File

@ -1,4 +1,4 @@
export batch export batch, Batch
# TODO: support the Batch type only # TODO: support the Batch type only
batch(x) = reshape(x, (1,size(x)...)) batch(x) = reshape(x, (1,size(x)...))
@ -10,22 +10,23 @@ end
Base.size(b::Batch) = (size(b.data, 1),) Base.size(b::Batch) = (size(b.data, 1),)
Base.getindex(b::Batch, i) = slicedim(b.data, 1, i)::eltype(b) Base.getindex(b::Batch, i)::eltype(b) = slicedim(b.data, 1, i)
Base.setindex!(b::Batch, v, i) = b[i, :] = v Base.setindex!(b::Batch, v, i) = b[i, :] = v
function (::Type{Batch{T}}){T}(xs::T...) function (::Type{Batch{T}}){T}(xs)
length(xs) == 1 || @assert ==(map(size, xs)...) x = first(xs)
batch = similar(xs[1], length(xs), size(xs[1])...) batch = similar(x, length(xs), size(x)...)
for i = 1:length(xs) for i = 1:length(xs)
@assert size(xs[i]) == size(x)
batch[i, :] = xs[i] batch[i, :] = xs[i]
end end
return Batch{T,typeof(batch)}(batch) return Batch{T,typeof(batch)}(batch)
end end
function Batch(xs...) function Batch(xs)
xs = promote(xs...) xs = promote(xs...)
Batch{typeof(xs[1])}(xs...) Batch{typeof(xs[1])}(xs)
end end
@render Juno.Inline b::Batch begin @render Juno.Inline b::Batch begin