more onehot indexing

This commit is contained in:
Mike J Innes 2017-12-15 16:17:39 +00:00
parent 9d0dd9fb7e
commit 9b833a4345

View File

@ -18,7 +18,9 @@ end
Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data)) Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data))
Base.getindex(xs::OneHotMatrix, i::Int, j::Int) = xs.data[j][i] Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs.data[j][i]
Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i]
Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i])
A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)] A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)]