initial sketch
This commit is contained in:
parent
53be49b102
commit
a2d2d068aa
@ -19,8 +19,8 @@ export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param
|
|||||||
include("optimise/Optimise.jl")
|
include("optimise/Optimise.jl")
|
||||||
using .Optimise
|
using .Optimise
|
||||||
using .Optimise: @epochs
|
using .Optimise: @epochs
|
||||||
export SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
|
export SGD, Descent, ADAM, AdaMax, Momentum, Nesterov,
|
||||||
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
|
RMSProp, ADAGrad, ADADelta, AMSGrad
|
||||||
|
|
||||||
include("utils.jl")
|
include("utils.jl")
|
||||||
include("onehot.jl")
|
include("onehot.jl")
|
||||||
|
@ -1,23 +1,9 @@
|
|||||||
module Optimise
|
module Optimise
|
||||||
|
|
||||||
export train!,
|
export train!,
|
||||||
SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
|
SGD, Descent, ADAM, AdaMax, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad
|
||||||
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM, stop, StopException
|
|
||||||
|
|
||||||
struct Param{T}
|
|
||||||
x::T
|
|
||||||
Δ::T
|
|
||||||
end
|
|
||||||
|
|
||||||
Param(x::AbstractArray) = Param(x, zero(x))
|
|
||||||
|
|
||||||
include("optimisers.jl")
|
include("optimisers.jl")
|
||||||
include("interface.jl")
|
|
||||||
include("train.jl")
|
include("train.jl")
|
||||||
|
|
||||||
using Flux.Tracker: TrackedArray
|
|
||||||
|
|
||||||
Param(x::TrackedArray) = Param(x.data, x.grad)
|
|
||||||
# Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad)
|
|
||||||
|
|
||||||
end
|
end
|
@ -1,110 +0,0 @@
|
|||||||
call(f, xs...) = f(xs...)
|
|
||||||
|
|
||||||
# note for optimisers: set to zero
|
|
||||||
# p.Δ at the end of the weights update
|
|
||||||
function optimiser(ps, fs...)
|
|
||||||
ps = [Param(p) for p in ps]
|
|
||||||
fs = map(ps) do p
|
|
||||||
os = map(f -> f(p), fs)
|
|
||||||
() -> foreach(call, os)
|
|
||||||
end
|
|
||||||
() -> foreach(call, fs)
|
|
||||||
end
|
|
||||||
|
|
||||||
"""
|
|
||||||
SGD(params, η = 0.1; decay = 0)
|
|
||||||
|
|
||||||
Classic gradient descent optimiser with learning rate `η`.
|
|
||||||
For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`.
|
|
||||||
|
|
||||||
Supports inverse decaying learning rate if the `decay` argument is provided.
|
|
||||||
"""
|
|
||||||
SGD(ps, η = 0.1; decay = 0) =
|
|
||||||
optimiser(ps, p -> invdecay(p, decay), p -> descent(p,η))
|
|
||||||
|
|
||||||
"""
|
|
||||||
Momentum(params, η = 0.01; ρ = 0.9, decay = 0)
|
|
||||||
|
|
||||||
SGD with learning rate `η`, momentum `ρ` and optional learning rate inverse decay.
|
|
||||||
"""
|
|
||||||
Momentum(ps, η = 0.01; ρ = 0.9, decay = 0) =
|
|
||||||
optimiser(ps, p->invdecay(p,decay), p->momentum(p, ρ, η), p->descent(p,1))
|
|
||||||
|
|
||||||
"""
|
|
||||||
Nesterov(params, η = 0.01; ρ = 0.9, decay = 0)
|
|
||||||
|
|
||||||
SGD with learning rate `η`, Nesterov momentum `ρ` and optional learning rate inverse decay.
|
|
||||||
"""
|
|
||||||
Nesterov(ps, η = 0.01; ρ = 0.9, decay = 0) =
|
|
||||||
optimiser(ps, p->invdecay(p,decay), p->nesterov(p, ρ, η), p->descent(p,1))
|
|
||||||
|
|
||||||
"""
|
|
||||||
RMSProp(params, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0)
|
|
||||||
|
|
||||||
[RMSProp](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
|
|
||||||
optimiser. Parameters other than learning rate don't need tuning. Often a good
|
|
||||||
choice for recurrent networks.
|
|
||||||
"""
|
|
||||||
RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) =
|
|
||||||
optimiser(ps, p->rmsprop(p; η=η, ρ=ρ, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
|
|
||||||
|
|
||||||
"""
|
|
||||||
ADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
|
|
||||||
|
|
||||||
[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
|
|
||||||
"""
|
|
||||||
ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
|
|
||||||
optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
|
|
||||||
|
|
||||||
"""
|
|
||||||
ADAMW((params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
|
|
||||||
|
|
||||||
[ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam.
|
|
||||||
"""
|
|
||||||
ADAMW(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
|
|
||||||
optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->descentweightdecay(p,1,decay))
|
|
||||||
|
|
||||||
"""
|
|
||||||
AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
|
|
||||||
|
|
||||||
[AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser. Variant of ADAM based on
|
|
||||||
the ∞-norm.
|
|
||||||
"""
|
|
||||||
AdaMax(ps, η = 0.002; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
|
|
||||||
optimiser(ps, p->adamax(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
|
|
||||||
|
|
||||||
"""
|
|
||||||
ADAGrad(params, η = 0.01; ϵ = 1e-8, decay = 0)
|
|
||||||
|
|
||||||
[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
|
|
||||||
Parameters don't need tuning.
|
|
||||||
"""
|
|
||||||
ADAGrad(ps, η = 0.01; ϵ = 1e-8, decay = 0) =
|
|
||||||
optimiser(ps, p->adagrad(p; η=η, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
|
|
||||||
|
|
||||||
"""
|
|
||||||
ADADelta(params; ρ = 0.9, ϵ = 1e-8, decay = 0)
|
|
||||||
|
|
||||||
[ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
|
|
||||||
tuning.
|
|
||||||
"""
|
|
||||||
ADADelta(ps; ρ = 0.9, ϵ = 1e-8, decay = 0) =
|
|
||||||
optimiser(ps, p->adadelta(p; ρ=ρ, ϵ=ϵ), p->descent(p,1))
|
|
||||||
|
|
||||||
"""
|
|
||||||
AMSGrad(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
|
|
||||||
|
|
||||||
[AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need
|
|
||||||
tuning.
|
|
||||||
"""
|
|
||||||
AMSGrad(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
|
|
||||||
optimiser(ps, p -> amsgrad(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
|
|
||||||
|
|
||||||
"""
|
|
||||||
NADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
|
|
||||||
|
|
||||||
[NADAM](https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ) optimiser. Parameters other
|
|
||||||
than learning rate don't need tuning.
|
|
||||||
"""
|
|
||||||
NADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
|
|
||||||
optimiser(ps, p->nadam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
|
|
@ -1,130 +1,139 @@
|
|||||||
function descent(p::Param, η::Real)
|
using Flux
|
||||||
function ()
|
using Base: @get!
|
||||||
@. p.x -= η * p.Δ
|
|
||||||
@. p.Δ = 0
|
const ϵ = 1e-8
|
||||||
end
|
|
||||||
|
# TODO: should use weak refs
|
||||||
|
|
||||||
|
"""
|
||||||
|
Descent(η)
|
||||||
|
|
||||||
|
Classic gradient descent optimiser with learning rate `η`.
|
||||||
|
For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`.
|
||||||
|
"""
|
||||||
|
mutable struct Descent
|
||||||
|
eta::Float64
|
||||||
end
|
end
|
||||||
|
|
||||||
# Ref: https://arxiv.org/abs/1711.05101.pdf
|
function update!(o::Descent, x, Δ)
|
||||||
function descentweightdecay(p::Param, η::Real, γ::Real)
|
Δ .*= o.eta
|
||||||
function ()
|
|
||||||
@. p.x = p.x - η * (p.Δ + γ * p.x)
|
|
||||||
@. p.Δ = 0
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
|
|
||||||
function momentum(p::Param, ρ, η)
|
"""
|
||||||
v = zero(p.x)
|
Momentum(params, η = 0.01; ρ = 0.9, decay = 0)
|
||||||
function ()
|
|
||||||
@. v = ρ * v - η * p.Δ
|
Gradient descent with learning rate `η` and momentum `ρ`.
|
||||||
@. p.Δ = -v
|
"""
|
||||||
end
|
mutable struct Momentum
|
||||||
|
eta::Float64
|
||||||
|
rho::Float64
|
||||||
|
velocity::ObjectIdDict
|
||||||
end
|
end
|
||||||
|
|
||||||
# Ref. https://arxiv.org/pdf/1212.0901.pdf
|
Momentum(η, ρ = 0.9) = Momentum(η, ρ, ObjectIdDict())
|
||||||
function nesterov(p::Param, ρ, η)
|
|
||||||
v = zero(p.x)
|
function update!(o::Momentum, x, Δ)
|
||||||
function ()
|
η, ρ = o.eta, o.rho
|
||||||
d = @. ρ^2 * v - (1+ρ) * η * p.Δ
|
v = @get!(o.velocity, x, zero(x))::typeof(x)
|
||||||
@. v = ρ*v - η*p.Δ
|
@. v = ρ * v - η * Δ
|
||||||
@. p.Δ = -d
|
@. Δ = -v
|
||||||
end
|
|
||||||
end
|
end
|
||||||
|
|
||||||
function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
|
"""
|
||||||
acc = zero(p.x)
|
Nesterov(eta, ρ = 0.9)
|
||||||
function ()
|
|
||||||
@. acc = ρ * acc + (1 - ρ) * p.Δ^2
|
Gradient descent with learning rate `η` and Nesterov momentum `ρ`.
|
||||||
@. p.Δ *= η / √(acc + ϵ)
|
"""
|
||||||
end
|
mutable struct Nesterov
|
||||||
|
eta::Float64
|
||||||
|
rho::Float64
|
||||||
|
velocity::ObjectIdDict
|
||||||
end
|
end
|
||||||
|
|
||||||
function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
|
Nesterov(η, ρ = 0.9) = Nesterov(η, ρ, ObjectIdDict())
|
||||||
acc = zero(p.x) .+ ϵ
|
|
||||||
function ()
|
function update!(o::Nesterov, x, Δ)
|
||||||
@. acc += p.Δ^2
|
η, ρ = o.eta, o.rho
|
||||||
@. p.Δ *= η / √(acc + ϵ)
|
v = @get!(o.velocity, x, zero(x))::typeof(x)
|
||||||
end
|
d = @. ρ^2 * v - (1+ρ) * η * Δ
|
||||||
|
@. v = ρ*v - η*Δ
|
||||||
|
@. Δ = -d
|
||||||
end
|
end
|
||||||
|
|
||||||
function adadelta(p::Param; ρ::Real = 0.9, ϵ::Real = 1e-8)
|
"""
|
||||||
acc = zero(p.x)
|
RMSProp(η = 0.001, ρ = 0.9)
|
||||||
Δacc = zero(p.x)
|
|
||||||
function ()
|
[RMSProp](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
|
||||||
@. acc = ρ * acc + (1 - ρ) * p.Δ^2
|
optimiser. Parameters other than learning rate don't need tuning. Often a good
|
||||||
@. p.Δ *= √(Δacc + ϵ) / √(acc + ϵ)
|
choice for recurrent networks.
|
||||||
@. Δacc = ρ * Δacc + (1 - ρ) * p.Δ^2
|
"""
|
||||||
end
|
mutable struct RMSProp
|
||||||
|
eta::Float64
|
||||||
|
rho::Float64
|
||||||
|
acc::ObjectIdDict
|
||||||
end
|
end
|
||||||
|
|
||||||
function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
|
RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, ObjectIdDict())
|
||||||
mt = zero(p.x)
|
|
||||||
vt = zero(p.x)
|
function update!(o::RMSProp, x, Δ)
|
||||||
β1p, β2p = β1, β2
|
η, ρ = o.eta, o.rho
|
||||||
function ()
|
acc = @get!(o.acc, x, zero(x))::typeof(x)
|
||||||
@. mt = β1 * mt + (1 - β1) * p.Δ
|
@. acc = ρ * acc + (1 - ρ) * Δ^2
|
||||||
@. vt = β2 * vt + (1 - β2) * p.Δ^2
|
@. Δ *= η / (√acc + ϵ)
|
||||||
@. p.Δ = mt / (1 - β1p) / √(vt / (1 - β2p) + ϵ) * η
|
|
||||||
β1p *= β1
|
|
||||||
β2p *= β2
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
|
|
||||||
function adamax(p::Param; η::Real = 0.002, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
|
"""
|
||||||
mt = zero(p.x)
|
ADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
|
||||||
ut = zero(p.x)
|
|
||||||
β1p = β1
|
[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
|
||||||
function ()
|
"""
|
||||||
@. mt = β1 * mt + (1 - β1) * p.Δ
|
mutable struct ADAM
|
||||||
@. ut = max(β2 * ut, abs(p.Δ))
|
eta::Float64
|
||||||
@. p.Δ = (η/(1 - β1p)) * mt/(ut + ϵ)
|
beta::Tuple{Float64,Float64}
|
||||||
β1p *= β1
|
state::ObjectIdDict
|
||||||
end
|
|
||||||
end
|
end
|
||||||
|
|
||||||
function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
|
ADAM(η = 0.001, β = (0.9, 0.999)) = ADAM(η, β, ObjectIdDict())
|
||||||
mt = zero(p.x)
|
|
||||||
vt = zero(p.x) .+ ϵ
|
function update!(o::ADAM, x, Δ)
|
||||||
v̂t = zero(p.x) .+ ϵ
|
η, β = o.eta, o.beta
|
||||||
function ()
|
mt, vt, βp = @get!(o.state, x, (zero(x), zero(x), β))
|
||||||
@. mt = β1 * mt + (1 - β1) * p.Δ
|
@. mt = β[1] * mt + (1 - β[1]) * Δ
|
||||||
@. vt = β2 * vt + (1 - β2) * p.Δ ^ 2
|
@. vt = β[2] * vt + (1 - β[2]) * Δ^2
|
||||||
@. v̂t = max.(v̂t, vt)
|
@. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η
|
||||||
@. p.Δ = η * mt / √v̂t
|
o.state[x] = (mt, vt, βp .* β)
|
||||||
end
|
|
||||||
end
|
end
|
||||||
|
|
||||||
function nadam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
|
# """
|
||||||
mt = zero(p.x)
|
# AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
|
||||||
vt = zero(p.x)
|
#
|
||||||
β1p, β2p = β1, β2
|
# [AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser. Variant of ADAM based on
|
||||||
function ()
|
# the ∞-norm.
|
||||||
@. mt = β1 * mt + (1 - β1) * p.Δ
|
# """
|
||||||
@. vt = β2 * vt + (1 - β2) * p.Δ^2
|
|
||||||
@. p.Δ = (β1 * mt / (1 - β1 * β1p) + (1 - β1) * p.Δ / (1 - β1p)) / √(vt * β2 / (1 - β2p) + ϵ) * η
|
|
||||||
β1p *= β1
|
|
||||||
β2p *= β2
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
clip(p::Param, thresh::Real) = () -> clamp!(p.Δ, -thresh, thresh)
|
# """
|
||||||
|
# ADAGrad(params, η = 0.01; ϵ = 1e-8, decay = 0)
|
||||||
|
#
|
||||||
|
# [ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
|
||||||
|
# Parameters don't need tuning.
|
||||||
|
# """
|
||||||
|
|
||||||
function expdecay(p::Param, γ::Real)
|
# """
|
||||||
if γ != 0
|
# ADADelta(params; ρ = 0.9, ϵ = 1e-8, decay = 0)
|
||||||
return () -> p.Δ .+= γ .* p.x
|
#
|
||||||
else
|
# [ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
|
||||||
return () -> nothing
|
# tuning.
|
||||||
end
|
# """
|
||||||
end
|
|
||||||
|
|
||||||
function invdecay(p::Param, γ::Real)
|
# """
|
||||||
if γ != 0
|
# AMSGrad(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
|
||||||
n = 0
|
#
|
||||||
return () -> begin
|
# [AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need
|
||||||
p.Δ .*= 1 / (1 + γ * n)
|
# tuning.
|
||||||
n += 1
|
# """
|
||||||
end
|
|
||||||
else
|
# struct Optimiser
|
||||||
return () -> nothing
|
# os::Vector{Any}
|
||||||
end
|
# end
|
||||||
end
|
|
||||||
|
# TODO: decay
|
||||||
|
@ -1,7 +1,16 @@
|
|||||||
using Juno
|
using Juno
|
||||||
using Flux.Tracker: back!
|
using Flux.Tracker: data, grad, back!
|
||||||
import Base.depwarn
|
|
||||||
|
|
||||||
|
function update!(opt, xs)
|
||||||
|
for x in xs
|
||||||
|
x, Δ = data(x), grad(x)
|
||||||
|
update!(opt, x, Δ)
|
||||||
|
x .-= Δ
|
||||||
|
Δ .= 0
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
# Callback niceties
|
||||||
runall(f) = f
|
runall(f) = f
|
||||||
runall(fs::AbstractVector) = () -> foreach(call, fs)
|
runall(fs::AbstractVector) = () -> foreach(call, fs)
|
||||||
|
|
||||||
|
@ -1,27 +1,23 @@
|
|||||||
module Tracker
|
module Tracker
|
||||||
|
|
||||||
using MacroTools
|
|
||||||
using MacroTools: @q, @forward
|
|
||||||
|
|
||||||
import Base: ==
|
import Base: ==
|
||||||
|
|
||||||
export TrackedArray, TrackedVector, TrackedMatrix, Params, param, back!
|
export TrackedArray, TrackedVector, TrackedMatrix, param, back!
|
||||||
|
|
||||||
tracker(x) = nothing
|
tracker(x) = nothing
|
||||||
|
|
||||||
istracked(x) = tracker(x) ≠ nothing
|
istracked(x) = tracker(x) ≠ nothing
|
||||||
isleaf(x) = !istracked(x) || isleaf(tracker(x))
|
isleaf(x) = !istracked(x) || isleaf(tracker(x))
|
||||||
|
data(x) = istracked(x) ? data(tracker(x)) : x
|
||||||
grad(x) = grad(tracker(x))
|
grad(x) = grad(tracker(x))
|
||||||
grad(::Nothing) = nothing
|
grad(::Nothing) = nothing
|
||||||
data(x) = x
|
|
||||||
|
|
||||||
struct Call{F,As<:Tuple}
|
struct Call{F,As<:Tuple}
|
||||||
func::F
|
func::F
|
||||||
args::As
|
args::As
|
||||||
end
|
end
|
||||||
|
|
||||||
Call(f::F, args::T) where {F,T} = Call{F,T}(f, args)
|
Call(f, args...) = Call{typeof(f),typeof(args)}(f, args)
|
||||||
Call() = Call(nothing, ())
|
|
||||||
|
|
||||||
# When deserialising, the object_id changes
|
# When deserialising, the object_id changes
|
||||||
a::Call == b::Call = a.func == b.func && a.args == b.args
|
a::Call == b::Call = a.func == b.func && a.args == b.args
|
||||||
@ -32,82 +28,33 @@ mutable struct Tracked{T}
|
|||||||
ref::UInt32
|
ref::UInt32
|
||||||
f::Call
|
f::Call
|
||||||
isleaf::Bool
|
isleaf::Bool
|
||||||
|
data::T
|
||||||
grad::T
|
grad::T
|
||||||
Tracked{T}(f::Call) where T = new(0, f, false)
|
Tracked{T}(f::Call, data::T) where T = new(0, f, false, data)
|
||||||
Tracked{T}(f::Call, grad::T) where T = new(0, f, false, grad)
|
Tracked{T}(f::Call, data::T, grad::T) where T = new(0, f, false, data, grad)
|
||||||
Tracked{T}(f::Call{Nothing}, grad::T) where T = new(0, f, true, grad)
|
Tracked{T}(f::Call{Void}, data::T, grad::T) where T = new(0, f, true, data, grad)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
Tracked(f::Call, x) = Tracked{typeof(x)}(f, x)
|
||||||
|
Tracked(f::Call, x, Δ) = Tracked{typeof(x)}(f, x, Δ)
|
||||||
|
|
||||||
|
track(f::Call, x) = Tracked(f, x)
|
||||||
|
track(f::Call) = track(f, f())
|
||||||
|
track(f, xs...) = track(Call(f, xs...))
|
||||||
|
|
||||||
istracked(x::Tracked) = true
|
istracked(x::Tracked) = true
|
||||||
isleaf(x::Tracked) = x.f == Call()
|
isleaf(x::Tracked) = x.f == Call(nothing)
|
||||||
|
data(x::Tracked) = x.data
|
||||||
grad(x::Tracked) = x.grad
|
grad(x::Tracked) = x.grad
|
||||||
|
|
||||||
track(f::Call, x) = Tracked{typeof(x)}(f)
|
|
||||||
|
|
||||||
function _forward end
|
|
||||||
|
|
||||||
function track(f::F, xs...; kw...) where F
|
|
||||||
y, back = _forward(f, xs...; kw...)
|
|
||||||
track(Call(back, tracker.(xs)), y)
|
|
||||||
end
|
|
||||||
|
|
||||||
macro grad(ex)
|
|
||||||
@capture(shortdef(ex), (name_(args__) = body_) |
|
|
||||||
(name_(args__) where {T__} = body_)) || error("Need a function definition")
|
|
||||||
T == nothing && (T = [])
|
|
||||||
isexpr(name, :(::)) || (name = :(::typeof($name)))
|
|
||||||
insert!(args, 1+isexpr(args[1], :parameters) , name)
|
|
||||||
@q(Tracker._forward($(args...)) where $(T...) = $body) |> esc
|
|
||||||
end
|
|
||||||
|
|
||||||
function update!(x, Δ)
|
|
||||||
x.data .+= data(Δ)
|
|
||||||
tracker(x).grad .= 0
|
|
||||||
return x
|
|
||||||
end
|
|
||||||
|
|
||||||
include("idset.jl")
|
|
||||||
include("back.jl")
|
include("back.jl")
|
||||||
include("scalar.jl")
|
include("scalar.jl")
|
||||||
include("array.jl")
|
include("array.jl")
|
||||||
include("numeric.jl")
|
include("numeric.jl")
|
||||||
|
|
||||||
"""
|
|
||||||
hook(f, x) -> x′
|
|
||||||
|
|
||||||
Hook into gradient backpropagation. `x` is unmodified, but when backpropagating
|
|
||||||
`f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse
|
|
||||||
the sign of the gradient applied to `x`."""
|
|
||||||
hook(f, x) = istracked(x) ? track(hook, f, x) : x
|
|
||||||
@grad hook(f, x) = data(x), Δ -> (nothing, f(Δ))
|
|
||||||
|
|
||||||
"""
|
|
||||||
checkpoint(f, args...)
|
|
||||||
|
|
||||||
Behaves like `f(args...)`, but avoids storing the intermediate values needed for
|
|
||||||
calculating gradients. Instead, `f(args...)` will be called again during the
|
|
||||||
backward pass. This can be used to save memory in larger models.
|
|
||||||
"""
|
|
||||||
checkpoint(f, args...) = track(checkpoint, f, args...)
|
|
||||||
|
|
||||||
@grad function checkpoint(f, args...)
|
|
||||||
data(f(args...)), function (Δ)
|
|
||||||
y, back = forward(f, args...)
|
|
||||||
(nothing, back(Δ)...)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
nobacksies(f, x) = track(nobacksies, f, x)
|
|
||||||
nobacksies(f, xs::Tuple) = map(x -> nobacksies(f, x), xs)
|
|
||||||
@grad nobacksies(f, x) = data(x), Δ -> error("Nested AD not defined for $f")
|
|
||||||
|
|
||||||
param(x::Number) = TrackedReal(float(x))
|
param(x::Number) = TrackedReal(float(x))
|
||||||
param(xs::AbstractArray) = TrackedArray(float.(xs))
|
param(xs::AbstractArray) = TrackedArray(float.(xs))
|
||||||
|
|
||||||
@grad identity(x) = data(x), Δ -> (Δ,)
|
|
||||||
param(x::TrackedReal) = track(identity, x)
|
|
||||||
param(x::TrackedArray) = track(identity, x)
|
|
||||||
|
|
||||||
import NNlib.cudata
|
import NNlib.cudata
|
||||||
import Adapt.adapt
|
import Adapt.adapt
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user