fixes #79
This commit is contained in:
parent
752a9e2808
commit
ccdc046546
|
@ -1,3 +1,5 @@
|
|||
import Base: *
|
||||
|
||||
struct OneHotVector <: AbstractVector{Bool}
|
||||
ix::UInt32
|
||||
of::UInt32
|
||||
|
@ -7,7 +9,7 @@ Base.size(xs::OneHotVector) = (Int64(xs.of),)
|
|||
|
||||
Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix
|
||||
|
||||
Base.:*(A::AbstractMatrix, b::OneHotVector) = A[:, b.ix]
|
||||
A::AbstractMatrix * b::OneHotVector = A[:, b.ix]
|
||||
|
||||
struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool}
|
||||
height::Int
|
||||
|
@ -18,7 +20,7 @@ Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data))
|
|||
|
||||
Base.getindex(xs::OneHotMatrix, i::Int, j::Int) = xs.data[j][i]
|
||||
|
||||
Base.:*(A::AbstractMatrix, B::OneHotMatrix) = A[:, map(x->x.ix, B.data)]
|
||||
A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)]
|
||||
|
||||
Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...])
|
||||
|
||||
|
@ -47,3 +49,8 @@ argmax(y::AbstractVector, labels = 1:length(y)) =
|
|||
|
||||
argmax(y::AbstractMatrix, l...) =
|
||||
squeeze(mapslices(y -> argmax(y, l...), y, 1), 1)
|
||||
|
||||
# Ambiguity hack
|
||||
|
||||
a::TrackedMatrix * b::OneHotVector = TrackedArray(Tracker.Call(*, a, b))
|
||||
a::TrackedMatrix * b::OneHotMatrix = TrackedArray(Tracker.Call(*, a, b))
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
module Tracker
|
||||
|
||||
export TrackedArray, param, back!
|
||||
export TrackedArray, TrackedVector, TrackedMatrix, param, back!
|
||||
|
||||
data(x) = x
|
||||
istracked(x) = false
|
||||
|
|
Loading…
Reference in New Issue