diff --git a/src/onehot.jl b/src/onehot.jl index 5414773c..f8061063 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -1,3 +1,5 @@ +import Base: * + struct OneHotVector <: AbstractVector{Bool} ix::UInt32 of::UInt32 @@ -7,7 +9,7 @@ Base.size(xs::OneHotVector) = (Int64(xs.of),) Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix -Base.:*(A::AbstractMatrix, b::OneHotVector) = A[:, b.ix] +A::AbstractMatrix * b::OneHotVector = A[:, b.ix] struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool} height::Int @@ -18,7 +20,7 @@ Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data)) Base.getindex(xs::OneHotMatrix, i::Int, j::Int) = xs.data[j][i] -Base.:*(A::AbstractMatrix, B::OneHotMatrix) = A[:, map(x->x.ix, B.data)] +A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)] Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...]) @@ -47,3 +49,8 @@ argmax(y::AbstractVector, labels = 1:length(y)) = argmax(y::AbstractMatrix, l...) = squeeze(mapslices(y -> argmax(y, l...), y, 1), 1) + +# Ambiguity hack + +a::TrackedMatrix * b::OneHotVector = TrackedArray(Tracker.Call(*, a, b)) +a::TrackedMatrix * b::OneHotMatrix = TrackedArray(Tracker.Call(*, a, b)) diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 5e26a051..3a64fcb7 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -1,6 +1,6 @@ module Tracker -export TrackedArray, param, back! +export TrackedArray, TrackedVector, TrackedMatrix, param, back! data(x) = x istracked(x) = false