2017-09-03 06:44:32 +00:00
|
|
|
using Juno
|
2018-05-31 19:29:59 +00:00
|
|
|
using Flux.Tracker: data, grad, back!
|
2018-09-14 15:02:56 +00:00
|
|
|
import Base.depwarn
|
2018-05-31 19:29:59 +00:00
|
|
|
|
|
|
|
"""
    update!(opt, xs)

Apply one optimiser step to every parameter `x` in `xs`.

For each `x`, the step `Δ = update!(opt, x.data, x.grad)` is computed and
subtracted from `x.data` in place; afterwards `x.grad` is reset to zero so the
next iteration starts from a clean gradient (presumably because `back!`
accumulates into `x.grad` — confirm against Tracker).
"""
function update!(opt, xs)
    for x in xs
        Δ = update!(opt, x.data, x.grad)
        x.data .-= Δ
        # Clear the stored gradient itself, not `Δ`: zeroing `Δ` only resets the
        # gradient when the optimiser returns `x.grad` mutated in place. An
        # optimiser that returns a fresh array would otherwise leave the old
        # gradient behind.
        x.grad .= 0
    end
    return nothing
end
|
2017-08-31 16:36:18 +00:00
|
|
|
|
2018-05-31 19:29:59 +00:00
|
|
|
# Callback niceties
|
2018-11-08 13:14:57 +00:00
|
|
|
"""
    call(f, xs...)

Invoke `f` on the arguments `xs...`. A named form of `f(xs...)` that can be
handed to higher-order functions such as `foreach`.
"""
function call(f, xs...)
    return f(xs...)
end

"""
    runall(f)
    runall(fs::AbstractVector)

Turn a callback argument into a single zero-argument callable.

A vector of callables is wrapped in a closure that invokes each element in
order (via `call`) and returns `nothing`. Any other value is returned as is.
"""
runall(f) = f

function runall(fs::AbstractVector)
    return function ()
        for g in fs
            call(g)
        end
        return nothing
    end
end
|
2017-09-07 04:27:16 +00:00
|
|
|
|
2018-03-05 23:05:45 +00:00
|
|
|
"""
    @interrupts ex

Evaluate `ex` and return its value.

If `ex` throws an `InterruptException` (e.g. the user interrupts training),
the exception is thrown again from this point, replacing the large backtrace
generated by the AD with a short one. Any other exception is rethrown with its
original backtrace intact.
"""
macro interrupts(ex)
    guarded = esc(ex)
    return :(
        try
            $guarded
        catch err
            err isa InterruptException ? throw(err) : rethrow()
        end
    )
end
|
|
|
|
|
2018-08-20 08:38:23 +00:00
|
|
|
"""
    StopException

Exception thrown by `stop()` to signal that training should end. `train!`
catches it and exits its loop; other exceptions are rethrown.
"""
struct StopException <: Exception
end

"""
    stop()

Call `Flux.stop()` in a callback to indicate when a callback condition is met.
It throws a `StopException`, which the train loop catches in order to stop and
exit.

```julia
# Example callback:

cb = function ()
    accuracy() > 0.9 && Flux.stop()
end
```
"""
stop() = throw(StopException())
|
2018-08-20 08:13:08 +00:00
|
|
|
|
2017-10-11 11:26:40 +00:00
|
|
|
"""
    train!(loss, ps, data, opt; cb = () -> ())

For each datapoint `d` in `data`, compute the gradient of `loss(d...)` with
respect to the parameters `ps` through backpropagation, then call
`update!(opt, ps)` to apply the optimiser `opt` to every parameter.

Takes a callback as keyword argument `cb`, called with no arguments after every
update. For example, this will print "training" every 10 seconds:

```julia
Flux.train!(loss, ps, data, opt,
            cb = throttle(() -> println("training"), 10))
```

Call `Flux.stop()` inside a callback to end the training loop early. Returning
`:stop` from a single (non-array) callback also ends the loop, but is
deprecated.

Multiple callbacks can be passed to `cb` as an array; they are run in order.
`opt` must be a single optimiser.
"""
function train!(loss, ps, data, opt; cb = () -> ())
    # A vector of callbacks becomes one closure; a single callback is kept as is.
    cb = runall(cb)
    @progress for d in data
        try
            l = loss(d...)
            # Keep interrupt backtraces short while running the backward pass.
            @interrupts back!(l)
            update!(opt, ps)
            if cb() == :stop
                depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
                break
            end
        catch ex
            # `Flux.stop()` signals a clean early exit; anything else propagates
            # with its original backtrace.
            ex isa StopException || rethrow()
            break
        end
    end
end
|
2018-03-05 22:56:22 +00:00
|
|
|
|
|
|
|
"""
    @epochs N body

Run `body` `N` times. Mainly useful for quickly doing multiple epochs of
training in a REPL. Before each run the epoch number is logged with `@info`.

```julia
julia> @epochs 2 println("hello")
[ Info: Epoch 1
hello
[ Info: Epoch 2
hello
```
"""
macro epochs(n, ex)
    # The loop variable `i` is not escaped, so macro hygiene keeps it from
    # clashing with any variable used inside `body`.
    :(@progress for i = 1:$(esc(n))
        @info "Epoch $i"
        $(esc(ex))
    end)
end
|