add test
This commit is contained in:
parent
def19b058e
commit
1dfec7f38b
|
@ -1,9 +1,12 @@
|
|||
module Optimise
|
||||
|
||||
using LinearAlgebra
|
||||
|
||||
export train!, update!,
|
||||
SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
|
||||
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,RADAM,
|
||||
InvDecay, ExpDecay, WeightDecay, stop, Optimiser
|
||||
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW, RADAM,
|
||||
InvDecay, ExpDecay, WeightDecay, ClipValue, ClipNorm,
|
||||
stop, Optimiser
|
||||
|
||||
include("optimisers.jl")
|
||||
include("train.jl")
|
||||
|
|
|
@ -89,3 +89,15 @@ end
|
|||
@test decay_steps == ground_truth
|
||||
@test o.eta == o.clip
|
||||
end
|
||||
|
||||
@testset "Clipping" begin
|
||||
w = randn(10, 10)
|
||||
loss(x) = sum(w * x)
|
||||
θ = Params([w])
|
||||
x = 1000 * randn(10)
|
||||
w̄ = gradient(() -> loss(x), θ)[w]
|
||||
w̄_value = Optimise.apply!(ClipValue(1.0), w, copy(w̄))
|
||||
@test all(w̄_value .<= 1)
|
||||
w̄_norm = Optimise.apply!(ClipNorm(1.0), w, copy(w̄))
|
||||
@test norm(w̄_norm) <= 1
|
||||
end
|
Loading…
Reference in New Issue