diff --git a/src/onehot.jl b/src/onehot.jl index f8061063..f94fb93e 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -42,7 +42,14 @@ function onehot(l, labels) OneHotVector(i, length(labels)) end -onehotbatch(ls, labels) = OneHotMatrix(length(labels), [onehot(l, labels) for l in ls]) +function onehot(l, labels, unk) + i = findfirst(labels, l) + i > 0 || return onehot(unk, labels) + OneHotVector(i, length(labels)) +end + +onehotbatch(ls, labels, unk...) = + OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls]) argmax(y::AbstractVector, labels = 1:length(y)) = labels[findfirst(y, maximum(y))]