This commit is contained in:
Yao Lu 2020-04-22 01:22:34 +08:00
parent def19b058e
commit 1dfec7f38b
2 changed files with 17 additions and 2 deletions

View File

@ -1,9 +1,12 @@
module Optimise
using LinearAlgebra
export train!, update!,
SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,RADAM,
InvDecay, ExpDecay, WeightDecay, stop, Optimiser
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW, RADAM,
InvDecay, ExpDecay, WeightDecay, ClipValue, ClipNorm,
stop, Optimiser
include("optimisers.jl")
include("train.jl")

View File

@ -89,3 +89,15 @@ end
@test decay_steps == ground_truth
@test o.eta == o.clip
end
@testset "Clipping" begin
    w = randn(10, 10)
    # Linear loss so the gradient w.r.t. w is large (driven by the 1000x input),
    # guaranteeing the clipping optimisers actually have something to clip.
    loss(x) = sum(w * x)
    θ = Params([w])
    x = 1000 * randn(10)
    # Gradient of the loss w.r.t. w, looked up from the Grads object by key.
    w̄ = gradient(() -> loss(x), θ)[w]
    # ClipValue(1.0) clamps each gradient entry into [-1, 1]; copy() so the
    # original gradient is preserved for the second, independent check below.
    w̄_value = Optimise.apply!(ClipValue(1.0), w, copy(w̄))
    @test all(w̄_value .<= 1)
    # ClipNorm(1.0) rescales the whole gradient so its norm is at most 1.
    w̄_norm = Optimise.apply!(ClipNorm(1.0), w, copy(w̄))
    @test norm(w̄_norm) <= 1
end