diff --git a/src/training.jl b/src/training.jl index b0daf211..c8deb238 100644 --- a/src/training.jl +++ b/src/training.jl @@ -2,12 +2,14 @@ tobatch(xs::Batch) = rawbatch(xs) tobatch(xs) = tobatch(batchone(xs)) function accuracy(m, data) + n = 0 correct = 0 for (x, y) in data x, y = tobatch.((x, y)) + n += size(x, 1) correct += sum(onecold(m(x)) .== onecold(y)) end - return correct/length(data) + return correct/n end function train!(m, train, test = [];