2017-08-22 16:13:03 +00:00
|
|
|
module Optimise

export update!, params, train!,
    SGD, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM

"""
    Param(x, Δ)

Pair a parameter array `x` with a companion buffer `Δ` of the same type `T`.

`convert(Param, a)` builds one from:
- a plain `AbstractArray` `a`, with `Δ` initialised to a zero array shaped like `a`;
- a `TrackedArray` `a`, reusing its `data` as `x` and its `grad` as `Δ`
  (so `Δ` shares storage with the tracked gradient rather than copying it).
"""
struct Param{T}
    x::T
    Δ::T
end

# `zero(x)` returns a zero-filled array with the same shape as `x`.
# The previous `zeros(x)` form (zeros of an array argument) was deprecated in
# Julia 0.7 and removed in 1.0. `zero` works on both the old and new releases.
Base.convert(::Type{Param}, x::AbstractArray) = Param(x, zero(x))

# NOTE(review): `Param` is defined before these includes, presumably because the
# included files refer to it. Keep this order unless that has been verified.
include("optimisers.jl")
include("interface.jl")
include("train.jl")

# NOTE(review): this import deliberately stays after the includes, as in the
# original. Confirm Flux's load order before hoisting it.
using Flux.Tracker: TrackedArray

# Wrap a tracked array without copying: `Δ` aliases the array's `grad` field.
Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad)

end
|