gpu docs
This commit is contained in:
parent
662439c164
commit
6b69edfe6c
|
@ -1,6 +1,6 @@
|
|||
# GPU Support
|
||||
|
||||
Support for array operations on other hardware backends, like GPUs, is provided by external packages like [CuArrays](https://github.com/JuliaGPU/CuArrays.jl) and [CLArrays](https://github.com/JuliaGPU/CLArrays.jl). Flux doesn't care what array type you use, so we can just plug these in without any other changes.
|
||||
Support for array operations on other hardware backends, like GPUs, is provided by external packages like [CuArrays](https://github.com/JuliaGPU/CuArrays.jl). Flux is agnostic to array types, so we simply need to move model weights and data to the GPU and Flux will handle it.
|
||||
|
||||
For example, we can use `CuArrays` (with the `cu` converter) to run our [basic example](models/basics.md) on an NVIDIA GPU.
|
||||
|
||||
|
@ -32,4 +32,23 @@ m = mapleaves(cu, m)
|
|||
d(cu(rand(10)))
|
||||
```
|
||||
|
||||
The [mnist example](https://github.com/FluxML/model-zoo/blob/master/mnist/mlp.jl) contains the code needed to run the model on the GPU; just uncomment the lines after `using CuArrays`.
|
||||
As a convenience, Flux provides the `gpu` function to convert models and data to the GPU if one is available. By default, it'll do nothing, but loading `CuArrays` will cause it to move data to the GPU instead.
|
||||
|
||||
```julia
|
||||
julia> using Flux, CuArrays
|
||||
|
||||
julia> m = Dense(10,5) |> gpu
|
||||
Dense(10, 5)
|
||||
|
||||
julia> x = rand(10) |> gpu
|
||||
10-element CuArray{Float32,1}:
|
||||
0.800225
|
||||
⋮
|
||||
0.511655
|
||||
|
||||
julia> m(x)
|
||||
Tracked 5-element CuArray{Float32,1}:
|
||||
-0.30535
|
||||
⋮
|
||||
-0.618002
|
||||
```
|
||||
|
|
Loading…
Reference in New Issue