Merge pull request #755 from Roger-luo/add-more-docs
add some docs for onehot & onecold
This commit is contained in:
commit
b0155ec1fe
@ -39,6 +39,29 @@ adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data)
|
|||||||
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
|
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
|
||||||
end
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
onehot(l, labels[, unk])
|
||||||
|
|
||||||
|
Create an [`OneHotVector`](@ref) wtih `l`-th element be `true` based on possible `labels` set.
|
||||||
|
If `unk` is given, it retruns `onehot(unk, labels)` if the input label `l` is not find in `labels`; otherwise
|
||||||
|
it will error.
|
||||||
|
|
||||||
|
## Examples
|
||||||
|
|
||||||
|
```jldoctest
|
||||||
|
julia> onehot(:b, [:a, :b, :c])
|
||||||
|
3-element Flux.OneHotVector:
|
||||||
|
false
|
||||||
|
true
|
||||||
|
false
|
||||||
|
|
||||||
|
julia> onehot(:c, [:a, :b, :c])
|
||||||
|
3-element Flux.OneHotVector:
|
||||||
|
false
|
||||||
|
false
|
||||||
|
true
|
||||||
|
```
|
||||||
|
"""
|
||||||
function onehot(l, labels)
|
function onehot(l, labels)
|
||||||
i = something(findfirst(isequal(l), labels), 0)
|
i = something(findfirst(isequal(l), labels), 0)
|
||||||
i > 0 || error("Value $l is not in labels")
|
i > 0 || error("Value $l is not in labels")
|
||||||
@ -51,9 +74,41 @@ function onehot(l, labels, unk)
|
|||||||
OneHotVector(i, length(labels))
|
OneHotVector(i, length(labels))
|
||||||
end
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
onehotbatch(ls, labels[, unk...])
|
||||||
|
|
||||||
|
Create an [`OneHotMatrix`](@ref) with a batch of labels based on possible `labels` set, returns the
|
||||||
|
`onehot(unk, labels)` if given labels `ls` is not found in set `labels`.
|
||||||
|
|
||||||
|
## Examples
|
||||||
|
|
||||||
|
```jldoctest
|
||||||
|
julia> onehotbatch([:b, :a, :b], [:a, :b, :c])
|
||||||
|
3×3 Flux.OneHotMatrix:
|
||||||
|
false true false
|
||||||
|
true false true
|
||||||
|
false false false
|
||||||
|
|
||||||
|
```
|
||||||
|
"""
|
||||||
onehotbatch(ls, labels, unk...) =
|
onehotbatch(ls, labels, unk...) =
|
||||||
OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])
|
OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])
|
||||||
|
|
||||||
|
"""
|
||||||
|
onecold(y[, labels = 1:length(y)])
|
||||||
|
|
||||||
|
Inverse operations of [`onehot`](@ref).
|
||||||
|
|
||||||
|
## Examples
|
||||||
|
|
||||||
|
```jldoctest
|
||||||
|
julia> onecold([true, false, false], [:a, :b, :c])
|
||||||
|
:a
|
||||||
|
|
||||||
|
julia> onecold([0.3, 0.2, 0.5], [:a, :b, :c])
|
||||||
|
:c
|
||||||
|
```
|
||||||
|
"""
|
||||||
onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)]
|
onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)]
|
||||||
|
|
||||||
onecold(y::AbstractMatrix, labels...) =
|
onecold(y::AbstractMatrix, labels...) =
|
||||||
|
Loading…
Reference in New Issue
Block a user