store onehotmatrix height
This commit is contained in:
parent
d3419c943b
commit
1b91e6b38d
@ -10,10 +10,11 @@ Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix
|
||||
Base.:*(A::AbstractMatrix, b::OneHotVector) = A[:, b.ix]
|
||||
|
||||
struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool}
|
||||
height::Int
|
||||
data::A
|
||||
end
|
||||
|
||||
Base.size(xs::OneHotMatrix) = (Int64(length(xs.data[1])),length(xs.data))
|
||||
Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data))
|
||||
|
||||
Base.getindex(xs::OneHotMatrix, i::Int, j::Int) = xs.data[j][i]
|
||||
|
||||
@ -23,13 +24,13 @@ Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix([x, xs...])
|
||||
|
||||
@require CuArrays begin
|
||||
import CuArrays: CuArray, cudaconvert
|
||||
CuArrays.cu(xs::OneHotMatrix) = OneHotMatrix(CuArrays.cu(xs.data))
|
||||
CuArrays.cu(xs::OneHotMatrix) = OneHotMatrix(xs.height, CuArrays.cu(xs.data))
|
||||
Base.Broadcast._containertype(::Type{<:OneHotMatrix{<:CuArray}}) = CuArray
|
||||
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(cudaconvert(x.data))
|
||||
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
|
||||
end
|
||||
|
||||
onehot(l, labels) = OneHotVector(findfirst(labels, l), length(labels))
|
||||
onehotbatch(ls, labels) = OneHotMatrix([onehot(l, labels) for l in ls])
|
||||
onehotbatch(ls, labels) = OneHotMatrix(length(labels), [onehot(l, labels) for l in ls])
|
||||
|
||||
argmax(y::AbstractVector, labels = 1:length(y)) =
|
||||
labels[findfirst(y, maximum(y))]
|
||||
|
Loading…
Reference in New Issue
Block a user