diff --git a/src/onehot.jl b/src/onehot.jl index 48167a0e..0bd694ef 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -22,6 +22,8 @@ Base.:*(A::AbstractMatrix, B::OneHotMatrix) = A[:, map(x->x.ix, B.data)] Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...]) +batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(length(first(xs)), xs) + import NNlib.adapt adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data)) diff --git a/src/utils.jl b/src/utils.jl index 2a86f9c5..581f9e01 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -9,6 +9,23 @@ unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...) stack(xs, dim) = cat(dim, unsqueeze.(xs, dim)...) unstack(xs, dim) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)] +batchindex(xs, i) = (reverse(Base.tail(reverse(indices(xs))))..., i) + +function batch(xs) + data = similar(first(xs), size(first(xs))..., length(xs)) + for (i, x) in enumerate(xs) + data[batchindex(data, i)...] = x + end + return data +end + +Base.rpad(v::AbstractVector, n::Integer, p) = [v; fill(p, max(n - length(v), 0))] + +function batchseq(xs, pad, n = maximum(length(x) for x in xs)) + xs_ = [rpad(x, n, pad) for x in xs] + [batch([xs_[j][i] for j = 1:length(xs_)]) for i = 1:n] +end + # Other function accuracy(m, data)