<html lang="en"><head><meta charset="UTF-8"/><meta name="viewport" content="width=device-width, initial-scale=1.0"/><title>GPU Support · Flux</title><link href="https://fonts.googleapis.com/css?family=Lato|Roboto+Mono" rel="stylesheet" type="text/css"/><link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.11.2/css/fontawesome.min.css" rel="stylesheet" type="text/css"/><link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.11.2/css/solid.min.css" rel="stylesheet" type="text/css"/><link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.11.2/css/brands.min.css" rel="stylesheet" type="text/css"/><link href="https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.11.1/katex.min.css" rel="stylesheet" type="text/css"/><script>documenterBaseURL=".."</script><script src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.6/require.min.js" data-main="../assets/documenter.js"></script><script src="../siteinfo.js"></script><script src="../../versions.js"></script><link href="../assets/flux.css" rel="stylesheet" type="text/css"/><link class="docs-theme-link" rel="stylesheet" type="text/css" href="../assets/themes/documenter-dark.css" data-theme-name="documenter-dark"/><link class="docs-theme-link" rel="stylesheet" type="text/css" href="../assets/themes/documenter-light.css" data-theme-name="documenter-light" data-theme-primary/><script src="../assets/themeswap.js"></script></head><body><div id="documenter"><div class="docs-main"><article class="content" id="documenter-page"><h1 id="GPU-Support-1"><a class="docs-heading-anchor" href="#GPU-Support-1">GPU Support</a><a class="docs-heading-anchor-perma
W = cu(rand(2, 5)) # a 2×5 CuArray
b = cu(rand(2))
predict(x) = W*x .+ b
loss(x, y) = sum((predict(x) .- y).^2)
x, y = cu(rand(5)), cu(rand(2)) # Dummy data
loss(x, y) # ~ 3</code></pre><p>Note that we convert both the parameters (<code>W</code>, <code>b</code>) and the data (<code>x</code>, <code>y</code>) to CUDA arrays. Taking derivatives and training work exactly as before.</p><p>If you define a structured model, like a <code>Dense</code> layer or a <code>Chain</code>, you only need to convert its internal parameters. Flux provides <code>fmap</code>, which lets you transform all the parameters of a model at once.</p><pre><code class="language-julia">d = Dense(10, 5, σ)
d = fmap(cu, d)  # move the layer's parameters to the GPU
d(cu(rand(10)))</code></pre><p>However, if you create a customized model, <code>fmap</code> may not work out of the box.</p><pre><code class="language-julia">julia> struct ActorCritic{A, C}
           actor::A
           critic::C
       end

julia> m = ActorCritic(ones(2, 2), ones(2));

julia> fmap(cu, m)
ActorCritic{Array{Float64,2},Array{Float64,1}}([1.0 1.0; 1.0 1.0], [1.0, 1.0])</code></pre><p>As you can see, nothing changed after <code>fmap(cu, m)</code>. The reason is that <code>Flux</code> does not know the structure of your custom type. To make it work as expected, you need the <code>@functor</code> macro.</p><pre><code class="language-julia">julia> Flux.@functor ActorCritic
julia> fmap(cu, m)
ActorCritic{CuArray{Float32,2,Nothing},CuArray{Float32,1,Nothing}}(Float32[1.0 1.0; 1.0 1.0], Float32[1.0, 1.0])</code></pre><p>Now the <code>actor</code> and <code>critic</code> fields have been converted to <code>CuArray</code>s. So what does the <code>@functor</code> macro do here? Essentially, it defines a method like this:</p><pre><code class="language-julia">Flux.functor(m::ActorCritic) = (actor = m.actor, critic = m.critic), fields -> ActorCritic(fields...)</code></pre><p><code>fmap</code> then calls <code>functor</code> recursively. As you can see, the result of <code>functor</code> has two parts: a <em>destructure</em> part and a <em>reconstruct</em> part. The first part turns the custom struct into a <code>trainable</code> data structure known to <code>Flux</code> (here, a <code>NamedTuple</code>), so that <code>m</code> can be mapped to <code>(actor=cu(ones(2,2)), critic=cu(ones(2)))</code>. The second part turns that result back into an <code>ActorCritic</code>, giving <code>ActorCritic(cu(ones(2,2)),cu(ones(2)))</code>.</p><p>By default, the <code>@functor</code> macro transforms all the fields of your struct. If you only want to transform some of them, specify those fields manually, as in <code>Flux.@functor ActorCritic (actor,)</code> (note that the field list must be a tuple), and make sure a matching <code>ActorCritic(actor)</code> constructor is also defined.</p><p>As a convenience, Flux provides the <code>gpu</code> function to move models and data to the GPU if one is available. By default it does nothing, but loading <code>CuArrays</code> causes it to move data to the GPU instead.</p><pre><code class="language-julia">julia> using Flux, CuArrays
-0.618002</code></pre><p>The analogous <code>cpu</code> function is also available for moving models and data back off the GPU.</p><pre><code class="language-julia">julia> x = rand(10) |> gpu
0.192538</code></pre></article></div><p>This document was generated with <a href="https://github.com/JuliaDocs/Documenter.jl">Documenter.jl</a> on <span class="colophon-date" title="Tuesday 3 March 2020 17:50">Tuesday 3 March 2020</span>. Using Julia version 1.3.1.</p></div></body></html>