This commit is contained in:
Mike J Innes 2018-03-05 17:24:46 +00:00
parent eaa9fd2dd3
commit 662439c164
3 changed files with 9 additions and 9 deletions

View File

@ -1,7 +1,7 @@
# Regularisation
Applying regularisation to model parameters is straightforward. We just need to
apply an appropriate regulariser, such as `norm`, to each model parameter and
apply an appropriate regulariser, such as `vecnorm`, to each model parameter and
add the result to the overall loss.
For example, say we have a simple regression.
@ -14,12 +14,12 @@ loss(x, y) = crossentropy(softmax(m(x)), y)
We can regularise this by taking the (L2) norm of the parameters, `m.W` and `m.b`.
```julia
penalty() = norm(m.W) + norm(m.b)
penalty() = vecnorm(m.W) + vecnorm(m.b)
loss(x, y) = crossentropy(softmax(m(x)), y) + penalty()
```
When working with layers, Flux provides the `params` function to grab all
parameters at once. We can easily penalise everything with `sum(norm, params)`.
parameters at once. We can easily penalise everything with `sum(vecnorm, params)`.
```julia
julia> params(m)
@ -27,7 +27,7 @@ julia> params(m)
param([0.355408 0.533092; … 0.430459 0.171498])
param([0.0, 0.0, 0.0, 0.0, 0.0])
julia> sum(norm, params(m))
julia> sum(vecnorm, params(m))
26.01749952921026 (tracked)
```
@ -39,7 +39,7 @@ m = Chain(
Dense(128, 32, relu),
Dense(32, 10), softmax)
loss(x, y) = crossentropy(m(x), y) + sum(norm, params(m))
loss(x, y) = crossentropy(m(x), y) + sum(vecnorm, params(m))
loss(rand(28^2), rand(10))
```

View File

@ -160,10 +160,8 @@ Base.std(x::TrackedArray; mean = Base.mean(x)) =
Base.std(x::TrackedArray, dim; mean = Base.mean(x, dim)) =
sqrt.(sum((x .- mean).^2, dim) ./ (size(x, dim)-1))
Base.norm(x::TrackedArray, p::Real = 2) =
p == 1 ? sum(abs.(x)) :
p == 2 ? sqrt(sum(abs2.(x) .+ 1e-6)) :
error("$p-norm not supported")
Base.vecnorm(x::TrackedArray, p::Real = 2) =
sum(abs.(x).^p .+ eps(0f0))^(1/p) # avoid d(sqrt(x))/dx == Inf at 0
back(::typeof(mean), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= Δ ./ length(xs.data))
back(::typeof(mean), Δ, xs::TrackedArray, region) =

View File

@ -56,6 +56,8 @@ end
@test gradtest((x, y) -> x .* y, rand(5), rand(5))
@test gradtest(dot, rand(5), rand(5))
@test gradtest(vecnorm, rand(5))
@test gradtest(rand(5)) do x
y = x.^2
2y + x