2017-08-22 16:13:03 +00:00
|
|
|
module Optimise
|
|
|
|
|
2018-02-22 16:25:47 +00:00
|
|
|
# Names made public by this module. The definitions themselves are not in this
# file; presumably they come from the files pulled in via `include` below.
export train!, stop, StopException
export SGD, Momentum, Nesterov, RMSProp,
    ADAM, ADAMW, AdaMax, ADAGrad, ADADelta, AMSGrad, NADAM
|
2017-08-22 21:25:18 +00:00
|
|
|
|
2017-09-27 20:11:21 +00:00
|
|
|
"""
    Param{T}

Immutable pair of two values that share the same type `T`.

# Fields
- `x::T`: the wrapped value.
- `Δ::T`: companion value of the same type as `x`. The constructors in this
  module fill it with `zero(x)` for plain arrays and with `x.grad` for tracked
  arrays.
"""
struct Param{T}
    x::T
    Δ::T
end
|
|
|
|
|
2018-08-11 12:53:47 +00:00
|
|
|
"""
    Param(x::AbstractArray)

Wrap the array `x` in a `Param`, using `zero(x)` as the accompanying `Δ` field.
"""
function Param(x::AbstractArray)
    Δ = zero(x)
    return Param(x, Δ)
end
|
2017-09-27 20:11:21 +00:00
|
|
|
|
2017-08-22 21:25:18 +00:00
|
|
|
include("optimisers.jl")
|
2017-09-01 21:06:51 +00:00
|
|
|
include("interface.jl")
|
2017-08-24 10:42:29 +00:00
|
|
|
include("train.jl")
|
2017-08-22 16:13:03 +00:00
|
|
|
|
2017-08-31 16:36:18 +00:00
|
|
|
using Flux.Tracker: TrackedArray
|
|
|
|
|
2018-08-11 12:53:47 +00:00
|
|
|
"""
    Param(x::TrackedArray)

Build a `Param` from a tracked array: `x.data` becomes the `x` field and
`x.grad` becomes the `Δ` field.
"""
function Param(x::TrackedArray)
    return Param(x.data, x.grad)
end
|
|
|
|
# Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad)
|
2017-08-31 16:36:18 +00:00
|
|
|
|
2018-08-28 09:54:50 +00:00
|
|
|
end
|