compose tweaks
This commit is contained in:
parent
9bc9771a8d
commit
bfe85e65f1
@ -222,7 +222,7 @@ function update!(o::NADAM, x, Δ)
|
||||
end
|
||||
|
||||
"""
|
||||
ADAMW((η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
|
||||
ADAMW((η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
|
||||
|
||||
[ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam.
|
||||
"""
|
||||
@ -231,36 +231,22 @@ ADAMW(η = 0.001, β = (0.9, 0.999), η_decay = 1, γ_decay = 0) = Compose(ADAM(
|
||||
# Compose optimizers
|
||||
|
||||
"""
|
||||
`Compose(Compose(...), ...)`
|
||||
Compose(a, b, c...)
|
||||
|
||||
Compose optimizers to support inbuilt or custom gradient updates while fitting the loss.
|
||||
|
||||
Example:\n\n
|
||||
`Compose(ADAM(), Compose(RMSProp(0.001), ExpDecay(0.02)))`
|
||||
Combine several optimisers into one; each optimiser produces a modified gradient
|
||||
that will be fed into the next, and this is finally applied to the parameter as
|
||||
usual.
|
||||
"""
|
||||
mutable struct Compose
|
||||
os::Vector{Any}
|
||||
Compose(o...) = new(Any[o...])
|
||||
end
|
||||
|
||||
Compose(o...) = Compose(flattenCompose(o...))
|
||||
|
||||
@forward Compose.os Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!, Base.setindex!
|
||||
@forward Compose.os Base.iterate
|
||||
|
||||
Base.getindex(c::Compose, i::AbstractArray) = Compose(c.os[i]...)
|
||||
|
||||
function flattenCompose(o...)
|
||||
res = []
|
||||
for opt in o
|
||||
if opt isa Compose
|
||||
push!(res, flattenCompose(opt.os...)...)
|
||||
else
|
||||
push!(res, opt)
|
||||
end
|
||||
end
|
||||
return res
|
||||
end
|
||||
|
||||
function update!(o::Compose, x, Δ)
|
||||
for opt in o.os
|
||||
Δ = update!(opt, x, Δ)
|
||||
|
Loading…
Reference in New Issue
Block a user