Merge pull request #651 from FluxML/mji/dogfood

Refactor training loop
Mike J Innes 2019-03-06 16:53:24 +00:00 committed by GitHub
commit 3a4c6274fa
7 changed files with 114 additions and 90 deletions
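For orientation (not part of the diff): a minimal sketch of what one training step looks like from user code after this refactor, assuming the Flux 0.8 tracked-parameter API; the model, loss and data below are illustrative.

using Flux
using Flux: Tracker, Optimise

m = Dense(10, 2)                        # illustrative model
loss(x, y) = Flux.mse(m(x), y)
opt = ADAM()
ps = params(m)                          # a Tracker.Params over the model's weights

# One step: take gradients keyed by parameter, then let the optimiser apply them.
gs = Tracker.gradient(() -> loss(rand(10), rand(2)), ps)
Optimise.update!(opt, ps, gs)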

View File

@@ -37,7 +37,7 @@ Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict())
function apply!(o::Momentum, x, Δ)
η, ρ = o.eta, o.rho
v = get!(o.velocity, x, zero(x))::typeof(x)
v = get!(o.velocity, x, zero(x))::typeof(data(x))
@. v = ρ * v - η * Δ
@. Δ = -v
end
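For illustration (not part of the commit): the contract these `apply!` methods follow is that the third argument is rewritten in place into the step for parameter `x`, and the caller subtracts it, as the new `update!` in train.jl does. A hand-rolled Momentum step on a plain array:

using Flux.Optimise: Momentum, apply!

w = randn(5)                     # a plain parameter array, for illustration
g = randn(5)                     # a made-up gradient for w
opt = Momentum(0.01, 0.9)        # keeps per-parameter velocity in an IdDict keyed by w

step = apply!(opt, w, copy(g))   # rewrites the copied gradient into the update step
w .-= step                       # descent: subtract the step from the parameter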
@@ -57,7 +57,7 @@ Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict())
function apply!(o::Nesterov, x, Δ)
η, ρ = o.eta, o.rho
v = get!(o.velocity, x, zero(x))::typeof(x)
v = get!(o.velocity, x, zero(x))::typeof(data(x))
d = @. ρ^2 * v - (1+ρ) * η * Δ
@. v = ρ*v - η*Δ
@. Δ = -d
@@ -80,7 +80,7 @@ RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict())
function apply!(o::RMSProp, x, Δ)
η, ρ = o.eta, o.rho
acc = get!(o.acc, x, zero(x))::typeof(x)
acc = get!(o.acc, x, zero(x))::typeof(data(x))
@. acc = ρ * acc + (1 - ρ) * Δ^2
@. Δ *= η / (acc + ϵ)
end
@@ -147,7 +147,7 @@ ADAGrad(η = 0.1) = ADAGrad(η, IdDict())
function apply!(o::ADAGrad, x, Δ)
η = o.eta
acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(x)
acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(data(x))
@. acc += Δ^2
@. Δ *= η / (acc + ϵ)
end
@@ -321,7 +321,7 @@ end
WeightDecay() = WeightDecay(0)
function apply!(o::WeightDecay, x, Δ)
wd = o.wd
@. Δ += wd * x
@. Δ += wd * data(x)
end
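For illustration (not part of the commit): `WeightDecay` only adds `wd * x` to the gradient, so it is meant to be chained in front of a step-size rule. A sketch assuming the `Optimiser` composition type defined elsewhere in this file:

using Flux.Optimise: Optimiser, WeightDecay, Descent, apply!

w = randn(4)
g = randn(4)

# First add the decay term wd * w to the gradient, then scale by the learning rate.
opt = Optimiser(WeightDecay(1e-4), Descent(0.1))
w .-= apply!(opt, w, copy(g))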

View File

@@ -1,16 +1,23 @@
using Juno
import Flux.Tracker: data, grad, back!, update!
import Flux.Tracker: Params, gradient, data, update!
import Base.depwarn
function update!(opt, x, x̄)
update!(x, apply!(opt, x, copy(data(x̄))))
update!(x, -apply!(opt, x, data(x̄)))
end
function _update_params!(opt, xs)
function update!(opt, xs::Params, gs)
for x in xs
Δ = apply!(opt, x.data, x.grad)
x.data .-= Δ
Δ .= 0
update!(opt, x, gs[x])
end
end
# Added as an internal API but everyone started using it.
function _update_params!(opt, xs)
depwarn("`_update_params!` is deprecated, use `update!` instead.", :stop)
for x in xs
update!(opt, x, Tracker.grad(x))
x.tracker.grad = Tracker.zero_grad!(x.tracker.grad)
end
end
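The new three-argument `update!(opt, xs::Params, gs)` expects gradients keyed by parameter, which is exactly what `Tracker.gradient(f, ps)` returns. A hedged sketch of the intended call pattern (`W`, `b` and the data are illustrative):

using Flux
using Flux: Tracker, Optimise

W = param(randn(2, 5))
b = param(zeros(2))
loss(x, y) = sum((W*x .+ b .- y).^2)

ps = Tracker.Params([W, b])
gs = Tracker.gradient(() -> loss(rand(5), rand(2)), ps)  # Grads keyed by W and b
Optimise.update!(Descent(0.1), ps, gs)                   # one step over every parameter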
@@ -19,16 +26,6 @@ call(f, xs...) = f(xs...)
runall(f) = f
runall(fs::AbstractVector) = () -> foreach(call, fs)
# The AD generates fairly large backtraces that are unhelpful if you interrupt
# while training; this just cleans that up.
macro interrupts(ex)
:(try $(esc(ex))
catch e
e isa InterruptException || rethrow()
throw(e)
end)
end
struct StopException <: Exception end
"""
stop()
@@ -67,13 +64,14 @@ The callback can call `Flux.stop()` to interrupt the training loop.
Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
"""
function train!(loss, ps, data, opt; cb = () -> ())
ps = Params(ps)
cb = runall(cb)
opt = runall(opt)
@progress for d in data
try
l = loss(d...)
@interrupts back!(l)
_update_params!(opt, ps)
gs = gradient(ps) do
loss(d...)
end
update!(opt, ps, gs)
if cb() == :stop
depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
break
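As the docstring above notes, the callback can call `Flux.stop()` instead of returning `:stop`. A sketch of that usage (not part of the diff; model, data and threshold are illustrative):

using Flux

m = Dense(10, 1)
loss(x, y) = Flux.mse(m(x), y)
dataset = [(rand(10), rand(1)) for _ = 1:100]   # illustrative data
opt = Descent(0.1)

# Stop early from the callback; returning :stop still works but is deprecated.
evalcb() = Flux.Tracker.data(loss(rand(10), rand(1))) < 1e-3 && Flux.stop()
Flux.train!(loss, params(m), dataset, opt, cb = evalcb)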

View File

@@ -62,6 +62,7 @@ macro grad(ex)
end
include("idset.jl")
include("params.jl")
include("back.jl")
include("numeric.jl")
include("lib/real.jl")

View File

@@ -1,3 +1,15 @@
# The AD generates fairly large backtraces that are unhelpful if you interrupt
# while training; this just cleans that up.
macro interrupts(ex)
:(try $(esc(ex))
catch e
e isa InterruptException || rethrow()
throw(e)
end)
end
# In-place gradients
init_grad(x) = zero(x)
zero_grad!(x) = zero(x)
zero_grad!(x::AbstractArray) = (x .= 0)
@@ -66,64 +78,34 @@ function back!(x, Δ; once = true)
return
end
function extract_grad!(x)
x̄ = copy(grad(x))
x̄ = nobacksies("Use `gradient(...; nest = true)` for nested derivatives", x̄)
tracker(x).grad = zero_grad!(grad(x))
return x̄
end
function gradient_(f, xs...)
xs = param.(data.(xs))
l = f(xs...)
losscheck(l)
back!(l)
nobacksies("Use `gradient(...; nest = true)` for nested derivatives",
grad.(xs))
@interrupts back!(l)
extract_grad!.(xs)
end
function gradient_(f, xs::Params)
l = f()
losscheck(l)
@interrupts back!(l)
gs = Grads()
for x in xs
gs[tracker(x)] = extract_grad!(x)
end
return gs
end
# Out-of-place gradients
struct Params
order::Vector{Any}
params::IdSet{Any}
Params() = new([], IdSet())
end
@forward Params.order Base.iterate, Base.length
function Base.push!(ps::Params, x)
if !(x in ps.params)
push!(ps.order, x)
push!(ps.params, x)
end
return ps
end
Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps)
Params(xs) = push!(Params(), xs...)
function Base.show(io::IO, ps::Params)
print(io, "Params([")
join(io, ps.order, ", ")
print(io, "])")
end
struct Grads
grads::IdDict{Any,Any}
end
Base.show(io::IO, ps::Grads) = println(io, "Grads(...)")
Grads() = Grads(IdDict())
@forward Grads.grads Base.setindex!, Base.haskey, Base.length, Base.iterate
Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps))
Base.getindex(g::Grads, x::Tracked) = g.grads[x]
function Base.getindex(g::Grads, x)
istracked(x) || error("Object not tracked: $x")
g[tracker(x)]
end
accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
function back_(g::Grads, c::Call, Δ)
Δs = c.func(Δ)
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
@@ -182,8 +164,6 @@ end
gradient(f, xs...; nest = false) =
nest ? gradient_nested(f, xs...) : gradient_(f, xs...)
gradient(f, ps::Params) = gradient_nested(f, ps)
# Jacobians and Hessians
import ..Flux
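With `gradient(f, ps::Params) = gradient_nested(f, ps)` gone, `gradient` now defaults to the non-nested, mutation-based backward pass for both positional arguments and `Params`, and higher-order differentiation stays opt-in via `nest = true`. A sketch of the two call styles (not part of the diff):

using Flux
using Flux: Tracker

# Default (nest = false): one gradient per positional argument.
gx, = Tracker.gradient(x -> sum(x.^2), [1.0, 2.0, 3.0])               # gx == [2.0, 4.0, 6.0]

# nest = true keeps the result tracked, so it can be differentiated again.
gn, = Tracker.gradient(x -> sum(x.^2), [1.0, 2.0, 3.0]; nest = true)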

View File

@@ -71,6 +71,11 @@ function update!(x::TrackedArray, Δ)
return x
end
function update!(x::AbstractArray, Δ)
x .+= data(Δ)
return x
end
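The new `AbstractArray` method mirrors the `TrackedArray` one above, so `update!` also works on plain arrays; both methods add `Δ`, which is why train.jl now negates the optimiser step. A small illustrative sketch (not part of the diff):

using Flux
using Flux: Tracker

w = param(ones(3))      # tracked parameter
v = ones(3)             # plain array
Δ = fill(0.5, 3)

Tracker.update!(w, -Δ)  # w.data is now [0.5, 0.5, 0.5]
Tracker.update!(v, -Δ)  # plain arrays are updated in place as well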
# Fallthrough methods
for f in :[Base.size, Base.ndims, Base.collect].args

src/tracker/params.jl (new file, 46 lines)
View File

@@ -0,0 +1,46 @@
struct Params
order::Vector{Any}
params::IdSet{Any}
Params() = new([], IdSet())
end
@forward Params.order Base.iterate, Base.length
function Base.push!(ps::Params, x)
if !(x in ps.params)
push!(ps.order, x)
push!(ps.params, x)
end
return ps
end
Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps)
Params(xs) = push!(Params(), xs...)
function Base.show(io::IO, ps::Params)
print(io, "Params([")
join(io, ps.order, ", ")
print(io, "])")
end
struct Grads
grads::IdDict{Any,Any}
end
Base.show(io::IO, ps::Grads) = println(io, "Grads(...)")
Grads() = Grads(IdDict())
@forward Grads.grads Base.setindex!, Base.haskey, Base.length, Base.iterate
Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps))
Base.getindex(g::Grads, x::Tracked) = g.grads[x]
function Base.getindex(g::Grads, x)
istracked(x) || error("Object not tracked: $x")
g[tracker(x)]
end
accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
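These containers back the `Params`/`Grads` API used in train.jl: `Params` keeps insertion order and de-duplicates by object identity, while `Grads` maps each parameter's tracker to its gradient. An illustrative sketch (not part of the diff):

using Flux
using Flux: Tracker
using Flux.Tracker: Params

x = param([1.0, 2.0])
y = param([3.0])

ps = Params([x, y, x])                                # duplicates ignored: length(ps) == 2
gs = Tracker.gradient(() -> sum(x.^2) + sum(y), ps)

gs[x]   # [2.0, 4.0]: gradients are looked up by the parameter itself
gs[y]   # [1.0]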

View File

@@ -4,21 +4,15 @@ using Flux.Tracker
using Test
@testset "Optimise" begin
w = randn(10, 10)
@testset for Opt in [ADAMW, ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, Descent, ADAM, Nesterov, RMSProp, Momentum]
@testset for opt in [ADAMW(), ADAGrad(0.1), AdaMax(), ADADelta(0.9), AMSGrad(),
NADAM(), Descent(0.1), ADAM(), Nesterov(), RMSProp(),
Momentum()]
w′ = param(randn(10, 10))
loss(x) = Flux.mse(w*x, w′*x)
opt = Opt(0.001)
if opt isa Descent || opt isa ADAGrad
opt = Opt(0.1)
end
if opt isa ADADelta
opt = Opt(0.9)
end
for t = 1: 10^5
l = loss(rand(10))
back!(l)
delta = Optimise.apply!(opt, w′.data, w′.grad)
w′.data .-= delta
θ = Params([w′])
θ̄ = gradient(() -> loss(rand(10)), θ)
Optimise.update!(opt, θ, θ̄)
end
@test Flux.mse(w, w′) < 0.01
end