# Flux.jl/src/optimise/train.jl
# Last change: Dhairya Gandhi, 756207e782 "added docs" (2018-08-20 14:20:33 +05:30)
# 94 lines, 1.8 KiB, Julia

using Juno
using Flux.Tracker: back!
# `train!` accepts either one callable or a vector of callables for both `opt`
# and `cb`; `runall` normalises either form into a single zero-argument callable.
runall(f) = f

function runall(fs::AbstractVector)
    # `call` is defined elsewhere in the module (not visible here); presumably
    # `call(g) = g()` — TODO confirm.
    return function ()
        for g in fs
            call(g)
        end
        return nothing
    end
end
# The AD produces very long backtraces, which are just noise when the user
# interrupts training. `@interrupts ex` evaluates `ex`:
#   * an `InterruptException` is thrown again from the catch site, so the
#     reported backtrace starts here instead of wherever the interrupt landed;
#   * any other exception is propagated with its original backtrace.
macro interrupts(ex)
    return quote
        try
            $(esc(ex))
        catch err
            err isa InterruptException && throw(err)
            rethrow()
        end
    end
end
"""
    StopException

Sentinel exception thrown by `stop()`; `train!` catches it to end its loop early.
"""
struct StopException <: Exception end

"""
    stop()

Signal from inside a training callback that the stopping condition has been
met. This throws a `StopException`, which makes the training loop stop and exit.

```julia
# Example callback:
cb = function ()
    accuracy() > 0.9 && stop()
end
```
"""
stop() = throw(StopException())
"""
    train!(loss, data, opt; cb = () -> ())

For each datapoint `d` in `data`, compute the gradient of `loss(d...)` through
backpropagation and then call the optimiser `opt`.

Each datapoint is splatted into `loss`, so `d` should be a tuple of arguments
(e.g. `(x, y)`); a bare array would be splatted element-wise.

Takes a callback as keyword argument `cb`. For example, this will print "training"
every 10 seconds:

```julia
Flux.train!(loss, data, opt,
            cb = throttle(() -> println("training"), 10))
```

The callback can call `stop()` to interrupt the training loop; its return value
is ignored.

Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
"""
function train!(loss, data, opt; cb = () -> ())
    # Normalise single callables / vectors of callables into one callable each.
    cb = runall(cb)
    opt = runall(opt)
    @progress for d in data
        l = loss(d...)
        # Keep interrupt backtraces short instead of dumping the whole AD trace.
        @interrupts back!(l)
        opt()
        try
            cb()
        catch ex
            # `stop()` requests an early exit by throwing `StopException`;
            # anything else is a genuine error from the callback.
            if ex isa StopException
                @info "Stop condition met"
                break
            else
                # Plain `rethrow()` re-raises the caught error with its original backtrace.
                rethrow()
            end
        end
    end
end
"""
    @epochs N body

Run `body` `N` times, logging an `Epoch i` message via `@info` before each run.
Mainly useful for quickly doing multiple epochs of training in a REPL.

```julia
julia> @epochs 2 println("hello")
[ Info: Epoch 1
hello
[ Info: Epoch 2
hello
```
"""
macro epochs(n, ex)
    # `n` and `ex` are escaped so they are evaluated in the caller's scope; the
    # loop variable `i` is not, so hygiene keeps it from clashing with (or
    # leaking into) user variables.
    :(@progress for i = 1:$(esc(n))
        @info "Epoch $i"
        $(esc(ex))
    end)
end