diff --git a/src/utils.jl b/src/utils.jl index 35f7343d..c61e2da2 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -7,7 +7,8 @@ initn(dims...) = randn(dims...)/100 tobatch(xs::Batch) = rawbatch(xs) tobatch(xs) = tobatch(batchone(xs)) -function train!(m, train, test = []; epoch = 1, η = 0.1) +function train!(m, train, test = []; + epoch = 1, η = 0.1, loss = mse) i = 0 for e in 1:epoch info("Epoch $e") @@ -16,7 +17,7 @@ function train!(m, train, test = []; epoch = 1, η = 0.1) i += 1 ŷ = m(x) any(isnan, ŷ) && error("NaN") - Δ = back!(mse, 1, ŷ, y) + Δ = back!(loss, 1, ŷ, y) back!(m, Δ, x) update!(m, η) i % 1000 == 0 && @show accuracy(m, test)