diff --git a/src/onehot.jl b/src/onehot.jl index 754d0607..7a3123ec 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -125,6 +125,4 @@ onecold(y::AbstractMatrix, labels...) = onecold(y::OneHotMatrix, labels...) = mapreduce(x -> Flux.onecold(x, labels...), |, y.data, dims = 2, init = 0) -# TODO probably still want this as a custom adjoint Zygote -# onecold(x::TrackedVector, l...) = onecold(data(x), l...) -# onecold(x::TrackedMatrix, l...) = onecold(data(x), l...) +@nograd onecold, onehot, onehotbatch