very basic step! implementation

This commit is contained in:
Mike J Innes 2019-03-12 12:21:12 +00:00
parent bde51aa5a6
commit 02c4ada05a
5 changed files with 78 additions and 13 deletions

View File

@ -281,7 +281,7 @@ version = "0.8.0"
[[Zygote]]
deps = ["DiffRules", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions"]
git-tree-sha1 = "7e99e2a6c5287fe658273fdd1723726ff8a211d9"
git-tree-sha1 = "029cbc1d784d4a2e3f2d26d9b1631d89c2a0afb2"
repo-rev = "master"
repo-url = "https://github.com/FluxML/Zygote.jl.git"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"

View File

@ -1,11 +1,12 @@
module Optimise
export train!,
export train!, step!,
SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,
InvDecay, ExpDecay, WeightDecay, stop, Optimiser
include("optimisers.jl")
include("update.jl")
include("train.jl")
end

View File

@ -4,8 +4,6 @@ using MacroTools: @forward
const ϵ = 1e-8
# TODO: should use weak refs
"""
Descent(η)
@ -18,8 +16,8 @@ end
Descent() = Descent(0.1)
function apply!(o::Descent, x, Δ)
Δ .*= o.eta
function apply(o::Descent, x, , state = nothing)
.* o.eta, state
end
"""

View File

@ -1,16 +1,26 @@
using Juno
import Zygote: Params, gradient
import Zygote: Context, Params, _forward, gradient
function update!(opt, x, )
update!(x, -apply!(opt, x, ))
# Training step
function losscheck(x)
x isa Real || error("Function output is not scalar")
isinf(x) && error("Loss is infinite")
isnan(x) && error("Loss is NaN")
end
function update!(opt, xs::Params, gs)
for x in xs
update!(opt, x, gs[x])
end
function step!(f, opt, x...)
cx = Context()
y, ∂f = _forward(cx, f, x...)
losscheck(y)
= ∂f(1)[1] # TODO update f
= Globals(cx)
update!(opt, nothing, )
return y
end
# Training loop
# Callback niceties
call(f, xs...) = f(xs...)
runall(f) = f

56
src/optimise/update.jl Normal file
View File

@ -0,0 +1,56 @@
using Zygote: Context, globals
const Param{T<:Number} = Union{AbstractArray{T},T}
struct Globals{T}
gs::T
end
Globals(cx::Context) = Globals(globals(cx))
_apply(opt, x, , state) = apply(opt, x, , state)
_apply(opt, x, , ::Nothing) = apply(opt, x, )
# Immutable updates
function update(opt, x::Param, ::Param, state = nothing)
Δ, state = _apply(opt, x, , state)
return x .- Δ, state
end
# Mutable updates
# Figure out if we can do in-place
inplace(x, y) = false
inplace(x, y::Nothing) = true
inplace(x::AbstractArray, ::AbstractArray) = true
inplace(x, ::NamedTuple) = all(inplace(getfield(x, f), getfield(, f)) for f in fieldnames(typeof()))
function update!(opt, x::AbstractArray{<:Number}, ::AbstractArray, state = nothing)
Δ, state = _apply(opt, x, , state)
x .-= Δ
return state
end
function update!(opt, x, ::NamedTuple)
for f in fieldnames(typeof())
= getfield(, f)
=== nothing || update!(opt, getfield(x, f), )
end
end
setglobal!(mod::Module, name::Symbol, x) =
ccall(:jl_set_global, Cvoid, (Any, Any, Any), mod, name, x)
function update!(opt, ::Nothing, gs::Globals)
for (id, ) in gs.gs
x = getfield(id.mod, id.name)
if inplace(x, )
update!(opt, x, )
else
isconst(id.mod, id.name) && error("Can't update constant $id")
x, state = update(opt, x, )
setglobal!(id.mod, id.name, x)
end
end
end