compose tweaks

This commit is contained in:
Mike Innes 2018-10-05 12:57:03 +01:00
parent 9bc9771a8d
commit 0f2019eba5

View File

@ -222,7 +222,7 @@ function update!(o::NADAM, x, Δ)
end end
""" """
ADAMW((η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) ADAMW((η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam. [ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam.
""" """
@ -231,36 +231,22 @@ ADAMW(η = 0.001, β = (0.9, 0.999), η_decay = 1, γ_decay = 0) = Compose(ADAM(
# Compose optimizers # Compose optimizers
""" """
`Compose(Compose(...), ...)` Compose(a, b, c...)
Compose optimizers to support inbuilt or custom gradient updates while fitting the loss. Combine several optimisers into one; each optimiser produces a modified gradient
that will be fed into the next, and this is finally applied to the parameter as
Example:\n\n usual.
`Compose(ADAM(), Compose(RMSProp(0.001), ExpDecay(0.02)))`
""" """
mutable struct Compose mutable struct Compose
os::Vector{Any} os::Vector{Any}
Compose(o...) = Compose(Any[o...])
end end
Compose(o...) = Compose(flattenCompose(o...))
@forward Compose.os Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!, Base.setindex! @forward Compose.os Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!, Base.setindex!
@forward Compose.os Base.iterate @forward Compose.os Base.iterate
Base.getindex(c::Compose, i::AbstractArray) = Compose(c.os[i]...) Base.getindex(c::Compose, i::AbstractArray) = Compose(c.os[i]...)
function flattenCompose(o...)
res = []
for opt in o
if opt isa Compose
push!(res, flattenCompose(opt.os...)...)
else
push!(res, opt)
end
end
return res
end
function update!(o::Compose, x, Δ) function update!(o::Compose, x, Δ)
for opt in o.os for opt in o.os
Δ = update!(opt, x, Δ) Δ = update!(opt, x, Δ)