# Flux.jl/src/optimise/train.jl

using Juno
using Flux.Tracker: back!
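# `runall` normalises the `opt` and `cb` arguments: a single function is returned
# as-is, while a vector of functions becomes a closure that invokes each one in
# turn (using the small `call` helper defined elsewhere in Flux).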
runall(f) = f
runall(fs::AbstractVector) = () -> foreach(call, fs)
# The AD generates fairly large backtraces that are unhelpful if you interrupt
# while training; this just cleans that up.
macro interrupts(ex)
  :(try $(esc(ex))
    catch e
      # Unexpected errors are rethrown with their full backtrace; an
      # InterruptException is re-thrown with `throw`, which drops the very long
      # AD backtrace and leaves a clean error at the point of interruption.
      e isa InterruptException || rethrow()
      throw(e)
    end)
end
"""
2017-12-13 18:24:56 +00:00
train!(loss, data, opt)
2017-10-11 11:26:40 +00:00
For each datapoint `d` in `data` computes the gradient of `loss(d...)` through
2017-12-13 18:24:56 +00:00
backpropagation and calls the optimizer `opt`.
Takes a callback as keyword argument `cb`. For example, this will print "training"
every 10 seconds:
```julia
Flux.train!(loss, data, opt,
cb = throttle(() -> println("training"), 10))
```
The callback can return `:stop` to interrupt the training loop.
2017-10-11 11:26:40 +00:00
Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
2017-10-11 11:26:40 +00:00
"""
function train!(loss, data, opt; cb = () -> ())
  cb = runall(cb)
  opt = runall(opt)
  @progress for d in data
    l = loss(d...)
    @interrupts back!(l)
    opt()
    cb() == :stop && break
  end
end
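# Illustrative usage sketch (not part of the library). It assumes a simple
# `Dense` model and the `SGD(params(m), η)` optimiser constructor from this
# version of Flux:
#
#   m = Dense(10, 2)
#   loss(x, y) = Flux.mse(m(x), y)
#   data = [(rand(10), rand(2)) for _ = 1:100]
#   opt = SGD(params(m), 0.1)
#   Flux.train!(loss, data, opt,
#               cb = throttle(() -> @show(loss(data[1]...)), 10))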
"""
@epochs N body
Run `body` `N` times. Mainly useful for quickly doing multiple epochs of
training in a REPL.
```julia
julia> @epochs 2 println("hello")
INFO: Epoch 1
hello
INFO: Epoch 2
hello
```
"""
macro epochs(n, ex)
  :(@progress for i = 1:$(esc(n))
      info("Epoch $i")
      $(esc(ex))
    end)
end
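# Example (sketch): combined with `train!` above, `@epochs` gives quick
# multi-epoch training at the REPL, assuming `loss`, `data` and `opt` are set up
# as in the sketch following `train!`:
#
#   @epochs 10 Flux.train!(loss, data, opt)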