add ClipValue and ClipNorm

This commit is contained in:
Yao Lu 2020-04-20 19:41:10 +08:00
parent 427c55af92
commit b33c4b49be
2 changed files with 38 additions and 2 deletions

View File

@ -3,7 +3,8 @@ module Flux
# Zero Flux Given # Zero Flux Given
using Base: tail using Base: tail
using Zygote, MacroTools, Juno, Reexport, Statistics, Random using Statistics, Random, LinearAlgebra
using Zygote, MacroTools, Juno, Reexport, Requires
using MacroTools: @forward using MacroTools: @forward
@reexport using NNlib @reexport using NNlib
using Zygote: Params, @adjoint, gradient, pullback, @nograd using Zygote: Params, @adjoint, gradient, pullback, @nograd
@ -20,7 +21,8 @@ using .Optimise
using .Optimise: @epochs using .Optimise: @epochs
export SGD, Descent, ADAM, Momentum, Nesterov, RMSProp, export SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
ADAMW, RADAM, InvDecay, ExpDecay, WeightDecay ADAMW, RADAM, InvDecay, ExpDecay, WeightDecay,
ClipValue, ClipNorm
using CuArrays using CuArrays

View File

@ -533,3 +533,37 @@ function apply!(o::WeightDecay, x, Δ)
wd = o.wd wd = o.wd
@. Δ += wd * x @. Δ += wd * x
end end
"""
ClipValue(thresh)
Clip gradients when their absolute value exceeds `thresh`.
# Parameters
- Clipping threshold (`thresh`)
"""
mutable struct ClipValue{T}
thresh::T
end
apply!(o::ClipValue, x, Δ) = clamp!(Δ, -o.thresh, o.thresh)
"""
ClipNorm(thresh)
Clip gradients when their L2 norm exceeds `thresh`.
# Parameters
- Clipping threshold (`thresh`)
"""
mutable struct ClipNorm{T}
thresh::T
end
function apply!(o::ClipNorm, x, Δ)
Δnrm = norm(Δ, 2)
if Δnrm > o.thresh
rmul!(Δ, o.thresh / Δnrm)
end
return Δ
end