add ClipValue and ClipNorm
This commit is contained in:
parent
427c55af92
commit
b33c4b49be
|
@ -3,7 +3,8 @@ module Flux
|
|||
# Zero Flux Given
|
||||
|
||||
using Base: tail
|
||||
using Zygote, MacroTools, Juno, Reexport, Statistics, Random
|
||||
using Statistics, Random, LinearAlgebra
|
||||
using Zygote, MacroTools, Juno, Reexport, Requires
|
||||
using MacroTools: @forward
|
||||
@reexport using NNlib
|
||||
using Zygote: Params, @adjoint, gradient, pullback, @nograd
|
||||
|
@ -20,7 +21,8 @@ using .Optimise
|
|||
using .Optimise: @epochs
|
||||
export SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
|
||||
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
|
||||
ADAMW, RADAM, InvDecay, ExpDecay, WeightDecay
|
||||
ADAMW, RADAM, InvDecay, ExpDecay, WeightDecay,
|
||||
ClipValue, ClipNorm
|
||||
|
||||
|
||||
using CuArrays
|
||||
|
|
|
@ -533,3 +533,37 @@ function apply!(o::WeightDecay, x, Δ)
|
|||
wd = o.wd
|
||||
@. Δ += wd * x
|
||||
end
|
||||
"""
    ClipValue(thresh)

Clip gradients when their absolute value exceeds `thresh`.

The matching `apply!` method clamps every gradient entry into the interval
`[-thresh, thresh]`, modifying the gradient array in place.

# Parameters
- Clipping threshold (`thresh`)
"""
mutable struct ClipValue{T}
    thresh::T
end
|
||||
|
||||
# Clamp each entry of the gradient `Δ` into `[-o.thresh, o.thresh]` in place and
# return the same array. `x` is accepted but not used by this method.
function apply!(o::ClipValue, x, Δ)
    bound = o.thresh
    return clamp!(Δ, -bound, bound)
end
|
||||
"""
    ClipNorm(thresh)

Clip gradients when their L2 norm exceeds `thresh`.

The matching `apply!` method rescales the gradient array in place so that its
L2 norm equals `thresh` whenever the norm was larger; gradients whose norm is
already within the threshold are left unchanged.

# Parameters
- Clipping threshold (`thresh`)
"""
mutable struct ClipNorm{T}
    thresh::T
end
|
||||
|
||||
# Rescale the gradient `Δ` in place so that its L2 norm does not exceed
# `o.thresh`, and return the same array. `x` is accepted but not used here.
function apply!(o::ClipNorm, x, Δ)
    Δnrm = norm(Δ, 2)
    # Only ever shrink: a gradient already within the threshold is left untouched.
    if Δnrm > o.thresh
        # Multiplying by thresh / ‖Δ‖ brings the norm down to exactly `thresh`.
        rmul!(Δ, o.thresh / Δnrm)
    end
    return Δ
end
|
Loading…
Reference in New Issue