pulled tracker from upstream
This commit is contained in:
parent
0b440f16ff
commit
d933f2079b
@ -19,8 +19,8 @@ export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param
|
|||||||
include("optimise/Optimise.jl")
|
include("optimise/Optimise.jl")
|
||||||
using .Optimise
|
using .Optimise
|
||||||
using .Optimise: @epochs
|
using .Optimise: @epochs
|
||||||
export SGD, Descent, ADAM, AdaMax, Momentum, Nesterov,
|
export Descent, ADAM, Momentum, Nesterov,
|
||||||
RMSProp, ADAGrad, ADADelta, AMSGrad
|
RMSProp, update!
|
||||||
|
|
||||||
include("utils.jl")
|
include("utils.jl")
|
||||||
include("onehot.jl")
|
include("onehot.jl")
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
module Optimise
|
module Optimise
|
||||||
|
|
||||||
export train!,
|
export train!,
|
||||||
SGD, Descent, ADAM, AdaMax, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad
|
Descent, ADAM, Momentum, Nesterov, RMSProp
|
||||||
|
|
||||||
include("optimisers.jl")
|
include("optimisers.jl")
|
||||||
include("train.jl")
|
include("train.jl")
|
||||||
|
@ -27,14 +27,14 @@ Gradient descent with learning rate `η` and momentum `ρ`.
|
|||||||
mutable struct Momentum
|
mutable struct Momentum
|
||||||
eta::Float64
|
eta::Float64
|
||||||
rho::Float64
|
rho::Float64
|
||||||
velocity::ObjectIdDict
|
velocity::IdDict
|
||||||
end
|
end
|
||||||
|
|
||||||
Momentum(η, ρ = 0.9) = Momentum(η, ρ, ObjectIdDict())
|
Momentum(η, ρ = 0.9) = Momentum(η, ρ, IdDict())
|
||||||
|
|
||||||
function update!(o::Momentum, x, Δ)
|
function update!(o::Momentum, x, Δ)
|
||||||
η, ρ = o.eta, o.rho
|
η, ρ = o.eta, o.rho
|
||||||
v = @get!(o.velocity, x, zero(x))::typeof(x)
|
v = get!(o.velocity, x, zero(x))::typeof(x)
|
||||||
@. v = ρ * v - η * Δ
|
@. v = ρ * v - η * Δ
|
||||||
@. Δ = -v
|
@. Δ = -v
|
||||||
end
|
end
|
||||||
@ -47,14 +47,14 @@ Gradient descent with learning rate `η` and Nesterov momentum `ρ`.
|
|||||||
mutable struct Nesterov
|
mutable struct Nesterov
|
||||||
eta::Float64
|
eta::Float64
|
||||||
rho::Float64
|
rho::Float64
|
||||||
velocity::ObjectIdDict
|
velocity::IdDict
|
||||||
end
|
end
|
||||||
|
|
||||||
Nesterov(η, ρ = 0.9) = Nesterov(η, ρ, ObjectIdDict())
|
Nesterov(η, ρ = 0.9) = Nesterov(η, ρ, IdDict())
|
||||||
|
|
||||||
function update!(o::Nesterov, x, Δ)
|
function update!(o::Nesterov, x, Δ)
|
||||||
η, ρ = o.eta, o.rho
|
η, ρ = o.eta, o.rho
|
||||||
v = @get!(o.velocity, x, zero(x))::typeof(x)
|
v = get!(o.velocity, x, zero(x))::typeof(x)
|
||||||
d = @. ρ^2 * v - (1+ρ) * η * Δ
|
d = @. ρ^2 * v - (1+ρ) * η * Δ
|
||||||
@. v = ρ*v - η*Δ
|
@. v = ρ*v - η*Δ
|
||||||
@. Δ = -d
|
@. Δ = -d
|
||||||
@ -70,14 +70,14 @@ choice for recurrent networks.
|
|||||||
mutable struct RMSProp
|
mutable struct RMSProp
|
||||||
eta::Float64
|
eta::Float64
|
||||||
rho::Float64
|
rho::Float64
|
||||||
acc::ObjectIdDict
|
acc::IdDict
|
||||||
end
|
end
|
||||||
|
|
||||||
RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, ObjectIdDict())
|
RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict())
|
||||||
|
|
||||||
function update!(o::RMSProp, x, Δ)
|
function update!(o::RMSProp, x, Δ)
|
||||||
η, ρ = o.eta, o.rho
|
η, ρ = o.eta, o.rho
|
||||||
acc = @get!(o.acc, x, zero(x))::typeof(x)
|
acc = get!(o.acc, x, zero(x))::typeof(x)
|
||||||
@. acc = ρ * acc + (1 - ρ) * Δ^2
|
@. acc = ρ * acc + (1 - ρ) * Δ^2
|
||||||
@. Δ *= η / (√acc + ϵ)
|
@. Δ *= η / (√acc + ϵ)
|
||||||
end
|
end
|
||||||
@ -90,14 +90,14 @@ end
|
|||||||
mutable struct ADAM
|
mutable struct ADAM
|
||||||
eta::Float64
|
eta::Float64
|
||||||
beta::Tuple{Float64,Float64}
|
beta::Tuple{Float64,Float64}
|
||||||
state::ObjectIdDict
|
state::IdDict
|
||||||
end
|
end
|
||||||
|
|
||||||
ADAM(η = 0.001, β = (0.9, 0.999)) = ADAM(η, β, ObjectIdDict())
|
ADAM(η = 0.001, β = (0.9, 0.999)) = ADAM(η, β, IdDict())
|
||||||
|
|
||||||
function update!(o::ADAM, x, Δ)
|
function update!(o::ADAM, x, Δ)
|
||||||
η, β = o.eta, o.beta
|
η, β = o.eta, o.beta
|
||||||
mt, vt, βp = @get!(o.state, x, (zero(x), zero(x), β))
|
mt, vt, βp = get!(o.state, x, (zero(x), zero(x), β))
|
||||||
@. mt = β[1] * mt + (1 - β[1]) * Δ
|
@. mt = β[1] * mt + (1 - β[1]) * Δ
|
||||||
@. vt = β[2] * vt + (1 - β[2]) * Δ^2
|
@. vt = β[2] * vt + (1 - β[2]) * Δ^2
|
||||||
@. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η
|
@. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η
|
||||||
|
@ -1,23 +1,27 @@
|
|||||||
module Tracker
|
module Tracker
|
||||||
|
|
||||||
|
using MacroTools
|
||||||
|
using MacroTools: @q, @forward
|
||||||
|
|
||||||
import Base: ==
|
import Base: ==
|
||||||
|
|
||||||
export TrackedArray, TrackedVector, TrackedMatrix, param, back!
|
export TrackedArray, TrackedVector, TrackedMatrix, Params, param, back!
|
||||||
|
|
||||||
tracker(x) = nothing
|
tracker(x) = nothing
|
||||||
|
|
||||||
istracked(x) = tracker(x) ≠ nothing
|
istracked(x) = tracker(x) ≠ nothing
|
||||||
isleaf(x) = !istracked(x) || isleaf(tracker(x))
|
isleaf(x) = !istracked(x) || isleaf(tracker(x))
|
||||||
data(x) = istracked(x) ? data(tracker(x)) : x
|
|
||||||
grad(x) = grad(tracker(x))
|
grad(x) = grad(tracker(x))
|
||||||
grad(::Nothing) = nothing
|
grad(::Nothing) = nothing
|
||||||
|
data(x) = x
|
||||||
|
|
||||||
struct Call{F,As<:Tuple}
|
struct Call{F,As<:Tuple}
|
||||||
func::F
|
func::F
|
||||||
args::As
|
args::As
|
||||||
end
|
end
|
||||||
|
|
||||||
Call(f, args...) = Call{typeof(f),typeof(args)}(f, args)
|
Call(f::F, args::T) where {F,T} = Call{F,T}(f, args)
|
||||||
|
Call() = Call(nothing, ())
|
||||||
|
|
||||||
# When deserialising, the object_id changes
|
# When deserialising, the object_id changes
|
||||||
a::Call == b::Call = a.func == b.func && a.args == b.args
|
a::Call == b::Call = a.func == b.func && a.args == b.args
|
||||||
@ -28,33 +32,80 @@ mutable struct Tracked{T}
|
|||||||
ref::UInt32
|
ref::UInt32
|
||||||
f::Call
|
f::Call
|
||||||
isleaf::Bool
|
isleaf::Bool
|
||||||
data::T
|
|
||||||
grad::T
|
grad::T
|
||||||
Tracked{T}(f::Call, data::T) where T = new(0, f, false, data)
|
Tracked{T}(f::Call) where T = new(0, f, false)
|
||||||
Tracked{T}(f::Call, data::T, grad::T) where T = new(0, f, false, data, grad)
|
Tracked{T}(f::Call, grad::T) where T = new(0, f, false, grad)
|
||||||
Tracked{T}(f::Call{Void}, data::T, grad::T) where T = new(0, f, true, data, grad)
|
Tracked{T}(f::Call{Nothing}, grad::T) where T = new(0, f, true, grad)
|
||||||
end
|
end
|
||||||
|
|
||||||
Tracked(f::Call, x) = Tracked{typeof(x)}(f, x)
|
|
||||||
Tracked(f::Call, x, Δ) = Tracked{typeof(x)}(f, x, Δ)
|
|
||||||
|
|
||||||
track(f::Call, x) = Tracked(f, x)
|
|
||||||
track(f::Call) = track(f, f())
|
|
||||||
track(f, xs...) = track(Call(f, xs...))
|
|
||||||
|
|
||||||
istracked(x::Tracked) = true
|
istracked(x::Tracked) = true
|
||||||
isleaf(x::Tracked) = x.f == Call(nothing)
|
isleaf(x::Tracked) = x.f == Call()
|
||||||
data(x::Tracked) = x.data
|
|
||||||
grad(x::Tracked) = x.grad
|
grad(x::Tracked) = x.grad
|
||||||
|
|
||||||
|
track(f::Call, x) = Tracked{typeof(x)}(f)
|
||||||
|
|
||||||
|
function _forward end
|
||||||
|
|
||||||
|
function track(f::F, xs...; kw...) where F
|
||||||
|
y, back = _forward(f, xs...; kw...)
|
||||||
|
track(Call(back, tracker.(xs)), y)
|
||||||
|
end
|
||||||
|
|
||||||
|
macro grad(ex)
|
||||||
|
@capture(shortdef(ex), (name_(args__) = body_) |
|
||||||
|
(name_(args__) where {T__} = body_)) || error("Need a function definition")
|
||||||
|
T == nothing && (T = [])
|
||||||
|
isexpr(name, :(::)) || (name = :(::typeof($name)))
|
||||||
|
insert!(args, 1+isexpr(args[1], :parameters) , name)
|
||||||
|
@q(Tracker._forward($(args...)) where $(T...) = $body) |> esc
|
||||||
|
end
|
||||||
|
|
||||||
|
function update!(x, Δ)
|
||||||
|
x.data .+= data(Δ)
|
||||||
|
tracker(x).grad .= 0
|
||||||
|
return x
|
||||||
|
end
|
||||||
|
|
||||||
|
include("idset.jl")
|
||||||
include("back.jl")
|
include("back.jl")
|
||||||
include("scalar.jl")
|
include("scalar.jl")
|
||||||
include("array.jl")
|
include("array.jl")
|
||||||
include("numeric.jl")
|
include("numeric.jl")
|
||||||
|
|
||||||
|
"""
|
||||||
|
hook(f, x) -> x′
|
||||||
|
Hook into gradient backpropagation. `x` is unmodified, but when backpropagating
|
||||||
|
`f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse
|
||||||
|
the sign of the gradient applied to `x`."""
|
||||||
|
hook(f, x) = istracked(x) ? track(hook, f, x) : x
|
||||||
|
@grad hook(f, x) = data(x), Δ -> (nothing, f(Δ))
|
||||||
|
|
||||||
|
"""
|
||||||
|
checkpoint(f, args...)
|
||||||
|
Behaves like `f(args...)`, but avoids storing the intermediate values needed for
|
||||||
|
calculating gradients. Instead, `f(args...)` will be called again during the
|
||||||
|
backward pass. This can be used to save memory in larger models.
|
||||||
|
"""
|
||||||
|
checkpoint(f, args...) = track(checkpoint, f, args...)
|
||||||
|
|
||||||
|
@grad function checkpoint(f, args...)
|
||||||
|
data(f(args...)), function (Δ)
|
||||||
|
y, back = forward(f, args...)
|
||||||
|
(nothing, back(Δ)...)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
nobacksies(f, x) = track(nobacksies, f, x)
|
||||||
|
nobacksies(f, xs::Tuple) = map(x -> nobacksies(f, x), xs)
|
||||||
|
@grad nobacksies(f, x) = data(x), Δ -> error("Nested AD not defined for $f")
|
||||||
|
|
||||||
param(x::Number) = TrackedReal(float(x))
|
param(x::Number) = TrackedReal(float(x))
|
||||||
param(xs::AbstractArray) = TrackedArray(float.(xs))
|
param(xs::AbstractArray) = TrackedArray(float.(xs))
|
||||||
|
|
||||||
|
@grad identity(x) = data(x), Δ -> (Δ,)
|
||||||
|
param(x::TrackedReal) = track(identity, x)
|
||||||
|
param(x::TrackedArray) = track(identity, x)
|
||||||
|
|
||||||
import NNlib.cudata
|
import NNlib.cudata
|
||||||
import Adapt.adapt
|
import Adapt.adapt
|
||||||
|
|
||||||
|
@ -87,14 +87,11 @@ Base.adjoint(xs::TrackedArray) = track(adjoint, xs)
|
|||||||
|
|
||||||
@grad transpose(xs) = transpose(data(xs)), Δ -> (reshape(transpose(Δ), size(xs)),)
|
@grad transpose(xs) = transpose(data(xs)), Δ -> (reshape(transpose(Δ), size(xs)),)
|
||||||
@grad adjoint(xs) = data(xs)', Δ -> (reshape(Δ', size(xs)),)
|
@grad adjoint(xs) = data(xs)', Δ -> (reshape(Δ', size(xs)),)
|
||||||
|
|
||||||
Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...)
|
Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...)
|
||||||
|
|
||||||
@grad function repeat(xs; inner=ntuple(x->1, ndims(xs)), outer=ntuple(x->1, ndims(xs)))
|
@grad function repeat(xs; inner=ntuple(x->1, ndims(xs)), outer=ntuple(x->1, ndims(xs)))
|
||||||
repeat(data(xs), inner = inner, outer = outer), function (Δ)
|
repeat(data(xs), inner = inner, outer = outer), function (Δ)
|
||||||
Δ′ = zero(xs)
|
Δ′ = zero(xs)
|
||||||
S = size(xs)
|
S = size(xs)
|
||||||
|
|
||||||
# Loop through each element of Δ, calculate source dimensions, accumulate into Δ′
|
# Loop through each element of Δ, calculate source dimensions, accumulate into Δ′
|
||||||
for (dest_idx, val) in pairs(IndexCartesian(), data(Δ))
|
for (dest_idx, val) in pairs(IndexCartesian(), data(Δ))
|
||||||
# First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then
|
# First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then
|
||||||
@ -105,7 +102,6 @@ Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...)
|
|||||||
(nobacksies(:repeat, Δ′),)
|
(nobacksies(:repeat, Δ′),)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
for f in [:vcat, :hcat]
|
for f in [:vcat, :hcat]
|
||||||
UArray = :(Union{TrackedArray,Vector,Matrix,Adjoint,Transpose})
|
UArray = :(Union{TrackedArray,Vector,Matrix,Adjoint,Transpose})
|
||||||
@eval begin
|
@eval begin
|
||||||
@ -361,7 +357,7 @@ end
|
|||||||
track(Call(back, tracker.(args)), y)
|
track(Call(back, tracker.(args)), y)
|
||||||
end
|
end
|
||||||
|
|
||||||
using Base.Broadcast: BroadcastStyle, ArrayStyle, Broadcasted, broadcasted
|
using Base.Broadcast: BroadcastStyle, ArrayStyle, Broadcasted, broadcasted, cat_nested
|
||||||
|
|
||||||
struct TrackedStyle <: BroadcastStyle end
|
struct TrackedStyle <: BroadcastStyle end
|
||||||
|
|
||||||
@ -385,6 +381,10 @@ end
|
|||||||
|
|
||||||
using Requires
|
using Requires
|
||||||
|
|
||||||
|
Base.Broadcast.cat_nested(t::Base.Broadcast.Broadcasted, rest...) = (cat_nested(t.args...)..., cat_nested(rest...)...)
|
||||||
|
Base.Broadcast.cat_nested(t::Any, rest...) = (t, cat_nested(rest...)...)
|
||||||
|
Base.Broadcast.cat_nested() = ()
|
||||||
|
|
||||||
# https://github.com/FluxML/Flux.jl/issues/353
|
# https://github.com/FluxML/Flux.jl/issues/353
|
||||||
@init Requires.isprecompiling() || @eval Base.Broadcast begin
|
@init Requires.isprecompiling() || @eval Base.Broadcast begin
|
||||||
function flatten(bc::Broadcasted{Style}) where {Style}
|
function flatten(bc::Broadcasted{Style}) where {Style}
|
||||||
|
@ -3,16 +3,18 @@ using Flux.Tracker
|
|||||||
using Test
|
using Test
|
||||||
@testset "Optimise" begin
|
@testset "Optimise" begin
|
||||||
w = randn(10, 10)
|
w = randn(10, 10)
|
||||||
@testset for Opt in [SGD, Nesterov, Momentum, ADAM, AdaMax, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad, NADAM]
|
@testset for Opt in [Descent, Nesterov, RMSProp, ADAM, Momentum]
|
||||||
w′ = param(randn(10, 10))
|
w′ = param(randn(10, 10))
|
||||||
loss(x) = Flux.mse(w*x, w′*x)
|
delta = param(Tracker.similar(w′))
|
||||||
opt = Opt([w′])
|
loss(x) = Flux.mse(w*x, w′*x)
|
||||||
for t=1:10^5
|
opt = Opt(0.1)
|
||||||
l = loss(rand(10))
|
for t=1:10^5
|
||||||
back!(l)
|
l = loss(rand(10))
|
||||||
opt()
|
back!(l)
|
||||||
end
|
update!(opt, w′.data, delta.data)
|
||||||
@test Flux.mse(w, w′) < 0.01
|
w′ .-= delta
|
||||||
|
end
|
||||||
|
@test Flux.mse(w, w′) < 0.01
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -23,7 +25,7 @@ end
|
|||||||
Flux.train!(() -> (sleep(0.1); i += 1; l),
|
Flux.train!(() -> (sleep(0.1); i += 1; l),
|
||||||
Iterators.repeated((), 100),
|
Iterators.repeated((), 100),
|
||||||
()->(),
|
()->(),
|
||||||
cb = Flux.throttle(() -> (i > 3 && stop()), 1))
|
cb = Flux.throttle(() -> (i > 3 && Flux.stop()), 1))
|
||||||
|
|
||||||
@test 3 < i < 50
|
@test 3 < i < 50
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user