closes #127
This commit is contained in:
parent
23096824d5
commit
5b97d2ba04
@ -1,15 +1,24 @@
|
|||||||
using Juno
|
using Juno
|
||||||
using Flux.Tracker: back!
|
using Flux.Tracker: back!, value
|
||||||
|
|
||||||
runall(f) = f
|
runall(f) = f
|
||||||
runall(fs::AbstractVector) = () -> foreach(call, fs)
|
runall(fs::AbstractVector) = () -> foreach(call, fs)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
train!(loss, data, opt; cb = () -> ())
|
train!(loss, data, opt)
|
||||||
|
|
||||||
For each datapoint `d` in `data` computes the gradient of `loss(d...)` through
|
For each datapoint `d` in `data` computes the gradient of `loss(d...)` through
|
||||||
backpropagation and calls the optimizer `opt` and the callback `cb`
|
backpropagation and calls the optimizer `opt`.
|
||||||
(i.e. `opt()` and `cb()`).
|
|
||||||
|
Takes a callback as keyword argument `cb`. For example, this will print "training"
|
||||||
|
every 10 seconds:
|
||||||
|
|
||||||
|
```julia
|
||||||
|
Flux.train!(loss, data, opt,
|
||||||
|
cb = throttle(() -> println("training"), 10))
|
||||||
|
```
|
||||||
|
|
||||||
|
The callback can return `:stop` to interrupt the training loop.
|
||||||
|
|
||||||
Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
|
Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
|
||||||
"""
|
"""
|
||||||
@ -18,10 +27,10 @@ function train!(loss, data, opt; cb = () -> ())
|
|||||||
opt = runall(opt)
|
opt = runall(opt)
|
||||||
@progress for d in data
|
@progress for d in data
|
||||||
l = loss(d...)
|
l = loss(d...)
|
||||||
isinf(l.data[]) && error("Loss is Inf")
|
isinf(value(l)) && error("Loss is Inf")
|
||||||
isnan(l.data[]) && error("Loss is NaN")
|
isnan(value(l)) && error("Loss is NaN")
|
||||||
back!(l)
|
back!(l)
|
||||||
opt()
|
opt()
|
||||||
cb()
|
cb() == :stop && break
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -95,13 +95,14 @@ but if you'd like to disable the execution on the leading edge, pass
|
|||||||
function throttle(f, timeout; leading=true, trailing=false)
|
function throttle(f, timeout; leading=true, trailing=false)
|
||||||
cooldown = true
|
cooldown = true
|
||||||
later = nothing
|
later = nothing
|
||||||
|
result = nothing
|
||||||
|
|
||||||
function throttled(args...; kwargs...)
|
function throttled(args...; kwargs...)
|
||||||
yield()
|
yield()
|
||||||
|
|
||||||
if cooldown
|
if cooldown
|
||||||
if leading
|
if leading
|
||||||
f(args...; kwargs...)
|
result = f(args...; kwargs...)
|
||||||
else
|
else
|
||||||
later = () -> f(args...; kwargs...)
|
later = () -> f(args...; kwargs...)
|
||||||
end
|
end
|
||||||
@ -116,10 +117,10 @@ function throttle(f, timeout; leading=true, trailing=false)
|
|||||||
cooldown = true
|
cooldown = true
|
||||||
end
|
end
|
||||||
elseif trailing
|
elseif trailing
|
||||||
later = () -> f(args...; kwargs...)
|
later = () -> (result = f(args...; kwargs...))
|
||||||
end
|
end
|
||||||
|
|
||||||
nothing
|
return result
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -15,3 +15,15 @@ using Flux.Tracker
|
|||||||
@test Flux.mse(w, w′) < 0.01
|
@test Flux.mse(w, w′) < 0.01
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@testset "Training Loop" begin
|
||||||
|
i = 0
|
||||||
|
l = param(1)
|
||||||
|
|
||||||
|
Flux.train!(() -> (sleep(0.1); i += 1; l),
|
||||||
|
Iterators.repeated((), 100),
|
||||||
|
()->(),
|
||||||
|
cb = Flux.throttle(() -> (i > 3 && :stop), 1))
|
||||||
|
|
||||||
|
@test 3 < i < 50
|
||||||
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user