From 88fa163c95231969a0898470a11c72dc65d4ca2e Mon Sep 17 00:00:00 2001 From: ylxdzsw Date: Fri, 21 Jul 2017 16:31:12 +0800 Subject: [PATCH] throttle --- src/training.jl | 45 +++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 43 insertions(+), 2 deletions(-) diff --git a/src/training.jl b/src/training.jl index 62a34c46..3632eb82 100644 --- a/src/training.jl +++ b/src/training.jl @@ -25,18 +25,59 @@ macro cb(ex, t, f) end) end +""" +Returns a function that when invoked, will only be triggered at most once +during `timeout` seconds. Normally, the throttled function will run +as much as it can, without ever going more than once per `wait` duration; +but if you'd like to disable the execution on the leading edge, pass +`leading=false`. To enable execution on the trailing edge, ditto. +""" +function throttle(f, timeout; leading=true, trailing=false) + cooldown = true + later = nothing + + function throttled(args...; kwargs...) + yield() + + if cooldown + if leading + f(args...; kwargs...) + else + later = () -> f(args...; kwargs...) + end + + cooldown = false + @schedule try + while (sleep(timeout); later != nothing) + later() + later = nothing + end + finally + cooldown = true + end + elseif trailing + later = () -> f(args...; kwargs...) + end + + nothing + end +end + function train!(m, train; cb = [], epoch = 1, η = 0.1, loss = mse) + callback = throttle(()->foreach(f -> f(), cb), 5) + @progress for e in 1:epoch info("Epoch $e") - @cb for (x, y) in train + for (x, y) in train x, y = mapt(tobatch, (x, y)) ŷ = m(x) any(isnan, ŷ) && error("NaN") Δ = back!(loss, 1, ŷ, y) back!(m, Δ, x) update!(m, η) - end 5 foreach(f -> f(), cb) + callback() + end end return m end