diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 34a98394..54b7f53a 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -79,7 +79,7 @@ function train!(loss, ps, data, opt; cb = () -> ()) cb = runall(cb) @progress for d in data try - if d isa AbstractArray + if d isa AbstractArray{<:Number} gs = gradient(ps) do loss(d) end