restrict train! special casing

This commit is contained in:
CarloLucibello 2020-02-27 20:49:05 +01:00
parent b6c79b38b4
commit 487002878e
1 changed files with 1 additions and 1 deletions

View File

@ -79,7 +79,7 @@ function train!(loss, ps, data, opt; cb = () -> ())
cb = runall(cb) cb = runall(cb)
@progress for d in data @progress for d in data
try try
if d isa AbstractArray if d isa AbstractArray{<:Number}
gs = gradient(ps) do gs = gradient(ps) do
loss(d) loss(d)
end end