diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 0a91e978..618ecf66 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -4,10 +4,10 @@ using Flux.Tracker: back! tocb(f) = f tocb(fs::AbstractVector) = () -> foreach(call, fs) -function train!(m, data, opt; cb = () -> ()) +function train!(loss, data, opt; cb = () -> ()) cb = tocb(cb) @progress for x in data - l = m(x...) + l = loss(x...) isinf(l.data[]) && error("Loss is Inf") isnan(l.data[]) && error("Loss is NaN") back!(l)