From 759fe9df2fb0a4665052383fae1b0fd8978a2f52 Mon Sep 17 00:00:00 2001
From: CarloLucibello
Date: Wed, 26 Feb 2020 20:27:39 +0100
Subject: [PATCH] update docs and export update!

---
 docs/src/training/optimisers.md |  3 ++-
 src/optimise/Optimise.jl        |  2 +-
 src/optimise/train.jl           | 17 +++++++++++++++--
 3 files changed, 18 insertions(+), 4 deletions(-)

diff --git a/docs/src/training/optimisers.md b/docs/src/training/optimisers.md
index 5e8b95de..37288b5d 100644
--- a/docs/src/training/optimisers.md
+++ b/docs/src/training/optimisers.md
@@ -21,7 +21,7 @@ grads = gradient(() -> loss(x, y), θ)
 We want to update each parameter, using the gradient, in order to improve (reduce) the loss. Here's one way to do that:
 
 ```julia
-using Flux: update!
+using Flux.Optimise: update!
 
 η = 0.1 # Learning Rate
 for p in (W, b)
@@ -46,6 +46,7 @@ An optimiser `update!` accepts a parameter and a gradient, and updates the param
 All optimisers return an object that, when passed to `train!`, will update the parameters passed to it.
 
 ```@docs
+Flux.Optimise.update!
 Descent
 Momentum
 Nesterov
diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl
index 68c18a6f..28a1849d 100644
--- a/src/optimise/Optimise.jl
+++ b/src/optimise/Optimise.jl
@@ -1,6 +1,6 @@
 module Optimise
 
-export train!,
+export train!, update!,
   SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
   ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,RADAM,
   InvDecay, ExpDecay, WeightDecay, stop, Optimiser
diff --git a/src/optimise/train.jl b/src/optimise/train.jl
index ae0f334c..59404a42 100644
--- a/src/optimise/train.jl
+++ b/src/optimise/train.jl
@@ -1,9 +1,22 @@
 using Juno
 import Zygote: Params, gradient
 
+
+"""
+    update!(opt, p, g)
+    update!(opt, ps::Params, gs)
+
+Perform an update step of the parameters `ps` (or the single parameter `p`)
+according to optimizer `opt` and the gradients `gs` (the gradient `g`).
+
+As a result, the parameters are mutated and the optimizer's internal state may change.
+
+    update!(x, x̄)
+
+Update the array `x` according to `x .-= x̄`.
+"""
 function update!(x::AbstractArray, x̄)
-  x .+= x̄
-  return x
+  x .-= x̄
 end
 
 function update!(opt, x, x̄)
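
Note on the semantics: the array method `update!(x, x̄)` now performs gradient descent (`x .-= x̄`) rather than the old `x .+= x̄`, so callers pass the scaled gradient itself instead of its negation. Below is a minimal sketch of the manual update loop from the updated docs; the toy parameters `W`, `b`, the loss, and the data are assumptions for illustration only, not part of the patch:

```julia
using Flux
using Flux.Optimise: update!

# Hypothetical toy model, used only to demonstrate the API.
W = rand(2, 5)
b = rand(2)
loss(x, y) = sum((W * x .+ b .- y) .^ 2)

x, y = rand(5), rand(2)
θ = Flux.params(W, b)
grads = gradient(() -> loss(x, y), θ)

# Array form: update!(p, p̄) performs p .-= p̄, so the caller scales
# the gradient by the learning rate before passing it in.
η = 0.1
for p in (W, b)
  update!(p, η .* grads[p])
end
```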
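
The newly exported optimiser form moves the learning rate and any internal state into `opt`. A sketch continuing the toy setup above (same hypothetical `loss`, `θ`, `x`, `y`):

```julia
# Optimiser form: update!(opt, ps, gs) applies opt's rule to every
# parameter in ps, using the matching gradient in gs.
opt = Descent(0.1)  # plain gradient descent with η = 0.1
update!(opt, θ, gradient(() -> loss(x, y), θ))

# Stateful optimisers are used identically; only `opt` changes.
opt = Momentum(0.1, 0.9)
update!(opt, θ, gradient(() -> loss(x, y), θ))
```

Both forms mutate the parameters in place, which is why the docstring added here notes that the optimiser's internal state may change as well.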