diff --git a/src/training.jl b/src/training.jl index c8deb238..f5b4826b 100644 --- a/src/training.jl +++ b/src/training.jl @@ -12,21 +12,42 @@ function accuracy(m, data) return correct/n end -function train!(m, train, test = []; +""" + @cb for ... end t expr + +Run the for loop, executing `expr` every `t` seconds. +""" +macro cb(ex, t, f) + @assert isexpr(ex, :for) + cond, body = ex.args + @esc t f cond body + :(let + t0 = time_ns() + dt = $t*1e9 + f = () -> $f + @progress $(Expr(:for, cond, quote + t = time_ns() + if t - t0 > dt + t0 = t + f() + end + $body + end)) + end) +end + +function train!(m, train; cb = [], epoch = 1, η = 0.1, loss = mse) - i = 0 @progress for e in 1:epoch info("Epoch $e") - @progress for (x, y) in train + @cb for (x, y) in train x, y = tobatch.((x, y)) - i += 1 ŷ = m(x) any(isnan, ŷ) && error("NaN") Δ = back!(loss, 1, ŷ, y) back!(m, Δ, x) update!(m, η) - i % 1000 == 0 && @show accuracy(m, test) - end + end 5 foreach(f -> f(), cb) end return m end