commit
3a4c6274fa
|
@ -37,7 +37,7 @@ Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict())
|
|||
|
||||
function apply!(o::Momentum, x, Δ)
|
||||
η, ρ = o.eta, o.rho
|
||||
v = get!(o.velocity, x, zero(x))::typeof(x)
|
||||
v = get!(o.velocity, x, zero(x))::typeof(data(x))
|
||||
@. v = ρ * v - η * Δ
|
||||
@. Δ = -v
|
||||
end
|
||||
|
@ -57,7 +57,7 @@ Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict())
|
|||
|
||||
function apply!(o::Nesterov, x, Δ)
|
||||
η, ρ = o.eta, o.rho
|
||||
v = get!(o.velocity, x, zero(x))::typeof(x)
|
||||
v = get!(o.velocity, x, zero(x))::typeof(data(x))
|
||||
d = @. ρ^2 * v - (1+ρ) * η * Δ
|
||||
@. v = ρ*v - η*Δ
|
||||
@. Δ = -d
|
||||
|
@ -80,7 +80,7 @@ RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict())
|
|||
|
||||
function apply!(o::RMSProp, x, Δ)
|
||||
η, ρ = o.eta, o.rho
|
||||
acc = get!(o.acc, x, zero(x))::typeof(x)
|
||||
acc = get!(o.acc, x, zero(x))::typeof(data(x))
|
||||
@. acc = ρ * acc + (1 - ρ) * Δ^2
|
||||
@. Δ *= η / (√acc + ϵ)
|
||||
end
|
||||
|
@ -147,7 +147,7 @@ ADAGrad(η = 0.1) = ADAGrad(η, IdDict())
|
|||
|
||||
function apply!(o::ADAGrad, x, Δ)
|
||||
η = o.eta
|
||||
acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(x)
|
||||
acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(data(x))
|
||||
@. acc += Δ^2
|
||||
@. Δ *= η / (√acc + ϵ)
|
||||
end
|
||||
|
@ -321,7 +321,7 @@ end
|
|||
|
||||
WeightDecay() = WeightDecay(0)
|
||||
|
||||
function apply!(o::WeightDecay, x, Δ)
|
||||
function apply!(o::WeightDecay, x, Δ)
|
||||
wd = o.wd
|
||||
@. Δ += wd * x
|
||||
@. Δ += wd * data(x)
|
||||
end
|
||||
|
|
|
@ -1,16 +1,23 @@
|
|||
using Juno
|
||||
import Flux.Tracker: data, grad, back!, update!
|
||||
import Flux.Tracker: Params, gradient, data, update!
|
||||
import Base.depwarn
|
||||
|
||||
function update!(opt, x, x̄)
|
||||
update!(x, apply!(opt, x, copy(data(x̄))))
|
||||
update!(x, -apply!(opt, x, data(x̄)))
|
||||
end
|
||||
|
||||
function _update_params!(opt, xs)
|
||||
function update!(opt, xs::Params, gs)
|
||||
for x in xs
|
||||
Δ = apply!(opt, x.data, x.grad)
|
||||
x.data .-= Δ
|
||||
Δ .= 0
|
||||
update!(opt, x, gs[x])
|
||||
end
|
||||
end
|
||||
|
||||
# Added as an internal API but everyone started using it.
|
||||
function _update_params!(opt, xs)
|
||||
depwarn("`_update_params!` is deprecated, use `update!` instead.", :stop)
|
||||
for x in xs
|
||||
update!(opt, x, Tracker.grad(x))
|
||||
x.tracker.grad = Tracker.zero_grad!(x.tracker.grad)
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -19,16 +26,6 @@ call(f, xs...) = f(xs...)
|
|||
runall(f) = f
|
||||
runall(fs::AbstractVector) = () -> foreach(call, fs)
|
||||
|
||||
# The AD generates fairly large backtraces that are unhelpful if you interrupt
|
||||
# while training; this just cleans that up.
|
||||
macro interrupts(ex)
|
||||
:(try $(esc(ex))
|
||||
catch e
|
||||
e isa InterruptException || rethrow()
|
||||
throw(e)
|
||||
end)
|
||||
end
|
||||
|
||||
struct StopException <: Exception end
|
||||
"""
|
||||
stop()
|
||||
|
@ -67,13 +64,14 @@ The callback can call `Flux.stop()` to interrupt the training loop.
|
|||
Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
|
||||
"""
|
||||
function train!(loss, ps, data, opt; cb = () -> ())
|
||||
ps = Params(ps)
|
||||
cb = runall(cb)
|
||||
opt = runall(opt)
|
||||
@progress for d in data
|
||||
try
|
||||
l = loss(d...)
|
||||
@interrupts back!(l)
|
||||
_update_params!(opt, ps)
|
||||
gs = gradient(ps) do
|
||||
loss(d...)
|
||||
end
|
||||
update!(opt, ps, gs)
|
||||
if cb() == :stop
|
||||
depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
|
||||
break
|
||||
|
|
|
@ -62,6 +62,7 @@ macro grad(ex)
|
|||
end
|
||||
|
||||
include("idset.jl")
|
||||
include("params.jl")
|
||||
include("back.jl")
|
||||
include("numeric.jl")
|
||||
include("lib/real.jl")
|
||||
|
|
|
@ -1,3 +1,15 @@
|
|||
# The AD generates fairly large backtraces that are unhelpful if you interrupt
|
||||
# while training; this just cleans that up.
|
||||
macro interrupts(ex)
|
||||
:(try $(esc(ex))
|
||||
catch e
|
||||
e isa InterruptException || rethrow()
|
||||
throw(e)
|
||||
end)
|
||||
end
|
||||
|
||||
# In-place gradients
|
||||
|
||||
init_grad(x) = zero(x)
|
||||
zero_grad!(x) = zero(x)
|
||||
zero_grad!(x::AbstractArray) = (x .= 0)
|
||||
|
@ -66,64 +78,34 @@ function back!(x, Δ; once = true)
|
|||
return
|
||||
end
|
||||
|
||||
function extract_grad!(x)
|
||||
x̄ = copy(grad(x))
|
||||
x̄ = nobacksies("Use `gradient(...; nest = true)` for nested derivatives", x̄)
|
||||
tracker(x).grad = zero_grad!(grad(x))
|
||||
return x̄
|
||||
end
|
||||
|
||||
function gradient_(f, xs...)
|
||||
xs = param.(data.(xs))
|
||||
l = f(xs...)
|
||||
losscheck(l)
|
||||
back!(l)
|
||||
nobacksies("Use `gradient(...; nest = true)` for nested derivatives",
|
||||
grad.(xs))
|
||||
@interrupts back!(l)
|
||||
extract_grad!.(xs)
|
||||
end
|
||||
|
||||
function gradient_(f, xs::Params)
|
||||
l = f()
|
||||
losscheck(l)
|
||||
@interrupts back!(l)
|
||||
gs = Grads()
|
||||
for x in xs
|
||||
gs[tracker(x)] = extract_grad!(x)
|
||||
end
|
||||
return gs
|
||||
end
|
||||
|
||||
# Out-of-place gradients
|
||||
|
||||
struct Params
|
||||
order::Vector{Any}
|
||||
params::IdSet{Any}
|
||||
Params() = new([], IdSet())
|
||||
end
|
||||
|
||||
@forward Params.order Base.iterate, Base.length
|
||||
|
||||
function Base.push!(ps::Params, x)
|
||||
if !(x in ps.params)
|
||||
push!(ps.order, x)
|
||||
push!(ps.params, x)
|
||||
end
|
||||
return ps
|
||||
end
|
||||
|
||||
Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps)
|
||||
|
||||
Params(xs) = push!(Params(), xs...)
|
||||
|
||||
function Base.show(io::IO, ps::Params)
|
||||
print(io, "Params([")
|
||||
join(io, ps.order, ", ")
|
||||
print(io, "])")
|
||||
end
|
||||
|
||||
struct Grads
|
||||
grads::IdDict{Any,Any}
|
||||
end
|
||||
|
||||
Base.show(io::IO, ps::Grads) = println(io, "Grads(...)")
|
||||
|
||||
Grads() = Grads(IdDict())
|
||||
|
||||
@forward Grads.grads Base.setindex!, Base.haskey, Base.length, Base.iterate
|
||||
|
||||
Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps))
|
||||
|
||||
Base.getindex(g::Grads, x::Tracked) = g.grads[x]
|
||||
|
||||
function Base.getindex(g::Grads, x)
|
||||
istracked(x) || error("Object not tracked: $x")
|
||||
g[tracker(x)]
|
||||
end
|
||||
|
||||
accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
|
||||
|
||||
function back_(g::Grads, c::Call, Δ)
|
||||
Δs = c.func(Δ)
|
||||
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
|
||||
|
@ -182,8 +164,6 @@ end
|
|||
gradient(f, xs...; nest = false) =
|
||||
nest ? gradient_nested(f, xs...) : gradient_(f, xs...)
|
||||
|
||||
gradient(f, ps::Params) = gradient_nested(f, ps)
|
||||
|
||||
# Jacobians and Hessians
|
||||
|
||||
import ..Flux
|
||||
|
|
|
@ -71,6 +71,11 @@ function update!(x::TrackedArray, Δ)
|
|||
return x
|
||||
end
|
||||
|
||||
function update!(x::AbstractArray, Δ)
|
||||
x .+= data(Δ)
|
||||
return x
|
||||
end
|
||||
|
||||
# Fallthrough methods
|
||||
|
||||
for f in :[Base.size, Base.ndims, Base.collect].args
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
struct Params
|
||||
order::Vector{Any}
|
||||
params::IdSet{Any}
|
||||
Params() = new([], IdSet())
|
||||
end
|
||||
|
||||
@forward Params.order Base.iterate, Base.length
|
||||
|
||||
function Base.push!(ps::Params, x)
|
||||
if !(x in ps.params)
|
||||
push!(ps.order, x)
|
||||
push!(ps.params, x)
|
||||
end
|
||||
return ps
|
||||
end
|
||||
|
||||
Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps)
|
||||
|
||||
Params(xs) = push!(Params(), xs...)
|
||||
|
||||
function Base.show(io::IO, ps::Params)
|
||||
print(io, "Params([")
|
||||
join(io, ps.order, ", ")
|
||||
print(io, "])")
|
||||
end
|
||||
|
||||
struct Grads
|
||||
grads::IdDict{Any,Any}
|
||||
end
|
||||
|
||||
Base.show(io::IO, ps::Grads) = println(io, "Grads(...)")
|
||||
|
||||
Grads() = Grads(IdDict())
|
||||
|
||||
@forward Grads.grads Base.setindex!, Base.haskey, Base.length, Base.iterate
|
||||
|
||||
Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps))
|
||||
|
||||
Base.getindex(g::Grads, x::Tracked) = g.grads[x]
|
||||
|
||||
function Base.getindex(g::Grads, x)
|
||||
istracked(x) || error("Object not tracked: $x")
|
||||
g[tracker(x)]
|
||||
end
|
||||
|
||||
accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
|
|
@ -4,21 +4,15 @@ using Flux.Tracker
|
|||
using Test
|
||||
@testset "Optimise" begin
|
||||
w = randn(10, 10)
|
||||
@testset for Opt in [ADAMW, ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, Descent, ADAM, Nesterov, RMSProp, Momentum]
|
||||
@testset for opt in [ADAMW(), ADAGrad(0.1), AdaMax(), ADADelta(0.9), AMSGrad(),
|
||||
NADAM(), Descent(0.1), ADAM(), Nesterov(), RMSProp(),
|
||||
Momentum()]
|
||||
w′ = param(randn(10, 10))
|
||||
loss(x) = Flux.mse(w*x, w′*x)
|
||||
opt = Opt(0.001)
|
||||
if opt isa Descent || opt isa ADAGrad
|
||||
opt = Opt(0.1)
|
||||
end
|
||||
if opt isa ADADelta
|
||||
opt = Opt(0.9)
|
||||
end
|
||||
for t = 1: 10^5
|
||||
l = loss(rand(10))
|
||||
back!(l)
|
||||
delta = Optimise.apply!(opt, w′.data, w′.grad)
|
||||
w′.data .-= delta
|
||||
θ = Params([w′])
|
||||
θ̄ = gradient(() -> loss(rand(10)), θ)
|
||||
Optimise.update!(opt, θ, θ̄)
|
||||
end
|
||||
@test Flux.mse(w, w′) < 0.01
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue