From b33c4b49be2bec0fd0034beee3c3d24d7bec289b Mon Sep 17 00:00:00 2001
From: Yao Lu
Date: Mon, 20 Apr 2020 19:41:10 +0800
Subject: [PATCH] add ClipValue and ClipNorm

---
 src/Flux.jl                |  6 ++++--
 src/optimise/optimisers.jl | 34 ++++++++++++++++++++++++++++++++++
 2 files changed, 38 insertions(+), 2 deletions(-)

diff --git a/src/Flux.jl b/src/Flux.jl
index f973dc4c..0195cc8c 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -3,7 +3,8 @@ module Flux
 # Zero Flux Given
 
 using Base: tail
-using Zygote, MacroTools, Juno, Reexport, Statistics, Random
+using Statistics, Random, LinearAlgebra
+using Zygote, MacroTools, Juno, Reexport, Requires  # NOTE(review): Requires is newly added here — confirm it is intended
 using MacroTools: @forward
 @reexport using NNlib
 using Zygote: Params, @adjoint, gradient, pullback, @nograd
@@ -20,7 +21,8 @@ using .Optimise
 using .Optimise: @epochs
 export SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
   ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
-  ADAMW, RADAM, InvDecay, ExpDecay, WeightDecay
+  ADAMW, RADAM, InvDecay, ExpDecay, WeightDecay,
+  ClipValue, ClipNorm
 
 using CuArrays
 
diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl
index 611edddb..3731e8e3 100644
--- a/src/optimise/optimisers.jl
+++ b/src/optimise/optimisers.jl
@@ -533,3 +533,37 @@ function apply!(o::WeightDecay, x, Δ)
   wd = o.wd
   @. Δ += wd * x
 end
+
+"""
+    ClipValue(thresh)
+
+Clamp each gradient element to the interval `[-thresh, thresh]`, in place.
+
+# Parameters
+- Clipping threshold (`thresh`): largest absolute value an element may keep.
+"""
+mutable struct ClipValue{T}
+    thresh::T
+end
+
+apply!(o::ClipValue, x, Δ) = clamp!(Δ, -o.thresh, o.thresh)  # `x` is unused; Δ is mutated and returned
+
+"""
+    ClipNorm(thresh)
+
+Rescale the gradient in place so its L2 norm is `thresh` whenever that norm exceeds `thresh`.
+
+# Parameters
+- Clipping threshold (`thresh`): largest L2 norm the gradient may keep.
+"""
+mutable struct ClipNorm{T}
+    thresh::T
+end
+
+function apply!(o::ClipNorm, x, Δ)
+    Δnrm = norm(Δ, 2)  # NOTE(review): norm/rmul! are LinearAlgebra names — confirm they are in scope in this file's module
+    if Δnrm > o.thresh
+        rmul!(Δ, o.thresh / Δnrm)  # scale in place so the norm becomes thresh
+    end
+    return Δ  # unchanged when the norm is already within thresh
+end
\ No newline at end of file