very basic step!
implementation
This commit is contained in:
parent
bde51aa5a6
commit
02c4ada05a
@ -281,7 +281,7 @@ version = "0.8.0"
|
||||
|
||||
[[Zygote]]
|
||||
deps = ["DiffRules", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions"]
|
||||
git-tree-sha1 = "7e99e2a6c5287fe658273fdd1723726ff8a211d9"
|
||||
git-tree-sha1 = "029cbc1d784d4a2e3f2d26d9b1631d89c2a0afb2"
|
||||
repo-rev = "master"
|
||||
repo-url = "https://github.com/FluxML/Zygote.jl.git"
|
||||
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
|
||||
|
@ -1,11 +1,12 @@
|
||||
module Optimise
|
||||
|
||||
export train!,
|
||||
export train!, step!,
|
||||
SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
|
||||
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,
|
||||
InvDecay, ExpDecay, WeightDecay, stop, Optimiser
|
||||
|
||||
include("optimisers.jl")
|
||||
include("update.jl")
|
||||
include("train.jl")
|
||||
|
||||
end
|
||||
|
@ -4,8 +4,6 @@ using MacroTools: @forward
|
||||
|
||||
const ϵ = 1e-8
|
||||
|
||||
# TODO: should use weak refs
|
||||
|
||||
"""
|
||||
Descent(η)
|
||||
|
||||
@ -18,8 +16,8 @@ end
|
||||
|
||||
Descent() = Descent(0.1)
|
||||
|
||||
function apply!(o::Descent, x, Δ)
|
||||
Δ .*= o.eta
|
||||
function apply(o::Descent, x, x̄, state = nothing)
|
||||
x̄ .* o.eta, state
|
||||
end
|
||||
|
||||
"""
|
||||
|
@ -1,16 +1,26 @@
|
||||
using Juno
|
||||
import Zygote: Params, gradient
|
||||
import Zygote: Context, Params, _forward, gradient
|
||||
|
||||
function update!(opt, x, x̄)
|
||||
update!(x, -apply!(opt, x, x̄))
|
||||
# Training step
|
||||
|
||||
function losscheck(x)
|
||||
x isa Real || error("Function output is not scalar")
|
||||
isinf(x) && error("Loss is infinite")
|
||||
isnan(x) && error("Loss is NaN")
|
||||
end
|
||||
|
||||
function update!(opt, xs::Params, gs)
|
||||
for x in xs
|
||||
update!(opt, x, gs[x])
|
||||
end
|
||||
function step!(f, opt, x...)
|
||||
cx = Context()
|
||||
y, ∂f = _forward(cx, f, x...)
|
||||
losscheck(y)
|
||||
f̄ = ∂f(1)[1] # TODO update f
|
||||
ḡ = Globals(cx)
|
||||
update!(opt, nothing, ḡ)
|
||||
return y
|
||||
end
|
||||
|
||||
# Training loop
|
||||
|
||||
# Callback niceties
|
||||
call(f, xs...) = f(xs...)
|
||||
runall(f) = f
|
||||
|
56
src/optimise/update.jl
Normal file
56
src/optimise/update.jl
Normal file
@ -0,0 +1,56 @@
|
||||
using Zygote: Context, globals
|
||||
|
||||
const Param{T<:Number} = Union{AbstractArray{T},T}
|
||||
|
||||
struct Globals{T}
|
||||
gs::T
|
||||
end
|
||||
|
||||
Globals(cx::Context) = Globals(globals(cx))
|
||||
|
||||
_apply(opt, x, x̄, state) = apply(opt, x, x̄, state)
|
||||
_apply(opt, x, x̄, ::Nothing) = apply(opt, x, x̄)
|
||||
|
||||
# Immutable updates
|
||||
|
||||
function update(opt, x::Param, x̄::Param, state = nothing)
|
||||
Δ, state = _apply(opt, x, x̄, state)
|
||||
return x .- Δ, state
|
||||
end
|
||||
|
||||
# Mutable updates
|
||||
|
||||
# Figure out if we can do in-place
|
||||
inplace(x, y) = false
|
||||
inplace(x, y::Nothing) = true
|
||||
inplace(x::AbstractArray, x̄::AbstractArray) = true
|
||||
inplace(x, x̄::NamedTuple) = all(inplace(getfield(x, f), getfield(x̄, f)) for f in fieldnames(typeof(x̄)))
|
||||
|
||||
function update!(opt, x::AbstractArray{<:Number}, x̄::AbstractArray, state = nothing)
|
||||
Δ, state = _apply(opt, x, x̄, state)
|
||||
x .-= Δ
|
||||
return state
|
||||
end
|
||||
|
||||
function update!(opt, x, x̄::NamedTuple)
|
||||
for f in fieldnames(typeof(x̄))
|
||||
f̄ = getfield(x̄, f)
|
||||
f̄ === nothing || update!(opt, getfield(x, f), f̄)
|
||||
end
|
||||
end
|
||||
|
||||
setglobal!(mod::Module, name::Symbol, x) =
|
||||
ccall(:jl_set_global, Cvoid, (Any, Any, Any), mod, name, x)
|
||||
|
||||
function update!(opt, ::Nothing, gs::Globals)
|
||||
for (id, x̄) in gs.gs
|
||||
x = getfield(id.mod, id.name)
|
||||
if inplace(x, x̄)
|
||||
update!(opt, x, x̄)
|
||||
else
|
||||
isconst(id.mod, id.name) && error("Can't update constant $id")
|
||||
x′, state = update(opt, x, x̄)
|
||||
setglobal!(id.mod, id.name, x′)
|
||||
end
|
||||
end
|
||||
end
|
Loading…
Reference in New Issue
Block a user