<!DOCTYPE html>
<html lang="en"><head><meta charset="UTF-8"/><meta name="viewport" content="width=device-width, initial-scale=1.0"/><title>Training · Flux</title><script>(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
ga('create', 'UA-36890222-9', 'auto');
ga('send', 'pageview');
</script><link href="https://cdnjs.cloudflare.com/ajax/libs/normalize/4.2.0/normalize.min.css" rel="stylesheet" type="text/css"/><link href="https://fonts.googleapis.com/css?family=Lato|Roboto+Mono" rel="stylesheet" type="text/css"/><link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.6.3/css/font-awesome.min.css" rel="stylesheet" type="text/css"/><link href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/default.min.css" rel="stylesheet" type="text/css"/><script>documenterBaseURL=".."</script><script src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.2.0/require.min.js" data-main="../assets/documenter.js"></script><script src="../siteinfo.js"></script><script src="../../versions.js"></script><link href="../assets/documenter.css" rel="stylesheet" type="text/css"/><link href="../../flux.css" rel="stylesheet" type="text/css"/></head><body><nav class="toc"><h1>Flux</h1><select id="version-selector" onChange="window.location.href=this.value" style="visibility: hidden"></select><form class="search" id="search-form" action="../search.html"><input id="search-query" name="q" type="text" placeholder="Search docs"/></form><ul><li><a class="toctext" href="../index.html">Home</a></li><li><span class="toctext">Building Models</span><ul><li><a class="toctext" href="../models/basics.html">Basics</a></li><li><a class="toctext" href="../models/recurrence.html">Recurrence</a></li><li><a class="toctext" href="../models/regularisation.html">Regularisation</a></li><li><a class="toctext" href="../models/layers.html">Model Reference</a></li></ul></li><li><span class="toctext">Training Models</span><ul><li><a class="toctext" href="optimisers.html">Optimisers</a></li><li class="current"><a class="toctext" href="training.html">Training</a><ul class="internal"><li><a class="toctext" href="#Loss-Functions-1">Loss Functions</a></li><li><a class="toctext" href="#Datasets-1">Datasets</a></li><li><a class="toctext" 
href="#Callbacks-1">Callbacks</a></li></ul></li></ul></li><li><a class="toctext" href="../data/onehot.html">One-Hot Encoding</a></li><li><a class="toctext" href="../gpu.html">GPU Support</a></li><li><a class="toctext" href="../saving.html">Saving & Loading</a></li><li><span class="toctext">Internals</span><ul><li><a class="toctext" href="../internals/tracker.html">Backpropagation</a></li></ul></li><li><a class="toctext" href="../community.html">Community</a></li></ul></nav><article id="docs"><header><nav><ul><li>Training Models</li><li><a href="training.html">Training</a></li></ul></nav><hr/><div id="topbar"><span>Training</span><a class="fa fa-bars" href="#"></a></div></header><h1><a class="nav-anchor" id="Training-1" href="#Training-1">Training</a></h1><p>To train a model, we need three things:</p><ul><li>An <em>objective function</em> that evaluates how well the model is doing on some input data.</li><li>A collection of data points that will be provided to the objective function.</li><li>An <a href="optimisers.html">optimiser</a> that will update the model parameters appropriately.</li></ul><p>With these we can call <code>Flux.train!</code>:</p><pre><code class="language-julia">Flux.train!(objective, data, opt)</code></pre><p>There are plenty of examples in the <a href="https://github.com/FluxML/model-zoo">model zoo</a>.</p><h2><a class="nav-anchor" id="Loss-Functions-1" href="#Loss-Functions-1">Loss Functions</a></h2><p>The objective function must return a single number representing how far the model is from its target; this number is the <em>loss</em> of the model. The <code>loss</code> function that we defined in <a href="../models/basics.html">basics</a> will work as an objective. We can also define an objective in terms of some model:</p><pre><code class="language-julia">using Flux

m = Chain(
  Dense(784, 32, σ),
  Dense(32, 10), softmax)

loss(x, y) = Flux.mse(m(x), y)
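
# Sketch only: `my_mse` is a hypothetical helper, not part of Flux. It spells
# out what mean squared error computes (the average squared difference between
# prediction and target). Any function returning a single number could serve
# as the cost in the same way, e.g. loss(x, y) = my_mse(m(x), y).
my_mse(pred, y) = sum((pred .- y) .^ 2) / length(y)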

# later
Flux.train!(loss, data, opt)</code></pre><p>The objective will almost always be defined in terms of some <em>cost function</em> that measures the distance of the prediction <code>m(x)</code> from the target <code>y</code>. Flux has several of these built in, such as <code>mse</code> for mean squared error and <code>crossentropy</code> for cross-entropy loss, but you can compute the cost however you like.</p><h2><a class="nav-anchor" id="Datasets-1" href="#Datasets-1">Datasets</a></h2><p>The <code>data</code> argument provides the collection of data to train with, usually pairs of inputs <code>x</code> and target outputs <code>y</code>. For example, here is a dummy dataset with only one data point:</p><pre><code class="language-julia">x = rand(784)
y = rand(10)
data = [(x, y)]</code></pre><p>For each data point, <code>Flux.train!</code> calls <code>loss(x, y)</code>, calculates gradients, and updates the weights, then moves on to the next data point if there is one. To train the model on the same data three times, we can repeat the data point:</p><pre><code class="language-julia">data = [(x, y), (x, y), (x, y)]
# Or equivalently
data = Iterators.repeated((x, y), 3)</code></pre><p>It's common to load the inputs (<code>xs</code>) and the targets (<code>ys</code>) separately. In this case you can pair them up with <code>zip</code>:</p><pre><code class="language-julia">xs = [rand(784), rand(784), rand(784)]
ys = [rand( 10), rand( 10), rand( 10)]
data = zip(xs, ys)</code></pre><p>Note that, by default, <code>train!</code> only loops over the data once (a single "epoch"). The <code>@epochs</code> macro provides a convenient way to run multiple epochs from the REPL:</p><pre><code class="language-julia">julia> using Flux: @epochs

julia> @epochs 2 println("hello")
INFO: Epoch 1
hello
INFO: Epoch 2
hello

julia> @epochs 2 Flux.train!(...)
# Train for two epochs</code></pre><h2><a class="nav-anchor" id="Callbacks-1" href="#Callbacks-1">Callbacks</a></h2><p><code>train!</code> takes an additional keyword argument, <code>cb</code>, which accepts a callback function so that you can observe the training process. For example:</p><pre><code class="language-julia">Flux.train!(objective, data, opt, cb = () -> println("training"))</code></pre><p>Callbacks are called for every batch of training data. You can reduce how often this happens with <code>Flux.throttle(f, timeout)</code>, which prevents <code>f</code> from being called more than once every <code>timeout</code> seconds.</p><p>A more typical callback might look like this:</p><pre><code class="language-julia">test_x, test_y = # ... create a single batch of test data ...
evalcb() = @show(loss(test_x, test_y))

Flux.train!(objective, data, opt,
  cb = Flux.throttle(evalcb, 5))</code></pre><footer><hr/><a class="previous" href="optimisers.html"><span class="direction">Previous</span><span class="title">Optimisers</span></a><a class="next" href="../data/onehot.html"><span class="direction">Next</span><span class="title">One-Hot Encoding</span></a></footer></article></body></html>