batch tweaks
This commit is contained in:
parent
6d53b7af47
commit
183c3b0680
|
@ -1,4 +1,4 @@
|
|||
export batch
|
||||
export batch, Batch
|
||||
|
||||
# TODO: support the Batch type only
|
||||
batch(x) = reshape(x, (1,size(x)...))
|
||||
|
@ -10,22 +10,23 @@ end
|
|||
|
||||
Base.size(b::Batch) = (size(b.data, 1),)
|
||||
|
||||
Base.getindex(b::Batch, i) = slicedim(b.data, 1, i)::eltype(b)
|
||||
Base.getindex(b::Batch, i)::eltype(b) = slicedim(b.data, 1, i)
|
||||
|
||||
Base.setindex!(b::Batch, v, i) = b[i, :] = v
|
||||
|
||||
function (::Type{Batch{T}}){T}(xs::T...)
|
||||
length(xs) == 1 || @assert ==(map(size, xs)...)
|
||||
batch = similar(xs[1], length(xs), size(xs[1])...)
|
||||
function (::Type{Batch{T}}){T}(xs)
|
||||
x = first(xs)
|
||||
batch = similar(x, length(xs), size(x)...)
|
||||
for i = 1:length(xs)
|
||||
@assert size(xs[i]) == size(x)
|
||||
batch[i, :] = xs[i]
|
||||
end
|
||||
return Batch{T,typeof(batch)}(batch)
|
||||
end
|
||||
|
||||
function Batch(xs...)
|
||||
function Batch(xs)
|
||||
xs′ = promote(xs...)
|
||||
Batch{typeof(xs′[1])}(xs′...)
|
||||
Batch{typeof(xs′[1])}(xs′)
|
||||
end
|
||||
|
||||
@render Juno.Inline b::Batch begin
|
||||
|
|
Loading…
Reference in New Issue