diff --git a/src/onehot.jl b/src/onehot.jl index b6cee63d..cd29f14e 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -68,3 +68,6 @@ end a::TrackedMatrix * b::OneHotVector = invoke(*, Tuple{AbstractMatrix,OneHotVector}, a, b) a::TrackedMatrix * b::OneHotMatrix = invoke(*, Tuple{AbstractMatrix,OneHotMatrix}, a, b) + +onecold(x::TrackedVector, l...) = onecold(data(x), l...) +onecold(x::TrackedMatrix, l...) = onecold(data(x), l...) diff --git a/src/tracker/lib/array.jl b/src/tracker/lib/array.jl index 838317cf..690b0e18 100644 --- a/src/tracker/lib/array.jl +++ b/src/tracker/lib/array.jl @@ -358,13 +358,6 @@ x::TrackedVector * y::TrackedVector = track(*, x, y) @grad a::AbstractMatrix * b::AbstractVecOrMat = data(a)*data(b), Δ -> (Δ * transpose(b), transpose(a) * Δ) -# Flux - -import ..Flux.onecold - -onecold(x::TrackedVector, l...) = onecold(data(x), l...) -onecold(x::TrackedMatrix, l...) = onecold(data(x), l...) - # NNlib using NNlib