fix accuracy for batches
This commit is contained in:
parent
dba6bef245
commit
5dce8df678
@ -2,12 +2,14 @@ tobatch(xs::Batch) = rawbatch(xs)
|
|||||||
tobatch(xs) = tobatch(batchone(xs))
|
tobatch(xs) = tobatch(batchone(xs))
|
||||||
|
|
||||||
function accuracy(m, data)
|
function accuracy(m, data)
|
||||||
|
n = 0
|
||||||
correct = 0
|
correct = 0
|
||||||
for (x, y) in data
|
for (x, y) in data
|
||||||
x, y = tobatch.((x, y))
|
x, y = tobatch.((x, y))
|
||||||
|
n += size(x, 1)
|
||||||
correct += sum(onecold(m(x)) .== onecold(y))
|
correct += sum(onecold(m(x)) .== onecold(y))
|
||||||
end
|
end
|
||||||
return correct/length(data)
|
return correct/n
|
||||||
end
|
end
|
||||||
|
|
||||||
function train!(m, train, test = [];
|
function train!(m, train, test = [];
|
||||||
|
Loading…
Reference in New Issue
Block a user