customisable loss

This commit is contained in:
Mike J Innes 2017-04-28 17:14:21 +01:00
parent 63b328142a
commit ef4ec5be4b

View File

@ -7,7 +7,8 @@ initn(dims...) = randn(dims...)/100
tobatch(xs::Batch) = rawbatch(xs)
tobatch(xs) = tobatch(batchone(xs))
function train!(m, train, test = []; epoch = 1, η = 0.1)
function train!(m, train, test = [];
epoch = 1, η = 0.1, loss = mse)
i = 0
for e in 1:epoch
info("Epoch $e")
@ -16,7 +17,7 @@ function train!(m, train, test = []; epoch = 1, η = 0.1)
i += 1
= m(x)
any(isnan, ) && error("NaN")
Δ = back!(mse, 1, , y)
Δ = back!(loss, 1, , y)
back!(m, Δ, x)
update!(m, η)
i % 1000 == 0 && @show accuracy(m, test)