105 lines
2.0 KiB
Julia
105 lines
2.0 KiB
Julia
using Juno
|
|
import Zygote: Context, Params, _forward, gradient

# Training step

"""
    losscheck(x)

Validate a loss value before it is used for a parameter update.

Throws an `ErrorException` if `x` is not a `Real` scalar, if it is infinite,
or if it is NaN (checked in that order). Returns `false` when `x` passes.
"""
function losscheck(x)
    x isa Real || error("Function output is not scalar")
    if isinf(x)
        error("Loss is infinite")
    elseif isnan(x)
        error("Loss is NaN")
    end
    return false
end
|
|
|
|
"""
    step!(f, opt, x...)

Run one optimisation step on `f(x...)`.

The step proceeds as follows:
- Evaluate `f(x...)` through `_forward` with a fresh `Context`.
- Validate the resulting loss with `losscheck`.
- Run the pullback seeded with `1`.
- Hand `Globals(ctx)` to `update!(opt, nothing, ...)`.

Returns the loss value.
"""
function step!(f, opt, x...)
    ctx = Context()
    loss, pullback = _forward(ctx, f, x...)
    losscheck(loss)
    # The pullback call is needed for its effect on `ctx` even though the
    # gradient w.r.t. `f` itself is not used yet. TODO: update f with it.
    fbar = pullback(1)[1]
    # NOTE(review): `Globals` is not among this file's visible imports
    # (only Context, Params, _forward, gradient from Zygote) — confirm it is
    # in scope elsewhere.
    update!(opt, nothing, Globals(ctx))
    return loss
end

# Training loop

# Callback niceties

# Apply `f` to `xs`.
call(f, xs...) = f(xs...)

# A single callback is used unchanged.
runall(f) = f

# A vector of callbacks becomes one zero-argument function that invokes each
# callback in order and returns `nothing`.
function runall(fs::AbstractVector)
    return function ()
        for callback in fs
            callback()
        end
        return nothing
    end
end
|
|
|
|
"""
    StopException

Exception thrown by `stop()` to signal that training should end early.
"""
struct StopException <: Exception end

"""
    stop()

Signal from inside a callback that a stopping condition has been met.

Calling `Flux.stop()` throws a `StopException`. The training loop catches it
and exits.

```julia
# Example callback:

cb = function ()
    accuracy() > 0.9 && Flux.stop()
end
```
"""
stop() = throw(StopException())
|
|
|
|
"""
    train!(loss, params, data, opt; cb)

For each datapoint `d` in `data`, compute the gradient of `loss(d...)` with
respect to `params` through backpropagation and call the optimiser `opt`.

Takes a callback as keyword argument `cb`, which is called with no arguments
after every update. For example, this will print "training" every 10 seconds:

```julia
Flux.train!(loss, params, data, opt,
            cb = throttle(() -> println("training"), 10))
```

The callback can call `Flux.stop()` to interrupt the training loop.

Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
"""
function train!(loss, ps, data, opt; cb = () -> ())
    ps = Params(ps)
    # A vector of callbacks is folded into a single zero-argument function.
    cb = runall(cb)
    @progress for d in data
        try
            gs = gradient(ps) do
                loss(d...)
            end
            update!(opt, ps, gs)
            # Run the user callback after every step. A `StopException` it
            # throws (via `Flux.stop()`) is caught below and ends training.
            cb()
        catch ex
            if ex isa StopException
                break
            else
                # `rethrow()` re-raises with the original backtrace intact.
                rethrow()
            end
        end
    end
end
|
|
|
|
"""
    @epochs N body

Run `body` `N` times, logging the epoch number with `@info` before each run.
Mainly useful for quickly doing multiple epochs of training in a REPL.

```julia
julia> @epochs 2 println("hello")
[ Info: Epoch 1
hello
[ Info: Epoch 2
hello
```
"""
macro epochs(n, ex)
    # `n` and `ex` are escaped so they resolve in the caller's scope; the loop
    # variable `i` stays hygienic and is invisible to `body`.
    return :(@progress for i = 1:$(esc(n))
        @info "Epoch $i"
        $(esc(ex))
    end)
end
|