diff --git a/src/onehot.jl b/src/onehot.jl index aea68829..48167a0e 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -20,7 +20,7 @@ Base.getindex(xs::OneHotMatrix, i::Int, j::Int) = xs.data[j][i] Base.:*(A::AbstractMatrix, B::OneHotMatrix) = A[:, map(x->x.ix, B.data)] -Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix([x, xs...]) +Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...]) import NNlib.adapt