using Juno
using Flux.Tracker: back!

# Callbacks and optimisers may be passed as arrays; `runall` wraps a vector
# in a single closure that calls each element in turn.
call(f, xs...) = f(xs...)
runall(f) = f
runall(fs::AbstractVector) = () -> foreach(call, fs)

# The AD generates fairly large backtraces that are unhelpful if you interrupt
# while training; this just cleans that up.
macro interrupts(ex)
  :(try $(esc(ex))
    catch e
      e isa InterruptException || rethrow()
      throw(e)
    end)
end

struct StopException <: Exception end

"""
    stop()

Call `Flux.stop()` inside a callback to signal that a stopping condition has
been met; `train!` catches the resulting `StopException` and exits the loop.
"""
function stop()
  throw(StopException())
end

"""
    train!(loss, data, opt)

For each datapoint `d` in `data`, computes the gradient of `loss(d...)` through
backpropagation and calls the optimizer `opt`.

Takes a callback as keyword argument `cb`. For example, this will print
"training" every 10 seconds:

```julia
Flux.train!(loss, data, opt,
            cb = throttle(() -> println("training"), 10))
```

The callback can call `Flux.stop()` to interrupt the training loop.

Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
"""
function train!(loss, data, opt; cb = () -> ())
  cb = runall(cb)
  opt = runall(opt)
  @progress for d in data
    l = loss(d...)
    @interrupts back!(l)
    opt()
    try
      cb()
    catch ex
      if ex isa StopException
        @info "Stop condition met"
        break
      else
        rethrow(ex)
      end
    end
  end
end

"""
    @epochs N body

Run `body` `N` times. Mainly useful for quickly doing multiple epochs of
training in a REPL.

```julia
julia> @epochs 2 println("hello")
[ Info: Epoch 1
hello
[ Info: Epoch 2
hello
```
"""
macro epochs(n, ex)
  :(@progress for i = 1:$(esc(n))
      @info "Epoch $i"
      $(esc(ex))
    end)
end
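# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the library). This assumes the
# Tracker-era Flux API, where `SGD(ps, η)` returns a zero-argument callable
# optimiser; the model `m`, the dataset `toy_data`, and the `0.01` loss
# threshold are hypothetical names chosen for this example.

using Flux

m = Dense(10, 2)                                # tiny model to train
loss(x, y) = Flux.mse(m(x), y)                  # mean squared error
toy_data = [(rand(10), rand(2)) for _ = 1:100]  # 100 random (x, y) pairs
opt = SGD(params(m), 0.1)                       # callable optimiser: opt()

# A callback may call `Flux.stop()` to break out of `train!` early; here we
# stop once the loss on the first datapoint drops below the threshold.
early_stop() = loss(toy_data[1]...) < 0.01 && Flux.stop()

Flux.train!(loss, toy_data, opt, cb = early_stop)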
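# ---------------------------------------------------------------------------
# Sketch of the array form handled by `call`/`runall` (illustrative, reusing
# the hypothetical `loss`, `toy_data`, and `opt` from the sketch above).
# Both `opt` and `cb` accept vectors; `runall` wraps a vector so that a
# single `opt()` or `cb()` call invokes every element in order.

log_loss()  = println("loss: ", loss(toy_data[1]...))
log_clock() = println("time: ", time())

Flux.train!(loss, toy_data, opt,
            cb = [Flux.throttle(log_loss, 5), log_clock])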
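# ---------------------------------------------------------------------------
# `@epochs` simply repeats its body, so running the same hypothetical loop
# for several epochs is one line:

Flux.@epochs 3 Flux.train!(loss, toy_data, opt)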