using Juno
import Zygote: Context, Params, _forward, gradient

# Training step

function losscheck(x)
  x isa Real || error("Function output is not scalar")
  isinf(x) && error("Loss is infinite")
  isnan(x) && error("Loss is NaN")
end

function step!(f, opt, x...)
  cx = Context()
  y, ∂f = _forward(cx, f, x...)
  losscheck(y)
  f̄ = ∂f(1)[1] # TODO update f
  ḡ = Globals(cx)
  update!(opt, nothing, ḡ)
  return y
end

# Training loop

# Callback niceties
call(f, xs...) = f(xs...)
runall(f) = f
runall(fs::AbstractVector) = () -> foreach(call, fs)

struct StopException <: Exception end

"""
    stop()

Call `Flux.stop()` in a callback to indicate that a stopping condition has been met.
This makes the training loop stop and exit.

```julia
# Example callback:

cb = function ()
  accuracy() > 0.9 && Flux.stop()
end
```
"""
function stop()
  throw(StopException())
end

"""
    train!(loss, params, data, opt; cb)

For each datapoint `d` in `data`, computes the gradient of `loss(d...)` with
respect to `params` through backpropagation and calls the optimizer `opt` to
update `params`.

Takes a callback as keyword argument `cb`, which is called after every update.
For example, this will print "training" every 10 seconds:

```julia
Flux.train!(loss, params, data, opt,
            cb = throttle(() -> println("training"), 10))
```

The callback can call `Flux.stop()` to interrupt the training loop.

Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
"""
function train!(loss, ps, data, opt; cb = () -> ())
  ps = Params(ps)
  cb = runall(cb)
  @progress for d in data
    try
      # Gradient of the loss on this datapoint w.r.t. the tracked parameters
      gs = gradient(ps) do
        loss(d...)
      end
      update!(opt, ps, gs)
      # Run the callback(s) after each update; they may call `Flux.stop()`
      cb()
    catch ex
      if ex isa StopException
        break
      else
        rethrow(ex)
      end
    end
  end
end

"""
    @epochs N body

Run `body` `N` times. Mainly useful for quickly doing multiple epochs of
training in a REPL.

```julia
julia> @epochs 2 println("hello")
[ Info: Epoch 1
hello
[ Info: Epoch 2
hello
```
"""
macro epochs(n, ex)
  :(@progress for i = 1:$(esc(n))
      @info "Epoch $i"
      $(esc(ex))
    end)
end