</script><link href="https://cdnjs.cloudflare.com/ajax/libs/normalize/4.2.0/normalize.min.css" rel="stylesheet" type="text/css"/><link href="https://fonts.googleapis.com/css?family=Lato|Roboto+Mono" rel="stylesheet" type="text/css"/><link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.6.3/css/font-awesome.min.css" rel="stylesheet" type="text/css"/><link href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/default.min.css" rel="stylesheet" type="text/css"/><script>documenterBaseURL=".."</script><script src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.2.0/require.min.js" data-main="../assets/documenter.js"></script><script src="../siteinfo.js"></script><script src="../../versions.js"></script><link href="../assets/documenter.css" rel="stylesheet" type="text/css"/><link href="../../flux.css" rel="stylesheet" type="text/css"/></head><body><nav class="toc"><h1>Flux</h1><select id="version-selector" onChange="window.location.href=this.value" style="visibility: hidden"></select><form class="search" id="search-form" action="../search.html"><input id="search-query" name="q" type="text" placeholder="Search docs"/></form><ul><li><a class="toctext" href="../index.html">Home</a></li><li><span class="toctext">Building Models</span><ul><li><a class="toctext" href="../models/basics.html">Basics</a></li><li><a class="toctext" href="../models/recurrence.html">Recurrence</a></li><li><a class="toctext" href="../models/layers.html">Layer Reference</a></li></ul></li><li><span class="toctext">Training Models</span><ul><li><a class="toctext" href="optimisers.html">Optimisers</a></li><li class="current"><a class="toctext" href="training.html">Training</a><ul class="internal"><li><a class="toctext" href="#Loss-Functions-1">Loss Functions</a></li><li><a class="toctext" href="#Datasets-1">Datasets</a></li><li><a class="toctext" href="#Callbacks-1">Callbacks</a></li></ul></li></ul></li><li><a class="toctext" href="../data/onehot.html">One-Hot Encoding</a></li><li><a class="toctext" href="../gpu.html">GPU
Support</a></li><li><a class="toctext" href="../community.html">Community</a></li></ul></nav><article id="docs"><header><nav><ul><li>Training Models</li><li><a href="training.html">Training</a></li></ul><a class="edit-page" href="https://github.com/FluxML/Flux.jl/blob/master/docs/src/training/training.md"><span class="fa"></span> Edit on GitHub</a></nav><hr/><div id="topbar"><span>Training</span><a class="fa fa-bars" href="#"></a></div></header><h1><a class="nav-anchor" id="Training-1" href="#Training-1">Training</a></h1><p>To actually train a model we need three things:</p><ul><li><p>A <em>model loss function</em> that evaluates how well the model is doing given some input data.</p></li><li><p>A collection of data points that will be provided to the loss function.</p></li><li><p>An <a href="optimisers.html">optimiser</a> that will update the model parameters appropriately.</p></li></ul><p>With these we can call <code>Flux.train!</code>:</p><pre><code class="language-julia">Flux.train!(modelLoss, data, opt)</code></pre><p>There are plenty of examples in the <a href="https://github.com/FluxML/model-zoo">model zoo</a>.</p><h2><a class="nav-anchor" id="Loss-Functions-1" href="#Loss-Functions-1">Loss Functions</a></h2><p>The <code>loss</code> that we defined in <a href="../models/basics.html">basics</a> is completely valid for training. We can also define a loss in terms of some model:</p><pre><code class="language-julia">m = Chain(
  Dense(784, 32, σ),
  Dense(32, 10), softmax)

loss(x, y) = Flux.mse(m(x), y)

Flux.train!(loss, data, opt)</code></pre><p>The loss will almost always be defined in terms of some <em>cost function</em> that measures the distance of the prediction <code>m(x)</code> from the target <code>y</code>. Flux has several of these built in, like <code>mse</code> for mean squared error or <code>crossentropy</code> for cross-entropy loss, but you can calculate it however you want.</p><h2><a class="nav-anchor" id="Datasets-1" href="#Datasets-1">Datasets</a></h2><p>The <code>data</code> argument provides a collection of data to train with (usually a set of inputs <code>x</code> and target outputs <code>y</code>). For example, here's a dummy data set with only one data point:</p><pre><code class="language-julia">x = rand(784)
y = rand(10)
data = [(x, y)]</code></pre><p><code>Flux.train!</code> will call <code>loss(x, y)</code>, calculate gradients, update the weights and then move on to the next data point if there is one. We can train the model on the same data three times:</p><pre><code class="language-julia">data = [(x, y), (x, y), (x, y)]
# Or equivalently
data = Iterators.repeated((x, y), 3)</code></pre><p>It's common to load the <code>x</code>s and <code>y</code>s separately. In this case you can use <code>zip</code>:</p><pre><code class="language-julia">xs = [rand(784), rand(784), rand(784)]
ys = [rand(10), rand(10), rand(10)]
data = zip(xs, ys)</code></pre><h2><a class="nav-anchor" id="Callbacks-1" href="#Callbacks-1">Callbacks</a></h2><p><code>train!</code> takes an additional argument, <code>cb</code>, that's used for callbacks so that you can observe the training process. For example:</p><pre><code class="language-julia">train!(loss, data, opt, cb = () -> println("training"))</code></pre><p>Callbacks are called for every batch of training data. You can slow this down using <code>Flux.throttle(f, timeout)</code>, which prevents <code>f</code> from being called more than once every <code>timeout</code> seconds.</p><p>A more typical callback might look like this:</p><pre><code class="language-julia">test_x, test_y = # ... create single batch of test data ...
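# A sketch of how the rest of this callback might be wired up (the name
# `evalcb` and the 10-second timeout are illustrative choices, assuming
# `loss`, `data` and `opt` are defined as in the sections above): report
# the loss on the held-out batch, at most once every ten seconds.
evalcb() = @show loss(test_x, test_y)

Flux.train!(loss, data, opt, cb = Flux.throttle(evalcb, 10))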