onehot unk arg
This commit is contained in:
parent
dc1f08a709
commit
2d33f19346
@ -42,7 +42,14 @@ function onehot(l, labels)
|
||||
OneHotVector(i, length(labels))
|
||||
end
|
||||
|
||||
onehotbatch(ls, labels) = OneHotMatrix(length(labels), [onehot(l, labels) for l in ls])
|
||||
function onehot(l, labels, unk)
|
||||
i = findfirst(labels, l)
|
||||
i > 0 || return onehot(unk, labels)
|
||||
OneHotVector(i, length(labels))
|
||||
end
|
||||
|
||||
onehotbatch(ls, labels, unk...) =
|
||||
OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])
|
||||
|
||||
argmax(y::AbstractVector, labels = 1:length(y)) =
|
||||
labels[findfirst(y, maximum(y))]
|
||||
|
Loading…
Reference in New Issue
Block a user