checkpointing

This commit is contained in:
Mike Innes 2018-03-06 03:01:40 +00:00
parent 432b9c3222
commit 646e90aae2


@@ -65,14 +65,46 @@ You can easily load parameters back into a model with `Flux.loadparams!`.
```julia
julia> using Flux

julia> model = Chain(Dense(10,5,relu),Dense(5,2),softmax)
Chain(Dense(10, 5, NNlib.relu), Dense(5, 2), NNlib.softmax)

julia> using BSON: @load

julia> @load "mymodel.bson" weights

julia> Flux.loadparams!(model, weights)
```

The new `model` we created will now be identical to the one we saved parameters for.
## Checkpointing
In longer training runs it's a good idea to periodically save your model, so that you can resume if training is interrupted (for example, if there's a power cut). You can do this by saving the model in the [callback provided to `train!`](training/training.md).
```julia
using Flux: throttle
using BSON: @save

model = Chain(Dense(10,5,relu),Dense(5,2),softmax)

evalcb = throttle(30) do
  # Save the current model to disk (at most once every thirty seconds)
  @save "model-checkpoint.bson" model
end
```
This will update the `"model-checkpoint.bson"` file every thirty seconds.
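For the checkpoint to actually be written during training, pass the callback to `train!` via its `cb` keyword. A minimal sketch, assuming `loss`, `data` and `opt` are already set up as described in the training docs:

```julia
# `loss`, `data` and `opt` are assumed to be defined as in training/training.md;
# `evalcb` is the throttled checkpointing callback defined above.
Flux.train!(loss, data, opt, cb = evalcb)
```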
You can get more advanced by saving a series of models throughout training, for example
```julia
@save "model-$(now()).bson" model
```
will produce a series of models like `"model-2018-03-06T02:57:10.41.bson"`. You
could also store the current test set loss, so that it's easy to (for example)
revert to an older copy of the model if it starts to overfit.
```julia
using BSON: bson

# `testloss()` stands in for however you compute your test set loss.
bson("model-$(now()).bson", model = model, loss = testloss())
```
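You can later reload one of these checkpoints to inspect its loss or roll the model back. A minimal sketch, assuming a file name produced by the timestamped pattern above:

```julia
using BSON: @load

# Hypothetical checkpoint file name from the pattern above.
@load "model-2018-03-06T02:57:10.41.bson" model loss

loss  # the test set loss recorded when this checkpoint was saved
```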