From ca1c73ed352d0557a72e12b663fff20332d9aaff Mon Sep 17 00:00:00 2001 From: Mike Innes Date: Thu, 24 Jan 2019 11:15:57 +0000 Subject: [PATCH] fixup --- src/onehot.jl | 3 +++ src/tracker/lib/array.jl | 7 ------- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/src/onehot.jl b/src/onehot.jl index b6cee63d..cd29f14e 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -68,3 +68,6 @@ end a::TrackedMatrix * b::OneHotVector = invoke(*, Tuple{AbstractMatrix,OneHotVector}, a, b) a::TrackedMatrix * b::OneHotMatrix = invoke(*, Tuple{AbstractMatrix,OneHotMatrix}, a, b) + +onecold(x::TrackedVector, l...) = onecold(data(x), l...) +onecold(x::TrackedMatrix, l...) = onecold(data(x), l...) diff --git a/src/tracker/lib/array.jl b/src/tracker/lib/array.jl index 838317cf..690b0e18 100644 --- a/src/tracker/lib/array.jl +++ b/src/tracker/lib/array.jl @@ -358,13 +358,6 @@ x::TrackedVector * y::TrackedVector = track(*, x, y) @grad a::AbstractMatrix * b::AbstractVecOrMat = data(a)*data(b), Δ -> (Δ * transpose(b), transpose(a) * Δ) -# Flux - -import ..Flux.onecold - -onecold(x::TrackedVector, l...) = onecold(data(x), l...) -onecold(x::TrackedMatrix, l...) = onecold(data(x), l...) - # NNlib using NNlib