From 487002878ed530303cf9527e7cca0ea57b34d5b2 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Thu, 27 Feb 2020 20:49:05 +0100 Subject: [PATCH] restrict train! special casing --- src/optimise/train.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 34a98394..54b7f53a 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -79,7 +79,7 @@ function train!(loss, ps, data, opt; cb = () -> ()) cb = runall(cb) @progress for d in data try - if d isa AbstractArray + if d isa AbstractArray{<:Number} gs = gradient(ps) do loss(d) end