using Juno
import Zygote: Params, gradient

function update!(x::AbstractArray, x̄)
  x .+= x̄
  return x
end

function update!(opt, x, x̄)
  update!(x, -apply!(opt, x, x̄))
end

function update!(opt, xs::Params, gs)
  for x in xs
    update!(opt, x, gs[x])
  end
end

# Callback niceties
call(f, xs...) = f(xs...)
runall(f) = f
runall(fs::AbstractVector) = () -> foreach(call, fs)

struct StopException <: Exception end

"""
    stop()

Call `Flux.stop()` in a callback to signal that a stopping condition has been
met. This will cause the training loop to stop and exit.

```julia
# Example callback:
cb = function ()
  accuracy() > 0.9 && Flux.stop()
end
```
"""
function stop()
  throw(StopException())
end

"""
    train!(loss, params, data, opt; cb)

For each datapoint `d` in `data`, compute the gradient of `loss(d...)` through
backpropagation and call the optimizer `opt`.

Takes a callback as keyword argument `cb`. For example, this will print
"training" every 10 seconds:

```julia
Flux.train!(loss, params, data, opt,
            cb = throttle(() -> println("training"), 10))
```

The callback can call `Flux.stop()` to interrupt the training loop.

Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
"""
function train!(loss, ps, data, opt; cb = () -> ())
  ps = Params(ps)
  cb = runall(cb)
  @progress for d in data
    try
      gs = gradient(ps) do
        loss(d...)
      end
      update!(opt, ps, gs)
      cb()  # invoke the callback(s) after each update, as documented above
    catch ex
      if ex isa StopException
        break
      else
        rethrow(ex)
      end
    end
  end
end

"""
    @epochs N body

Run `body` `N` times. Mainly useful for quickly doing multiple epochs of
training in a REPL.

```julia
julia> @epochs 2 println("hello")
[ Info: Epoch 1
hello
[ Info: Epoch 2
hello
```
"""
macro epochs(n, ex)
  :(@progress for i = 1:$(esc(n))
      @info "Epoch $i"
      $(esc(ex))
    end)
end