From 02c4ada05a180d162011810e754339b6009e2100 Mon Sep 17 00:00:00 2001
From: Mike J Innes
Date: Tue, 12 Mar 2019 12:21:12 +0000
Subject: [PATCH] very basic `step!` implementation

---
 Manifest.toml              |  2 +-
 src/optimise/Optimise.jl   |  3 +-
 src/optimise/optimisers.jl |  6 ++--
 src/optimise/train.jl      | 24 +++++++++++-----
 src/optimise/update.jl     | 56 ++++++++++++++++++++++++++++++++++++++
 5 files changed, 78 insertions(+), 13 deletions(-)
 create mode 100644 src/optimise/update.jl

diff --git a/Manifest.toml b/Manifest.toml
index 8f362303..b297735a 100644
--- a/Manifest.toml
+++ b/Manifest.toml
@@ -281,7 +281,7 @@ version = "0.8.0"
 
 [[Zygote]]
 deps = ["DiffRules", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions"]
-git-tree-sha1 = "7e99e2a6c5287fe658273fdd1723726ff8a211d9"
+git-tree-sha1 = "029cbc1d784d4a2e3f2d26d9b1631d89c2a0afb2"
 repo-rev = "master"
 repo-url = "https://github.com/FluxML/Zygote.jl.git"
 uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl
index e98c5afc..f16b5342 100644
--- a/src/optimise/Optimise.jl
+++ b/src/optimise/Optimise.jl
@@ -1,11 +1,12 @@
 module Optimise
 
-export train!,
+export train!, step!,
   SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
   ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,
   InvDecay, ExpDecay, WeightDecay, stop, Optimiser
 
 include("optimisers.jl")
+include("update.jl")
 include("train.jl")
 
 end
diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl
index d151cf32..ae112746 100644
--- a/src/optimise/optimisers.jl
+++ b/src/optimise/optimisers.jl
@@ -4,8 +4,6 @@ using MacroTools: @forward
 
 const ϵ = 1e-8
 
-# TODO: should use weak refs
-
 """
     Descent(η)
 
@@ -18,8 +16,8 @@ end
 
 Descent() = Descent(0.1)
 
-function apply!(o::Descent, x, Δ)
-  Δ .*= o.eta
+function apply(o::Descent, x, x̄, state = nothing)
+  x̄ .* o.eta, state
 end
 
 """
diff --git a/src/optimise/train.jl b/src/optimise/train.jl
index 6cc4efcf..7dec59a3 100644
--- a/src/optimise/train.jl
+++ b/src/optimise/train.jl
@@ -1,16 +1,26 @@
 using Juno
-import Zygote: Params, gradient
+import Zygote: Context, Params, _forward, gradient
 
-function update!(opt, x, x̄)
-  update!(x, -apply!(opt, x, x̄))
+# Training step
+
+function losscheck(x)
+  x isa Real || error("Function output is not scalar")
+  isinf(x) && error("Loss is infinite")
+  isnan(x) && error("Loss is NaN")
 end
 
-function update!(opt, xs::Params, gs)
-  for x in xs
-    update!(opt, x, gs[x])
-  end
+function step!(f, opt, x...)
+  cx = Context()
+  y, ∂f = _forward(cx, f, x...)
+  losscheck(y)
+  f̄ = ∂f(1)[1] # TODO update f
+  ḡ = Globals(cx)
+  update!(opt, nothing, ḡ)
+  return y
 end
 
+# Training loop
+
 # Callback niceties
 call(f, xs...) = f(xs...)
 runall(f) = f
diff --git a/src/optimise/update.jl b/src/optimise/update.jl
new file mode 100644
index 00000000..81aa7e81
--- /dev/null
+++ b/src/optimise/update.jl
@@ -0,0 +1,56 @@
+using Zygote: Context, globals
+
+const Param{T<:Number} = Union{AbstractArray{T},T}
+
+struct Globals{T}
+  gs::T
+end
+
+Globals(cx::Context) = Globals(globals(cx))
+
+_apply(opt, x, x̄, state) = apply(opt, x, x̄, state)
+_apply(opt, x, x̄, ::Nothing) = apply(opt, x, x̄)
+
+# Immutable updates
+
+function update(opt, x::Param, x̄::Param, state = nothing)
+  Δ, state = _apply(opt, x, x̄, state)
+  return x .- Δ, state
+end
+
+# Mutable updates
+
+# Figure out if we can do in-place
+inplace(x, y) = false
+inplace(x, y::Nothing) = true
+inplace(x::AbstractArray, x̄::AbstractArray) = true
+inplace(x, x̄::NamedTuple) = all(inplace(getfield(x, f), getfield(x̄, f)) for f in fieldnames(typeof(x̄)))
+
+function update!(opt, x::AbstractArray{<:Number}, x̄::AbstractArray, state = nothing)
+  Δ, state = _apply(opt, x, x̄, state)
+  x .-= Δ
+  return state
+end
+
+function update!(opt, x, x̄::NamedTuple)
+  for f in fieldnames(typeof(x̄))
+    f̄ = getfield(x̄, f)
+    f̄ === nothing || update!(opt, getfield(x, f), f̄)
+  end
+end
+
+setglobal!(mod::Module, name::Symbol, x) =
+  ccall(:jl_set_global, Cvoid, (Any, Any, Any), mod, name, x)
+
+function update!(opt, ::Nothing, gs::Globals)
+  for (id, x̄) in gs.gs
+    x = getfield(id.mod, id.name)
+    if inplace(x, x̄)
+      update!(opt, x, x̄)
+    else
+      isconst(id.mod, id.name) && error("Can't update constant $id")
+      x′, state = update(opt, x, x̄)
+      setglobal!(id.mod, id.name, x′)
+    end
+  end
+end
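
Reviewer note, not part of the patch: a minimal sketch of how the new `step!` entry point is meant to be driven, under the assumption that the Zygote master revision pinned in the Manifest above is active. The names `W`, `b`, `loss` and `opt` below are illustrative, not defined by this PR:

    using Flux.Optimise: step!, Descent

    # Non-const globals: Zygote's `globals(cx)` records gradients for any
    # global binding the loss touches, so no explicit `Params` list is needed.
    W = rand(2, 5)
    b = rand(2)

    loss(x, y) = sum((W*x .+ b .- y).^2)

    opt = Descent(0.1)

    # One training step: `_forward` runs `loss` under a `Context`, `losscheck`
    # rejects non-scalar/Inf/NaN losses, then `update!` walks the recorded
    # globals -- updating arrays in place, and rebinding other values via
    # `setglobal!`. Returns the loss value.
    l = step!(loss, opt, rand(5), rand(2))

Note the design shift this relies on: `apply` now returns a `(Δ, state)` pair instead of mutating its argument as `apply!` did, which is what lets `update` offer an immutable path alongside the in-place `update!`.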