very basic step! implementation

Mike J Innes 2019-03-12 12:21:12 +00:00
parent bde51aa5a6
commit 02c4ada05a
5 changed files with 78 additions and 13 deletions
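In short, the commit adds a step!(f, opt, xs...) entry point that runs the forward pass under a Zygote Context, checks the loss, and updates whatever global bindings Zygote recorded gradients for. A minimal usage sketch under that assumption (the W, b and loss names are illustrative, not part of the commit; this relies on the Zygote master revision pinned in the Manifest below):

using Flux.Optimise   # exports step! and Descent as of this commit

W = randn(2, 5)
b = zeros(2)

# the parameters are plain globals; the assumption is that Zygote's Context
# records gradients for the globals the loss closes over
loss(x, y) = sum((W*x .+ b .- y).^2)

opt = Descent(0.1)
x, y = rand(5), rand(2)

# one optimisation step: forward pass, check the loss is a finite scalar,
# then update W and b in place via the optimiser
l = step!(loss, opt, x, y)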

Manifest.toml

@@ -281,7 +281,7 @@ version = "0.8.0"
 [[Zygote]]
 deps = ["DiffRules", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions"]
-git-tree-sha1 = "7e99e2a6c5287fe658273fdd1723726ff8a211d9"
+git-tree-sha1 = "029cbc1d784d4a2e3f2d26d9b1631d89c2a0afb2"
 repo-rev = "master"
 repo-url = "https://github.com/FluxML/Zygote.jl.git"
 uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"

src/optimise/Optimise.jl

@ -1,11 +1,12 @@
module Optimise module Optimise
export train!, export train!, step!,
SGD, Descent, ADAM, Momentum, Nesterov, RMSProp, SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW, ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,
InvDecay, ExpDecay, WeightDecay, stop, Optimiser InvDecay, ExpDecay, WeightDecay, stop, Optimiser
include("optimisers.jl") include("optimisers.jl")
include("update.jl")
include("train.jl") include("train.jl")
end end

src/optimise/optimisers.jl

@@ -4,8 +4,6 @@ using MacroTools: @forward
 const ϵ = 1e-8

-# TODO: should use weak refs
-
 """
     Descent(η)
@@ -18,8 +16,8 @@
 Descent() = Descent(0.1)

-function apply!(o::Descent, x, Δ)
-  Δ .*= o.eta
+function apply(o::Descent, x, Δ, state = nothing)
+  Δ .* o.eta, state
 end

 """

src/optimise/train.jl

@@ -1,16 +1,26 @@
 using Juno
-import Zygote: Params, gradient
+import Zygote: Context, Params, _forward, gradient

-function update!(opt, x, x̄)
-  update!(x, -apply!(opt, x, x̄))
-end
+# Training step

-function update!(opt, xs::Params, gs)
-  for x in xs
-    update!(opt, x, gs[x])
-  end
+function losscheck(x)
+  x isa Real || error("Function output is not scalar")
+  isinf(x) && error("Loss is infinite")
+  isnan(x) && error("Loss is NaN")
+end
+
+function step!(f, opt, x...)
+  cx = Context()
+  y, ∂f = _forward(cx, f, x...)
+  losscheck(y)
+  f̄ = ∂f(1)[1] # TODO update f
+  x̄s = Globals(cx)
+  update!(opt, nothing, x̄s)
+  return y
 end

+# Training loop
+
 # Callback niceties
 call(f, xs...) = f(xs...)
 runall(f) = f
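For reference, the behaviour the new losscheck guard enforces before any update is applied (illustrative calls, not part of the commit):

losscheck(0.5f0)       # passes: a finite real number
losscheck([1.0, 2.0])  # throws: "Function output is not scalar"
losscheck(Inf)         # throws: "Loss is infinite"
losscheck(NaN)         # throws: "Loss is NaN"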

src/optimise/update.jl (new file, 56 lines)

@@ -0,0 +1,56 @@
+using Zygote: Context, globals
+
+const Param{T<:Number} = Union{AbstractArray{T},T}
+
+struct Globals{T}
+  gs::T
+end
+
+Globals(cx::Context) = Globals(globals(cx))
+
+_apply(opt, x, x̄, state) = apply(opt, x, x̄, state)
+_apply(opt, x, x̄, ::Nothing) = apply(opt, x, x̄)
+
+# Immutable updates
+function update(opt, x::Param, x̄::Param, state = nothing)
+  Δ, state = _apply(opt, x, x̄, state)
+  return x .- Δ, state
+end
+
+# Mutable updates
+
+# Figure out if we can do in-place
+inplace(x, y) = false
+inplace(x, y::Nothing) = true
+inplace(x::AbstractArray, x̄::AbstractArray) = true
+inplace(x, x̄::NamedTuple) = all(inplace(getfield(x, f), getfield(x̄, f)) for f in fieldnames(typeof(x̄)))
+
+function update!(opt, x::AbstractArray{<:Number}, x̄::AbstractArray, state = nothing)
+  Δ, state = _apply(opt, x, x̄, state)
+  x .-= Δ
+  return state
+end
+
+function update!(opt, x, x̄::NamedTuple)
+  for f in fieldnames(typeof(x̄))
+    x̄f = getfield(x̄, f)
+    x̄f === nothing || update!(opt, getfield(x, f), x̄f)
+  end
+end
+
+setglobal!(mod::Module, name::Symbol, x) =
+  ccall(:jl_set_global, Cvoid, (Any, Any, Any), mod, name, x)
+
+function update!(opt, ::Nothing, gs::Globals)
+  for (id, x̄) in gs.gs
+    x = getfield(id.mod, id.name)
+    if inplace(x, x̄)
+      update!(opt, x, x̄)
+    else
+      isconst(id.mod, id.name) && error("Can't update constant $id")
+      x, state = update(opt, x, x̄)
+      setglobal!(id.mod, id.name, x)
+    end
+  end
+end
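Taken together, the new file gives the optimiser one entry point with several methods: update returns a fresh value (plus state) for immutable parameters, update! mutates arrays in place, recurses through NamedTuple gradients for structs, and the Globals method rebinds module-level parameters that cannot be updated in place. A small sketch of the first two paths (assumes this commit's in-repo, unexported names; values shown in comments follow from Descent's 0.1 step):

using Flux.Optimise: Descent, update, update!

opt = Descent(0.1)

# mutable path: the array is updated in place, only the optimiser state is returned
W  = [1.0 2.0; 3.0 4.0]
dW = ones(2, 2)
state = update!(opt, W, dW)       # W is now the old W .- 0.1 .* dW

# immutable path: a scalar parameter comes back as a new value
b, state = update(opt, 0.5, 0.2)  # b == 0.5 - 0.1 * 0.2 == 0.48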