tweaks
This commit is contained in:
parent
6c97846551
commit
dcde6d2217
|
@ -3,7 +3,7 @@
|
|||
It's common to encode categorical variables (like `true`, `false` or `cat`, `dog`) in "one-of-k" or ["one-hot"](https://en.wikipedia.org/wiki/One-hot) form. Flux provides the `onehot` function to make this easy.
|
||||
|
||||
```
|
||||
julia> using Flux: onehot
|
||||
julia> using Flux: onehot, onecold
|
||||
|
||||
julia> onehot(:b, [:a, :b, :c])
|
||||
3-element Flux.OneHotVector:
|
||||
|
@ -18,22 +18,22 @@ julia> onehot(:c, [:a, :b, :c])
|
|||
true
|
||||
```
|
||||
|
||||
The inverse is `argmax` (which can take a general probability distribution, as well as just booleans).
|
||||
The inverse is `onecold` (which can take a general probability distribution, as well as just booleans).
|
||||
|
||||
```julia
|
||||
julia> argmax(ans, [:a, :b, :c])
|
||||
julia> onecold(ans, [:a, :b, :c])
|
||||
:c
|
||||
|
||||
julia> argmax([true, false, false], [:a, :b, :c])
|
||||
julia> onecold([true, false, false], [:a, :b, :c])
|
||||
:a
|
||||
|
||||
julia> argmax([0.3, 0.2, 0.5], [:a, :b, :c])
|
||||
julia> onecold([0.3, 0.2, 0.5], [:a, :b, :c])
|
||||
:c
|
||||
```
|
||||
|
||||
## Batches
|
||||
|
||||
`onehotbatch` creates a batch (matrix) of one-hot vectors, and `argmax` treats matrices as batches.
|
||||
`onehotbatch` creates a batch (matrix) of one-hot vectors, and `onecold` treats matrices as batches.
|
||||
|
||||
```julia
|
||||
julia> using Flux: onehotbatch
|
||||
|
|
|
@ -54,16 +54,15 @@ end
|
|||
onehotbatch(ls, labels, unk...) =
|
||||
OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])
|
||||
|
||||
import Base:argmax
|
||||
|
||||
onecold(y::AbstractVector, labels = 1:length(y)) =
|
||||
labels[something(findfirst(isequal(maximum(y)), y), 0)]
|
||||
onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)]
|
||||
|
||||
onecold(y::AbstractMatrix, labels...) =
|
||||
dropdims(mapslices(y -> onecold(y, labels...), y, dims=1), dims=1)
|
||||
|
||||
@deprecate argmax(y::AbstractVector, labels::AbstractVector) onecold(y, labels)
|
||||
@deprecate argmax(y::AbstractMatrix, labels::AbstractVector) onecold(y, labels)
|
||||
function argmax(xs...)
|
||||
Base.depwarn("`argmax(...) is deprecated, use `onecold(...)` instead.", :argmax)
|
||||
return onecold(xs...)
|
||||
end
|
||||
|
||||
# Ambiguity hack
|
||||
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
using Flux:onecold
|
||||
using Test
|
||||
|
||||
@testset "argmax" begin
|
||||
@testset "onecold" begin
|
||||
a = [1, 2, 5, 3.]
|
||||
A = [1 20 5; 2 7 6; 3 9 10; 2 1 14]
|
||||
labels = ['A', 'B', 'C', 'D']
|
||||
|
||||
|
||||
@test onecold(a) == 3
|
||||
@test onecold(A) == [3, 1, 4]
|
||||
@test onecold(a, labels) == 'C'
|
||||
|
|
Loading…
Reference in New Issue