# Flux.jl/src/optimise/train.jl

using Juno
import Zygote: Params, gradient

"""
    update!(x, x̄)

Update the array `x` according to `x .-= x̄`.
"""
function update!(x::AbstractArray, x̄)
  x .-= x̄
end
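
# Illustrative sketch (comment only, not part of the library): a manual
# gradient-descent step on a plain array. `W`, `ΔW` and the 0.1 step size
# below are hypothetical.
#
#     W  = rand(3, 3)
#     ΔW = rand(3, 3)          # stand-in for a gradient of the same shape
#     update!(W, 0.1 .* ΔW)    # equivalent to W .-= 0.1 .* ΔW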
"""
update!(opt, p, g)
update!(opt, ps::Params, gs)
Perform an update step of the parameters `ps` (or the single parameter `p`)
according to optimizer `opt` and the gradients `gs` (the gradient `g`).

As a result, the parameters are mutated and the optimizer's internal state may change.
"""
function update!(opt, x, x̄)
  x .-= apply!(opt, x, x̄)
end

function update!(opt, xs::Params, gs)
  for x in xs
    gs[x] == nothing && continue
    update!(opt, x, gs[x])
  end
end
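
# Illustrative sketch (comment only): one optimiser-driven update step.
# Assumes parameters `W`, `b`, data `x`, `y` and a `loss(x, y)` function are
# defined; `Descent` is the basic gradient-descent optimiser from this module.
#
#     ps  = Params([W, b])
#     opt = Descent(0.1)
#     gs  = gradient(() -> loss(x, y), ps)   # Zygote.Grads keyed by parameter
#     update!(opt, ps, gs)                   # mutates W and b in place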

# Callback niceties
call(f, xs...) = f(xs...)
runall(f) = f
runall(fs::AbstractVector) = () -> foreach(call, fs)
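
# Illustrative sketch (comment only): `runall` lets `train!` accept either a
# single callback or a vector of callbacks. The callbacks below are hypothetical.
#
#     evalcb = () -> @info "evaluating"
#     logcb  = () -> @info "logging"
#     cb = runall([evalcb, logcb])
#     cb()    # calls evalcb() then logcb(), in order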

struct StopException <: Exception end

"""
    stop()

Call `Flux.stop()` in a callback to indicate when a callback condition is met.
This will trigger the train loop to stop and exit.

# Examples
```julia
cb = function ()
  accuracy() > 0.9 && Flux.stop()
end
```
"""
function stop()
  throw(StopException())
end

"""
    train!(loss, params, data, opt; cb)

For each datapoint `d` in `data` compute the gradient of `loss(d...)` through
backpropagation and call the optimizer `opt`.

In case datapoints `d` are of numeric array type, assume no splatting is needed
and compute the gradient of `loss(d)`.

A callback is given with the keyword argument `cb`. For example, this will print
"training" every 10 seconds (using [`Flux.throttle`](@ref)):

    train!(loss, params, data, opt,
           cb = throttle(() -> println("training"), 10))

The callback can call [`Flux.stop`](@ref) to interrupt the training loop.
2017-10-11 11:26:40 +00:00
Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
2017-10-11 11:26:40 +00:00
"""
function train!(loss, ps, data, opt; cb = () -> ())
  ps = Params(ps)
  cb = runall(cb)
  @progress for d in data
    try
      if d isa AbstractArray{<:Number}
        gs = gradient(ps) do
          loss(d)
        end
      else
        gs = gradient(ps) do
          loss(d...)
        end
      end
      update!(opt, ps, gs)
      cb()
    catch ex
      if ex isa StopException
        break
      else
        rethrow(ex)
      end
    end
  end
end
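
# Illustrative sketch (comment only): a typical `train!` call. The model,
# loss and data below are hypothetical; `Dense`, `mse`, `params`, `throttle`
# and `Descent` come from Flux.
#
#     m = Dense(10, 2)
#     loss(x, y) = Flux.mse(m(x), y)
#     data = [(rand(10, 100), rand(2, 100))]
#     opt = Descent(0.01)
#     Flux.train!(loss, Flux.params(m), data, opt,
#                 cb = Flux.throttle(() -> @show(loss(data[1]...)), 10))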

"""
    @epochs N body

Run `body` `N` times. Mainly useful for quickly doing multiple epochs of
training in a REPL.

# Examples
```jldoctest
julia> Flux.@epochs 2 println("hello")
[ Info: Epoch 1
hello
[ Info: Epoch 2
hello
```
"""
macro epochs(n, ex)
  :(@progress for i = 1:$(esc(n))
      @info "Epoch $i"
      $(esc(ex))
    end)
end