Flux.jl/src/optimise/Optimise.jl
2017-12-04 09:17:05 +01:00

22 lines
401 B
Julia

"""
    Optimise

Optimisation utilities. This module exports:
- `update!`, `params` and `train!`;
- the optimisers `SGD`, `ADAM`, `Momentum`, `Nesterov`, `RMSProp`, `ADAGrad`,
  `ADADelta` and `AMSGrad`.

None of these exported names is defined in this file. They presumably come from
the files included below; confirm there.

This file itself defines only the `Param` wrapper and its `convert` methods.
"""
module Optimise

export update!, params, train!,
    SGD, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad

"""
    Param{T}(x, Δ)

Pair a value `x` with a companion buffer `Δ` of the same type `T`. Because both
fields share the single type parameter `T`, constructing a `Param` from two
values of different types is a `MethodError`.

NOTE(review): `Δ` presumably holds the gradient/step that the optimisers apply
to `x`. Confirm against optimisers.jl.
"""
struct Param{T}
    x::T
    Δ::T
end

# Wrap a plain array as a `Param` with a freshly allocated, zero-filled buffer.
# `zero(x)` builds `fill!(similar(x), zero(eltype(x)))`, which matches `x`'s type
# for ordinary `Array`s. The older `zeros(x)` array form was removed in Julia 1.0,
# whereas `zero(x)` works on both 0.6 and 1.x.
Base.convert(::Type{Param}, x::AbstractArray) = Param(x, zero(x))

include("optimisers.jl")
include("interface.jl")
include("train.jl")

# This import is evaluated after the includes. It requires `Flux.Tracker` to
# already be loaded when this module is evaluated.
using Flux.Tracker: TrackedArray

# Wrap a tracked array by reusing its `data` and `grad` fields directly; no
# copies are made. Assuming those fields are mutable arrays, in-place writes
# through the resulting `Param` are therefore visible on `x` itself.
Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad)

end