From 6499344af397db698706b3325d2ba6831178ac65 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Thu, 6 Feb 2020 13:42:17 +0100 Subject: [PATCH] nograd for onecold, onehot, onehotbatch --- src/onehot.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/onehot.jl b/src/onehot.jl index 754d0607..7a3123ec 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -125,6 +125,4 @@ onecold(y::AbstractMatrix, labels...) = onecold(y::OneHotMatrix, labels...) = mapreduce(x -> Flux.onecold(x, labels...), |, y.data, dims = 2, init = 0) -# TODO probably still want this as a custom adjoint Zygote -# onecold(x::TrackedVector, l...) = onecold(data(x), l...) -# onecold(x::TrackedMatrix, l...) = onecold(data(x), l...) +@nograd onecold, onehot, onehotbatch