fix colon indexing
This commit is contained in:
parent
6825639f79
commit
4f1336905f
|
@ -23,6 +23,7 @@ Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data))
|
|||
Base.getindex(xs::OneHotMatrix, i::Union{Integer, AbstractVector}, j::Integer) = xs.data[j][i]
|
||||
Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i]
|
||||
Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i])
|
||||
# Base.getindex(xs::OneHotMatrix, ::Colon, ::Colon) = OneHotMatrix(xs.height, copy(xs.data))
|
||||
|
||||
Base.getindex(xs::OneHotMatrix, i::Integer, ::Colon) = map(x -> x[i], xs.data)
|
||||
|
||||
|
|
|
@ -15,4 +15,5 @@ end
|
|||
@testset "onehotbatch indexing" begin
|
||||
y = Flux.onehotbatch(ones(3), 1:10)
|
||||
@test y[:,1] isa Flux.OneHotVector
|
||||
@test y[:,:] isa Flux.OneHotMatrix
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue