using Juno
import Zygote: Params, gradient
"""
    update!(x, x̄)

Update the array `x` according to `x .-= x̄`.
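
# Examples

A minimal sketch (array values chosen purely for illustration):

```julia
x = [1.0, 2.0]
x̄ = [0.1, 0.2]
update!(x, x̄)  # x is now [0.9, 1.8]
```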
"""
function update!(x::AbstractArray, x̄)
  x .-= x̄
end
"""
    update!(opt, p, g)
    update!(opt, ps::Params, gs)

Perform an update step of the parameters `ps` (or the single parameter `p`)
according to the optimizer `opt` and the gradients `gs` (or the gradient `g`).
As a result, the parameters are mutated and the optimizer's internal state may change.
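
# Examples

A minimal sketch using the [`Descent`](@ref) optimiser, which scales the gradient
by its learning rate (values chosen purely for illustration):

```julia
opt = Descent(0.1)
x = [1.0, 2.0]
x̄ = [1.0, 1.0]
update!(opt, x, x̄)  # x .-= 0.1 .* x̄, so x is now [0.9, 1.9]
```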
"""
function update!(opt, x, x̄)
  x .-= apply!(opt, x, x̄)
end
function update!(opt, xs::Params, gs)
  for x in xs
    gs[x] === nothing && continue  # skip parameters that received no gradient
    update!(opt, x, gs[x])
  end
end
# Callback niceties
call(f, xs...) = f(xs...)
runall(f) = f
runall(fs::AbstractVector) = () -> foreach(call, fs)
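
# `runall` normalises the `cb` argument of `train!`: a single callback is used as-is,
# while a vector of callbacks is wrapped so that calling the result runs each one in
# turn; e.g. `runall([cb1, cb2])()` calls `cb1()` and then `cb2()`.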
struct StopException <: Exception end
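# Thrown by `stop()` and caught inside `train!` to break out of the training loop early.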
"""
    stop()

Call `Flux.stop()` in a callback to signal that a stopping condition has been met.
This will cause the training loop to exit early.
# Examples
```julia
cb = function ()
  accuracy() > 0.9 && Flux.stop()
end
```
"""
function stop()
  throw(StopException())
end
"""
    train!(loss, params, data, opt; cb)

For each datapoint `d` in `data`, compute the gradient of `loss(d...)` through
backpropagation and call the optimizer `opt`.
If a datapoint `d` is a numeric array, assume no splatting is needed
and compute the gradient of `loss(d)`.
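
For example (a sketch; `model`, `xs`, `ys`, and `opt` are assumed to be defined):

    data = zip(xs, ys)                     # yields tuples, so the loss is called as loss(x, y)
    loss(x, y) = Flux.mse(model(x), y)
    train!(loss, params(model), data, opt)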
A callback can be passed via the keyword argument `cb`. For example, this will print
"training" every 10 seconds (using [`Flux.throttle`](@ref)):

    train!(loss, params, data, opt,
           cb = throttle(() -> println("training"), 10))
The callback can call [`Flux.stop`](@ref) to interrupt the training loop.
Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
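
For instance, to run two callbacks (hypothetical helpers `evalcb` and `logcb`):

    train!(loss, params, data, opt, cb = [evalcb, logcb])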
"""
function train!(loss, ps, data, opt; cb = () -> ())
  ps = Params(ps)
  cb = runall(cb)
  @progress for d in data
    try
      # Numeric arrays are passed to the loss whole; any other datapoint is splatted.
      if d isa AbstractArray{<:Number}
        gs = gradient(ps) do
          loss(d)
        end
      else
        gs = gradient(ps) do
          loss(d...)
        end
      end
      update!(opt, ps, gs)
      cb()
    catch ex
      # `Flux.stop()` throws a StopException; treat it as a clean exit from training.
      if ex isa StopException
        break
      else
        rethrow(ex)
      end
    end
  end
end
"""
    @epochs N body

Run `body` `N` times. Mainly useful for quickly doing multiple epochs of
training in a REPL.
# Examples
```jldoctest
julia> Flux.@epochs 2 println("hello")
[ Info: Epoch 1
hello
[ Info: Epoch 2
hello
```
"""
macro epochs(n, ex)
  :(@progress for i = 1:$(esc(n))
      @info "Epoch $i"
      $(esc(ex))
    end)
end