From 9b833a434525bc7afc00dd95c3799b71784f84d1 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Fri, 15 Dec 2017 16:17:39 +0000 Subject: [PATCH] more onehot indexing --- src/onehot.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/onehot.jl b/src/onehot.jl index f94fb93e..4f121958 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -18,7 +18,9 @@ end Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data)) -Base.getindex(xs::OneHotMatrix, i::Int, j::Int) = xs.data[j][i] +Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs.data[j][i] +Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i] +Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i]) A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)]