Merge pull request #612 from dhairyagandhi96/onecold
Fixes OneHotMatrix/Vector GPU Performance
This commit is contained in:
commit
eff600642a
@ -9,6 +9,8 @@ Base.size(xs::OneHotVector) = (Int64(xs.of),)
|
|||||||
|
|
||||||
Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix
|
Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix
|
||||||
|
|
||||||
|
Base.getindex(xs::OneHotVector, ::Colon) = OneHotVector(xs.ix, xs.of)
|
||||||
|
|
||||||
A::AbstractMatrix * b::OneHotVector = A[:, b.ix]
|
A::AbstractMatrix * b::OneHotVector = A[:, b.ix]
|
||||||
|
|
||||||
struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool}
|
struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool}
|
||||||
@ -18,9 +20,12 @@ end
|
|||||||
|
|
||||||
Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data))
|
Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data))
|
||||||
|
|
||||||
Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs.data[j][i]
|
Base.getindex(xs::OneHotMatrix, i::Union{Integer, AbstractVector}, j::Integer) = xs.data[j][i]
|
||||||
Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i]
|
Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i]
|
||||||
Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i])
|
Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i])
|
||||||
|
Base.getindex(xs::OneHotMatrix, ::Colon, ::Colon) = OneHotMatrix(xs.height, copy(xs.data))
|
||||||
|
|
||||||
|
Base.getindex(xs::OneHotMatrix, i::Integer, ::Colon) = map(x -> x[i], xs.data)
|
||||||
|
|
||||||
A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)]
|
A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)]
|
||||||
|
|
||||||
@ -94,6 +99,8 @@ julia> onehotbatch([:b, :a, :b], [:a, :b, :c])
|
|||||||
onehotbatch(ls, labels, unk...) =
|
onehotbatch(ls, labels, unk...) =
|
||||||
OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])
|
OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])
|
||||||
|
|
||||||
|
Base.argmax(xs::OneHotVector) = xs.ix
|
||||||
|
|
||||||
"""
|
"""
|
||||||
onecold(y[, labels = 1:length(y)])
|
onecold(y[, labels = 1:length(y)])
|
||||||
|
|
||||||
@ -114,8 +121,11 @@ onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)]
|
|||||||
onecold(y::AbstractMatrix, labels...) =
|
onecold(y::AbstractMatrix, labels...) =
|
||||||
dropdims(mapslices(y -> onecold(y, labels...), y, dims=1), dims=1)
|
dropdims(mapslices(y -> onecold(y, labels...), y, dims=1), dims=1)
|
||||||
|
|
||||||
|
onecold(y::OneHotMatrix, labels...) =
|
||||||
|
mapreduce(x -> Flux.onecold(x, labels...), |, y.data, dims = 2, init = 0)
|
||||||
|
|
||||||
function argmax(xs...)
|
function argmax(xs...)
|
||||||
Base.depwarn("`argmax(...) is deprecated, use `onecold(...)` instead.", :argmax)
|
Base.depwarn("`argmax(...)` is deprecated, use `onecold(...)` instead.", :argmax)
|
||||||
return onecold(xs...)
|
return onecold(xs...)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -38,6 +38,12 @@ Flux.back!(sum(l))
|
|||||||
|
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@testset "onecold gpu" begin
|
||||||
|
y = Flux.onehotbatch(ones(3), 1:10) |> gpu;
|
||||||
|
@test Flux.onecold(y) isa CuArray
|
||||||
|
@test y[3,:] isa CuArray
|
||||||
|
end
|
||||||
|
|
||||||
if CuArrays.libcudnn != nothing
|
if CuArrays.libcudnn != nothing
|
||||||
@info "Testing Flux/CUDNN"
|
@info "Testing Flux/CUDNN"
|
||||||
include("cudnn.jl")
|
include("cudnn.jl")
|
||||||
|
@ -11,3 +11,9 @@ using Test
|
|||||||
@test onecold(a, labels) == 'C'
|
@test onecold(a, labels) == 'C'
|
||||||
@test onecold(A, labels) == ['C', 'A', 'D']
|
@test onecold(A, labels) == ['C', 'A', 'D']
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@testset "onehotbatch indexing" begin
|
||||||
|
y = Flux.onehotbatch(ones(3), 1:10)
|
||||||
|
@test y[:,1] isa Flux.OneHotVector
|
||||||
|
@test y[:,:] isa Flux.OneHotMatrix
|
||||||
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user