simpler/nicer training loop

parent cd091ad005
commit 4cf43c0c41
@@ -37,7 +37,7 @@ Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict())
 function apply!(o::Momentum, x, Δ)
   η, ρ = o.eta, o.rho
-  v = get!(o.velocity, x, zero(x))::typeof(x)
+  v = get!(o.velocity, x, zero(x))::typeof(data(x))
   @. v = ρ * v - η * Δ
   @. Δ = -v
 end
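The only change in this hunk is the type assertion: the velocity buffer is now checked against `typeof(data(x))`, so it matches the plain array underneath a tracked parameter rather than the tracked wrapper. As a reminder of the rule itself, here is a standalone sketch of the same update on plain arrays (the helper names are illustrative, not part of Flux):

```julia
# Classical momentum, mirroring apply!(o::Momentum, x, Δ) above:
# v ← ρ·v − η·Δ, and the step written back into Δ is −v,
# which the caller then subtracts from the parameter.
function momentum_step!(v, Δ; η = 0.01, ρ = 0.9)
    @. v = ρ * v - η * Δ
    @. Δ = -v
    return Δ
end

x = randn(3)
v = zero(x)
for _ in 1:100
    Δ = 2 .* x                     # gradient of sum(abs2, x)
    x .-= momentum_step!(v, Δ)     # x drifts toward the minimiser at zero
end
```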
@@ -57,7 +57,7 @@ Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict())
 function apply!(o::Nesterov, x, Δ)
   η, ρ = o.eta, o.rho
-  v = get!(o.velocity, x, zero(x))::typeof(x)
+  v = get!(o.velocity, x, zero(x))::typeof(data(x))
   d = @. ρ^2 * v - (1+ρ) * η * Δ
   @. v = ρ*v - η*Δ
   @. Δ = -d
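The same `typeof(x)` → `typeof(data(x))` change as in Momentum. For reference, the rule it implements, as a standalone sketch used the same way as `momentum_step!` above:

```julia
# Nesterov momentum as in apply!(o::Nesterov, x, Δ): the applied step uses the
# look-ahead d = ρ²·v − (1+ρ)·η·Δ, while v keeps the plain momentum recursion.
function nesterov_step!(v, Δ; η = 0.001, ρ = 0.9)
    d = @. ρ^2 * v - (1 + ρ) * η * Δ
    @. v = ρ * v - η * Δ
    @. Δ = -d
    return Δ
end
```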
@@ -80,7 +80,7 @@ RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict())
 function apply!(o::RMSProp, x, Δ)
   η, ρ = o.eta, o.rho
-  acc = get!(o.acc, x, zero(x))::typeof(x)
+  acc = get!(o.acc, x, zero(x))::typeof(data(x))
   @. acc = ρ * acc + (1 - ρ) * Δ^2
   @. Δ *= η / (√acc + ϵ)
 end
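Again only the accumulator's type assertion changes. The rule keeps a running mean of squared gradients and divides the step through by its square root; a standalone sketch with an explicit `ϵ` (the diff relies on a module-level constant for it):

```julia
# RMSProp: acc ← ρ·acc + (1−ρ)·Δ², then Δ ← Δ · η / (√acc + ϵ),
# mirroring apply!(o::RMSProp, x, Δ) above.
function rmsprop_step!(acc, Δ; η = 0.001, ρ = 0.9, ϵ = 1e-8)
    @. acc = ρ * acc + (1 - ρ) * Δ^2
    @. Δ *= η / (√acc + ϵ)
    return Δ
end
```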
@@ -147,7 +147,7 @@ ADAGrad(η = 0.1) = ADAGrad(η, IdDict())
 function apply!(o::ADAGrad, x, Δ)
   η = o.eta
-  acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(x)
+  acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(data(x))
   @. acc += Δ^2
   @. Δ *= η / (√acc + ϵ)
 end
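ADAGrad's accumulator is created with `fill(ϵ, size(x))` rather than `zero(x)`, so the first division already has a non-zero denominator; the asserted type is now the underlying array type as above. Sketch:

```julia
# ADAGrad: acc ← acc + Δ², then Δ ← Δ · η / (√acc + ϵ),
# mirroring apply!(o::ADAGrad, x, Δ) above.
function adagrad_step!(acc, Δ; η = 0.1, ϵ = 1e-8)
    @. acc += Δ^2
    @. Δ *= η / (√acc + ϵ)
    return Δ
end

acc = fill(1e-8, 3)   # seeded the same way as the IdDict entry in the diff
```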
@@ -321,7 +321,7 @@ end
 WeightDecay() = WeightDecay(0)
 
 function apply!(o::WeightDecay, x, Δ)
   wd = o.wd
-  @. Δ += wd * x
+  @. Δ += wd * data(x)
 end
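`WeightDecay` differs from the rules above in that it does not scale the step: it adds `wd * x` to the gradient, which is exactly the gradient of an extra (wd/2)·‖x‖² penalty, i.e. L2 regularisation. The change is again only `x` → `data(x)`, so a tracked parameter contributes its plain values. Standalone sketch:

```julia
# Δ ← Δ + wd·x, mirroring apply!(o::WeightDecay, x, Δ):
# equivalent to adding an (wd/2)·‖x‖² penalty to the loss.
function weight_decay!(Δ, x; wd = 1e-4)
    @. Δ += wd * x
    return Δ
end
```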
@@ -1,16 +1,23 @@
 using Juno
-import Flux.Tracker: data, grad, back!, update!
+import Flux.Tracker: Params, gradient, data, update!
+import Base.depwarn
 
 function update!(opt, x, x̄)
-  update!(x, apply!(opt, x, copy(data(x̄))))
+  update!(x, -apply!(opt, x, data(x̄)))
 end
 
-function _update_params!(opt, xs)
+function update!(opt, xs::Params, gs)
   for x in xs
-    Δ = apply!(opt, x.data, x.grad)
-    x.data .-= Δ
-    Δ .= 0
+    update!(opt, x, gs[x])
   end
 end
+
+# Added as an internal API but everyone started using it.
+function _update_params!(opt, xs)
+  depwarn("`_update_params!` is deprecated, use `update!` instead.", :stop)
+  for x in xs
+    update!(opt, x, Tracker.grad(x))
+    x.tracker.grad = Tracker.zero_grad!(x.tracker.grad)
+  end
+end
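The user-facing effect of this hunk: `update!(opt, x, x̄)` now negates the step returned by `apply!`, and the new method `update!(opt, xs::Params, gs)` walks a whole parameter collection using the gradient object's `gs[x]` lookup, while `_update_params!` survives only as a deprecation shim. A minimal usage sketch, assuming the Flux/Tracker API as of this revision:

```julia
using Flux                            # provides param, ADAM, the Optimise module
using Flux.Tracker: Params, gradient

W = param(randn(2, 2))                # tracked parameter
loss() = sum(W .^ 2)

θ = Params([W])
θ̄ = gradient(loss, θ)                 # gradients indexed by parameter: θ̄[W]
Flux.Optimise.update!(ADAM(), θ, θ̄)   # one optimiser step over everything in θ
```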
@@ -19,16 +26,6 @@ call(f, xs...) = f(xs...)
 runall(f) = f
 runall(fs::AbstractVector) = () -> foreach(call, fs)
 
-# The AD generates fairly large backtraces that are unhelpful if you interrupt
-# while training; this just cleans that up.
-macro interrupts(ex)
-  :(try $(esc(ex))
-  catch e
-    e isa InterruptException || rethrow()
-    throw(e)
-  end)
-end
-
 struct StopException <: Exception end
 """
     stop()
@@ -67,13 +64,14 @@ The callback can call `Flux.stop()` to interrupt the training loop.
 Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
 """
 function train!(loss, ps, data, opt; cb = () -> ())
+  ps = Params(ps)
   cb = runall(cb)
-  opt = runall(opt)
   @progress for d in data
     try
-      l = loss(d...)
-      @interrupts back!(l)
-      _update_params!(opt, ps)
+      gs = gradient(ps) do
+        loss(d...)
+      end
+      update!(opt, ps, gs)
       if cb() == :stop
         depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
         break
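`train!` now builds `Params` itself, takes gradients of the loss closure over those parameters, and hands them to `update!`; the `opt = runall(opt)` line goes away, since `opt` is now an optimiser object passed straight to `update!`. A toy end-to-end call, assuming the API at this revision (the model, data and learning rate here are made up for illustration):

```julia
using Flux

m = Dense(4, 2)                                    # toy model
loss(x, y) = Flux.mse(m(x), y)
data = [(rand(4, 8), rand(2, 8)) for _ in 1:100]   # 100 toy batches
opt = Descent(0.1)

evalcb() = @show loss(rand(4, 8), rand(2, 8))
Flux.train!(loss, params(m), data, opt; cb = Flux.throttle(evalcb, 5))
```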
@@ -1,3 +1,13 @@
+# The AD generates fairly large backtraces that are unhelpful if you interrupt
+# while training; this just cleans that up.
+macro interrupts(ex)
+  :(try $(esc(ex))
+  catch e
+    e isa InterruptException || rethrow()
+    throw(e)
+  end)
+end
+
 # In-place gradients
 
 init_grad(x) = zero(x)
@@ -79,14 +89,14 @@ function gradient_(f, xs...)
   xs = param.(data.(xs))
   l = f(xs...)
   losscheck(l)
-  back!(l)
+  @interrupts back!(l)
   extract_grad!.(xs)
 end
 
 function gradient_(f, xs::Params)
   l = f()
   losscheck(l)
-  back!(l)
+  @interrupts back!(l)
   gs = Grads()
   for x in xs
     gs[tracker(x)] = extract_grad!(x)
@@ -71,6 +71,11 @@ function update!(x::TrackedArray, Δ)
   return x
 end
 
+function update!(x::AbstractArray, Δ)
+  x .+= data(Δ)
+  return x
+end
+
 # Fallthrough methods
 
 for f in :[Base.size, Base.ndims, Base.collect].args
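The new method gives plain arrays the same in-place `update!(x, Δ)` entry point that `TrackedArray` already had: it simply adds `data(Δ)` to `x`. Combined with `update!(opt, x, x̄) = update!(x, -apply!(opt, x, data(x̄)))` from the training code, the minus sign lives in the wrapper and `apply!` only rescales the gradient. A standalone sketch of that composition (illustrative names, not Flux code):

```julia
scale_step!(Δ; η = 0.1) = (Δ .*= η; Δ)   # stand-in for an optimiser's apply!
add_step!(x, Δ) = (x .+= Δ; x)           # what update!(x::AbstractArray, Δ) does

x = randn(3)
g = 2 .* x                               # gradient of sum(abs2, x)
add_step!(x, -scale_step!(g))            # one gradient-descent step on x
```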
@@ -4,21 +4,15 @@ using Flux.Tracker
 using Test
 @testset "Optimise" begin
   w = randn(10, 10)
-  @testset for Opt in [ADAMW, ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, Descent, ADAM, Nesterov, RMSProp, Momentum]
+  @testset for opt in [ADAMW(), ADAGrad(0.1), AdaMax(), ADADelta(0.9), AMSGrad(),
+                       NADAM(), Descent(0.1), ADAM(), Nesterov(), RMSProp(),
+                       Momentum()]
     w′ = param(randn(10, 10))
     loss(x) = Flux.mse(w*x, w′*x)
-    opt = Opt(0.001)
-    if opt isa Descent || opt isa ADAGrad
-      opt = Opt(0.1)
-    end
-    if opt isa ADADelta
-      opt = Opt(0.9)
-    end
     for t = 1: 10^5
-      l = loss(rand(10))
-      back!(l)
-      delta = Optimise.apply!(opt, w′.data, w′.grad)
-      w′.data .-= delta
+      θ = Params([w′])
+      θ̄ = gradient(() -> loss(rand(10)), θ)
+      Optimise.update!(opt, θ, θ̄)
     end
     @test Flux.mse(w, w′) < 0.01
   end
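The tests no longer poke `w′.data`/`w′.grad` directly; they exercise the public path (`Params` → `gradient` → `Optimise.update!`) for each pre-constructed optimiser. The same convergence check can be run outside the test harness for a single optimiser (assumes the API at this revision):

```julia
using Flux
using Flux.Tracker: Params, gradient

w  = randn(10, 10)           # fixed target weights
w′ = param(randn(10, 10))    # trainable weights
loss(x) = Flux.mse(w * x, w′ * x)

opt = Momentum()
for _ in 1:10^5
    θ = Params([w′])
    θ̄ = gradient(() -> loss(rand(10)), θ)
    Flux.Optimise.update!(opt, θ, θ̄)
end
@show Flux.mse(w, w′)        # should end up well below 0.01, as in the test
```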