Merge pull request #755 from Roger-luo/add-more-docs

add some docs for onehot & onecold
This commit is contained in:
Mike J Innes 2019-04-26 11:54:54 +01:00 committed by GitHub
commit b0155ec1fe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 55 additions and 0 deletions

View File

@ -39,6 +39,29 @@ adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data)
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
end
"""
onehot(l, labels[, unk])
Create an [`OneHotVector`](@ref) wtih `l`-th element be `true` based on possible `labels` set.
If `unk` is given, it retruns `onehot(unk, labels)` if the input label `l` is not find in `labels`; otherwise
it will error.
## Examples
```jldoctest
julia> onehot(:b, [:a, :b, :c])
3-element Flux.OneHotVector:
false
true
false
julia> onehot(:c, [:a, :b, :c])
3-element Flux.OneHotVector:
false
false
true
```
"""
function onehot(l, labels)
i = something(findfirst(isequal(l), labels), 0)
i > 0 || error("Value $l is not in labels")
@ -51,9 +74,41 @@ function onehot(l, labels, unk)
OneHotVector(i, length(labels))
end
"""
onehotbatch(ls, labels[, unk...])
Create an [`OneHotMatrix`](@ref) with a batch of labels based on possible `labels` set, returns the
`onehot(unk, labels)` if given labels `ls` is not found in set `labels`.
## Examples
```jldoctest
julia> onehotbatch([:b, :a, :b], [:a, :b, :c])
3×3 Flux.OneHotMatrix:
false true false
true false true
false false false
```
"""
onehotbatch(ls, labels, unk...) =
OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])
"""
onecold(y[, labels = 1:length(y)])
Inverse operations of [`onehot`](@ref).
## Examples
```jldoctest
julia> onecold([true, false, false], [:a, :b, :c])
:a
julia> onecold([0.3, 0.2, 0.5], [:a, :b, :c])
:c
```
"""
onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)]
onecold(y::AbstractMatrix, labels...) =