Merge pull request #52 from ylxdzsw/throttle

implement the training callback with `throttle` function
This commit is contained in:
Mike J Innes 2017-08-17 14:09:45 +01:00 committed by GitHub
commit a1b7434599
3 changed files with 93 additions and 2 deletions

View File

@ -25,18 +25,59 @@ macro cb(ex, t, f)
end)
end
"""
Returns a function that when invoked, will only be triggered at most once
during `timeout` seconds. Normally, the throttled function will run
as much as it can, without ever going more than once per `wait` duration;
but if you'd like to disable the execution on the leading edge, pass
`leading=false`. To enable execution on the trailing edge, ditto.
"""
function throttle(f, timeout; leading=true, trailing=false)
cooldown = true
later = nothing
function throttled(args...; kwargs...)
yield()
if cooldown
if leading
f(args...; kwargs...)
else
later = () -> f(args...; kwargs...)
end
cooldown = false
@schedule try
while (sleep(timeout); later != nothing)
later()
later = nothing
end
finally
cooldown = true
end
elseif trailing
later = () -> f(args...; kwargs...)
end
nothing
end
end
function train!(m, train; cb = [],
epoch = 1, η = 0.1, loss = mse)
callback = throttle(()->foreach(f -> f(), cb), 5)
@progress for e in 1:epoch
info("Epoch $e")
@cb for (x, y) in train
for (x, y) in train
x, y = mapt(tobatch, (x, y))
= m(x)
any(isnan, ) && error("NaN")
Δ = back!(loss, 1, , y)
back!(m, Δ, x)
update!(m, η)
end 5 foreach(f -> f(), cb)
callback()
end
end
return m
end

View File

@ -18,6 +18,7 @@ include("backend/common.jl")
include("basic.jl")
include("recurrent.jl")
include("optimizer.jl")
include("throttle.jl")
@tfonly include("backend/tensorflow.jl")
@mxonly include("backend/mxnet.jl")

49
test/throttle.jl Normal file
View File

@ -0,0 +1,49 @@
using Flux.throttle
@testset "throttle" begin
@testset "default behaviour" begin
a = []
f = throttle(()->push!(a, now()), 1, leading=true, trailing=false)
f()
f()
f()
sleep(1.01)
@test length(a) == 1
end
@testset "leading behaviour" begin
a = []
f = throttle(()->push!(a, now()), 1, leading=true, trailing=false)
f()
@test length(a) == 1
f()
@test length(a) == 1
sleep(1.01)
f()
@test length(a) == 2
end
@testset "trailing behaviour" begin
a = []
f = throttle(()->push!(a, now()), 1, leading=false, trailing=true)
f()
@test length(a) == 0
f()
@test length(a) == 0
sleep(1.01)
@test length(a) == 1
end
@testset "arguments" begin
a = []
f = throttle((x)->push!(a, x), 1, leading=true, trailing=true)
f(1)
@test a == [1]
f(2)
@test a == [1]
f(3)
@test a == [1]
sleep(1.01)
@test a == [1, 3]
end
end