diff --git a/Project.toml b/Project.toml
index 1c5bc128..a9af31da 100644
--- a/Project.toml
+++ b/Project.toml
@@ -11,6 +11,7 @@ CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
 DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
diff --git a/docs/src/training/optimisers.md b/docs/src/training/optimisers.md
index 1334f07d..0cf165b6 100644
--- a/docs/src/training/optimisers.md
+++ b/docs/src/training/optimisers.md
@@ -140,3 +140,16 @@ ExpDecay
 InvDecay
 WeightDecay
 ```
+
+## Gradient Clipping
+
+Gradient clipping is useful for training recurrent neural networks, which have a tendency to suffer from the exploding gradient problem. An example usage is
+
+```julia
+opt = Optimiser(ClipValue(1e-3), ADAM(1e-3))
+```
+
+```@docs
+ClipValue
+ClipNorm
+```
\ No newline at end of file
diff --git a/src/Flux.jl b/src/Flux.jl
index 90dcb630..c28a7d36 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -3,7 +3,8 @@ module Flux
 # Zero Flux Given
 
 using Base: tail
-using Zygote, MacroTools, Juno, Reexport, Statistics, Random
+using Statistics, Random, LinearAlgebra
+using Zygote, MacroTools, Juno, Reexport
 using MacroTools: @forward
 @reexport using NNlib
 using Zygote: Params, @adjoint, gradient, pullback, @nograd
@@ -20,7 +21,8 @@ using .Optimise
 using .Optimise: @epochs
 
 export Descent, ADAM, Momentum, Nesterov, RMSProp,
   ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
-  ADAMW, RADAM, InvDecay, ExpDecay, WeightDecay
+  ADAMW, RADAM, InvDecay, ExpDecay, WeightDecay,
+  ClipValue, ClipNorm
 
 using CuArrays
diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl
index 5ffe4a8f..0f5e644f 100644
--- a/src/optimise/Optimise.jl
+++ b/src/optimise/Optimise.jl
@@ -1,9 +1,12 @@
 module Optimise
 
+using LinearAlgebra
+
 export train!, update!,
   Descent, ADAM, Momentum, Nesterov, RMSProp,
   ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,RADAM,
-  InvDecay, ExpDecay, WeightDecay, stop, Optimiser
+  InvDecay, ExpDecay, WeightDecay, stop, Optimiser,
+  ClipValue, ClipNorm
 
 include("optimisers.jl")
 include("train.jl")
diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl
index 611edddb..466b7b6d 100644
--- a/src/optimise/optimisers.jl
+++ b/src/optimise/optimisers.jl
@@ -533,3 +533,31 @@ function apply!(o::WeightDecay, x, Δ)
   wd = o.wd
   @. Δ += wd * x
 end
+
+"""
+    ClipValue(thresh)
+
+Clip gradients when their absolute value exceeds `thresh`.
+"""
+mutable struct ClipValue{T}
+    thresh::T
+end
+
+apply!(o::ClipValue, x, Δ) = clamp!(Δ, -o.thresh, o.thresh)
+
+"""
+    ClipNorm(thresh)
+
+Clip gradients when their L2 norm exceeds `thresh`.
+"""
+mutable struct ClipNorm{T}
+    thresh::T
+end
+
+function apply!(o::ClipNorm, x, Δ)
+    Δnrm = norm(Δ)
+    if Δnrm > o.thresh
+        rmul!(Δ, o.thresh / Δnrm)
+    end
+    return Δ
+end
\ No newline at end of file
diff --git a/test/optimise.jl b/test/optimise.jl
index ac131b96..b3a0250c 100644
--- a/test/optimise.jl
+++ b/test/optimise.jl
@@ -89,3 +89,15 @@ end
   @test decay_steps == ground_truth
   @test o.eta == o.clip
 end
+
+@testset "Clipping" begin
+  w = randn(10, 10)
+  loss(x) = sum(w * x)
+  θ = Params([w])
+  x = 1000 * randn(10)
+  w̄ = gradient(() -> loss(x), θ)[w]
+  w̄_value = Optimise.apply!(ClipValue(1.0), w, copy(w̄))
+  @test all(w̄_value .<= 1)
+  w̄_norm = Optimise.apply!(ClipNorm(1.0), w, copy(w̄))
+  @test norm(w̄_norm) <= 1
+end
\ No newline at end of file