diff --git a/src/utils.jl b/src/utils.jl index 48af20a7..daac39cf 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -15,3 +15,14 @@ mapt(f, xs::Tuple) = map(x -> mapt(f, x), xs) convertel(T::Type, xs::AbstractArray) = convert.(T, xs) convertel{T}(::Type{T}, xs::AbstractArray{T}) = xs + +function accuracy(m, data) + n = 0 + correct = 0 + for (x, y) in data + x, y = tobatch.((x, y)) + n += size(x, 1) + correct += sum(onecold(m(x)) .== onecold(y)) + end + return correct/n +end