training callbacks
This commit is contained in:
parent
5dce8df678
commit
eae13c533f
@ -12,21 +12,42 @@ function accuracy(m, data)
|
|||||||
return correct/n
|
return correct/n
|
||||||
end
|
end
|
||||||
|
|
||||||
function train!(m, train, test = [];
|
"""
|
||||||
|
@cb for ... end t expr
|
||||||
|
|
||||||
|
Run the for loop, executing `expr` every `t` seconds.
|
||||||
|
"""
|
||||||
|
macro cb(ex, t, f)
|
||||||
|
@assert isexpr(ex, :for)
|
||||||
|
cond, body = ex.args
|
||||||
|
@esc t f cond body
|
||||||
|
:(let
|
||||||
|
t0 = time_ns()
|
||||||
|
dt = $t*1e9
|
||||||
|
f = () -> $f
|
||||||
|
@progress $(Expr(:for, cond, quote
|
||||||
|
t = time_ns()
|
||||||
|
if t - t0 > dt
|
||||||
|
t0 = t
|
||||||
|
f()
|
||||||
|
end
|
||||||
|
$body
|
||||||
|
end))
|
||||||
|
end)
|
||||||
|
end
|
||||||
|
|
||||||
|
function train!(m, train; cb = [],
|
||||||
epoch = 1, η = 0.1, loss = mse)
|
epoch = 1, η = 0.1, loss = mse)
|
||||||
i = 0
|
|
||||||
@progress for e in 1:epoch
|
@progress for e in 1:epoch
|
||||||
info("Epoch $e")
|
info("Epoch $e")
|
||||||
@progress for (x, y) in train
|
@cb for (x, y) in train
|
||||||
x, y = tobatch.((x, y))
|
x, y = tobatch.((x, y))
|
||||||
i += 1
|
|
||||||
ŷ = m(x)
|
ŷ = m(x)
|
||||||
any(isnan, ŷ) && error("NaN")
|
any(isnan, ŷ) && error("NaN")
|
||||||
Δ = back!(loss, 1, ŷ, y)
|
Δ = back!(loss, 1, ŷ, y)
|
||||||
back!(m, Δ, x)
|
back!(m, Δ, x)
|
||||||
update!(m, η)
|
update!(m, η)
|
||||||
i % 1000 == 0 && @show accuracy(m, test)
|
end 5 foreach(f -> f(), cb)
|
||||||
end
|
|
||||||
end
|
end
|
||||||
return m
|
return m
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user