

<!DOCTYPE html>
<html lang="en"><head><meta charset="UTF-8"/><meta name="viewport" content="width=device-width, initial-scale=1.0"/><title>Training · Flux</title><script>(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
ga('create', 'UA-36890222-9', 'auto');
ga('send', 'pageview', {'page': location.pathname + location.search + location.hash});
</script><link href="https://fonts.googleapis.com/css?family=Lato|Roboto+Mono" rel="stylesheet" type="text/css"/><link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.11.2/css/fontawesome.min.css" rel="stylesheet" type="text/css"/><link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.11.2/css/solid.min.css" rel="stylesheet" type="text/css"/><link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.11.2/css/brands.min.css" rel="stylesheet" type="text/css"/><link href="https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.11.1/katex.min.css" rel="stylesheet" type="text/css"/><script>documenterBaseURL="../.."</script><script src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.6/require.min.js" data-main="../../assets/documenter.js"></script><script src="../../siteinfo.js"></script><script src="../../../versions.js"></script><link href="../../assets/flux.css" rel="stylesheet" type="text/css"/><link class="docs-theme-link" rel="stylesheet" type="text/css" href="../../assets/themes/documenter-dark.css" data-theme-name="documenter-dark"/><link class="docs-theme-link" rel="stylesheet" type="text/css" href="../../assets/themes/documenter-light.css" data-theme-name="documenter-light" data-theme-primary/><script src="../../assets/themeswap.js"></script></head><body><div id="documenter"><nav class="docs-sidebar"><div class="docs-package-name"><span class="docs-autofit">Flux</span></div><form class="docs-search" action="../../search/"><input class="docs-search-query" id="documenter-search-query" name="q" type="text" placeholder="Search docs"/></form><ul class="docs-menu"><li><a class="tocitem" href="../../">Home</a></li><li><span class="tocitem">Building Models</span><ul><li><a class="tocitem" href="../../models/basics/">Basics</a></li><li><a class="tocitem" href="../../models/recurrence/">Recurrence</a></li><li><a class="tocitem" href="../../models/regularisation/">Regularisation</a></li><li><a class="tocitem" href="../../models/layers/">Model 
Reference</a></li><li><a class="tocitem" href="../../models/advanced/">Advanced Model Building</a></li><li><a class="tocitem" href="../../models/nnlib/">NNlib</a></li></ul></li><li><span class="tocitem">Handling Data</span><ul><li><a class="tocitem" href="../../data/onehot/">One-Hot Encoding</a></li><li><a class="tocitem" href="../../data/dataloader/">DataLoader</a></li></ul></li><li><span class="tocitem">Training Models</span><ul><li><a class="tocitem" href="../optimisers/">Optimisers</a></li><li class="is-active"><a class="tocitem" href>Training</a><ul class="internal"><li><a class="tocitem" href="#Loss-Functions-1"><span>Loss Functions</span></a></li><li><a class="tocitem" href="#Model-parameters-1"><span>Model parameters</span></a></li><li><a class="tocitem" href="#Datasets-1"><span>Datasets</span></a></li><li><a class="tocitem" href="#Callbacks-1"><span>Callbacks</span></a></li><li><a class="tocitem" href="#Custom-Training-loops-1"><span>Custom Training loops</span></a></li></ul></li></ul></li><li><a class="tocitem" href="../../gpu/">GPU Support</a></li><li><a class="tocitem" href="../../saving/">Saving &amp; Loading</a></li><li><a class="tocitem" href="../../ecosystem/">The Julia Ecosystem</a></li><li><a class="tocitem" href="../../performance/">Performance Tips</a></li><li><a class="tocitem" href="../../community/">Community</a></li></ul><div class="docs-version-selector field has-addons"><div class="control"><span class="docs-label button is-static is-size-7">Version</span></div><div class="docs-selector control is-expanded"><div class="select is-fullwidth is-size-7"><select id="documenter-version-selector"></select></div></div></div></nav><div class="docs-main"><header class="docs-navbar"><nav class="breadcrumb"><ul class="is-hidden-mobile"><li><a class="is-disabled">Training Models</a></li><li class="is-active"><a href>Training</a></li></ul><ul class="is-hidden-tablet"><li class="is-active"><a href>Training</a></li></ul></nav><div class="docs-right"><a 
class="docs-settings-button fas fa-cog" id="documenter-settings-button" href="#" title="Settings"></a><a class="docs-sidebar-button fa fa-bars is-hidden-desktop" id="documenter-sidebar-button" href="#"></a></div></header><article class="content" id="documenter-page"><h1 id="Training-1"><a class="docs-heading-anchor" href="#Training-1">Training</a><a class="docs-heading-anchor-permalink" href="#Training-1" title="Permalink"></a></h1><p>To train a model we need four things:</p><ul><li>An <em>objective function</em> that evaluates how well the model is doing on some input data.</li><li>The trainable parameters of the model.</li><li>A collection of data points that will be passed to the objective function.</li><li>An <a href="../optimisers/">optimiser</a> that will update the model parameters.</li></ul><p>With these we can call <code>train!</code>:</p><article class="docstring"><header><a class="docstring-binding" id="Flux.Optimise.train!" href="#Flux.Optimise.train!"><code>Flux.Optimise.train!</code></a><span class="docstring-category">Function</span></header><section><div><pre><code class="language-julia">train!(loss, params, data, opt; cb)</code></pre><p>For each datapoint <code>d</code> in <code>data</code>, compute the gradient of <code>loss(d...)</code> through backpropagation and call the optimiser <code>opt</code>.</p><p>If the datapoints <code>d</code> are numeric arrays, no splatting is done and the gradient of <code>loss(d)</code> is computed instead.</p><p>Takes a callback as the keyword argument <code>cb</code>. 
For example, this will print &quot;training&quot; at most once every 10 seconds:</p><pre><code class="language-julia">train!(loss, params, data, opt, cb = Flux.throttle(() -&gt; println(&quot;training&quot;), 10))</code></pre><p>The callback can call <code>Flux.stop()</code> to interrupt the training loop.</p><p>Multiple optimisers and callbacks can be passed to <code>opt</code> and <code>cb</code> as arrays.</p></div><a class="docs-sourcelink" target="_blank" href="https://github.com/FluxML/Flux.jl/blob/df3f904f7c34f095562693b6a9ca67047319dea0/src/optimise/train.jl#L58-L76">source</a></section></article><p>There are plenty of examples in the <a href="https://github.com/FluxML/model-zoo">model zoo</a>.</p><h2 id="Loss-Functions-1"><a class="docs-heading-anchor" href="#Loss-Functions-1">Loss Functions</a><a class="docs-heading-anchor-permalink" href="#Loss-Functions-1" title="Permalink"></a></h2><p>The objective function must return a number representing how far the model is from its target: the <em>loss</em> of the model. The <code>loss</code> function that we defined in <a href="../../models/basics/">basics</a> will work as an objective. We can also define an objective in terms of some model:</p><pre><code class="language-julia">using Flux

m = Chain(
  Dense(784, 32, σ),
  Dense(32, 10), softmax)
loss(x, y) = Flux.mse(m(x), y)
ps = Flux.params(m)
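# Illustrative check with dummy data (not needed for training): the objective
# returns a single number, and ps holds the weight matrix and bias vector of each
# Dense layer, i.e. 4 arrays for this model.
loss(rand(784), rand(10)) isa Real  # true
length(ps)                          # 4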
# later
Flux.train!(loss, ps, data, opt)</code></pre><p>The objective will almost always be defined in terms of some <em>cost function</em> that measures the distance of the prediction <code>m(x)</code> from the target <code>y</code>. Flux has several of these built in, like <code>mse</code> for mean squared error or <code>crossentropy</code> for cross-entropy loss, but you can compute the loss however you like.</p><p>At first glance it may seem strange that the model we want to train is not one of the arguments of <code>Flux.train!</code>. However, what the optimiser minimises is not the model itself but the objective function, which measures the discrepancy between the model&#39;s predictions and the observed data. In other words, the model is implicitly part of the objective function, so there is no need to pass it separately. Passing a single objective function, rather than a model and a cost function, is also more flexible: the objective can, for example, add regularisation terms or combine the outputs of several models.</p><h2 id="Model-parameters-1"><a class="docs-heading-anchor" href="#Model-parameters-1">Model parameters</a><a class="docs-heading-anchor-permalink" href="#Model-parameters-1" title="Permalink"></a></h2><p>The model to be trained must have a set of trainable parameters with respect to which the gradients of the objective function are calculated. The <a href="../../models/basics/">basics</a> section explains how to create models with such parameters. The second argument of <code>Flux.train!</code> must be an object containing those parameters, which can be obtained from a model <code>m</code> as <code>params(m)</code>.</p><p>This object holds references to the model&#39;s parameter arrays, not copies, so once training has updated them the model itself uses the new values.</p><p>Handling parameters on a layer-by-layer basis is explained in the <a href="../../models/basics/">Layer Helpers</a> section. 
Also, for freezing model parameters, see the <a href="../../models/advanced/">Advanced Usage Guide</a>.</p><h2 id="Datasets-1"><a class="docs-heading-anchor" href="#Datasets-1">Datasets</a><a class="docs-heading-anchor-permalink" href="#Datasets-1" title="Permalink"></a></h2><p>The <code>data</code> argument provides a collection of data to train with (usually a set of inputs <code>x</code> and target outputs <code>y</code>). For example, here&#39;s a dummy data set with only one data point:</p><pre><code class="language-julia">x = rand(784)
y = rand(10)
data = [(x, y)]</code></pre><p><code>Flux.train!</code> will call <code>loss(x, y)</code>, calculate gradients, update the weights and then move on to the next data point if there is one. We can train the model on the same data three times:</p><pre><code class="language-julia">data = [(x, y), (x, y), (x, y)]
# Or equivalently
using IterTools: ncycle
data = ncycle([(x, y)], 3)</code></pre><p>It&#39;s common to load the <code>x</code>s and <code>y</code>s separately. In this case you can use <code>zip</code>:</p><pre><code class="language-julia">xs = [rand(784), rand(784), rand(784)]
ys = [rand( 10), rand( 10), rand( 10)]
data = zip(xs, ys)</code></pre><p>Training data can be conveniently partitioned for mini-batch training using the <a href="../../data/dataloader/#Flux.Data.DataLoader"><code>Flux.Data.DataLoader</code></a> type:</p><pre><code class="language-julia">using Flux.Data: DataLoader

X = rand(28, 28, 60000)   # the last dimension indexes the 60000 observations
Y = rand(0:9, 60000)      # one label per observation
data = DataLoader(X, Y, batchsize=128)</code></pre><p>Note that, by default, <code>train!</code> only loops over the data once (a single &quot;epoch&quot;). The <code>@epochs</code> macro provides a convenient way to run multiple epochs from the REPL:</p><pre><code class="language-julia">julia&gt; using Flux: @epochs
julia&gt; @epochs 2 println(&quot;hello&quot;)
[ Info: Epoch 1
hello
[ Info: Epoch 2
hello
julia&gt; @epochs 2 Flux.train!(...)
# Train for two epochs</code></pre><h2 id="Callbacks-1"><a class="docs-heading-anchor" href="#Callbacks-1">Callbacks</a><a class="docs-heading-anchor-permalink" href="#Callbacks-1" title="Permalink"></a></h2><p><code>train!</code> takes an additional keyword argument, <code>cb</code>, for callbacks that let you observe the training process. For example:</p><pre><code class="language-julia">Flux.train!(objective, ps, data, opt, cb = () -&gt; println(&quot;training&quot;))</code></pre><p>Callbacks are called after every batch of training data. You can slow this down using <code>Flux.throttle(f, timeout)</code>, which prevents <code>f</code> from being called more than once every <code>timeout</code> seconds.</p><p>A more typical callback might look like this:</p><pre><code class="language-julia">test_x, test_y = rand(784, 100), rand(10, 100)  # stand-in for a single batch of real test data
evalcb() = @show(loss(test_x, test_y))
Flux.train!(objective, ps, data, opt,
            cb = Flux.throttle(evalcb, 5))</code></pre><p>Calling <code>Flux.stop()</code> in a callback will exit the training loop early.</p><pre><code class="language-julia">cb = function ()
  # accuracy() is a placeholder for your own evaluation function; it is not part of Flux
  accuracy() &gt; 0.9 &amp;&amp; Flux.stop()
end</code></pre><h2 id="Custom-Training-loops-1"><a class="docs-heading-anchor" href="#Custom-Training-loops-1">Custom Training loops</a><a class="docs-heading-anchor-permalink" href="#Custom-Training-loops-1" title="Permalink"></a></h2><p>The <code>Flux.train!</code> function can be very convenient, especially for simple problems, and callbacks make it quite flexible. For some problems, however, it is much cleaner to write your own custom training loop. The following example works much like the default <code>Flux.train!</code>, but without callbacks: you don&#39;t need them if you call your own functions directly inside the loop, for example in the places marked with comments.</p><pre><code class="language-julia">using Flux
using Flux: Params
using Flux.Optimise: update!

function my_custom_train!(loss, ps, data, opt)
  ps = Params(ps)
  for d in data
    gs = gradient(ps) do
      training_loss = loss(d...)
      # Insert whatever code you want here that needs the training loss, e.g. logging
      return training_loss
    end
    # Insert whatever code you want here that needs the gradients,
    # e.g. logging them as histograms with TensorBoardLogger.jl to see whether they are exploding
    update!(opt, ps, gs)
    # Here you might check validation-set accuracy and break out of the loop for early stopping
  end
end</code></pre><p>You could simplify this further, for example by hard-coding in the loss function.</p></article><nav class="docs-footer"><a class="docs-footer-prevpage" href="../optimisers/">« Optimisers</a><a class="docs-footer-nextpage" href="../../gpu/">GPU Support »</a></nav></div><div class="modal" id="documenter-settings"><div class="modal-background"></div><div class="modal-card"><header class="modal-card-head"><p class="modal-card-title">Settings</p><button class="delete"></button></header><section class="modal-card-body"><p><label class="label">Theme</label><div class="select"><select id="documenter-themepicker"><option value="documenter-light">documenter-light</option><option value="documenter-dark">documenter-dark</option></select></div></p><hr/><p>This document was generated with <a href="https://github.com/JuliaDocs/Documenter.jl">Documenter.jl</a> on <span class="colophon-date" title="Wednesday 4 March 2020 04:13">Wednesday 4 March 2020</span>. Using Julia version 1.3.1.</p></section><footer class="modal-card-foot"></footer></div></div></div></body></html>