From 88fa163c95231969a0898470a11c72dc65d4ca2e Mon Sep 17 00:00:00 2001 From: ylxdzsw Date: Fri, 21 Jul 2017 16:31:12 +0800 Subject: [PATCH 1/2] throttle --- src/training.jl | 45 +++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 43 insertions(+), 2 deletions(-) diff --git a/src/training.jl b/src/training.jl index 62a34c46..3632eb82 100644 --- a/src/training.jl +++ b/src/training.jl @@ -25,18 +25,59 @@ macro cb(ex, t, f) end) end +""" +Returns a function that when invoked, will only be triggered at most once +during `timeout` seconds. Normally, the throttled function will run +as much as it can, without ever going more than once per `wait` duration; +but if you'd like to disable the execution on the leading edge, pass +`leading=false`. To enable execution on the trailing edge, ditto. +""" +function throttle(f, timeout; leading=true, trailing=false) + cooldown = true + later = nothing + + function throttled(args...; kwargs...) + yield() + + if cooldown + if leading + f(args...; kwargs...) + else + later = () -> f(args...; kwargs...) + end + + cooldown = false + @schedule try + while (sleep(timeout); later != nothing) + later() + later = nothing + end + finally + cooldown = true + end + elseif trailing + later = () -> f(args...; kwargs...) + end + + nothing + end +end + function train!(m, train; cb = [], epoch = 1, η = 0.1, loss = mse) + callback = throttle(()->foreach(f -> f(), cb), 5) + @progress for e in 1:epoch info("Epoch $e") - @cb for (x, y) in train + for (x, y) in train x, y = mapt(tobatch, (x, y)) ŷ = m(x) any(isnan, ŷ) && error("NaN") Δ = back!(loss, 1, ŷ, y) back!(m, Δ, x) update!(m, η) - end 5 foreach(f -> f(), cb) + callback() + end end return m end From cce1a2a73e4155a6233822380c26d6c063f72de1 Mon Sep 17 00:00:00 2001 From: ylxdzsw Date: Wed, 26 Jul 2017 09:57:20 +0800 Subject: [PATCH 2/2] add tests for throttle --- test/runtests.jl | 1 + test/throttle.jl | 49 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+) create mode 100644 test/throttle.jl diff --git a/test/runtests.jl b/test/runtests.jl index 655d924a..7128e13b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -18,6 +18,7 @@ include("backend/common.jl") include("basic.jl") include("recurrent.jl") include("optimizer.jl") +include("throttle.jl") @tfonly include("backend/tensorflow.jl") @mxonly include("backend/mxnet.jl") diff --git a/test/throttle.jl b/test/throttle.jl new file mode 100644 index 00000000..d3dcd925 --- /dev/null +++ b/test/throttle.jl @@ -0,0 +1,49 @@ +using Flux.throttle + +@testset "throttle" begin + @testset "default behaviour" begin + a = [] + f = throttle(()->push!(a, now()), 1, leading=true, trailing=false) + f() + f() + f() + sleep(1.01) + @test length(a) == 1 + end + + @testset "leading behaviour" begin + a = [] + f = throttle(()->push!(a, now()), 1, leading=true, trailing=false) + f() + @test length(a) == 1 + f() + @test length(a) == 1 + sleep(1.01) + f() + @test length(a) == 2 + end + + @testset "trailing behaviour" begin + a = [] + f = throttle(()->push!(a, now()), 1, leading=false, trailing=true) + f() + @test length(a) == 0 + f() + @test length(a) == 0 + sleep(1.01) + @test length(a) == 1 + end + + @testset "arguments" begin + a = [] + f = throttle((x)->push!(a, x), 1, leading=true, trailing=true) + f(1) + @test a == [1] + f(2) + @test a == [1] + f(3) + @test a == [1] + sleep(1.01) + @test a == [1, 3] + end +end