From 0f2019eba5d2f2c61e90c5594f13954d9cff0f3f Mon Sep 17 00:00:00 2001
From: Mike Innes
Date: Fri, 5 Oct 2018 12:57:03 +0100
Subject: [PATCH] compose tweaks

---
Review notes (this area is ignored by `git am` and is not part of the commit):
 * The inner constructor now calls `new(Any[o...])`. Written as
   `Compose(Any[o...])` inside the struct body it re-dispatches to itself
   (an explicit inner constructor suppresses the default constructors),
   so every `Compose(...)` call would end in a StackOverflowError.
 * The ADAMW docstring signature now mirrors the positional arguments of the
   method it documents and no longer has an unbalanced parenthesis.
 * Leading whitespace was reconstructed from a whitespace-collapsed copy of
   this patch; verify the removed lines' indentation (in particular the
   removed ADAMW docstring line, assumed to start with a tab) against the
   pre-image before applying.
 * The post-image blob id on the `index` line is unchanged from the original
   patch and therefore no longer matches the edited result.

 src/optimise/optimisers.jl | 26 ++++++--------------------
 1 file changed, 6 insertions(+), 20 deletions(-)

diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl
index ae30445a..c3db9959 100644
--- a/src/optimise/optimisers.jl
+++ b/src/optimise/optimisers.jl
@@ -222,7 +222,7 @@ function update!(o::NADAM, x, Δ)
 end
 
 """
-	ADAMW((η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
+    ADAMW(η = 0.001, β = (0.9, 0.999), η_decay = 1, γ_decay = 0)
 
 [ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam.
 """
@@ -231,36 +231,22 @@ ADAMW(η = 0.001, β = (0.9, 0.999), η_decay = 1, γ_decay = 0) = Compose(ADAM(
 # Compose optimizers
 
 """
-    `Compose(Compose(...), ...)`
+    Compose(a, b, c...)
 
-Compose optimizers to support inbuilt or custom gradient updates while fitting the loss.
-
-Example:\n\n
-`Compose(ADAM(), Compose(RMSProp(0.001), ExpDecay(0.02)))`
+Combine several optimisers into one; each optimiser produces a modified gradient
+that will be fed into the next, and this is finally applied to the parameter as
+usual.
 """
 mutable struct Compose
   os::Vector{Any}
+  Compose(o...) = new(Any[o...])
 end
 
-Compose(o...) = Compose(flattenCompose(o...))
-
 @forward Compose.os Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!, Base.setindex!
 @forward Compose.os Base.iterate
 
 Base.getindex(c::Compose, i::AbstractArray) = Compose(c.os[i]...)
 
-function flattenCompose(o...)
-  res = []
-  for opt in o
-    if opt isa Compose
-      push!(res, flattenCompose(opt.os...)...)
-    else
-      push!(res, opt)
-    end
-  end
-  return res
-end
-
 function update!(o::Compose, x, Δ)
   for opt in o.os
     Δ = update!(opt, x, Δ)