restrict train! special casing
This commit is contained in:
parent
b6c79b38b4
commit
487002878e
|
@ -79,7 +79,7 @@ function train!(loss, ps, data, opt; cb = () -> ())
|
|||
cb = runall(cb)
|
||||
@progress for d in data
|
||||
try
|
||||
if d isa AbstractArray
|
||||
if d isa AbstractArray{<:Number}
|
||||
gs = gradient(ps) do
|
||||
loss(d)
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue