# Flux.jl/src/optimise/train.jl
using Juno
import Zygote: Context, Params, _forward, gradient
# Training step
function losscheck(x)
  x isa Real || error("Function output is not scalar")
  isinf(x) && error("Loss is infinite")
  isnan(x) && error("Loss is NaN")
end
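# For instance, `losscheck(0.5f0)` passes, while `losscheck(NaN)` or
# `losscheck([1.0, 2.0])` throw before any parameter update happens.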
function step!(f, opt, x...)
  cx = Context()
  y, ∂f = _forward(cx, f, x...)
  losscheck(y)
  ∇f = ∂f(1)[1] # TODO update f
  gs = Globals(cx)
  update!(opt, nothing, gs)
  return y
end
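# A rough usage sketch for `step!` (`model`, `loss`, `x` and `y` below are
# hypothetical, not defined in this file):
#
#   opt = Descent(0.1)
#   l = step!((x, y) -> loss(model(x), y), opt, x, y)
#
# Each call runs one forward/backward pass through Zygote, hands the
# collected gradients to `update!`, and returns the scalar loss value.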
# Training loop
# Callback niceties
call(f, xs...) = f(xs...)
runall(f) = f
runall(fs::AbstractVector) = () -> foreach(call, fs)
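# `runall` means `cb` can be either a single function or a vector of functions,
# e.g. (sketch; `accuracy` is hypothetical):
#
#   cbs = [() -> println("tick"), () -> accuracy() > 0.9 && Flux.stop()]
#   Flux.train!(loss, ps, data, opt, cb = cbs)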
struct StopException <: Exception end
"""
stop()
Call `Flux.stop()` in a callback to indicate when a callback condition is met.
This would trigger the train loop to stop and exit.
```julia
# Example callback:
cb = function ()
accuracy() > 0.9 && Flux.stop()
end
```
"""
function stop()
throw(StopException())
end
"""
train!(loss, params, data, opt; cb)
For each datapoint `d` in `data` computes the gradient of `loss(d...)` through
backpropagation and calls the optimizer `opt`.
Takes a callback as keyword argument `cb`. For example, this will print "training"
every 10 seconds:
```julia
Flux.train!(loss, params, data, opt,
cb = throttle(() -> println("training"), 10))
```
The callback can call `Flux.stop()` to interrupt the training loop.
Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
"""
function train!(loss, ps, data, opt; cb = () -> ())
  ps = Params(ps)
  cb = runall(cb)
  @progress for d in data
    try
      gs = gradient(ps) do
        loss(d...)
      end
      update!(opt, ps, gs)
      cb() # invoke the user callback(s)
    catch ex
      if ex isa StopException
        break
      else
        rethrow(ex)
      end
    end
  end
end
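# A minimal end-to-end sketch of `train!` (the model, data and loss here are
# made up for illustration):
#
#   m = Dense(10, 2)
#   loss(x, y) = Flux.mse(m(x), y)
#   data = [(rand(10), rand(2)) for _ in 1:100]
#   opt = Descent(0.1)
#   Flux.train!(loss, params(m), data, opt,
#               cb = Flux.throttle(() -> @show(loss(data[1]...)), 10))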
"""
@epochs N body
Run `body` `N` times. Mainly useful for quickly doing multiple epochs of
training in a REPL.
```julia
julia> @epochs 2 println("hello")
INFO: Epoch 1
hello
INFO: Epoch 2
hello
```
"""
macro epochs(n, ex)
  :(@progress for i = 1:$(esc(n))
      @info "Epoch $i"
      $(esc(ex))
    end)
end
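# `@epochs` composes naturally with `train!` for multi-epoch training, e.g.
# (reusing the sketch variables from the `train!` example above):
#
#   @epochs 10 Flux.train!(loss, params(m), data, opt)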