fixes #79
This commit is contained in:
parent
bdf02e42ae
commit
e5d99d784e
@ -1,3 +1,5 @@
|
|||||||
|
import Base: *
|
||||||
|
|
||||||
struct OneHotVector <: AbstractVector{Bool}
|
struct OneHotVector <: AbstractVector{Bool}
|
||||||
ix::UInt32
|
ix::UInt32
|
||||||
of::UInt32
|
of::UInt32
|
||||||
@ -7,7 +9,7 @@ Base.size(xs::OneHotVector) = (Int64(xs.of),)
|
|||||||
|
|
||||||
Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix
|
Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix
|
||||||
|
|
||||||
Base.:*(A::AbstractMatrix, b::OneHotVector) = A[:, b.ix]
|
A::AbstractMatrix * b::OneHotVector = A[:, b.ix]
|
||||||
|
|
||||||
struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool}
|
struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool}
|
||||||
height::Int
|
height::Int
|
||||||
@ -18,7 +20,7 @@ Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data))
|
|||||||
|
|
||||||
Base.getindex(xs::OneHotMatrix, i::Int, j::Int) = xs.data[j][i]
|
Base.getindex(xs::OneHotMatrix, i::Int, j::Int) = xs.data[j][i]
|
||||||
|
|
||||||
Base.:*(A::AbstractMatrix, B::OneHotMatrix) = A[:, map(x->x.ix, B.data)]
|
A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)]
|
||||||
|
|
||||||
Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...])
|
Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...])
|
||||||
|
|
||||||
@ -47,3 +49,8 @@ argmax(y::AbstractVector, labels = 1:length(y)) =
|
|||||||
|
|
||||||
argmax(y::AbstractMatrix, l...) =
|
argmax(y::AbstractMatrix, l...) =
|
||||||
squeeze(mapslices(y -> argmax(y, l...), y, 1), 1)
|
squeeze(mapslices(y -> argmax(y, l...), y, 1), 1)
|
||||||
|
|
||||||
|
# Ambiguity hack
|
||||||
|
|
||||||
|
a::TrackedMatrix * b::OneHotVector = TrackedArray(Tracker.Call(*, a, b))
|
||||||
|
a::TrackedMatrix * b::OneHotMatrix = TrackedArray(Tracker.Call(*, a, b))
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
module Tracker
|
module Tracker
|
||||||
|
|
||||||
export TrackedArray, param, back!
|
export TrackedArray, TrackedVector, TrackedMatrix, param, back!
|
||||||
|
|
||||||
data(x) = x
|
data(x) = x
|
||||||
istracked(x) = false
|
istracked(x) = false
|
||||||
|
Loading…
Reference in New Issue
Block a user