batch tweaks

This commit is contained in:
Mike J Innes 2016-10-15 18:16:04 +01:00
parent 6d53b7af47
commit 183c3b0680
1 changed files with 8 additions and 7 deletions

View File

@ -1,4 +1,4 @@
export batch
export batch, Batch
# TODO: support the Batch type only
batch(x) = reshape(x, (1,size(x)...))
@ -10,22 +10,23 @@ end
Base.size(b::Batch) = (size(b.data, 1),)
Base.getindex(b::Batch, i) = slicedim(b.data, 1, i)::eltype(b)
Base.getindex(b::Batch, i)::eltype(b) = slicedim(b.data, 1, i)
Base.setindex!(b::Batch, v, i) = b[i, :] = v
function (::Type{Batch{T}}){T}(xs::T...)
length(xs) == 1 || @assert ==(map(size, xs)...)
batch = similar(xs[1], length(xs), size(xs[1])...)
function (::Type{Batch{T}}){T}(xs)
x = first(xs)
batch = similar(x, length(xs), size(x)...)
for i = 1:length(xs)
@assert size(xs[i]) == size(x)
batch[i, :] = xs[i]
end
return Batch{T,typeof(batch)}(batch)
end
function Batch(xs...)
function Batch(xs)
xs = promote(xs...)
Batch{typeof(xs[1])}(xs...)
Batch{typeof(xs[1])}(xs)
end
@render Juno.Inline b::Batch begin