diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl index 28a1849d..184d472c 100644 --- a/src/optimise/Optimise.jl +++ b/src/optimise/Optimise.jl @@ -1,9 +1,12 @@ module Optimise +using LinearAlgebra + export train!, update!, SGD, Descent, ADAM, Momentum, Nesterov, RMSProp, - ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,RADAM, - InvDecay, ExpDecay, WeightDecay, stop, Optimiser + ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW, RADAM, + InvDecay, ExpDecay, WeightDecay, ClipValue, ClipNorm, + stop, Optimiser include("optimisers.jl") include("train.jl") diff --git a/test/optimise.jl b/test/optimise.jl index ac131b96..b3a0250c 100644 --- a/test/optimise.jl +++ b/test/optimise.jl @@ -89,3 +89,15 @@ end @test decay_steps == ground_truth @test o.eta == o.clip end + +@testset "Clipping" begin + w = randn(10, 10) + loss(x) = sum(w * x) + θ = Params([w]) + x = 1000 * randn(10) + w̄ = gradient(() -> loss(x), θ)[w] + w̄_value = Optimise.apply!(ClipValue(1.0), w, copy(w̄)) + @test all(w̄_value .<= 1) + w̄_norm = Optimise.apply!(ClipNorm(1.0), w, copy(w̄)) + @test norm(w̄_norm) <= 1 +end \ No newline at end of file