diff --git a/src/Flux.jl b/src/Flux.jl
index f973dc4c..0195cc8c 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -3,7 +3,8 @@ module Flux
 # Zero Flux Given
 
 using Base: tail
-using Zygote, MacroTools, Juno, Reexport, Statistics, Random
+using Statistics, Random, LinearAlgebra
+using Zygote, MacroTools, Juno, Reexport, Requires
 using MacroTools: @forward
 @reexport using NNlib
 using Zygote: Params, @adjoint, gradient, pullback, @nograd
@@ -20,7 +21,8 @@ using .Optimise
 using .Optimise: @epochs
 export SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
   ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
-  ADAMW, RADAM, InvDecay, ExpDecay, WeightDecay
+  ADAMW, RADAM, InvDecay, ExpDecay, WeightDecay,
+  ClipValue, ClipNorm
 
 using CuArrays
 
diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl
index 611edddb..3731e8e3 100644
--- a/src/optimise/optimisers.jl
+++ b/src/optimise/optimisers.jl
@@ -533,3 +533,47 @@ function apply!(o::WeightDecay, x, Δ)
   wd = o.wd
   @. Δ += wd * x
 end
+
+"""
+    ClipValue(thresh)
+
+Clip gradients when their absolute value exceeds `thresh`: every element of the
+gradient is clamped to the interval `[-thresh, thresh]`.
+
+# Parameters
+- Clipping threshold (`thresh`): Largest absolute value a gradient element keeps;
+                                 assumed to be non-negative.
+"""
+mutable struct ClipValue{T}
+    thresh::T
+end
+
+# Clamps `Δ` in place and returns it; the parameter `x` is not used.
+apply!(o::ClipValue, x, Δ) = clamp!(Δ, -o.thresh, o.thresh)
+
+"""
+    ClipNorm(thresh)
+
+Clip gradients when their L2 norm exceeds `thresh`: the gradient is rescaled by
+`thresh / norm`, and is left untouched otherwise.
+
+# Parameters
+- Clipping threshold (`thresh`): Largest L2 norm a gradient keeps; assumed to be
+                                 positive.
+"""
+mutable struct ClipNorm{T}
+    thresh::T
+end
+
+# Rescales `Δ` in place and returns it; the parameter `x` is not used.
+# NOTE(review): `norm` and `rmul!` come from LinearAlgebra; confirm that the
+# module this file is included into imports it.
+function apply!(o::ClipNorm, x, Δ)
+    Δnrm = norm(Δ, 2)
+    # With a positive `thresh`, a zero norm never enters the branch, so the
+    # division below is safe.
+    if Δnrm > o.thresh
+        rmul!(Δ, o.thresh / Δnrm)
+    end
+    return Δ
+end