From 7aa0b43ceb12876b7cf542d687b3c098c605d0d0 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 17 Oct 2017 00:07:58 +0100 Subject: [PATCH] onehot sanity check --- src/onehot.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/onehot.jl b/src/onehot.jl index 0bd694ef..5414773c 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -34,7 +34,12 @@ adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data)) cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data)) end -onehot(l, labels) = OneHotVector(findfirst(labels, l), length(labels)) +function onehot(l, labels) + i = findfirst(labels, l) + i > 0 || error("Value $l is not in labels") + OneHotVector(i, length(labels)) +end + onehotbatch(ls, labels) = OneHotMatrix(length(labels), [onehot(l, labels) for l in ls]) argmax(y::AbstractVector, labels = 1:length(y)) =