restrict train! special casing
This commit is contained in:
parent
b6c79b38b4
commit
487002878e
|
@ -79,7 +79,7 @@ function train!(loss, ps, data, opt; cb = () -> ())
|
||||||
cb = runall(cb)
|
cb = runall(cb)
|
||||||
@progress for d in data
|
@progress for d in data
|
||||||
try
|
try
|
||||||
if d isa AbstractArray
|
if d isa AbstractArray{<:Number}
|
||||||
gs = gradient(ps) do
|
gs = gradient(ps) do
|
||||||
loss(d)
|
loss(d)
|
||||||
end
|
end
|
||||||
|
|
Loading…
Reference in New Issue