update freeze docs
parent 94ba1e8ede
commit 12106ff4cc
@@ -39,23 +39,35 @@ However, doing this requires the `struct` to have a corresponding constructor th
 When it is desired to not include all the model parameters (e.g. for transfer learning), we can simply not pass those layers into our call to `params`.

-Consider the simple multi-layer model where we want to omit optimising the first two `Dense` layers. This setup would look something like so:
+Consider a simple multi-layer perceptron model where we want to avoid optimising the first two `Dense` layers. We can obtain
+this using the slicing features `Chain` provides:

 ```julia
 m = Chain(
-  Dense(784, 64, σ),
-  Dense(64, 32),
-  Dense(32, 10), softmax)
+  Dense(784, 64, relu),
+  Dense(64, 32, relu),
+  Dense(32, 10)
+)

 ps = Flux.params(m[3:end])
 ```
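
As a quick sanity check of the slicing approach, the sketch below builds a smaller stand-in model (the layer sizes are illustrative, not taken from the docs) and confirms that `Flux.params(m[3:end])` collects only the arrays of the last `Dense` layer. It assumes a Flux release contemporary with these docs, where `Flux.params` and the positional `Dense(in, out, σ)` constructor are available.

```julia
using Flux

# Hypothetical small model; sizes are chosen only to keep the example quick.
m = Chain(
  Dense(4, 3, relu),
  Dense(3, 3, relu),
  Dense(3, 2)
)

# Collect parameters from the third layer onwards only.
ps = Flux.params(m[3:end])

# Parameters of the last layer on its own, for comparison.
last_ps = Flux.params(m[3])

# True if every array in `ps` is (by identity) one of the last layer's arrays.
only_last = all(p -> any(q -> q === p, last_ps), ps)

println(length(ps))   # number of trainable arrays collected
println(only_last)
```

Comparing with `===` checks object identity rather than equal values, which matches the statement that `ps` holds references to the layers' own arrays.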

-`ps` now holds a reference to only the parameters of the layers passed to it.
+The `Zygote.Params` object `ps` now holds a reference to only the parameters of the layers passed to it.

-During training, now the gradients would only be applied to the last `Dense` layer (and the `softmax` layer, but that is stateless so doesn't have any parameters), so only that would have its parameters changed.
+During training, the gradients will only be computed for (and applied to) the last `Dense` layer, so only that layer will have its parameters changed.

 `Flux.params` also takes multiple inputs to make it easy to collect parameters from heterogeneous models with a single call. A simple demonstration would be if we wanted to omit optimising the second `Dense` layer in the previous example. It would look something like this:

 ```julia
 Flux.params(m[1], m[3:end])
 ```
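
The multi-argument call can be checked in the same way. This is a minimal sketch on a hypothetical small model (sizes are illustrative), assuming the same Flux API as above; it confirms that the second layer's arrays are left out while the first and third layers' arrays are collected.

```julia
using Flux

# Hypothetical stand-in for the three-layer model in the docs.
m = Chain(Dense(4, 3, relu), Dense(3, 3, relu), Dense(3, 2))

# Collect parameters from the first layer and from the third layer onwards.
ps = Flux.params(m[1], m[3:end])

# Parameters of the layer we intend to skip.
skipped = Flux.params(m[2])

# True if none of the skipped layer's arrays appear (by identity) in `ps`.
second_layer_excluded = !any(p -> any(q -> q === p, skipped), ps)

println(length(ps))
println(second_layer_excluded)
```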
+
+Sometimes, more fine-grained control is needed.
+We can freeze a specific parameter of a specific layer that has already entered a `Params` object `ps`
+by simply deleting it from `ps`:
+
+```julia
+ps = params(m)
+delete!(ps, m[2].b)
+```
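
The effect of `delete!` can be sketched as follows, again on a hypothetical small model. To avoid depending on the bias field name (`b` in the docs above), the example takes the second layer's bias to be the last array yielded by `Flux.params(m[2])`; that ordering is an assumption of this sketch.

```julia
using Flux

# Hypothetical stand-in for the three-layer model in the docs.
m = Chain(Dense(4, 3, relu), Dense(3, 3, relu), Dense(3, 2))

ps = Flux.params(m)
n_before = length(ps)

# Assumed to be the second layer's bias: the last array collected from it.
frozen = last(collect(Flux.params(m[2])))

# Remove that single array from the collection; the model itself is untouched.
delete!(ps, frozen)

n_after = length(ps)

# True if the removed array can still be found (by identity) in `ps`.
still_present = any(p -> p === frozen, ps)

println((n_before, n_after))
println(still_present)
```

Only the `Params` collection changes here: the array stays in the model and is still used in the forward pass, it is just no longer among the parameters handed to the optimiser.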