This commit is contained in:
Mike J Innes 2018-03-05 19:23:13 +00:00
parent 662439c164
commit 6b69edfe6c
1 changed files with 21 additions and 2 deletions

View File

@ -1,6 +1,6 @@
# GPU Support
Support for array operations on other hardware backends, like GPUs, is provided by external packages like [CuArrays](https://github.com/JuliaGPU/CuArrays.jl) and [CLArrays](https://github.com/JuliaGPU/CLArrays.jl). Flux doesn't care what array type you use, so we can just plug these in without any other changes.
Support for array operations on other hardware backends, like GPUs, is provided by external packages like [CuArrays](https://github.com/JuliaGPU/CuArrays.jl). Flux is agnostic to array types, so we simply need to move model weights and data to the GPU and Flux will handle it.
For example, we can use `CuArrays` (with the `cu` converter) to run our [basic example](models/basics.md) on an NVIDIA GPU.
@ -32,4 +32,23 @@ m = mapleaves(cu, m)
d(cu(rand(10)))
```
The [mnist example](https://github.com/FluxML/model-zoo/blob/master/mnist/mlp.jl) contains the code needed to run the model on the GPU; just uncomment the lines after `using CuArrays`.
As a convenience, Flux provides the `gpu` function to convert models and data to the GPU if one is available. By default, it'll do nothing, but loading `CuArrays` will cause it to move data to the GPU instead.
```julia
julia> using Flux, CuArrays
julia> m = Dense(10,5) |> gpu
Dense(10, 5)
julia> x = rand(10) |> gpu
10-element CuArray{Float32,1}:
0.800225
0.511655
julia> m(x)
Tracked 5-element CuArray{Float32,1}:
-0.30535
-0.618002
```