add test
This commit is contained in:
parent
def19b058e
commit
1dfec7f38b
|
@ -1,9 +1,12 @@
|
||||||
module Optimise
|
module Optimise
|
||||||
|
|
||||||
|
using LinearAlgebra
|
||||||
|
|
||||||
export train!, update!,
|
export train!, update!,
|
||||||
SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
|
SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
|
||||||
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,RADAM,
|
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW, RADAM,
|
||||||
InvDecay, ExpDecay, WeightDecay, stop, Optimiser
|
InvDecay, ExpDecay, WeightDecay, ClipValue, ClipNorm,
|
||||||
|
stop, Optimiser
|
||||||
|
|
||||||
include("optimisers.jl")
|
include("optimisers.jl")
|
||||||
include("train.jl")
|
include("train.jl")
|
||||||
|
|
|
@ -89,3 +89,15 @@ end
|
||||||
@test decay_steps == ground_truth
|
@test decay_steps == ground_truth
|
||||||
@test o.eta == o.clip
|
@test o.eta == o.clip
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@testset "Clipping" begin
|
||||||
|
w = randn(10, 10)
|
||||||
|
loss(x) = sum(w * x)
|
||||||
|
θ = Params([w])
|
||||||
|
x = 1000 * randn(10)
|
||||||
|
w̄ = gradient(() -> loss(x), θ)[w]
|
||||||
|
w̄_value = Optimise.apply!(ClipValue(1.0), w, copy(w̄))
|
||||||
|
@test all(w̄_value .<= 1)
|
||||||
|
w̄_norm = Optimise.apply!(ClipNorm(1.0), w, copy(w̄))
|
||||||
|
@test norm(w̄_norm) <= 1
|
||||||
|
end
|
Loading…
Reference in New Issue