diff --git a/src/onehot.jl b/src/onehot.jl index 12a77ecd..0cd145f8 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -23,7 +23,7 @@ Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data)) Base.getindex(xs::OneHotMatrix, i::Union{Integer, AbstractVector}, j::Integer) = xs.data[j][i] Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i] Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i]) -# Base.getindex(xs::OneHotMatrix, ::Colon, ::Colon) = OneHotMatrix(xs.height, copy(xs.data)) +Base.getindex(xs::OneHotMatrix, ::Colon, ::Colon) = OneHotMatrix(xs.height, copy(xs.data)) Base.getindex(xs::OneHotMatrix, i::Integer, ::Colon) = map(x -> x[i], xs.data)