training callbacks

This commit is contained in:
Mike J Innes 2017-05-01 13:43:07 +01:00
parent 5dce8df678
commit eae13c533f

View File

@ -12,21 +12,42 @@ function accuracy(m, data)
return correct/n return correct/n
end end
function train!(m, train, test = []; """
@cb for ... end t expr
Run the for loop, executing `expr` every `t` seconds.
"""
macro cb(ex, t, f)
@assert isexpr(ex, :for)
cond, body = ex.args
@esc t f cond body
:(let
t0 = time_ns()
dt = $t*1e9
f = () -> $f
@progress $(Expr(:for, cond, quote
t = time_ns()
if t - t0 > dt
t0 = t
f()
end
$body
end))
end)
end
function train!(m, train; cb = [],
epoch = 1, η = 0.1, loss = mse) epoch = 1, η = 0.1, loss = mse)
i = 0
@progress for e in 1:epoch @progress for e in 1:epoch
info("Epoch $e") info("Epoch $e")
@progress for (x, y) in train @cb for (x, y) in train
x, y = tobatch.((x, y)) x, y = tobatch.((x, y))
i += 1
= m(x) = m(x)
any(isnan, ) && error("NaN") any(isnan, ) && error("NaN")
Δ = back!(loss, 1, , y) Δ = back!(loss, 1, , y)
back!(m, Δ, x) back!(m, Δ, x)
update!(m, η) update!(m, η)
i % 1000 == 0 && @show accuracy(m, test) end 5 foreach(f -> f(), cb)
end
end end
return m return m
end end