Compare commits

...

1 Commits

Author SHA1 Message Date
Lyndon White f87257f043
Tell people sto stop using global models 2020-03-14 18:09:06 +00:00
1 changed files with 53 additions and 0 deletions

@ -4,6 +4,59 @@ All the usual [Julia performance tips apply](https://docs.julialang.org/en/v1/ma
As always [profiling your code](https://docs.julialang.org/en/v1/manual/profile/#Profiling-1) is generally a useful way of finding bottlenecks.
Below follow some Flux specific tips/reminders.
## Don't write loss functions that use a non-constant globally declared model.
This is a special case of one of the most important [Julia Performance Tips](https://docs.julialang.org/en/v1/manual/performance-tips/#Avoid-global-variables-1).
Non-constant global variables are slow.
We repeat it here as it is a common mistake.
This advice also applies to writing callbacks, and more generally to all Julia code.
### Don't write:
```julia
data = ...
m = Chain(Dense(784, 32, σ), Dense(32, 10), softmax)
loss(x, y) = Flux.mse(m(x), y)
Flux.train!(loss, Flux.params(m), data, Descent(0.1))
```
In this bad code, the model `m` is a non-constant global.
It is used inside the function `loss`, which is one of the most performance-critical parts of this code.
It will be slow because the compiler cannot assume that `m` always has the same type: as a non-constant global, it could be rebound to a value of any type at any time.
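The effect on type inference can be seen without Flux at all. The following is a minimal sketch in plain Julia; the names `w`, `W`, `apply_global`, and `apply_const` are hypothetical and chosen only for illustration. It asks the compiler which return type it can infer for a function that reads a non-constant global versus one that reads a `const` global.

```julia
# Hypothetical names for illustration; none of these come from Flux.
w = 2.0          # non-constant global: may be rebound to a value of any type
const W = 2.0    # constant global: its type is fixed

apply_global(x) = w * x   # reads the non-constant global
apply_const(x) = W * x    # reads the constant global

# Ask the compiler which return types it can infer for a Float64 argument.
global_types = Base.return_types(apply_global, (Float64,))
const_types = Base.return_types(apply_const, (Float64,))

println("non-const global: ", global_types)
println("const global:     ", const_types)
```

The first result is expected to contain `Any`, meaning every use of the result must be dispatched at run time, while the second should contain the concrete type `Float64`. A loss function that reads a non-constant global model is in the same situation as `apply_global`.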
### Correct alternatives:
#### Mark the model `const`
```julia
data = ...
const m = Chain(Dense(784, 32, σ), Dense(32, 10), softmax)
loss(x, y) = Flux.mse(m(x), y)
Flux.train!(loss, Flux.params(m), data, Descent(0.1))
```
Similarly, any other non-constant global that is used inside functions should also be made `const`.
#### Put everything in a main function:
For more flexibility, you could make this function take `m` as an argument. It doesn't matter if `m` was originally declared as a non-const global: once it has been passed in as an argument, it is a local variable.
```julia
function main(data)
    m = Chain(Dense(784, 32, σ), Dense(32, 10), softmax)
    loss(x, y) = Flux.mse(m(x), y)
    Flux.train!(loss, Flux.params(m), data, Descent(0.1))
end
```
```
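The variant that takes `m` as an argument might look like the sketch below. It assumes the same Flux API as the other snippets in this section (`Flux.params`, four-argument `Flux.train!`, `Dense(in, out, σ)`). The random `data` is a placeholder assumption, not part of the original example.

```julia
using Flux

# Variant of `main` that receives the model as an argument (illustrative sketch).
function main(m, data)
    # `m` is a local variable here, so `loss` closes over a value of known type.
    loss(x, y) = Flux.mse(m(x), y)
    Flux.train!(loss, Flux.params(m), data, Descent(0.1))
end

# The model may be a non-const global; it only enters `main` as an argument.
m = Chain(Dense(784, 32, σ), Dense(32, 10), softmax)

# Assumed placeholder data: a single batch of 16 random samples.
data = [(rand(Float32, 784, 16), rand(Float32, 10, 16))]

main(m, data)
```

Because Julia compiles `main` for the concrete type of the model it receives, the global lookup happens once at the call site rather than on every evaluation of `loss`.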
#### Make the loss function actually close over `m`.
Closures can be very useful.
```julia
data = ...
m = Chain(Dense(784, 32, σ), Dense(32, 10), softmax)
get_loss_function(mdl) = (x, y) -> Flux.mse(mdl(x), y)
Flux.train!(get_loss_function(m), Flux.params(m), data, Descent(0.1))
```
This example is particularly applicable to callbacks.
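For callbacks, the same pattern could be sketched as follows. This assumes a Flux version whose `train!` accepts the `cb` keyword and the `Flux.params` form used elsewhere in this section. `get_callback`, `test_x`, `test_y`, and the random data are hypothetical names and placeholders introduced only for illustration.

```julia
using Flux

get_loss_function(mdl) = (x, y) -> Flux.mse(mdl(x), y)

# Hypothetical helper: builds a zero-argument callback that closes over the
# model and a held-out batch instead of reading them from globals.
get_callback(mdl, x, y) = () -> println("held-out loss: ", Flux.mse(mdl(x), y))

m = Chain(Dense(784, 32, σ), Dense(32, 10), softmax)

# Assumed placeholder data: one training batch and one held-out batch.
data = [(rand(Float32, 784, 16), rand(Float32, 10, 16))]
test_x, test_y = rand(Float32, 784, 16), rand(Float32, 10, 16)

Flux.train!(get_loss_function(m), Flux.params(m), data, Descent(0.1);
            cb = get_callback(m, test_x, test_y))
```

The callback never refers to the global names `m`, `test_x`, or `test_y` inside its body; it only sees the values captured when `get_callback` was called.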
## Don't use more precision than you need
Flux works great with all kinds of number types.