<htmllang="en"><head><metacharset="UTF-8"/><metaname="viewport"content="width=device-width, initial-scale=1.0"/><title>GPU Support · Flux</title>
<linkhref="https://cdnjs.cloudflare.com/ajax/libs/normalize/4.2.0/normalize.min.css"rel="stylesheet"type="text/css"/><linkhref="https://fonts.googleapis.com/css?family=Lato|Roboto+Mono"rel="stylesheet"type="text/css"/><linkhref="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.6.3/css/font-awesome.min.css"rel="stylesheet"type="text/css"/><linkhref="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/default.min.css"rel="stylesheet"type="text/css"/><script>documenterBaseURL=".."</script><scriptsrc="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.2.0/require.min.js"data-main="../assets/documenter.js"></script><scriptsrc="../siteinfo.js"></script><scriptsrc="../../versions.js"></script><linkhref="../assets/documenter.css"rel="stylesheet"type="text/css"/><linkhref="../assets/flux.css"rel="stylesheet"type="text/css"/></head><body><navclass="toc"><h1>Flux</h1><selectid="version-selector"onChange="window.location.href=this.value"style="visibility: hidden"></select><formclass="search"id="search-form"action="../search/"><inputid="search-query"name="q"type="text"placeholder="Search docs"/></form><ul><li><aclass="toctext"href="../">Home</a></li><li><spanclass="toctext">Building Models</span><ul><li><aclass="toctext"href="../models/basics/">Basics</a></li><li><aclass="toctext"href="../models/recurrence/">Recurrence</a></li><li><aclass="toctext"href="../models/regularisation/">Regularisation</a></li><li><aclass="toctext"href="../models/layers/">Model Reference</a></li></ul></li><li><spanclass="toctext">Training Models</span><ul><li><aclass="toctext"href="../training/optimisers/">Optimisers</a></li><li><aclass="toctext"href="../training/training/">Training</a></li></ul></li><li><aclass="toctext"href="../data/onehot/">One-Hot Encoding</a></li><liclass="current"><aclass="toctext"href>GPU Support</a><ulclass="internal"><li><aclass="toctext"href="#GPU-Usage-1">GPU Usage</a></li></ul></li><li><aclass="toctext"href="../saving/">Saving & 
Loading</a></li><li><aclass="toctext"href="../performance/">Performance Tips</a></li><li><aclass="toctext"href="../community/">Community</a></li></ul></nav><articleid="docs"><header><nav><ul><li><ahref>GPU Support</a></li></ul></nav><hr/><divid="topbar"><span>GPU Support</span><aclass="fa fa-bars"href="#"></a></div></header><h1><aclass="nav-anchor"id="GPU-Support-1"href="#GPU-Support-1">GPU Support</a></h1><p>NVIDIA GPU support should work out of the box on systems with CUDA and cuDNN installed. For more details, see the <ahref="https://github.com/JuliaGPU/CuArrays.jl">CuArrays</a> README.</p><h2><aclass="nav-anchor"id="GPU-Usage-1"href="#GPU-Usage-1">GPU Usage</a></h2><p>Support for array operations on other hardware backends, such as GPUs, is provided by external packages such as <ahref="https://github.com/JuliaGPU/CuArrays.jl">CuArrays</a>. Flux is agnostic to array types, so we only need to move the model weights and data to the GPU; Flux handles the rest.</p><p>For example, we can use <code>CuArrays</code> (with the <code>cu</code> converter) to run our <ahref="../models/basics/">basic example</a> on an NVIDIA GPU.</p><p>(Note that CuArrays requires CUDA to be available; see the <ahref="https://github.com/JuliaGPU/CuArrays.jl">CuArrays.jl</a> instructions for more details.)</p><pre><codeclass="language-julia">using CuArrays
W, b = cu(rand(2, 5)), cu(rand(2)) # parameters, converted to CuArrays
predict(x) = W*x .+ b
loss(x, y) = sum((predict(x) .- y).^2)
x, y = cu(rand(5)), cu(rand(2)) # dummy data, also on the GPU
loss(x, y) # ~ 3</code></pre><p>Note that we convert both the parameters (<code>W</code>, <code>b</code>) and the data set (<code>x</code>, <code>y</code>) to CUDA arrays. Taking derivatives and training work exactly as before.</p><p>If you define a structured model, such as a <code>Dense</code> layer or a <code>Chain</code>, you only need to convert its internal parameters. Flux provides <code>mapleaves</code>, which lets you alter all parameters of a model at once.</p><pre><codeclass="language-julia">d = Dense(10, 5, σ)
d = mapleaves(cu, d)
d.W # Tracked CuArray
d(cu(rand(10))) # CuArray output
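# A CPU-only sketch of the idea behind mapleaves, assuming a model is a nest of
# structs and tuples with arrays at the leaves. This is an illustration, not
# Flux's implementation; ToyDense, toy_mapleaves and to32 are hypothetical names.
struct ToyDense
  W
  b
end
toy_mapleaves(f, x::AbstractArray) = f(x) # leaf: convert the array itself
toy_mapleaves(f, l::ToyDense) = ToyDense(toy_mapleaves(f, l.W), toy_mapleaves(f, l.b))
toy_mapleaves(f, xs::Tuple) = map(xs) do l # container: recurse into each element
  toy_mapleaves(f, l)
end
to32(x) = Float32.(x) # stand-in for cu, so the sketch runs without a GPU
toy = (ToyDense(rand(5, 10), rand(5)), ToyDense(rand(2, 5), rand(2)))
toy32 = toy_mapleaves(to32, toy)
eltype(toy32[1].W) # Float32: every array leaf was converted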
m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
m = mapleaves(cu, m)
m(cu(rand(10))) # runs the whole chain on the GPU</code></pre><p>As a convenience, Flux provides the <code>gpu</code> function, which moves models and data to the GPU when one is available. By default it does nothing; once <code>CuArrays</code> is loaded, it moves data to the GPU instead.</p><pre><codeclass="language-julia">julia> using Flux, CuArrays
julia> m = Dense(10,5) |> gpu
Dense(10, 5)
julia> x = rand(10) |> gpu
10-element CuArray{Float32,1}:
0.800225
⋮
0.511655
julia> m(x)
Tracked 5-element CuArray{Float32,1}:
-0.30535
⋮
-0.618002</code></pre><p>The analogous <code>cpu</code> function is also available for moving models and data back from the GPU to the CPU.</p><pre><codeclass="language-julia">julia> x = rand(10) |> gpu
10-element CuArray{Float32,1}:
0.235164
⋮
0.192538
julia> x |> cpu
10-element Array{Float32,1}:
0.235164
⋮
0.192538</code></pre><footer><hr/><aclass="previous"href="../data/onehot/"><spanclass="direction">Previous</span><spanclass="title">One-Hot Encoding</span></a><aclass="next"href="../saving/"><spanclass="direction">Next</span><spanclass="title">Saving & Loading</span></a></footer></article></body></html>