# Flux.jl/src/optimise/train.jl
using Juno
using Flux.Tracker: back!
# Normalise the `opt`/`cb` arguments of `train!` into a single zero-argument
# callable. A single callable is returned unchanged. A vector of callables is
# wrapped in a closure that calls each element, in order, with no arguments.
# (Uses an explicit `f -> f()` rather than a `call` helper, which is no longer
# part of Base.)
runall(f) = f
runall(fs::AbstractVector) = () -> foreach(f -> f(), fs)

"""
    @interrupts ex

Evaluate `ex` and return its value. The AD generates fairly large backtraces
that are unhelpful when training is interrupted. If `ex` raises an
`InterruptException` (e.g. Ctrl-C), that exception is thrown again from here,
which replaces the deep backtrace with a short one. Every other exception is
rethrown with its original backtrace intact.
"""
macro interrupts(ex)
    return quote
        try
            $(esc(ex))
        catch err
            if !(err isa InterruptException)
                # Genuine errors: keep the full backtrace.
                rethrow()
            end
            # Interrupts: throw afresh so the backtrace starts at this frame.
            throw(err)
        end
    end
end
# Signal exception thrown by `stop()`; `train!` catches it to leave its loop early.
struct StopException <: Exception end

"""
    stop()

Call `stop()` in a callback to indicate when a callback condition is met.
This triggers the train loop to stop and exit.

Implemented by throwing a `StopException`, which the training loop catches.

```julia
# Example callback:
cb = function ()
    accuracy() > 0.9 && stop()
end
```
"""
function stop()
    throw(StopException())
end
"""
    train!(loss, data, opt; cb = () -> ())

For each datapoint `d` in `data` computes the gradient of `loss(d...)` through
backpropagation and calls the optimizer `opt`.

Takes a callback as keyword argument `cb`, which is invoked after every
optimiser step. For example, this will print "training" every 10 seconds:

```julia
Flux.train!(loss, data, opt,
            cb = throttle(() -> println("training"), 10))
```

The callback can call `stop()` to interrupt the training loop; any other
exception raised by the callback propagates to the caller.

Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
"""
function train!(loss, data, opt; cb = () -> ())
    # Collapse vectors of callbacks/optimisers into single zero-arg callables.
    cb = runall(cb)
    opt = runall(opt)
    @progress for d in data
        l = loss(d...)
        # Keep Ctrl-C backtraces short instead of showing the whole AD stack.
        @interrupts back!(l)
        opt()
        try
            cb()
        catch ex
            if ex isa StopException
                # Raised by `stop()`: a requested early exit, not an error.
                @info "Stop condition met"
                break
            else
                rethrow()
            end
        end
    end
end
"""
    @epochs N body

Run `body` `N` times. Mainly useful for quickly doing multiple epochs of
training in a REPL. Each iteration first logs its epoch number via `@info`.

```julia
julia> @epochs 2 println("hello")
[ Info: Epoch 1
hello
[ Info: Epoch 2
hello
```
"""
macro epochs(n, ex)
    # `i` is hygienic (renamed by the macro expander), so it cannot clash with
    # variables used in the caller's `body`.
    :(@progress for i = 1:$(esc(n))
        @info "Epoch $i"
        $(esc(ex))
    end)
end