This commit is contained in:
Mike J Innes 2017-12-13 18:24:56 +00:00
parent 23096824d5
commit 5b97d2ba04
3 changed files with 32 additions and 10 deletions

View File

@ -1,15 +1,24 @@
using Juno using Juno
using Flux.Tracker: back! using Flux.Tracker: back!, value
runall(f) = f runall(f) = f
runall(fs::AbstractVector) = () -> foreach(call, fs) runall(fs::AbstractVector) = () -> foreach(call, fs)
""" """
train!(loss, data, opt; cb = () -> ()) train!(loss, data, opt)
For each datapoint `d` in `data` computes the gradient of `loss(d...)` through For each datapoint `d` in `data` computes the gradient of `loss(d...)` through
backpropagation and calls the optimizer `opt` and the callback `cb` backpropagation and calls the optimizer `opt`.
(i.e. `opt()` and `cb()`).
Takes a callback as keyword argument `cb`. For example, this will print "training"
every 10 seconds:
```julia
Flux.train!(loss, data, opt,
cb = throttle(() -> println("training"), 10))
```
The callback can return `:stop` to interrupt the training loop.
Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays. Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
""" """
@ -18,10 +27,10 @@ function train!(loss, data, opt; cb = () -> ())
opt = runall(opt) opt = runall(opt)
@progress for d in data @progress for d in data
l = loss(d...) l = loss(d...)
isinf(l.data[]) && error("Loss is Inf") isinf(value(l)) && error("Loss is Inf")
isnan(l.data[]) && error("Loss is NaN") isnan(value(l)) && error("Loss is NaN")
back!(l) back!(l)
opt() opt()
cb() cb() == :stop && break
end end
end end

View File

@ -95,13 +95,14 @@ but if you'd like to disable the execution on the leading edge, pass
function throttle(f, timeout; leading=true, trailing=false) function throttle(f, timeout; leading=true, trailing=false)
cooldown = true cooldown = true
later = nothing later = nothing
result = nothing
function throttled(args...; kwargs...) function throttled(args...; kwargs...)
yield() yield()
if cooldown if cooldown
if leading if leading
f(args...; kwargs...) result = f(args...; kwargs...)
else else
later = () -> f(args...; kwargs...) later = () -> f(args...; kwargs...)
end end
@ -116,10 +117,10 @@ function throttle(f, timeout; leading=true, trailing=false)
cooldown = true cooldown = true
end end
elseif trailing elseif trailing
later = () -> f(args...; kwargs...) later = () -> (result = f(args...; kwargs...))
end end
nothing return result
end end
end end

View File

@ -15,3 +15,15 @@ using Flux.Tracker
@test Flux.mse(w, w) < 0.01 @test Flux.mse(w, w) < 0.01
end end
end end
@testset "Training Loop" begin
i = 0
l = param(1)
Flux.train!(() -> (sleep(0.1); i += 1; l),
Iterators.repeated((), 100),
()->(),
cb = Flux.throttle(() -> (i > 3 && :stop), 1))
@test 3 < i < 50
end