# Flux.jl/src/optimise/train.jl

using Juno
using Flux.Tracker: back!
2017-08-31 16:36:18 +00:00
# Normalise the argument into a single zero-argument callable.
# Anything that is not an `AbstractVector` is passed through untouched.
runall(f) = f

# For a vector, build a closure that applies `call` to each element in order.
# `call` is defined elsewhere in the package (presumably it invokes its
# argument with no arguments — confirm against its definition). The closure
# returns whatever `foreach` returns, i.e. `nothing`.
function runall(fs::AbstractVector)
    runeach() = foreach(call, fs)
    return runeach
end
2017-09-07 04:27:16 +00:00
2017-10-11 11:26:40 +00:00
"""
    train!(loss, data, opt; cb = () -> ())

For each datapoint `d` in `data`, compute the gradient of `loss(d...)` through
backpropagation and call the optimiser `opt`.

Takes a callback as keyword argument `cb`, which is invoked after every
optimiser step. For example, this will print "training" every 10 seconds:

```julia
Flux.train!(loss, data, opt,
            cb = throttle(() -> println("training"), 10))
```

The callback can return `:stop` to interrupt the training loop.

Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
When `cb` is an array, every callback is run on each iteration, and training
stops after that iteration if any of them returned `:stop`.

Throws an error if the loss evaluates to `Inf` or `NaN`.
"""
function train!(loss, data, opt; cb = () -> ())
    # Keep the callbacks as a collection instead of fusing them with `runall`:
    # the fused closure discards each callback's return value, so a `:stop`
    # returned from a callback inside an array would never be seen.
    callbacks = isa(cb, AbstractVector) ? cb : [cb]
    opt = runall(opt)
    @progress for d in data
        l = loss(d...)
        # Fail fast on a diverged loss, before backpropagating through it.
        isinf(l) && error("Loss is Inf")
        isnan(l) && error("Loss is NaN")
        back!(l)
        opt()
        # Always run every callback (no short-circuit), then honour `:stop`.
        stop = false
        for f in callbacks
            if f() == :stop
                stop = true
            end
        end
        stop && break
    end
end
2018-03-05 22:56:22 +00:00
"""
    @epochs N body

Evaluate `body` `N` times, logging the epoch number before each run. Mainly
useful for quickly doing multiple epochs of training in a REPL.

```julia
julia> @epochs 2 println("hello")
INFO: Epoch 1
hello
INFO: Epoch 2
hello
```
"""
macro epochs(n, ex)
    # Escape the user-supplied pieces so they are evaluated in the caller's
    # scope; the loop variable `i` stays hygienic and cannot clash with names
    # used inside `body`.
    count, body = esc(n), esc(ex)
    return quote
        @progress for i in 1:$count
            info("Epoch $i")
            $body
        end
    end
end