Flux.jl/release-0.3/training/training.html
2017-09-28 10:32:09 +00:00

<!DOCTYPE html>
<html lang="en"><head><meta charset="UTF-8"/><meta name="viewport" content="width=device-width, initial-scale=1.0"/><title>Training · Flux</title><script>(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
ga('create', 'UA-36890222-9', 'auto');
ga('send', 'pageview');
</script><link href="https://cdnjs.cloudflare.com/ajax/libs/normalize/4.2.0/normalize.min.css" rel="stylesheet" type="text/css"/><link href="https://fonts.googleapis.com/css?family=Lato|Roboto+Mono" rel="stylesheet" type="text/css"/><link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.6.3/css/font-awesome.min.css" rel="stylesheet" type="text/css"/><link href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/default.min.css" rel="stylesheet" type="text/css"/><script>documenterBaseURL=".."</script><script src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.2.0/require.min.js" data-main="../assets/documenter.js"></script><script src="../siteinfo.js"></script><script src="../../versions.js"></script><link href="../assets/documenter.css" rel="stylesheet" type="text/css"/><link href="../../flux.css" rel="stylesheet" type="text/css"/></head><body><nav class="toc"><h1>Flux</h1><select id="version-selector" onChange="window.location.href=this.value" style="visibility: hidden"></select><form class="search" action="../search.html"><input id="search-query" name="q" type="text" placeholder="Search docs"/></form><ul><li><a class="toctext" href="../index.html">Home</a></li><li><span class="toctext">Building Models</span><ul><li><a class="toctext" href="../models/basics.html">Basics</a></li><li><a class="toctext" href="../models/recurrence.html">Recurrence</a></li><li><a class="toctext" href="../models/layers.html">Layer Reference</a></li></ul></li><li><span class="toctext">Training Models</span><ul><li><a class="toctext" href="optimisers.html">Optimisers</a></li><li class="current"><a class="toctext" href="training.html">Training</a><ul class="internal"><li><a class="toctext" href="#Loss-Functions-1">Loss Functions</a></li><li><a class="toctext" href="#Callbacks-1">Callbacks</a></li></ul></li></ul></li><li><a class="toctext" href="../data/onehot.html">One-Hot Encoding</a></li><li><a class="toctext" href="../gpu.html">GPU Support</a></li><li><a 
class="toctext" href="../contributing.html">Contributing &amp; Help</a></li></ul></nav><article id="docs"><header><nav><ul><li>Training Models</li><li><a href="training.html">Training</a></li></ul></nav><hr/><div id="topbar"><span>Training</span><a class="fa fa-bars" href="#"></a></div></header><h1><a class="nav-anchor" id="Training-1" href="#Training-1">Training</a></h1><p>To train a model we need three things:</p><ul><li><p>A <em>model loss function</em> that evaluates how well the model is doing on some input data.</p></li><li><p>A collection of data points that will be provided to the loss function.</p></li><li><p>An <a href="optimisers.html">optimiser</a> that will update the model parameters appropriately.</p></li></ul><p>With these we can call <code>Flux.train!</code>:</p><pre><code class="language-julia">Flux.train!(modelLoss, data, opt)</code></pre><p>There are plenty of examples in the <a href="https://github.com/FluxML/model-zoo">model zoo</a>.</p><h2><a class="nav-anchor" id="Loss-Functions-1" href="#Loss-Functions-1">Loss Functions</a></h2><p>The <code>loss</code> that we defined in <a href="../models/basics.html">basics</a> can be used for training as it is. We can also define a loss in terms of some model:</p><pre><code class="language-julia">m = Chain(
  Dense(784, 32, σ),
  Dense(32, 10), softmax)
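</code></pre><p>The loss for this model is built on a cost function such as <code>Flux.mse</code>. As a rough illustration of what such a function computes, mean squared error can be sketched in plain Julia. The name <code>my_mse</code> is hypothetical and used only here; Flux's own implementation may differ:</p><pre><code class="language-julia"># Hypothetical stand-in for a cost function (an assumption, not Flux's code):
# the mean of the squared element-wise differences between prediction and target.
my_mse(pred, target) = sum(abs2, pred .- target) / length(target)

# A prediction close to the target gives a small cost (roughly 0.01 here).
my_mse([0.9, 0.1], [1.0, 0.0])</code></pre><p>The built-in version is used in the same way to define the loss for <code>m</code>:</p><pre><code class="language-julia"># m is the Chain defined above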
# Model loss function
loss(x, y) = Flux.mse(m(x), y)</code></pre><p>The loss will almost always be defined in terms of some <em>cost function</em> that measures the distance of the prediction <code>m(x)</code> from the target <code>y</code>. Flux has several of these built in, such as <code>mse</code> for mean squared error or <code>logloss</code> for cross-entropy loss, but you can calculate the loss however you want.</p><h2><a class="nav-anchor" id="Callbacks-1" href="#Callbacks-1">Callbacks</a></h2><p><code>Flux.train!</code> takes an additional keyword argument, <code>cb</code>, for callbacks that let you observe the training process. For example:</p><pre><code class="language-julia">Flux.train!(loss, data, opt, cb = () -&gt; println(&quot;training&quot;))</code></pre><p>Callbacks are called for every batch of training data. You can limit how often a callback runs with <code>Flux.throttle(f, timeout)</code>, which prevents <code>f</code> from being called more than once every <code>timeout</code> seconds.</p><p>A more typical callback might look like this:</p><pre><code class="language-julia">test_x, test_y = # ... create single batch of test data ...
evalcb() = @show(loss(test_x, test_y))
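</code></pre><p>Evaluating on test data after every batch can be expensive, which is why the callback is throttled before it is passed to training. To make the idea concrete, here is a minimal plain-Julia sketch of a rate limiter. The name <code>simple_throttle</code> is hypothetical, and this is only an illustration of the behaviour described above, not how <code>Flux.throttle</code> is actually implemented:</p><pre><code class="language-julia"># Hypothetical helper (an assumption for illustration only): wrap `f` so that it
# runs at most once every `timeout` seconds; calls made sooner are skipped.
function simple_throttle(f, timeout)
  last_call = -Inf                  # time of the most recent call that ran
  function throttled(args...)
    t = time()
    if t - last_call &gt;= timeout
      last_call = t
      return f(args...)
    end
    return nothing                  # skipped: still inside the timeout window
  end
  return throttled
end

# Usage: only the first of three back-to-back calls runs.
counter = [0]
bump() = (counter[1] += 1)
slow_bump = simple_throttle(bump, 60)
slow_bump(); slow_bump(); slow_bump()
@show counter[1]</code></pre><p>With the evaluation callback defined, training with a throttled callback looks like this:</p><pre><code class="language-julia"># Evaluate on the test batch at most once every 5 seconds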
Flux.train!(loss, data, opt,
  cb = Flux.throttle(evalcb, 5))</code></pre><footer><hr/><a class="previous" href="optimisers.html"><span class="direction">Previous</span><span class="title">Optimisers</span></a><a class="next" href="../data/onehot.html"><span class="direction">Next</span><span class="title">One-Hot Encoding</span></a></footer></article></body></html>