basic training loop

This commit is contained in:
Mike J Innes 2017-08-24 11:42:29 +01:00
parent 9ce0439943
commit 1526b13691
3 changed files with 16 additions and 2 deletions

View File

@ -1,10 +1,11 @@
module Optimise
using ..Tracker: TrackedArray, data, grad
using ..Tracker: TrackedArray, data, grad, back!
export sgd, update!, params
export sgd, update!, params, train!
include("params.jl")
include("optimisers.jl")
include("train.jl")
end

View File

@ -9,5 +9,6 @@ function update!(o::SGD)
for p in o.ps
x, Δ = data(p), grad(p)
x .-= Δ .* o.η
Δ .= 0
end
end

12
src/optimise/train.jl Normal file
View File

@ -0,0 +1,12 @@
function train!(m, data, opt; epoch = 1)
for e in 1:epoch
epoch > 1 && info("Epoch $e")
for (x, y) in data
loss = m(x, y)
@show loss
back!(loss)
update!(opt)
end
end
return m
end