diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 123117a2..ae0f334c 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -7,14 +7,12 @@ function update!(x::AbstractArray, x̄) end function update!(opt, x, x̄) - if x̄ == nothing - x̄ = zeros(size(x)...) - end - update!(x, -apply!(opt, x, x̄)) + x .-= apply!(opt, x, x̄) end function update!(opt, xs::Params, gs) for x in xs + gs[x] == nothing && continue update!(opt, x, gs[x]) end end @@ -25,6 +23,7 @@ runall(f) = f runall(fs::AbstractVector) = () -> foreach(call, fs) struct StopException <: Exception end + """ stop()