basic training loop
This commit is contained in:
parent
9ce0439943
commit
1526b13691
@ -1,10 +1,11 @@
|
|||||||
module Optimise
|
module Optimise
|
||||||
|
|
||||||
using ..Tracker: TrackedArray, data, grad
|
using ..Tracker: TrackedArray, data, grad, back!
|
||||||
|
|
||||||
export sgd, update!, params
|
export sgd, update!, params, train!
|
||||||
|
|
||||||
include("params.jl")
|
include("params.jl")
|
||||||
include("optimisers.jl")
|
include("optimisers.jl")
|
||||||
|
include("train.jl")
|
||||||
|
|
||||||
end
|
end
|
||||||
|
@ -9,5 +9,6 @@ function update!(o::SGD)
|
|||||||
for p in o.ps
|
for p in o.ps
|
||||||
x, Δ = data(p), grad(p)
|
x, Δ = data(p), grad(p)
|
||||||
x .-= Δ .* o.η
|
x .-= Δ .* o.η
|
||||||
|
Δ .= 0
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
12
src/optimise/train.jl
Normal file
12
src/optimise/train.jl
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
function train!(m, data, opt; epoch = 1)
|
||||||
|
for e in 1:epoch
|
||||||
|
epoch > 1 && info("Epoch $e")
|
||||||
|
for (x, y) in data
|
||||||
|
loss = m(x, y)
|
||||||
|
@show loss
|
||||||
|
back!(loss)
|
||||||
|
update!(opt)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
return m
|
||||||
|
end
|
Loading…
Reference in New Issue
Block a user