Merge pull request #755 from Roger-luo/add-more-docs
add some docs for onehot & onecold
This commit is contained in:
commit
b0155ec1fe
|
@ -39,6 +39,29 @@ adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data)
|
|||
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
|
||||
end
|
||||
|
||||
"""
|
||||
onehot(l, labels[, unk])
|
||||
|
||||
Create an [`OneHotVector`](@ref) wtih `l`-th element be `true` based on possible `labels` set.
|
||||
If `unk` is given, it retruns `onehot(unk, labels)` if the input label `l` is not find in `labels`; otherwise
|
||||
it will error.
|
||||
|
||||
## Examples
|
||||
|
||||
```jldoctest
|
||||
julia> onehot(:b, [:a, :b, :c])
|
||||
3-element Flux.OneHotVector:
|
||||
false
|
||||
true
|
||||
false
|
||||
|
||||
julia> onehot(:c, [:a, :b, :c])
|
||||
3-element Flux.OneHotVector:
|
||||
false
|
||||
false
|
||||
true
|
||||
```
|
||||
"""
|
||||
function onehot(l, labels)
|
||||
i = something(findfirst(isequal(l), labels), 0)
|
||||
i > 0 || error("Value $l is not in labels")
|
||||
|
@ -51,9 +74,41 @@ function onehot(l, labels, unk)
|
|||
OneHotVector(i, length(labels))
|
||||
end
|
||||
|
||||
"""
|
||||
onehotbatch(ls, labels[, unk...])
|
||||
|
||||
Create an [`OneHotMatrix`](@ref) with a batch of labels based on possible `labels` set, returns the
|
||||
`onehot(unk, labels)` if given labels `ls` is not found in set `labels`.
|
||||
|
||||
## Examples
|
||||
|
||||
```jldoctest
|
||||
julia> onehotbatch([:b, :a, :b], [:a, :b, :c])
|
||||
3×3 Flux.OneHotMatrix:
|
||||
false true false
|
||||
true false true
|
||||
false false false
|
||||
|
||||
```
|
||||
"""
|
||||
onehotbatch(ls, labels, unk...) =
|
||||
OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])
|
||||
|
||||
"""
|
||||
onecold(y[, labels = 1:length(y)])
|
||||
|
||||
Inverse operations of [`onehot`](@ref).
|
||||
|
||||
## Examples
|
||||
|
||||
```jldoctest
|
||||
julia> onecold([true, false, false], [:a, :b, :c])
|
||||
:a
|
||||
|
||||
julia> onecold([0.3, 0.2, 0.5], [:a, :b, :c])
|
||||
:c
|
||||
```
|
||||
"""
|
||||
onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)]
|
||||
|
||||
onecold(y::AbstractMatrix, labels...) =
|
||||
|
|
Loading…
Reference in New Issue