diff --git a/src/onehot.jl b/src/onehot.jl index 488167e2..12a77ecd 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -9,7 +9,7 @@ Base.size(xs::OneHotVector) = (Int64(xs.of),) Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix -Base.getindex(xs::OneHotVector, ::Colon) = xs +Base.getindex(xs::OneHotVector, ::Colon) = OneHotVector(xs.ix, xs.of) A::AbstractMatrix * b::OneHotVector = A[:, b.ix]