diff --git a/src/onehot.jl b/src/onehot.jl index f94fb93e..4f121958 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -18,7 +18,9 @@ end Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data)) -Base.getindex(xs::OneHotMatrix, i::Int, j::Int) = xs.data[j][i] +Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs.data[j][i] +Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i] +Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i]) A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)]