diff --git a/src/utils.jl b/src/utils.jl index e9d0eddc..5f1e2807 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -5,7 +5,7 @@ const AArray = AbstractArray initn(dims...) = randn(dims...)/100 tobatch(xs::Batch) = rawbatch(xs) -tobatch(xs) = unsqueeze(xs) +tobatch(xs) = tobatch(batchone(xs)) function train!(m, train, test = []; epoch = 1, η = 0.1) i = 0