Compare commits
1 Commits
master
...
ox/perfdoc
Author | SHA1 | Date |
---|---|---|
Lyndon White | f87257f043 |
|
@ -4,6 +4,59 @@ All the usual [Julia performance tips apply](https://docs.julialang.org/en/v1/ma
|
|||
As always [profiling your code](https://docs.julialang.org/en/v1/manual/profile/#Profiling-1) is generally a useful way of finding bottlenecks.
|
||||
Below follow some Flux specific tips/reminders.
|
||||
|
||||
## Don't write loss functions that use a non-constant globally declared model.
|
||||
This is a special case of one of the most important [Julia Performance Tips](https://docs.julialang.org/en/v1/manual/performance-tips/#Avoid-global-variables-1).
|
||||
Non-constant global variables are slow.
|
||||
We repeat it here as it is a common mistake.
|
||||
|
||||
This advice is appliable also to writing callbacks, and more generally to all Julia code.
|
||||
|
||||
### Don't write:
|
||||
```julia
|
||||
data = ...
|
||||
m = Chain(Dense(784, 32, σ), Dense(32, 10), softmax)
|
||||
loss(x, y) = Flux.mse(m(x), y)
|
||||
|
||||
Flux.train!(loss, Flux.params(m), data, Descent(0.1))
|
||||
```
|
||||
In this bad code, the model `m` is a non-constant global.
|
||||
It is being used inside the function `loss`, which is one of the most performance critical parts of this code.
|
||||
It will be slow, as the compiler can't rely on `m` always being the same type -- it is a mutable global, it could change at any time.
|
||||
|
||||
### Correct alternatives:
|
||||
#### Mark the model `const`
|
||||
```julia
|
||||
data = ...
|
||||
const m = Chain(Dense(784, 32, σ), Dense(32, 10), softmax)
|
||||
loss(x, y) = Flux.mse(m(x), y)
|
||||
|
||||
Flux.train!(loss, Flux.params(m), data, Descent(0.1))
|
||||
```
|
||||
Similarly anything else that is a non-constant global that is used in functions should also be made constant
|
||||
|
||||
#### Put everything in a main function:
|
||||
For more flexibility, you could even make this take `m` as a argument -- it doesn't matter of `m` was originally declared as a non-const global once it has been passed in as a argument because it then becomes a local variable.
|
||||
```julia
|
||||
function main(data)
|
||||
m = Chain(Dense(784, 32, σ), Dense(32, 10), softmax)
|
||||
loss(x, y) = Flux.mse(m(x), y)
|
||||
|
||||
Flux.train!(loss, Flux.params(m), data, Descent(0.1))
|
||||
end
|
||||
```
|
||||
|
||||
#### Make the loss function actually close over `m`.
|
||||
Closures can be very useful.
|
||||
|
||||
```julia
|
||||
data = ...
|
||||
m = Chain(Dense(784, 32, σ), Dense(32, 10), softmax)
|
||||
get_loss_function(mdl) = (x, y) -> Flux.mse(mdl(x), y)
|
||||
|
||||
Flux.train!(get_loss_function(m), Flux.params(m), data, Descent(0.1))
|
||||
```
|
||||
This example is particularly applicable to callbacks.
|
||||
|
||||
## Don't use more precision than you need
|
||||
|
||||
Flux works great with all kinds of number types.
|
||||
|
|
Loading…
Reference in New Issue