diff --git a/src/onehot.jl b/src/onehot.jl index c82dce23..cd29f14e 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -9,8 +9,6 @@ Base.size(xs::OneHotVector) = (Int64(xs.of),) Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix -Base.getindex(xs::OneHotVector, ::Colon) = xs - A::AbstractMatrix * b::OneHotVector = A[:, b.ix] struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool} @@ -24,21 +22,6 @@ Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs.data[j][i] Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i] Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i]) -Base.getindex(xs::Flux.OneHotMatrix, j::Base.UnitRange, i::Int) = xs.data[i][j] - -Base.getindex(xs::OneHotMatrix, ::Colon, ::Colon) = xs - -# handle special case for when we want the entire column -function Base.getindex(xs::Flux.OneHotMatrix{T}, ot::Union{Base.Slice, Base.OneTo}, i::Int) where {T<:AbstractArray} - res = similar(xs, size(xs, 1), 1) - if length(ot) == size(xs, 1) - res = xs[:,i] - else - res = xs[1:length(ot),i] - end - res -end - A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)] Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...]) @@ -71,17 +54,13 @@ end onehotbatch(ls, labels, unk...) = OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls]) -Base.argmax(xs::OneHotVector) = xs.ix - onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)] onecold(y::AbstractMatrix, labels...) = dropdims(mapslices(y -> onecold(y, labels...), y, dims=1), dims=1) -onecold(y::OneHotMatrix, labels...) = map(x -> onecold(x, labels...), y.data) - function argmax(xs...) - Base.depwarn("`argmax(...)` is deprecated, use `onecold(...)` instead.", :argmax) + Base.depwarn("`argmax(...)` is deprecated, use `onecold(...)` instead.", :argmax) return onecold(xs...) 
end diff --git a/test/cuda/cuda.jl b/test/cuda/cuda.jl index 43340c74..f7a08503 100644 --- a/test/cuda/cuda.jl +++ b/test/cuda/cuda.jl @@ -38,13 +38,6 @@ Flux.back!(sum(l)) end -@testset "onecold gpu" begin - x = zeros(Float32, 10, 3) |> gpu; - y = Flux.onehotbatch(ones(3), 1:10) |> gpu; - res = Flux.onecold(x) .== Flux.onecold(y) - @test res isa CuArray -end - if CuArrays.libcudnn != nothing @info "Testing Flux/CUDNN" include("cudnn.jl")