1021: nograd for onecold, onehot, onehotbatch r=MikeInnes a=CarloLucibello

fixes #1020 

Co-authored-by: CarloLucibello <carlo.lucibello@gmail.com>
This commit is contained in:
bors[bot] 2020-02-17 14:12:48 +00:00 committed by GitHub
commit e4a84c120f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 3 deletions

View File

@ -125,6 +125,4 @@ onecold(y::AbstractMatrix, labels...) =
onecold(y::OneHotMatrix, labels...) = onecold(y::OneHotMatrix, labels...) =
mapreduce(x -> Flux.onecold(x, labels...), |, y.data, dims = 2, init = 0) mapreduce(x -> Flux.onecold(x, labels...), |, y.data, dims = 2, init = 0)
# TODO probably still want this as a custom adjoint Zygote @nograd onecold, onehot, onehotbatch
# onecold(x::TrackedVector, l...) = onecold(data(x), l...)
# onecold(x::TrackedMatrix, l...) = onecold(data(x), l...)