add ClipValue and ClipNorm

This commit is contained in:
Yao Lu 2020-04-20 19:41:10 +08:00
parent 427c55af92
commit b33c4b49be
2 changed files with 38 additions and 2 deletions

View File

@ -3,7 +3,8 @@ module Flux
# Zero Flux Given
using Base: tail
using Zygote, MacroTools, Juno, Reexport, Statistics, Random
using Statistics, Random, LinearAlgebra
using Zygote, MacroTools, Juno, Reexport, Requires
using MacroTools: @forward
@reexport using NNlib
using Zygote: Params, @adjoint, gradient, pullback, @nograd
@ -20,7 +21,8 @@ using .Optimise
using .Optimise: @epochs
export SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
ADAMW, RADAM, InvDecay, ExpDecay, WeightDecay
ADAMW, RADAM, InvDecay, ExpDecay, WeightDecay,
ClipValue, ClipNorm
using CuArrays

View File

@ -533,3 +533,37 @@ function apply!(o::WeightDecay, x, Δ)
wd = o.wd
@. Δ += wd * x
end
"""
ClipValue(thresh)
Clip gradients when their absolute value exceeds `thresh`.
# Parameters
- Clipping threshold (`thresh`)
"""
mutable struct ClipValue{T}
thresh::T
end
apply!(o::ClipValue, x, Δ) = clamp!(Δ, -o.thresh, o.thresh)
"""
ClipNorm(thresh)
Clip gradients when their L2 norm exceeds `thresh`.
# Parameters
- Clipping threshold (`thresh`)
"""
mutable struct ClipNorm{T}
thresh::T
end
function apply!(o::ClipNorm, x, Δ)
Δnrm = norm(Δ, 2)
if Δnrm > o.thresh
rmul!(Δ, o.thresh / Δnrm)
end
return Δ
end