onehot sanity check

This commit is contained in:
Mike J Innes 2017-10-17 00:07:58 +01:00
parent e02e320008
commit 7aa0b43ceb

View File

@ -34,7 +34,12 @@ adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data)) cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
end end
onehot(l, labels) = OneHotVector(findfirst(labels, l), length(labels)) function onehot(l, labels)
i = findfirst(labels, l)
i > 0 || error("Value $l is not in labels")
OneHotVector(i, length(labels))
end
onehotbatch(ls, labels) = OneHotMatrix(length(labels), [onehot(l, labels) for l in ls]) onehotbatch(ls, labels) = OneHotMatrix(length(labels), [onehot(l, labels) for l in ls])
argmax(y::AbstractVector, labels = 1:length(y)) = argmax(y::AbstractVector, labels = 1:length(y)) =