Flux.jl/src/optimise/train.jl

using Juno
import Flux.Tracker: data, grad, back!, update!
import Base.depwarn

function update!(opt, x, x̄)
  update!(x, apply!(opt, x, copy(data(x̄))))
end
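
# A minimal usage sketch (illustrative, not part of the library): given a
# tracked parameter whose gradient has been filled in by `back!`, one
# optimiser step can be applied like so. `Descent` is assumed to be the
# plain gradient-descent optimiser exported by Flux.
#
#   W = param(rand(2, 2))        # tracked parameter
#   l = sum(W * ones(2))         # tracked scalar loss
#   back!(l)                     # populates grad(W)
#   update!(Descent(0.1), W, grad(W))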

function _update_params!(opt, xs)
  for x in xs
    Δ = apply!(opt, x.data, x.grad)
    x.data .-= Δ
    Δ .= 0
  end
end
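
# For reference, a sketch of the update performed above, under the assumption
# that `apply!` for plain gradient descent simply rescales the gradient (the
# actual optimiser definitions live elsewhere in Flux):
#
#   Δ = η .* x.grad      # what `apply!` would return for vanilla descent
#   x.data .-= Δ         # parameter step: x ← x - η·∇x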

# Callback niceties
call(f, xs...) = f(xs...)
runall(f) = f
runall(fs::AbstractVector) = () -> foreach(call, fs)
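
# Usage sketch: `runall` lets `cb` be either a single function or a vector of
# functions, normalising both to one zero-argument callable.
#
#   cb = runall([() -> println("a"), () -> println("b")])
#   cb()   # runs each callback in turn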

# The AD generates fairly large backtraces that are unhelpful if you interrupt
# while training; this just cleans that up.
macro interrupts(ex)
  :(try $(esc(ex))
    catch e
      e isa InterruptException || rethrow()
      throw(e)
    end)
end
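
# Usage sketch: wrap a long-running call (as `train!` does for the backward
# pass below) so that a Ctrl-C surfaces as a short, clean backtrace:
#
#   @interrupts back!(l)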

struct StopException <: Exception end

"""
    stop()

Call `Flux.stop()` in a callback to signal that a stopping condition has been
met. This will cause the training loop to stop and exit.

```julia
# Example callback:
cb = function ()
  accuracy() > 0.9 && Flux.stop()
end
```
"""
function stop()
  throw(StopException())
end

"""
    train!(loss, params, data, opt; cb)

For each datapoint `d` in `data`, computes the gradient of `loss(d...)` through
backpropagation and calls the optimiser `opt`.

Takes a callback as keyword argument `cb`. For example, this will print
"training" every 10 seconds:

```julia
Flux.train!(loss, params, data, opt,
            cb = throttle(() -> println("training"), 10))
```

The callback can call `Flux.stop()` to interrupt the training loop.

Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
"""
function train!(loss, ps, data, opt; cb = () -> ())
  cb = runall(cb)
  opt = runall(opt)
  @progress for d in data
    try
      l = loss(d...)
      @interrupts back!(l)
      _update_params!(opt, ps)
      if cb() == :stop
        depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
        break
      end
    catch ex
      if ex isa StopException
        break
      else
        rethrow(ex)
      end
    end
  end
end
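
# A fuller usage sketch (illustrative; the model, loss and data below are made
# up, and `Dense`, `mse`, `params`, `Descent` and `throttle` are assumed to
# come from Flux):
#
#   m = Dense(10, 2)
#   loss(x, y) = Flux.mse(m(x), y)
#   data = [(rand(10), rand(2)) for _ in 1:100]
#   Flux.train!(loss, params(m), data, Descent(0.1),
#               cb = throttle(() -> println("training"), 10))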

"""
    @epochs N body

Run `body` `N` times. Mainly useful for quickly doing multiple epochs of
training in a REPL.

```julia
julia> @epochs 2 println("hello")
[ Info: Epoch 1
hello
[ Info: Epoch 2
hello
```
"""
macro epochs(n, ex)
  :(@progress for i = 1:$(esc(n))
      @info "Epoch $i"
      $(esc(ex))
    end)
end
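
# Usage sketch: combining `@epochs` with `train!` for several passes over the
# data (assumes `m`, `loss` and `data` as in the `train!` example above):
#
#   ps = params(m)
#   @epochs 10 Flux.train!(loss, ps, data, Descent(0.1))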