fix accuracy for batches

This commit is contained in:
Mike J Innes 2017-05-01 13:40:11 +01:00
parent dba6bef245
commit 5dce8df678

View File

@ -2,12 +2,14 @@ tobatch(xs::Batch) = rawbatch(xs)
tobatch(xs) = tobatch(batchone(xs)) tobatch(xs) = tobatch(batchone(xs))
function accuracy(m, data) function accuracy(m, data)
n = 0
correct = 0 correct = 0
for (x, y) in data for (x, y) in data
x, y = tobatch.((x, y)) x, y = tobatch.((x, y))
n += size(x, 1)
correct += sum(onecold(m(x)) .== onecold(y)) correct += sum(onecold(m(x)) .== onecold(y))
end end
return correct/length(data) return correct/n
end end
function train!(m, train, test = []; function train!(m, train, test = [];