adding tests

This commit is contained in:
Dhairya Gandhi 2019-02-09 22:32:02 +05:30
parent f17a5acd2b
commit 35cd9761a8
3 changed files with 35 additions and 1 deletions

View File

@ -9,6 +9,8 @@ Base.size(xs::OneHotVector) = (Int64(xs.of),)
Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix
Base.getindex(xs::OneHotVector, ::Colon) = xs
A::AbstractMatrix * b::OneHotVector = A[:, b.ix]
struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool}
@ -22,6 +24,22 @@ Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs.data[j][i]
Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i]
Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i])
Base.getindex(xs::Flux.OneHotMatrix, j::Base.UnitRange, i::Int) = xs.data[i][j]
Base.getindex(xs::OneHotMatrix, ::Colon, ::Colon) = xs
Base.getindex(xs::OneHotMatrix, i::Integer, ::Colon) = map(x -> x[i], xs.data)
# handle special case when we want the whole column
function Base.getindex(xs::Flux.OneHotMatrix{T}, ot::Union{Base.Slice, Base.OneTo}, i::Int) where {T<:AbstractArray}
res = similar(xs, size(xs, 1), 1)
if length(ot) == size(xs, 1)
res = xs[:,i]
else
res = xs[1:length(ot),i]
end
res
end
A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)]
Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...])
@ -54,13 +72,17 @@ end
onehotbatch(ls, labels, unk...) =
OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])
Base.argmax(xs::OneHotVector) = xs.ix
onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)]
onecold(y::AbstractMatrix, labels...) =
dropdims(mapslices(y -> onecold(y, labels...), y, dims=1), dims=1)
onecold(y::OneHotMatrix, labels...) = map(x -> onecold(x, labels...), y.data)
function argmax(xs...)
Base.depwarn("`argmax(...) is deprecated, use `onecold(...)` instead.", :argmax)
Base.depwarn("`argmax(...)` is deprecated, use `onecold(...)` instead.", :argmax)
return onecold(xs...)
end

View File

@ -38,6 +38,12 @@ Flux.back!(sum(l))
end
@testset "onecold gpu" begin
y = Flux.onehotbatch(ones(3), 1:10) |> gpu;
@test Flux.onecold(y) isa CuArray
@test y[3,:] isa CuArray
end
if CuArrays.libcudnn != nothing
@info "Testing Flux/CUDNN"
include("cudnn.jl")

View File

@ -11,3 +11,9 @@ using Test
@test onecold(a, labels) == 'C'
@test onecold(A, labels) == ['C', 'A', 'D']
end
@testset "onehotbatch indexing" begin
y = Flux.onehotbatch(ones(3), 1:10)
@test y[:,1] isa Flux.OneHotVector
@test y[:,:] isa Flux.OneHotMatrix
end