diff --git a/src/cost.jl b/src/cost.jl index e69de29b..62772894 100644 --- a/src/cost.jl +++ b/src/cost.jl @@ -0,0 +1,8 @@ +export mse, mse! + +function mse!(∇, pred, target) + map!(-, ∇, pred, target) + sumabs2(∇)/2 +end + +mse(pred, target) = mse(similar(pred), pred, target) diff --git a/src/utils.jl b/src/utils.jl index e5b8b815..b02bab53 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -4,3 +4,26 @@ const AArray = AbstractArray onehot(label, labels) = [i == label for i in labels] onecold(pred, labels = 1:length(pred)) = labels[findfirst(pred, maximum(pred))] + +function train!(m::Model, train, test = []; epoch = 1, batch = 10, η = 0.1) + i = 0 + ∇ = zeros(length(train[1][2])) + for _ in 1:epoch + for (x, y) in shuffle!(train) + i += 1 + err = mse!(∇, m(x), y) + back!(m, ∇) + i % batch == 0 && update!(m, η/batch) + end + @show accuracy(m, test) + end + return m +end + +function accuracy(m::Model, data) + correct = 0 + for (x, y) in data + onecold(m(x)) == onecold(y) && (correct += 1) + end + return correct/length(data) +end