batch and batchseq apis

This commit is contained in:
Mike J Innes 2017-10-15 23:44:40 +01:00
parent 646720cd05
commit 9a155abecd
2 changed files with 19 additions and 0 deletions

View File

@ -22,6 +22,8 @@ Base.:*(A::AbstractMatrix, B::OneHotMatrix) = A[:, map(x->x.ix, B.data)]
Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...]) Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...])
batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(length(first(xs)), xs)
import NNlib.adapt import NNlib.adapt
adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data)) adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))

View File

@ -9,6 +9,23 @@ unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...)
stack(xs, dim) = cat(dim, unsqueeze.(xs, dim)...) stack(xs, dim) = cat(dim, unsqueeze.(xs, dim)...)
unstack(xs, dim) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)] unstack(xs, dim) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)]
batchindex(xs, i) = (reverse(Base.tail(reverse(indices(xs))))..., i)
function batch(xs)
data = similar(first(xs), size(first(xs))..., length(xs))
for (i, x) in enumerate(xs)
data[batchindex(data, i)...] = x
end
return data
end
Base.rpad(v::AbstractVector, n::Integer, p) = [v; fill(p, max(n - length(v), 0))]
function batchseq(xs, pad, n = maximum(length(x) for x in xs))
xs_ = [rpad(x, n, pad) for x in xs]
[batch([xs_[j][i] for j = 1:length(xs_)]) for i = 1:n]
end
# Other # Other
function accuracy(m, data) function accuracy(m, data)