Compare commits: master...ox/perfdoc (1 commit)

Commit f87257f043
@ -4,6 +4,59 @@ All the usual [Julia performance tips apply](https://docs.julialang.org/en/v1/ma
As always, [profiling your code](https://docs.julialang.org/en/v1/manual/profile/#Profiling-1) is generally a useful way of finding bottlenecks.
Below are some Flux-specific tips and reminders.

## Don't write loss functions that use a non-constant globally declared model.
This is a special case of one of the most important [Julia Performance Tips](https://docs.julialang.org/en/v1/manual/performance-tips/#Avoid-global-variables-1):
non-constant global variables are slow.
We repeat it here because it is a common mistake.

This advice also applies to writing callbacks, and more generally to all Julia code.

### Don't write:
```julia
data = ...
m = Chain(Dense(784, 32, σ), Dense(32, 10), softmax)
loss(x, y) = Flux.mse(m(x), y)

Flux.train!(loss, Flux.params(m), data, Descent(0.1))
```
In this bad code, the model `m` is a non-constant global.
It is used inside the function `loss`, which is one of the most performance-critical parts of this code.
It will be slow because the compiler can't rely on `m` always having the same type: it is a mutable global binding, so it could be reassigned at any time.

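The effect on type inference can be seen without Flux. The following is a minimal sketch using only Base Julia; the names `scale`, `SCALE`, `loss_global` and `loss_const` are hypothetical stand-ins for a model and a loss function, not part of Flux.

```julia
# Hypothetical stand-ins: a plain number plays the role of the model.
scale = 2.0          # non-constant global, like `m` in the code above
const SCALE = 2.0    # constant global

loss_global(x) = scale * x   # reads the non-constant global
loss_const(x) = SCALE * x    # reads the constant global

# Ask the compiler what return type it infers for a Float64 argument.
rt_global = only(Base.return_types(loss_global, Tuple{Float64}))
rt_const = only(Base.return_types(loss_const, Tuple{Float64}))

println("non-const global: ", rt_global)
println("const global:     ", rt_const)
```

The function that reads the non-constant global should be inferred with a non-concrete return type, while the `const` version is inferred as `Float64`. In the REPL, `@code_warntype loss_global(1.0)` shows the same information interactively.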
### Correct alternatives:
#### Mark the model `const`
```julia
data = ...
const m = Chain(Dense(784, 32, σ), Dense(32, 10), softmax)
loss(x, y) = Flux.mse(m(x), y)

Flux.train!(loss, Flux.params(m), data, Descent(0.1))
```
Similarly, any other non-constant global that is used inside functions should also be made `const`.

#### Put everything in a main function:
For more flexibility, you could even make this function take `m` as an argument. It doesn't matter if `m` was originally declared as a non-const global: once it has been passed in as an argument, it is a local variable inside the function.
```julia
function main(data)
    m = Chain(Dense(784, 32, σ), Dense(32, 10), softmax)
    loss(x, y) = Flux.mse(m(x), y)

    Flux.train!(loss, Flux.params(m), data, Descent(0.1))
end
```
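The argument-passing variant described in this section can be illustrated without Flux. This is a sketch under stated assumptions: `Scale` is a hypothetical callable struct standing in for a model, and `sq_error` is a hypothetical loss, neither of which comes from Flux.

```julia
# Hypothetical stand-in for a model: a callable struct with a single weight.
struct Scale
    w::Float64
end
(s::Scale)(x) = s.w * x

m = Scale(2.0)   # still a non-constant global

# `mdl` is an argument, hence a local variable: the compiler specializes
# this function on typeof(mdl), whatever global the value came from.
sq_error(mdl, x, y) = (mdl(x) - y)^2

rt = only(Base.return_types(sq_error, Tuple{Scale, Float64, Float64}))
println("inferred return type: ", rt)
println("loss value: ", sq_error(m, 1.0, 2.0))
```

The call `sq_error(m, 1.0, 2.0)` looks up the global `m` once at the call site; everything inside the function then runs on a value of known type.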
#### Make the loss function actually close over `m`.
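Closures can be very useful here: the function returned below captures `mdl`, which is an argument of `get_loss_function`, rather than referring to the global `m`.
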
```julia
data = ...
m = Chain(Dense(784, 32, σ), Dense(32, 10), softmax)
get_loss_function(mdl) = (x, y) -> Flux.mse(mdl(x), y)

Flux.train!(get_loss_function(m), Flux.params(m), data, Descent(0.1))
```
This example is particularly applicable to callbacks.

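As a rough illustration of the callback case, the sketch below builds a zero-argument callback as a closure over its inputs. It uses only Base Julia: `Scale`, `make_callback` and the sample values are hypothetical, and no Flux API is assumed.

```julia
# Hypothetical stand-in for a model: a callable struct with a single weight.
struct Scale
    w::Float64
end
(s::Scale)(x) = s.w * x

m = Scale(2.0)   # non-constant global

# The returned callback captures `mdl`, `x` and `y` as locals,
# so its body never reads a global variable.
make_callback(mdl, x, y) = () -> (mdl(x) - y)^2

cb = make_callback(m, 1.0, 3.0)
rt = only(Base.return_types(cb, Tuple{}))

println("inferred return type: ", rt)
println("callback value: ", cb())
```

A callback written directly as `() -> (m(1.0) - 3.0)^2` at top level would instead read the global `m` on every call, which is the pattern this section advises against.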
## Don't use more precision than you need

Flux works great with all kinds of number types.