# Training utilities: train!, @epochs, and callback helpers.
using Juno
using Flux.Tracker: back!
# NOTE(review): "utls.jl" looks like a typo for "utils.jl" — confirm the relative path
# (this file is expected to provide `call`, used by `runall` below).
include("../utls.jl")
# Normalise a callback/optimiser argument into a single zero-argument callable.
#
# A lone callable passes through untouched; a vector of callables is wrapped in
# a thunk that invokes each element in order.
# NOTE(review): `call` is presumably defined in the included utils file — confirm.
function runall(f)
  return f
end

function runall(fs::AbstractVector)
  return () -> foreach(call, fs)
end
# The AD generates fairly large backtraces that are unhelpful if you interrupt
# while training; this macro re-throws an `InterruptException` with a clean
# (short) backtrace, while letting every other exception propagate untouched
# via `rethrow`.
macro interrupts(ex)
  :(try
      $(esc(ex))
    catch e
      e isa InterruptException || rethrow()
      throw(e)
    end)
end
"""
    train!(loss, data, opt)

For each datapoint `d` in `data` computes the gradient of `loss(d...)` through
backpropagation and calls the optimizer `opt`.

Takes a callback as keyword argument `cb`. For example, this will print "training"
every 10 seconds:

```julia
Flux.train!(loss, data, opt,
            cb = throttle(() -> println("training"), 10))
```

The callback can return `:stop` to interrupt the training loop.

Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
"""
function train!(loss, data, opt; cb = () -> ())
  # Normalise `cb`/`opt` (either a callable or a vector of callables) into
  # single zero-argument callables.
  cb = runall(cb)
  opt = runall(opt)
  @progress for d in data
    l = loss(d...)
    # Backpropagate; @interrupts keeps the backtrace short on Ctrl-C.
    @interrupts back!(l)
    opt()
    # A callback returning `:stop` terminates training early.
    cb() == :stop && break
  end
end
"""
    @epochs N body

Run `body` `N` times. Mainly useful for quickly doing multiple epochs of
training in a REPL.

```julia
julia> @epochs 2 println("hello")
INFO: Epoch 1
hello
INFO: Epoch 2
hello
```
"""
macro epochs(n, ex)
  # Escape the epoch count and the body so both are evaluated in the caller's
  # scope; the loop variable `i` is hygienic and cannot clash with user code.
  :(@progress for i in 1:$(esc(n))
      @info "Epoch $i"
      $(esc(ex))
    end)
end
|