add ClipValue and ClipNorm
This commit is contained in:
parent
427c55af92
commit
b33c4b49be
|
@ -3,7 +3,8 @@ module Flux
|
||||||
# Zero Flux Given
|
# Zero Flux Given
|
||||||
|
|
||||||
using Base: tail
|
using Base: tail
|
||||||
using Zygote, MacroTools, Juno, Reexport, Statistics, Random
|
using Statistics, Random, LinearAlgebra
|
||||||
|
using Zygote, MacroTools, Juno, Reexport, Requires
|
||||||
using MacroTools: @forward
|
using MacroTools: @forward
|
||||||
@reexport using NNlib
|
@reexport using NNlib
|
||||||
using Zygote: Params, @adjoint, gradient, pullback, @nograd
|
using Zygote: Params, @adjoint, gradient, pullback, @nograd
|
||||||
|
@ -20,7 +21,8 @@ using .Optimise
|
||||||
using .Optimise: @epochs
|
using .Optimise: @epochs
|
||||||
export SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
|
export SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
|
||||||
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
|
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
|
||||||
ADAMW, RADAM, InvDecay, ExpDecay, WeightDecay
|
ADAMW, RADAM, InvDecay, ExpDecay, WeightDecay,
|
||||||
|
ClipValue, ClipNorm
|
||||||
|
|
||||||
|
|
||||||
using CuArrays
|
using CuArrays
|
||||||
|
|
|
@ -533,3 +533,37 @@ function apply!(o::WeightDecay, x, Δ)
|
||||||
wd = o.wd
|
wd = o.wd
|
||||||
@. Δ += wd * x
|
@. Δ += wd * x
|
||||||
end
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
ClipValue(thresh)
|
||||||
|
|
||||||
|
Clip gradients when their absolute value exceeds `thresh`.
|
||||||
|
|
||||||
|
# Parameters
|
||||||
|
- Clipping threshold (`thresh`)
|
||||||
|
"""
|
||||||
|
mutable struct ClipValue{T}
|
||||||
|
thresh::T
|
||||||
|
end
|
||||||
|
|
||||||
|
apply!(o::ClipValue, x, Δ) = clamp!(Δ, -o.thresh, o.thresh)
|
||||||
|
|
||||||
|
"""
|
||||||
|
ClipNorm(thresh)
|
||||||
|
|
||||||
|
Clip gradients when their L2 norm exceeds `thresh`.
|
||||||
|
|
||||||
|
# Parameters
|
||||||
|
- Clipping threshold (`thresh`)
|
||||||
|
"""
|
||||||
|
mutable struct ClipNorm{T}
|
||||||
|
thresh::T
|
||||||
|
end
|
||||||
|
|
||||||
|
function apply!(o::ClipNorm, x, Δ)
|
||||||
|
Δnrm = norm(Δ, 2)
|
||||||
|
if Δnrm > o.thresh
|
||||||
|
rmul!(Δ, o.thresh / Δnrm)
|
||||||
|
end
|
||||||
|
return Δ
|
||||||
|
end
|
Loading…
Reference in New Issue