very basic step!
implementation
This commit is contained in:
parent
bde51aa5a6
commit
02c4ada05a
@ -281,7 +281,7 @@ version = "0.8.0"
|
|||||||
|
|
||||||
[[Zygote]]
|
[[Zygote]]
|
||||||
deps = ["DiffRules", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions"]
|
deps = ["DiffRules", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions"]
|
||||||
git-tree-sha1 = "7e99e2a6c5287fe658273fdd1723726ff8a211d9"
|
git-tree-sha1 = "029cbc1d784d4a2e3f2d26d9b1631d89c2a0afb2"
|
||||||
repo-rev = "master"
|
repo-rev = "master"
|
||||||
repo-url = "https://github.com/FluxML/Zygote.jl.git"
|
repo-url = "https://github.com/FluxML/Zygote.jl.git"
|
||||||
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
|
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
|
||||||
|
@ -1,11 +1,12 @@
|
|||||||
module Optimise
|
module Optimise
|
||||||
|
|
||||||
export train!,
|
export train!, step!,
|
||||||
SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
|
SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
|
||||||
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,
|
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,
|
||||||
InvDecay, ExpDecay, WeightDecay, stop, Optimiser
|
InvDecay, ExpDecay, WeightDecay, stop, Optimiser
|
||||||
|
|
||||||
include("optimisers.jl")
|
include("optimisers.jl")
|
||||||
|
include("update.jl")
|
||||||
include("train.jl")
|
include("train.jl")
|
||||||
|
|
||||||
end
|
end
|
||||||
|
@ -4,8 +4,6 @@ using MacroTools: @forward
|
|||||||
|
|
||||||
const ϵ = 1e-8
|
const ϵ = 1e-8
|
||||||
|
|
||||||
# TODO: should use weak refs
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Descent(η)
|
Descent(η)
|
||||||
|
|
||||||
@ -18,8 +16,8 @@ end
|
|||||||
|
|
||||||
Descent() = Descent(0.1)
|
Descent() = Descent(0.1)
|
||||||
|
|
||||||
function apply!(o::Descent, x, Δ)
|
function apply(o::Descent, x, x̄, state = nothing)
|
||||||
Δ .*= o.eta
|
x̄ .* o.eta, state
|
||||||
end
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
@ -1,16 +1,26 @@
|
|||||||
using Juno
|
using Juno
|
||||||
import Zygote: Params, gradient
|
import Zygote: Context, Params, _forward, gradient
|
||||||
|
|
||||||
function update!(opt, x, x̄)
|
# Training step
|
||||||
update!(x, -apply!(opt, x, x̄))
|
|
||||||
|
function losscheck(x)
|
||||||
|
x isa Real || error("Function output is not scalar")
|
||||||
|
isinf(x) && error("Loss is infinite")
|
||||||
|
isnan(x) && error("Loss is NaN")
|
||||||
end
|
end
|
||||||
|
|
||||||
function update!(opt, xs::Params, gs)
|
function step!(f, opt, x...)
|
||||||
for x in xs
|
cx = Context()
|
||||||
update!(opt, x, gs[x])
|
y, ∂f = _forward(cx, f, x...)
|
||||||
end
|
losscheck(y)
|
||||||
|
f̄ = ∂f(1)[1] # TODO update f
|
||||||
|
ḡ = Globals(cx)
|
||||||
|
update!(opt, nothing, ḡ)
|
||||||
|
return y
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# Training loop
|
||||||
|
|
||||||
# Callback niceties
|
# Callback niceties
|
||||||
call(f, xs...) = f(xs...)
|
call(f, xs...) = f(xs...)
|
||||||
runall(f) = f
|
runall(f) = f
|
||||||
|
56
src/optimise/update.jl
Normal file
56
src/optimise/update.jl
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
using Zygote: Context, globals
|
||||||
|
|
||||||
|
const Param{T<:Number} = Union{AbstractArray{T},T}
|
||||||
|
|
||||||
|
struct Globals{T}
|
||||||
|
gs::T
|
||||||
|
end
|
||||||
|
|
||||||
|
Globals(cx::Context) = Globals(globals(cx))
|
||||||
|
|
||||||
|
_apply(opt, x, x̄, state) = apply(opt, x, x̄, state)
|
||||||
|
_apply(opt, x, x̄, ::Nothing) = apply(opt, x, x̄)
|
||||||
|
|
||||||
|
# Immutable updates
|
||||||
|
|
||||||
|
function update(opt, x::Param, x̄::Param, state = nothing)
|
||||||
|
Δ, state = _apply(opt, x, x̄, state)
|
||||||
|
return x .- Δ, state
|
||||||
|
end
|
||||||
|
|
||||||
|
# Mutable updates
|
||||||
|
|
||||||
|
# Figure out if we can do in-place
|
||||||
|
inplace(x, y) = false
|
||||||
|
inplace(x, y::Nothing) = true
|
||||||
|
inplace(x::AbstractArray, x̄::AbstractArray) = true
|
||||||
|
inplace(x, x̄::NamedTuple) = all(inplace(getfield(x, f), getfield(x̄, f)) for f in fieldnames(typeof(x̄)))
|
||||||
|
|
||||||
|
function update!(opt, x::AbstractArray{<:Number}, x̄::AbstractArray, state = nothing)
|
||||||
|
Δ, state = _apply(opt, x, x̄, state)
|
||||||
|
x .-= Δ
|
||||||
|
return state
|
||||||
|
end
|
||||||
|
|
||||||
|
function update!(opt, x, x̄::NamedTuple)
|
||||||
|
for f in fieldnames(typeof(x̄))
|
||||||
|
f̄ = getfield(x̄, f)
|
||||||
|
f̄ === nothing || update!(opt, getfield(x, f), f̄)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
setglobal!(mod::Module, name::Symbol, x) =
|
||||||
|
ccall(:jl_set_global, Cvoid, (Any, Any, Any), mod, name, x)
|
||||||
|
|
||||||
|
function update!(opt, ::Nothing, gs::Globals)
|
||||||
|
for (id, x̄) in gs.gs
|
||||||
|
x = getfield(id.mod, id.name)
|
||||||
|
if inplace(x, x̄)
|
||||||
|
update!(opt, x, x̄)
|
||||||
|
else
|
||||||
|
isconst(id.mod, id.name) && error("Can't update constant $id")
|
||||||
|
x′, state = update(opt, x, x̄)
|
||||||
|
setglobal!(id.mod, id.name, x′)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
Loading…
Reference in New Issue
Block a user