diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 02de7b21..29916b13 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -6,8 +6,8 @@ tocb(fs::AbstractVector) = () -> foreach(call, fs) function train!(m, data, opt; cb = () -> ()) cb = tocb(cb) - @progress for (x, y) in data - back!(m(x, y)) + @progress for x in data + back!(m(x...)) opt() cb() end