Flux.jl/src/optimise/Optimise.jl
2018-08-28 10:54:50 +01:00

24 lines
455 B
Julia

module Optimise

export train!,
    SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
    RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM, stop, StopException

"""
    Param(x, Δ)
    Param(x::AbstractArray)
    Param(x::TrackedArray)

Pair a value `x` with a companion `Δ` of the same type `T`.

- Given a plain `AbstractArray`, `Δ` is initialised to `zero(x)`.
- Given a `TrackedArray`, the pair is built from the array's `data` and `grad` fields.
"""
struct Param{T}
    x::T
    Δ::T
end

# A plain array has no companion of its own, so pair it with `zero(x)`.
function Param(x::AbstractArray)
    return Param(x, zero(x))
end

# Optimiser definitions, their interface layer and the training loop live in
# these files; they are evaluated inside this module, after `Param` exists.
include("optimisers.jl")
include("interface.jl")
include("train.jl")

# Kept after the includes to preserve the original top-level evaluation order.
using Flux.Tracker: TrackedArray

# Tracked arrays already hold both halves of the pair: reuse `data` and `grad`.
function Param(x::TrackedArray)
    return Param(x.data, x.grad)
end
# NOTE(review): the conversion method below is disabled in the original; confirm
# whether it is still wanted or can be removed.
# Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad)

end