# Optimisers

Consider a [simple linear regression](../models/basics.html). We create some dummy data, calculate a loss, and backpropagate to calculate gradients for the parameters `W` and `b`.

```julia
using Flux.Tracker
W = param(rand(2, 5))
b = param(rand(2))
predict(x) = W*x .+ b
loss(x, y) = sum((predict(x) .- y).^2)
x, y = rand(5), rand(2) # Dummy data
l = loss(x, y) # ~ 3
params = Params([W, b])
grads = Tracker.gradient(() -> loss(x, y), params)
```

We want to update each parameter, using the gradient, in order to improve (reduce) the loss. Here's one way to do that:

```julia
using Flux.Tracker: grad, update!
function sgd()
  η = 0.1 # Learning Rate
  for p in (W, b)
    update!(p, -η * grads[p])
  end
end
```

If we call `sgd`, the parameters `W` and `b` will change and our loss should go down.

There are two pieces here: one is that we need a list of trainable parameters for the model (`[W, b]` in this case), and the other is the update step. In this case the update is simply gradient descent (`x .-= η .* Δ`), but we might choose to do something more advanced, like adding momentum.
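To make the momentum idea concrete, here is a minimal hand-rolled sketch in the same style as `sgd()` above. The names `vW`, `vb` and `sgd_momentum`, and the choices of `η` and `ρ`, are invented for illustration; Flux's built-in `Momentum` optimiser (see the reference below) packages this up for you.

```julia
# Velocity buffers, one per parameter (illustrative names).
vW, vb = zeros(size(W)), zeros(size(b))

function sgd_momentum(η = 0.1, ρ = 0.9)
  for (p, v) in zip((W, b), (vW, vb))
    v .= ρ .* v .+ Tracker.data(grads[p])  # accumulate gradients, decaying past ones by ρ
    update!(p, -η .* v)                    # step along the accumulated velocity
  end
end
```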
In this case, getting the variables is trivial, but you can imagine it'd be more of a pain with some complex stack of layers.

```julia
m = Chain(
  Dense(10, 5, σ),
  Dense(5, 2), softmax)
```

Instead of having to write `[m[1].W, m[1].b, ...]`, Flux provides a `params(m)` function that returns a list of all parameters in the model for you.
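For example, the gradient-and-update pattern from earlier carries over unchanged to the whole `Chain`. This is only a sketch: the `loss_m` function and the dummy data below are invented for illustration.

```julia
ps = Params(params(m))              # like `Params([W, b])` above, but for the whole Chain

loss_m(x, y) = sum((m(x) .- y).^2)  # an illustrative loss for the Chain above
x, y = rand(10), rand(2)            # dummy data matching the Chain's input and output sizes

grads = Tracker.gradient(() -> loss_m(x, y), ps)
for p in params(m)
  update!(p, -0.1 * grads[p])       # the same update step as in `sgd()`
end
```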
For the update step, there's nothing whatsoever wrong with writing the loop above – it'll work just fine – but Flux provides various *optimisers* that make it more convenient.

```julia
opt = SGD([W, b], 0.1) # Gradient descent with learning rate 0.1
opt() # Carry out the update, modifying `W` and `b`.
```

An optimiser takes a parameter list and returns a function that does the same thing as `update` above. We can pass either `opt` or `update` to our [training loop](training.html), which will then run the optimiser after every mini-batch of data.

## Optimiser Reference

All optimisers return a function that, when called, will update the parameters passed to it.

**`Flux.Optimise.SGD`** — Function.

```
SGD(params, η = 0.1; decay = 0)
```

Classic gradient descent optimiser with learning rate `η`. For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`.

Supports inverse decaying learning rate if the `decay` argument is provided.

[source](https://github.com/FluxML/Flux.jl/blob/9d563820f8f2f5ae91364afc1e9f371f75466e77/src/optimise/interface.jl#L14-L21)

**`Flux.Optimise.Momentum`** — Function.

```
Momentum(params, η = 0.01; ρ = 0.9, decay = 0)
```

SGD with learning rate `η`, momentum `ρ` and optional learning rate inverse decay.

[source](https://github.com/FluxML/Flux.jl/blob/9d563820f8f2f5ae91364afc1e9f371f75466e77/src/optimise/interface.jl#L25-L29)

**`Flux.Optimise.Nesterov`** — Function.

```
Nesterov(params, η = 0.01; ρ = 0.9, decay = 0)
```

SGD with learning rate `η`, Nesterov momentum `ρ` and optional learning rate inverse decay.

[source](https://github.com/FluxML/Flux.jl/blob/9d563820f8f2f5ae91364afc1e9f371f75466e77/src/optimise/interface.jl#L33-L37)

**`Flux.Optimise.ADAM`** — Function.

```
ADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
```

[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.

[source](https://github.com/FluxML/Flux.jl/blob/9d563820f8f2f5ae91364afc1e9f371f75466e77/src/optimise/interface.jl#L51-L55)
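As a rough usage sketch, an optimiser built over `params(m)` can be handed to the training loop described on the [Training](training.html) page. The dataset below is a placeholder, `loss_m` comes from the sketch earlier on this page, and the exact `train!` call is an assumption based on that page rather than something defined here.

```julia
opt = ADAM(params(m), 0.001)   # or Momentum(...), Nesterov(...), SGD(...)

# A placeholder dataset of (input, target) pairs.
data = [(rand(10), rand(2)) for i in 1:100]

# The training loop runs the optimiser after every mini-batch.
Flux.train!(loss_m, data, opt)
```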