compose tweaks
This commit is contained in:
parent
9bc9771a8d
commit
0f2019eba5
@ -222,7 +222,7 @@ function update!(o::NADAM, x, Δ)
|
|||||||
end
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
ADAMW((η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
|
ADAMW((η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
|
||||||
|
|
||||||
[ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam.
|
[ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam.
|
||||||
"""
|
"""
|
||||||
@ -231,36 +231,22 @@ ADAMW(η = 0.001, β = (0.9, 0.999), η_decay = 1, γ_decay = 0) = Compose(ADAM(
|
|||||||
# Compose optimizers
|
# Compose optimizers
|
||||||
|
|
||||||
"""
|
"""
|
||||||
`Compose(Compose(...), ...)`
|
Compose(a, b, c...)
|
||||||
|
|
||||||
Compose optimizers to support inbuilt or custom gradient updates while fitting the loss.
|
Combine several optimisers into one; each optimiser produces a modified gradient
|
||||||
|
that will be fed into the next, and this is finally applied to the parameter as
|
||||||
Example:\n\n
|
usual.
|
||||||
`Compose(ADAM(), Compose(RMSProp(0.001), ExpDecay(0.02)))`
|
|
||||||
"""
|
"""
|
||||||
mutable struct Compose
|
mutable struct Compose
|
||||||
os::Vector{Any}
|
os::Vector{Any}
|
||||||
|
Compose(o...) = Compose(Any[o...])
|
||||||
end
|
end
|
||||||
|
|
||||||
Compose(o...) = Compose(flattenCompose(o...))
|
|
||||||
|
|
||||||
@forward Compose.os Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!, Base.setindex!
|
@forward Compose.os Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!, Base.setindex!
|
||||||
@forward Compose.os Base.iterate
|
@forward Compose.os Base.iterate
|
||||||
|
|
||||||
Base.getindex(c::Compose, i::AbstractArray) = Compose(c.os[i]...)
|
Base.getindex(c::Compose, i::AbstractArray) = Compose(c.os[i]...)
|
||||||
|
|
||||||
function flattenCompose(o...)
|
|
||||||
res = []
|
|
||||||
for opt in o
|
|
||||||
if opt isa Compose
|
|
||||||
push!(res, flattenCompose(opt.os...)...)
|
|
||||||
else
|
|
||||||
push!(res, opt)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
return res
|
|
||||||
end
|
|
||||||
|
|
||||||
function update!(o::Compose, x, Δ)
|
function update!(o::Compose, x, Δ)
|
||||||
for opt in o.os
|
for opt in o.os
|
||||||
Δ = update!(opt, x, Δ)
|
Δ = update!(opt, x, Δ)
|
||||||
|
Loading…
Reference in New Issue
Block a user