Merge pull request #52 from ylxdzsw/throttle
implement the training callback with `throttle` function
This commit is contained in:
commit
a1b7434599
@ -25,18 +25,59 @@ macro cb(ex, t, f)
|
||||
end)
|
||||
end
|
||||
|
||||
"""
|
||||
Returns a function that when invoked, will only be triggered at most once
|
||||
during `timeout` seconds. Normally, the throttled function will run
|
||||
as much as it can, without ever going more than once per `wait` duration;
|
||||
but if you'd like to disable the execution on the leading edge, pass
|
||||
`leading=false`. To enable execution on the trailing edge, ditto.
|
||||
"""
|
||||
function throttle(f, timeout; leading=true, trailing=false)
|
||||
cooldown = true
|
||||
later = nothing
|
||||
|
||||
function throttled(args...; kwargs...)
|
||||
yield()
|
||||
|
||||
if cooldown
|
||||
if leading
|
||||
f(args...; kwargs...)
|
||||
else
|
||||
later = () -> f(args...; kwargs...)
|
||||
end
|
||||
|
||||
cooldown = false
|
||||
@schedule try
|
||||
while (sleep(timeout); later != nothing)
|
||||
later()
|
||||
later = nothing
|
||||
end
|
||||
finally
|
||||
cooldown = true
|
||||
end
|
||||
elseif trailing
|
||||
later = () -> f(args...; kwargs...)
|
||||
end
|
||||
|
||||
nothing
|
||||
end
|
||||
end
|
||||
|
||||
function train!(m, train; cb = [],
|
||||
epoch = 1, η = 0.1, loss = mse)
|
||||
callback = throttle(()->foreach(f -> f(), cb), 5)
|
||||
|
||||
@progress for e in 1:epoch
|
||||
info("Epoch $e")
|
||||
@cb for (x, y) in train
|
||||
for (x, y) in train
|
||||
x, y = mapt(tobatch, (x, y))
|
||||
ŷ = m(x)
|
||||
any(isnan, ŷ) && error("NaN")
|
||||
Δ = back!(loss, 1, ŷ, y)
|
||||
back!(m, Δ, x)
|
||||
update!(m, η)
|
||||
end 5 foreach(f -> f(), cb)
|
||||
callback()
|
||||
end
|
||||
end
|
||||
return m
|
||||
end
|
||||
|
@ -18,6 +18,7 @@ include("backend/common.jl")
|
||||
include("basic.jl")
|
||||
include("recurrent.jl")
|
||||
include("optimizer.jl")
|
||||
include("throttle.jl")
|
||||
|
||||
@tfonly include("backend/tensorflow.jl")
|
||||
@mxonly include("backend/mxnet.jl")
|
||||
|
49
test/throttle.jl
Normal file
49
test/throttle.jl
Normal file
@ -0,0 +1,49 @@
|
||||
using Flux.throttle
|
||||
|
||||
@testset "throttle" begin
|
||||
@testset "default behaviour" begin
|
||||
a = []
|
||||
f = throttle(()->push!(a, now()), 1, leading=true, trailing=false)
|
||||
f()
|
||||
f()
|
||||
f()
|
||||
sleep(1.01)
|
||||
@test length(a) == 1
|
||||
end
|
||||
|
||||
@testset "leading behaviour" begin
|
||||
a = []
|
||||
f = throttle(()->push!(a, now()), 1, leading=true, trailing=false)
|
||||
f()
|
||||
@test length(a) == 1
|
||||
f()
|
||||
@test length(a) == 1
|
||||
sleep(1.01)
|
||||
f()
|
||||
@test length(a) == 2
|
||||
end
|
||||
|
||||
@testset "trailing behaviour" begin
|
||||
a = []
|
||||
f = throttle(()->push!(a, now()), 1, leading=false, trailing=true)
|
||||
f()
|
||||
@test length(a) == 0
|
||||
f()
|
||||
@test length(a) == 0
|
||||
sleep(1.01)
|
||||
@test length(a) == 1
|
||||
end
|
||||
|
||||
@testset "arguments" begin
|
||||
a = []
|
||||
f = throttle((x)->push!(a, x), 1, leading=true, trailing=true)
|
||||
f(1)
|
||||
@test a == [1]
|
||||
f(2)
|
||||
@test a == [1]
|
||||
f(3)
|
||||
@test a == [1]
|
||||
sleep(1.01)
|
||||
@test a == [1, 3]
|
||||
end
|
||||
end
|
Loading…
Reference in New Issue
Block a user