Flux.jl/src/onehot.jl

45 lines
1.3 KiB
Julia
Raw Normal View History

2017-09-06 22:58:55 +00:00
"""
    OneHotVector(ix, of)

A boolean vector of length `of` that is `true` only at index `ix`.
Only the two integers are stored (as `UInt32`), not a dense `Bool` array.
"""
struct OneHotVector <: AbstractVector{Bool}
  ix::UInt32  # position of the single `true` entry
  of::UInt32  # total length of the vector
end

# The AbstractArray interface expects `size` to return a tuple of `Int`
# (the platform integer); hard-coding `Int64` is wrong on 32-bit systems.
Base.size(xs::OneHotVector) = (Int(xs.of),)

# This method bypasses Base's generic indexing path, so it must check bounds
# itself; otherwise an out-of-range index silently returns `false`.
function Base.getindex(xs::OneHotVector, i::Integer)
  @boundscheck checkbounds(xs, i)
  return i == xs.ix
end

# Multiplying by a one-hot vector just selects column `ix` of `A`, so no
# arithmetic is needed. The dimension check matches what a dense `A * b`
# would enforce instead of silently returning a column for mismatched sizes.
function Base.:*(A::AbstractMatrix, b::OneHotVector)
  size(A, 2) == b.of || throw(DimensionMismatch(
    "matrix has $(size(A, 2)) columns but one-hot vector has length $(b.of)"))
  return A[:, b.ix]
end
2017-09-27 20:58:34 +00:00
"""
    OneHotMatrix(height, data)

A boolean matrix whose columns are the `OneHotVector`s in `data`, with `height`
rows. Only the collection of one-hot columns is stored, never a dense matrix.
"""
struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool}
2017-10-02 19:50:11 +00:00
height::Int  # number of rows; expected to equal each column's length (not checked here)
2017-09-27 20:58:34 +00:00
data::A  # one-hot columns; the array type is a parameter so non-CPU storage (e.g. CuArray) can be wrapped
2017-09-06 22:58:55 +00:00
end
2017-10-02 19:50:11 +00:00
# Size is (rows, columns). `height` is already an `Int`, so it is returned
# directly; wrapping it in `Int64` produced a mixed-type tuple on 32-bit
# systems, violating the `NTuple{N,Int}` contract of `size`.
Base.size(xs::OneHotMatrix) = (xs.height, length(xs.data))
2017-09-06 22:58:55 +00:00
# Entry (i, j) of a one-hot matrix is entry `i` of its `j`-th one-hot column.
function Base.getindex(xs::OneHotMatrix, i::Int, j::Int)
  column = xs.data[j]
  return column[i]
end
# Multiplying by a one-hot matrix selects one column of `A` per one-hot
# column (the column at that column's `ix`), so it is done by indexing rather
# than a dense product. The dimension check mirrors dense matrix multiplication.
function Base.:*(A::AbstractMatrix, B::OneHotMatrix)
  size(A, 2) == B.height || throw(DimensionMismatch(
    "matrix has $(size(A, 2)) columns but one-hot matrix has height $(B.height)"))
  return A[:, map(x -> x.ix, B.data)]
end
2017-10-15 22:44:16 +00:00
# Concatenate one-hot vectors side by side into a `OneHotMatrix` whose height
# is their common length. Vectors of differing lengths are rejected, as dense
# `hcat` does, instead of building a matrix whose height disagrees with some
# of its columns.
function Base.hcat(x::OneHotVector, xs::OneHotVector...)
  height = length(x)
  for v in xs
    length(v) == height || throw(DimensionMismatch(
      "one-hot vectors must have the same length to hcat; got $height and $(length(v))"))
  end
  return OneHotMatrix(height, [x, xs...])
end
2017-09-06 22:58:55 +00:00
2017-10-15 22:44:40 +00:00
# Wrap a collection of one-hot vectors as the columns of a `OneHotMatrix`.
# The height is taken from the first vector, and every vector must share it;
# otherwise the matrix's `size` would disagree with some of its columns.
# An empty `xs` still errors in `first`, since no height can be inferred.
function batch(xs::AbstractArray{<:OneHotVector})
  height = length(first(xs))
  all(x -> length(x) == height, xs) || throw(DimensionMismatch(
    "all one-hot vectors in a batch must have the same length ($height)"))
  return OneHotMatrix(height, xs)
end
2017-10-04 17:55:56 +00:00
import NNlib.adapt
# Convert only the wrapped one-hot columns with `adapt(T, ...)` and rewrap them,
# keeping the height. Presumably used to move a batch to another array type
# such as a CuArray (assumption; the target types are decided by callers).
function adapt(T, xs::OneHotMatrix)
  converted = adapt(T, xs.data)
  return OneHotMatrix(xs.height, converted)
end
2017-09-27 20:58:34 +00:00
# CuArrays-specific methods; `@require` defers this block until the CuArrays
# package is loaded, so CuArrays stays an optional dependency.
@require CuArrays begin
2017-09-27 21:51:00 +00:00
import CuArrays: CuArray, cudaconvert
# Make broadcasting over a CuArray-backed OneHotMatrix resolve to the CuArray
# container type. NOTE(review): `_containertype` is an internal, underscore-prefixed
# Base.Broadcast hook and may not exist in other Julia versions.
Base.Broadcast._containertype(::Type{<:OneHotMatrix{<:CuArray}}) = CuArray
2017-10-02 19:50:11 +00:00
# Extend CuArrays' `cudaconvert`: convert the wrapped columns and rewrap them with
# the same height (presumably for passing into GPU kernels; confirm against CuArrays).
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
2017-09-27 20:58:34 +00:00
end
2017-09-06 22:58:55 +00:00
"""
    onehot(l, labels)

Return a `OneHotVector` of length `length(labels)` that is `true` at the position
of the first element of `labels` equal (`==`) to `l`.

Throws an `ArgumentError` if `l` does not occur in `labels`. Previously this
silently produced an invalid all-`false` vector with index 0.
"""
function onehot(l, labels)
  i = findfirst(x -> x == l, labels)
  # `findfirst` signals "not found" with 0 on Julia 0.6 and `nothing` on later versions.
  (i === nothing || i == 0) && throw(ArgumentError("label $(repr(l)) is not in labels"))
  return OneHotVector(i, length(labels))
end
2017-10-02 19:50:11 +00:00
"""
    onehotbatch(ls, labels)

One-hot encode every element of `ls` against `labels` (via `onehot`) and use the
resulting vectors as the columns of a `OneHotMatrix` with `length(labels)` rows.
"""
function onehotbatch(ls, labels)
  columns = collect(onehot(l, labels) for l in ls)
  return OneHotMatrix(length(labels), columns)
end
2017-09-06 22:58:55 +00:00
2017-09-11 12:40:11 +00:00
"""
    argmax(y::AbstractVector, labels = 1:length(y))

Return the element of `labels` at the position of the largest entry of `y`
(the first such position if the maximum occurs more than once). With the
default `labels`, this is simply the index of the maximum.
"""
function argmax(y::AbstractVector, labels = 1:length(y))
  # `findmax` returns the maximum and its index in a single pass. The old
  # `findfirst(y, maximum(y))` scanned twice and re-located the maximum with `==`,
  # which fails for a NaN maximum (NaN != NaN) and then indexed `labels[0]`.
  _, i = findmax(y)
  return labels[i]
end
2017-09-11 12:40:11 +00:00
# Column-wise `argmax`: apply the vector method to each column of `y` (slicing
# along dimension 1), forwarding any `labels` in `l...`, and return one result
# per column as a vector.
function argmax(y::AbstractMatrix, l...)
  percolumn = mapslices(col -> argmax(col, l...), y, 1)  # singleton first dimension
  return squeeze(percolumn, 1)
end