removing non-allocating functions and tests
This commit is contained in:
parent
d16ef75b1c
commit
2ec35861b5
|
@ -24,9 +24,6 @@ Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs.data[j][i]
|
|||
Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i]
|
||||
Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i])
|
||||
|
||||
Base.getindex(xs::Flux.OneHotMatrix, j::Base.UnitRange, i::Int) = xs.data[i][j]
|
||||
|
||||
Base.getindex(xs::OneHotMatrix, ::Colon, ::Colon) = xs
|
||||
Base.getindex(xs::OneHotMatrix, i::Integer, ::Colon) = map(x -> x[i], xs.data)
|
||||
|
||||
# handle special case when we want the whole column
|
||||
|
|
|
@ -41,7 +41,6 @@ end
|
|||
@testset "onecold gpu" begin
|
||||
y = Flux.onehotbatch(ones(3), 1:10) |> gpu;
|
||||
@test Flux.onecold(y) isa CuArray
|
||||
@test y[:,:] isa Flux.OneHotMatrix{<:CuArray}
|
||||
@test y[3,:] isa CuArray
|
||||
end
|
||||
|
||||
|
|
|
@ -15,5 +15,4 @@ end
|
|||
@testset "onehotbatch indexing" begin
|
||||
y = Flux.onehotbatch(ones(3), 1:10)
|
||||
@test y[:,1] isa Flux.OneHotVector
|
||||
@test y[:,:] isa Flux.OneHotMatrix
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue