commit
953280d57f
@ -3,7 +3,7 @@
|
|||||||
It's common to encode categorical variables (like `true`, `false` or `cat`, `dog`) in "one-of-k" or ["one-hot"](https://en.wikipedia.org/wiki/One-hot) form. Flux provides the `onehot` function to make this easy.
|
It's common to encode categorical variables (like `true`, `false` or `cat`, `dog`) in "one-of-k" or ["one-hot"](https://en.wikipedia.org/wiki/One-hot) form. Flux provides the `onehot` function to make this easy.
|
||||||
|
|
||||||
```
|
```
|
||||||
julia> using Flux: onehot
|
julia> using Flux: onehot, onecold
|
||||||
|
|
||||||
julia> onehot(:b, [:a, :b, :c])
|
julia> onehot(:b, [:a, :b, :c])
|
||||||
3-element Flux.OneHotVector:
|
3-element Flux.OneHotVector:
|
||||||
@ -18,22 +18,22 @@ julia> onehot(:c, [:a, :b, :c])
|
|||||||
true
|
true
|
||||||
```
|
```
|
||||||
|
|
||||||
The inverse is `argmax` (which can take a general probability distribution, as well as just booleans).
|
The inverse is `onecold` (which can take a general probability distribution, as well as just booleans).
|
||||||
|
|
||||||
```julia
|
```julia
|
||||||
julia> argmax(ans, [:a, :b, :c])
|
julia> onecold(ans, [:a, :b, :c])
|
||||||
:c
|
:c
|
||||||
|
|
||||||
julia> argmax([true, false, false], [:a, :b, :c])
|
julia> onecold([true, false, false], [:a, :b, :c])
|
||||||
:a
|
:a
|
||||||
|
|
||||||
julia> argmax([0.3, 0.2, 0.5], [:a, :b, :c])
|
julia> onecold([0.3, 0.2, 0.5], [:a, :b, :c])
|
||||||
:c
|
:c
|
||||||
```
|
```
|
||||||
|
|
||||||
## Batches
|
## Batches
|
||||||
|
|
||||||
`onehotbatch` creates a batch (matrix) of one-hot vectors, and `argmax` treats matrices as batches.
|
`onehotbatch` creates a batch (matrix) of one-hot vectors, and `onecold` treats matrices as batches.
|
||||||
|
|
||||||
```julia
|
```julia
|
||||||
julia> using Flux: onehotbatch
|
julia> using Flux: onehotbatch
|
||||||
|
@ -54,11 +54,15 @@ end
|
|||||||
onehotbatch(ls, labels, unk...) =
|
onehotbatch(ls, labels, unk...) =
|
||||||
OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])
|
OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])
|
||||||
|
|
||||||
argmax(y::AbstractVector, labels = 1:length(y)) =
|
onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)]
|
||||||
labels[findfirst(y, maximum(y))]
|
|
||||||
|
|
||||||
argmax(y::AbstractMatrix, l...) =
|
onecold(y::AbstractMatrix, labels...) =
|
||||||
squeeze(mapslices(y -> argmax(y, l...), y, 1), 1)
|
dropdims(mapslices(y -> onecold(y, labels...), y, dims=1), dims=1)
|
||||||
|
|
||||||
|
function argmax(xs...)
|
||||||
|
Base.depwarn("`argmax(...) is deprecated, use `onecold(...)` instead.", :argmax)
|
||||||
|
return onecold(xs...)
|
||||||
|
end
|
||||||
|
|
||||||
# Ambiguity hack
|
# Ambiguity hack
|
||||||
|
|
||||||
|
13
test/onehot.jl
Normal file
13
test/onehot.jl
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
using Flux:onecold
|
||||||
|
using Test
|
||||||
|
|
||||||
|
@testset "onecold" begin
|
||||||
|
a = [1, 2, 5, 3.]
|
||||||
|
A = [1 20 5; 2 7 6; 3 9 10; 2 1 14]
|
||||||
|
labels = ['A', 'B', 'C', 'D']
|
||||||
|
|
||||||
|
@test onecold(a) == 3
|
||||||
|
@test onecold(A) == [3, 1, 4]
|
||||||
|
@test onecold(a, labels) == 'C'
|
||||||
|
@test onecold(A, labels) == ['C', 'A', 'D']
|
||||||
|
end
|
@ -24,6 +24,7 @@ insert!(LOAD_PATH, 2, "@v#.#")
|
|||||||
@testset "Flux" begin
|
@testset "Flux" begin
|
||||||
|
|
||||||
include("utils.jl")
|
include("utils.jl")
|
||||||
|
include("onehot.jl")
|
||||||
include("tracker.jl")
|
include("tracker.jl")
|
||||||
include("layers/normalisation.jl")
|
include("layers/normalisation.jl")
|
||||||
include("layers/stateless.jl")
|
include("layers/stateless.jl")
|
||||||
|
Loading…
Reference in New Issue
Block a user