onecold fix

This commit is contained in:
Mike J Innes 2019-01-24 10:16:41 +00:00
parent 1eee724054
commit 62d780c77f
1 changed files with 7 additions and 0 deletions

View File

@ -358,6 +358,13 @@ x::TrackedVector * y::TrackedVector = track(*, x, y)
@grad a::AbstractMatrix * b::AbstractVecOrMat =
data(a)*data(b), Δ -> (Δ * transpose(b), transpose(a) * Δ)
# Flux
import ..Flux.onecold
onecold(x::TrackedVector, l...) = onecold(data(x), l...)
onecold(x::TrackedMatrix, l...) = onecold(data(x), l...)
# NNlib
using NNlib