batch and batchseq apis
This commit is contained in:
parent
646720cd05
commit
9a155abecd
@ -22,6 +22,8 @@ Base.:*(A::AbstractMatrix, B::OneHotMatrix) = A[:, map(x->x.ix, B.data)]
|
||||
|
||||
Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...])
|
||||
|
||||
batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(length(first(xs)), xs)
|
||||
|
||||
import NNlib.adapt
|
||||
|
||||
adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
|
||||
|
17
src/utils.jl
17
src/utils.jl
@ -9,6 +9,23 @@ unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...)
|
||||
stack(xs, dim) = cat(dim, unsqueeze.(xs, dim)...)
|
||||
unstack(xs, dim) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)]
|
||||
|
||||
batchindex(xs, i) = (reverse(Base.tail(reverse(indices(xs))))..., i)
|
||||
|
||||
function batch(xs)
|
||||
data = similar(first(xs), size(first(xs))..., length(xs))
|
||||
for (i, x) in enumerate(xs)
|
||||
data[batchindex(data, i)...] = x
|
||||
end
|
||||
return data
|
||||
end
|
||||
|
||||
Base.rpad(v::AbstractVector, n::Integer, p) = [v; fill(p, max(n - length(v), 0))]
|
||||
|
||||
function batchseq(xs, pad, n = maximum(length(x) for x in xs))
|
||||
xs_ = [rpad(x, n, pad) for x in xs]
|
||||
[batch([xs_[j][i] for j = 1:length(xs_)]) for i = 1:n]
|
||||
end
|
||||
|
||||
# Other
|
||||
|
||||
function accuracy(m, data)
|
||||
|
Loading…
Reference in New Issue
Block a user