Revert "Fix OneHotVector/Matrix performance on GPU"
This commit is contained in:
parent
e8b2ec6f67
commit
ecc55ec9e1
|
@ -9,8 +9,6 @@ Base.size(xs::OneHotVector) = (Int64(xs.of),)
|
|||
|
||||
Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix
|
||||
|
||||
Base.getindex(xs::OneHotVector, ::Colon) = xs
|
||||
|
||||
A::AbstractMatrix * b::OneHotVector = A[:, b.ix]
|
||||
|
||||
struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool}
|
||||
|
@ -24,21 +22,6 @@ Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs.data[j][i]
|
|||
Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i]
|
||||
Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i])
|
||||
|
||||
Base.getindex(xs::Flux.OneHotMatrix, j::Base.UnitRange, i::Int) = xs.data[i][j]
|
||||
|
||||
Base.getindex(xs::OneHotMatrix, ::Colon, ::Colon) = xs
|
||||
|
||||
# handle special case for when we want the entire column
|
||||
function Base.getindex(xs::Flux.OneHotMatrix{T}, ot::Union{Base.Slice, Base.OneTo}, i::Int) where {T<:AbstractArray}
|
||||
res = similar(xs, size(xs, 1), 1)
|
||||
if length(ot) == size(xs, 1)
|
||||
res = xs[:,i]
|
||||
else
|
||||
res = xs[1:length(ot),i]
|
||||
end
|
||||
res
|
||||
end
|
||||
|
||||
A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)]
|
||||
|
||||
Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...])
|
||||
|
@ -71,17 +54,13 @@ end
|
|||
onehotbatch(ls, labels, unk...) =
|
||||
OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])
|
||||
|
||||
Base.argmax(xs::OneHotVector) = xs.ix
|
||||
|
||||
onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)]
|
||||
|
||||
onecold(y::AbstractMatrix, labels...) =
|
||||
dropdims(mapslices(y -> onecold(y, labels...), y, dims=1), dims=1)
|
||||
|
||||
onecold(y::OneHotMatrix, labels...) = map(x -> onecold(x, labels...), y.data)
|
||||
|
||||
function argmax(xs...)
|
||||
Base.depwarn("`argmax(...)` is deprecated, use `onecold(...)` instead.", :argmax)
|
||||
Base.depwarn("`argmax(...) is deprecated, use `onecold(...)` instead.", :argmax)
|
||||
return onecold(xs...)
|
||||
end
|
||||
|
||||
|
|
|
@ -38,13 +38,6 @@ Flux.back!(sum(l))
|
|||
|
||||
end
|
||||
|
||||
@testset "onecold gpu" begin
|
||||
x = zeros(Float32, 10, 3) |> gpu;
|
||||
y = Flux.onehotbatch(ones(3), 1:10) |> gpu;
|
||||
res = Flux.onecold(x) .== Flux.onecold(y)
|
||||
@test res isa CuArray
|
||||
end
|
||||
|
||||
if CuArrays.libcudnn != nothing
|
||||
@info "Testing Flux/CUDNN"
|
||||
include("cudnn.jl")
|
||||
|
|
Loading…
Reference in New Issue