changes as requested
This commit is contained in:
parent
1af7a53e1f
commit
ed044e2df7
@ -21,7 +21,7 @@ include("optimise/Optimise.jl")
|
|||||||
using .Optimise
|
using .Optimise
|
||||||
using .Optimise: @epochs
|
using .Optimise: @epochs
|
||||||
export SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
|
export SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
|
||||||
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM, stop, StopException
|
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM, stop
|
||||||
|
|
||||||
include("utils.jl")
|
include("utils.jl")
|
||||||
include("onehot.jl")
|
include("onehot.jl")
|
||||||
|
@ -2,7 +2,7 @@ module Optimise
|
|||||||
|
|
||||||
export train!,
|
export train!,
|
||||||
SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
|
SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
|
||||||
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM, stop, StopException
|
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
|
||||||
|
|
||||||
struct Param{T}
|
struct Param{T}
|
||||||
x::T
|
x::T
|
||||||
|
@ -55,14 +55,13 @@ function train!(loss, data, opt; cb = () -> ())
|
|||||||
cb = runall(cb)
|
cb = runall(cb)
|
||||||
opt = runall(opt)
|
opt = runall(opt)
|
||||||
@progress for d in data
|
@progress for d in data
|
||||||
l = loss(d...)
|
|
||||||
@interrupts back!(l)
|
|
||||||
opt()
|
|
||||||
try
|
try
|
||||||
cb()
|
l = loss(d...)
|
||||||
|
@interrupts back!(l)
|
||||||
|
opt()
|
||||||
|
cb() == :stop && break
|
||||||
catch ex
|
catch ex
|
||||||
if ex isa StopException
|
if ex isa StopException
|
||||||
@info "Stop condition met"
|
|
||||||
break
|
break
|
||||||
else
|
else
|
||||||
rethrow(ex)
|
rethrow(ex)
|
||||||
|
Loading…
Reference in New Issue
Block a user