basic training loop
This commit is contained in:
parent
9ce0439943
commit
1526b13691
@ -1,10 +1,11 @@
|
||||
module Optimise
|
||||
|
||||
using ..Tracker: TrackedArray, data, grad
|
||||
using ..Tracker: TrackedArray, data, grad, back!
|
||||
|
||||
export sgd, update!, params
|
||||
export sgd, update!, params, train!
|
||||
|
||||
include("params.jl")
|
||||
include("optimisers.jl")
|
||||
include("train.jl")
|
||||
|
||||
end
|
||||
|
@ -9,5 +9,6 @@ function update!(o::SGD)
|
||||
for p in o.ps
|
||||
x, Δ = data(p), grad(p)
|
||||
x .-= Δ .* o.η
|
||||
Δ .= 0
|
||||
end
|
||||
end
|
||||
|
12
src/optimise/train.jl
Normal file
12
src/optimise/train.jl
Normal file
@ -0,0 +1,12 @@
|
||||
function train!(m, data, opt; epoch = 1)
|
||||
for e in 1:epoch
|
||||
epoch > 1 && info("Epoch $e")
|
||||
for (x, y) in data
|
||||
loss = m(x, y)
|
||||
@show loss
|
||||
back!(loss)
|
||||
update!(opt)
|
||||
end
|
||||
end
|
||||
return m
|
||||
end
|
Loading…
Reference in New Issue
Block a user