# Flux.jl/src/optimise/Optimise.jl
"""
    Optimise

Optimisers and training utilities. This file declares the shared `Param`
container. The exported names are expected to be defined in the files
included below (`optimisers.jl`, `interface.jl`, `train.jl`).
"""
module Optimise

export train!,
  SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
  RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM, stop, StopException

"""
    Param(x, Δ)
    Param(x::AbstractArray)
    Param(x::TrackedArray)

Pair a parameter value `x` with a companion buffer `Δ` of the same type `T`.

- `Param(x::AbstractArray)` allocates `Δ` as `zero(x)`.
- `Param(x::TrackedArray)` stores the tracked array's own `data` and `grad`
  fields directly; no copies are made.
"""
struct Param{T}
  x::T
  Δ::T
end

Param(x::AbstractArray) = Param(x, zero(x))

include("optimisers.jl")
include("interface.jl")
include("train.jl")

# `TrackedArray` is needed only as the dispatch type for the method below.
using Flux.Tracker: TrackedArray

Param(x::TrackedArray) = Param(x.data, x.grad)

end