diff --git a/src/utils.jl b/src/utils.jl index b02bab53..5f81462e 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -12,7 +12,7 @@ function train!(m::Model, train, test = []; epoch = 1, batch = 10, η = 0.1) for (x, y) in shuffle!(train) i += 1 err = mse!(∇, m(x), y) - back!(m, ∇) + back!(m, ∇, x) i % batch == 0 && update!(m, η/batch) end @show accuracy(m, test)