2017-08-22 16:13:03 +00:00
|
|
|
module Optimise
|
|
|
|
|
2018-02-22 16:25:47 +00:00
|
|
|
export train!,
|
2018-07-03 10:15:43 +00:00
|
|
|
SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
|
2018-06-08 11:24:41 +00:00
|
|
|
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
|
2017-08-22 21:25:18 +00:00
|
|
|
|
2017-09-27 20:11:21 +00:00
|
|
|
"""
    Param{T}

Immutable pairing of a parameter value `x` with a companion buffer `Δ` of the
same type `T`.

Within this module, `Δ` holds the gradient associated with `x`:
- converting a plain array yields a zero-filled `Δ`;
- converting a tracked array copies the tracked gradient into `Δ`.
"""
struct Param{T}
    x::T  # the parameter's current value
    Δ::T  # gradient buffer; must share the exact type of `x`
end
|
|
|
|
|
|
|
|
# Wrap a plain (untracked) array as a `Param` whose gradient buffer `Δ` starts
# as an all-zero array with the same size and element type as `x`.
# `zero(x)` is used because the array-argument form `zeros(x)` was deprecated
# in Julia 0.7 and no longer exists in Julia 1.0 (it raises a MethodError).
Base.convert(::Type{Param}, x::AbstractArray) = Param(x, zero(x))
|
|
|
|
|
2017-08-22 21:25:18 +00:00
|
|
|
include("optimisers.jl")
|
2017-09-01 21:06:51 +00:00
|
|
|
include("interface.jl")
|
2017-08-24 10:42:29 +00:00
|
|
|
include("train.jl")
|
2017-08-22 16:13:03 +00:00
|
|
|
|
2017-08-31 16:36:18 +00:00
|
|
|
using Flux.Tracker: TrackedArray
|
|
|
|
|
2017-10-18 21:54:58 +00:00
|
|
|
# Unwrap a tracked array into a `Param`: its underlying value becomes `x` and
# its accumulated gradient becomes `Δ`.
function Base.convert(::Type{Param}, tracked::TrackedArray)
    value = tracked.data
    gradient = tracked.grad
    return Param(value, gradient)
end
|
2017-08-31 16:36:18 +00:00
|
|
|
|
2017-08-22 16:13:03 +00:00
|
|
|
end
|