<!DOCTYPE html>
<html lang="en"><head><meta charset="UTF-8"/><meta name="viewport" content="width=device-width, initial-scale=1.0"/><title>Advanced Model Building · Flux</title><script>(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
ga('create', 'UA-36890222-9', 'auto');
ga('send', 'pageview', {'page': location.pathname + location.search + location.hash});
</script><link href="https://fonts.googleapis.com/css?family=Lato|Roboto+Mono" rel="stylesheet" type="text/css"/><link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.11.2/css/fontawesome.min.css" rel="stylesheet" type="text/css"/><link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.11.2/css/solid.min.css" rel="stylesheet" type="text/css"/><link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.11.2/css/brands.min.css" rel="stylesheet" type="text/css"/><link href="https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.11.1/katex.min.css" rel="stylesheet" type="text/css"/><script>documenterBaseURL="../.."</script><script src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.6/require.min.js" data-main="../../assets/documenter.js"></script><script src="../../siteinfo.js"></script><script src="../../../versions.js"></script><link href="../../assets/flux.css" rel="stylesheet" type="text/css"/><link class="docs-theme-link" rel="stylesheet" type="text/css" href="../../assets/themes/documenter-dark.css" data-theme-name="documenter-dark"/><link class="docs-theme-link" rel="stylesheet" type="text/css" href="../../assets/themes/documenter-light.css" data-theme-name="documenter-light" data-theme-primary/><script src="../../assets/themeswap.js"></script></head><body><div id="documenter"><nav class="docs-sidebar"><div class="docs-package-name"><span class="docs-autofit">Flux</span></div><form class="docs-search" action="../../search/"><input class="docs-search-query" id="documenter-search-query" name="q" type="text" placeholder="Search docs"/></form><ul class="docs-menu"><li><a class="tocitem" href="../../">Home</a></li><li><span class="tocitem">Building Models</span><ul><li><a class="tocitem" href="../basics/">Basics</a></li><li><a class="tocitem" href="../recurrence/">Recurrence</a></li><li><a class="tocitem" href="../regularisation/">Regularisation</a></li><li><a class="tocitem" href="../layers/">Model Reference</a></li><li 
class="is-active"><a class="tocitem" href>Advanced Model Building</a><ul class="internal"><li><a class="tocitem" href="#Customising-Parameter-Collection-for-a-Model-1"><span>Customising Parameter Collection for a Model</span></a></li><li><a class="tocitem" href="#Freezing-Layer-Parameters-1"><span>Freezing Layer Parameters</span></a></li></ul></li><li><a class="tocitem" href="../nnlib/">NNlib</a></li></ul></li><li><span class="tocitem">Handling Data</span><ul><li><a class="tocitem" href="../../data/onehot/">One-Hot Encoding</a></li><li><a class="tocitem" href="../../data/dataloader/">DataLoader</a></li></ul></li><li><span class="tocitem">Training Models</span><ul><li><a class="tocitem" href="../../training/optimisers/">Optimisers</a></li><li><a class="tocitem" href="../../training/training/">Training</a></li></ul></li><li><a class="tocitem" href="../../gpu/">GPU Support</a></li><li><a class="tocitem" href="../../saving/">Saving &amp; Loading</a></li><li><a class="tocitem" href="../../ecosystem/">The Julia Ecosystem</a></li><li><a class="tocitem" href="../../utilities/">Utility Functions</a></li><li><a class="tocitem" href="../../performance/">Performance Tips</a></li><li><a class="tocitem" href="../../datasets/">Datasets</a></li><li><a class="tocitem" href="../../community/">Community</a></li></ul><div class="docs-version-selector field has-addons"><div class="control"><span class="docs-label button is-static is-size-7">Version</span></div><div class="docs-selector control is-expanded"><div class="select is-fullwidth is-size-7"><select id="documenter-version-selector"></select></div></div></div></nav><div class="docs-main"><header class="docs-navbar"><nav class="breadcrumb"><ul class="is-hidden-mobile"><li><a class="is-disabled">Building Models</a></li><li class="is-active"><a href>Advanced Model Building</a></li></ul><ul class="is-hidden-tablet"><li class="is-active"><a href>Advanced Model Building</a></li></ul></nav><div class="docs-right"><a 
class="docs-edit-link" href="https://github.com/FluxML/Flux.jl/blob/master/docs/src/models/advanced.md" title="Edit on GitHub"><span class="docs-icon fab"></span><span class="docs-label is-hidden-touch">Edit on GitHub</span></a><a class="docs-settings-button fas fa-cog" id="documenter-settings-button" href="#" title="Settings"></a><a class="docs-sidebar-button fa fa-bars is-hidden-desktop" id="documenter-sidebar-button" href="#"></a></div></header><article class="content" id="documenter-page"><h1 id="Advanced-Model-Building-and-Customisation-1"><a class="docs-heading-anchor" href="#Advanced-Model-Building-and-Customisation-1">Advanced Model Building and Customisation</a><a class="docs-heading-anchor-permalink" href="#Advanced-Model-Building-and-Customisation-1" title="Permalink"></a></h1><p>This page describes some of the more advanced features that Flux provides for finer control over model building.</p><h2 id="Customising-Parameter-Collection-for-a-Model-1"><a class="docs-heading-anchor" href="#Customising-Parameter-Collection-for-a-Model-1">Customising Parameter Collection for a Model</a><a class="docs-heading-anchor-permalink" href="#Customising-Parameter-Collection-for-a-Model-1" title="Permalink"></a></h2><p>The examples in this section build on the <code>Affine</code> layer from the <a href="../basics/#Building-Layers-1">basics</a>.</p><p>By default, all the fields of the <code>Affine</code> type are collected as its parameters. In some cases, however, a layer needs to hold other metadata that is not needed for training and should therefore be ignored when the parameters are collected. Flux offers two ways to mark which fields of a layer are trainable.</p><p>The first way is to overload the <code>trainable</code> function.</p><pre><code class="language-julia-repl">julia&gt; Flux.@functor Affine
julia&gt; a = Affine(rand(3,3), rand(3))
Affine{Array{Float64,2},Array{Float64,1}}([0.66722 0.774872 0.249809; 0.843321 0.403843 0.429232; 0.683525 0.662455 0.065297], [0.42394, 0.0170927, 0.544955])
julia&gt; Flux.params(a) # default behavior
Params([[0.66722 0.774872 0.249809; 0.843321 0.403843 0.429232; 0.683525 0.662455 0.065297], [0.42394, 0.0170927, 0.544955]])
julia&gt; Flux.trainable(a::Affine) = (a.W,)
julia&gt; Flux.params(a)
Params([[0.66722 0.774872 0.249809; 0.843321 0.403843 0.429232; 0.683525 0.662455 0.065297]])</code></pre><p>Only the fields returned by <code>trainable</code> are collected as trainable parameters of the layer when calling <code>Flux.params</code>.</p><p>The second way is to use the <code>@functor</code> macro directly, listing the fields we are interested in as its second argument:</p><pre><code class="language-julia">Flux.@functor Affine (W,)</code></pre><p>However, this requires the <code>struct</code> to have a corresponding constructor that accepts those fields as its arguments.</p><h2 id="Freezing-Layer-Parameters-1"><a class="docs-heading-anchor" href="#Freezing-Layer-Parameters-1">Freezing Layer Parameters</a><a class="docs-heading-anchor-permalink" href="#Freezing-Layer-Parameters-1" title="Permalink"></a></h2><p>When some of the model parameters should be excluded from training (e.g. for transfer learning), we can simply leave the corresponding layers out of our call to <code>params</code>.</p><p>Consider a simple multi-layer perceptron model where we want to avoid optimising the first two <code>Dense</code> layers. We can achieve this using the slicing features <code>Chain</code> provides:</p><pre><code class="language-julia">m = Chain(
  Dense(784, 64, relu),
  Dense(64, 64, relu),
  Dense(64, 10)
)
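# Illustrative aside, assuming `using Flux` has been run. The name `unfrozen`
# is hypothetical and not part of Flux. Slicing a Chain returns a new Chain
# that wraps the same layer objects, so the slice refers to the original
# parameters rather than to copies of them.
unfrozen = m[3:end]
@assert unfrozen[1] === m[3]  # same Dense layer object as in `m`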
ps = Flux.params(m[3:end])</code></pre><p>The <code>Zygote.Params</code> object <code>ps</code> now holds references only to the parameters of the layers passed to it.</p><p>During training, gradients are computed for (and applied to) only the last <code>Dense</code> layer, so only its parameters change.</p><p><code>Flux.params</code> also takes multiple inputs to make it easy to collect parameters from heterogeneous models with a single call. For example, to skip optimising only the second <code>Dense</code> layer in the previous example, we could write:</p><pre><code class="language-julia">Flux.params(m[1], m[3:end])</code></pre><p>Sometimes finer-grained control is needed. We can freeze a specific parameter of a specific layer that has already been collected into a <code>Params</code> object <code>ps</code> by simply deleting it from <code>ps</code>:</p><pre><code class="language-julia">ps = Flux.params(m)
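# Illustrative aside: `ps` tracks arrays by object identity, so the bias of the
# second layer can be located with `===`. This check is a sketch added for
# illustration and assumes the layer stores its bias in a field named `b`,
# as the call below does.
@assert any(p === m[2].b for p in ps)  # the bias is collected at this point
# The next call removes only that array; every other parameter, including
# `m[2].W`, stays in `ps`.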
delete!(ps, m[2].b) </code></pre></article><nav class="docs-footer"><a class="docs-footer-prevpage" href="../layers/">« Model Reference</a><a class="docs-footer-nextpage" href="../nnlib/">NNlib »</a></nav></div><div class="modal" id="documenter-settings"><div class="modal-background"></div><div class="modal-card"><header class="modal-card-head"><p class="modal-card-title">Settings</p><button class="delete"></button></header><section class="modal-card-body"><p><label class="label">Theme</label><div class="select"><select id="documenter-themepicker"><option value="documenter-light">documenter-light</option><option value="documenter-dark">documenter-dark</option></select></div></p><hr/><p>This document was generated with <a href="https://github.com/JuliaDocs/Documenter.jl">Documenter.jl</a> on <span class="colophon-date" title="Wednesday 27 May 2020 11:52">Wednesday 27 May 2020</span>. Using Julia version 1.3.1.</p></section><footer class="modal-card-foot"></footer></div></div></div></body></html>