export onehot, onecold, chunk, partition, batches, sequences

"""
    onehot(label, labels)
    onehot(T::Type, label, labels)

Produce a one-hot-encoded version of `label`, given the collection `labels` of
possible values for the item.

The result has one entry per element of `labels`, in iteration order. An entry is
one where the element `== label` and zero elsewhere, stored with element type `T`
(`Int` when `T` is omitted). If `label` does not occur in `labels` every entry is
zero; if it occurs several times every matching position is set.

    onehot('b', ['a', 'b', 'c', 'd']) => [0, 1, 0, 0]

    onehot(Bool, 'b', ['a', 'b', 'c', 'd']) => [false, true, false, false]

    onehot(Float32, 'c', ['a', 'b', 'c', 'd']) => [0., 0., 1., 0.]
"""
onehot(T::Type, label, labels) = T[candidate == label for candidate in labels]
onehot(label, labels) = onehot(Int, label, labels)

"""
    onecold(pred, labels = 1:length(pred))

The inverse of `onehot`; takes an output prediction vector `pred` and a list of
possible values `labels`, and returns the element of `labels` at the position of
the largest entry of `pred`. When several entries share the maximum, the first
one wins. When `labels` is omitted, the position itself is returned.

    onecold([0.0, 1.0, 0.0], ['a', 'b', 'c']) => 'b'

    onecold([0.1, 0.2, 0.7]) => 3
"""
function onecold(pred, labels = 1:length(pred))
    peak = maximum(pred)
    # Use the predicate form of `findfirst`: the value-search form
    # `findfirst(A, v)` is not available on Julia >= 1.0, whereas this form
    # performs the same `==` comparison on old and new versions alike.
    return labels[findfirst(x -> x == peak, pred)]
end

using Iterators