onehot unk arg

This commit is contained in:
Mike J Innes 2017-11-29 16:45:50 +00:00
parent dc1f08a709
commit 2d33f19346

View File

@ -42,7 +42,14 @@ function onehot(l, labels)
OneHotVector(i, length(labels)) OneHotVector(i, length(labels))
end end
onehotbatch(ls, labels) = OneHotMatrix(length(labels), [onehot(l, labels) for l in ls]) function onehot(l, labels, unk)
i = findfirst(labels, l)
i > 0 || return onehot(unk, labels)
OneHotVector(i, length(labels))
end
onehotbatch(ls, labels, unk...) =
OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])
argmax(y::AbstractVector, labels = 1:length(y)) = argmax(y::AbstractVector, labels = 1:length(y)) =
labels[findfirst(y, maximum(y))] labels[findfirst(y, maximum(y))]