From ef4ec5be4b3b3eded169027567da89e7820f586c Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Fri, 28 Apr 2017 17:14:21 +0100 Subject: [PATCH] customisable loss --- src/utils.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 35f7343d..c61e2da2 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -7,7 +7,8 @@ initn(dims...) = randn(dims...)/100 tobatch(xs::Batch) = rawbatch(xs) tobatch(xs) = tobatch(batchone(xs)) -function train!(m, train, test = []; epoch = 1, η = 0.1) +function train!(m, train, test = []; + epoch = 1, η = 0.1, loss = mse) i = 0 for e in 1:epoch info("Epoch $e") @@ -16,7 +17,7 @@ function train!(m, train, test = []; epoch = 1, η = 0.1) i += 1 ŷ = m(x) any(isnan, ŷ) && error("NaN") - Δ = back!(mse, 1, ŷ, y) + Δ = back!(loss, 1, ŷ, y) back!(m, Δ, x) update!(m, η) i % 1000 == 0 && @show accuracy(m, test)