diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl index 8fe141f8..fd987861 100644 --- a/src/optimise/Optimise.jl +++ b/src/optimise/Optimise.jl @@ -1,10 +1,11 @@ module Optimise -using ..Tracker: TrackedArray, data, grad +using ..Tracker: TrackedArray, data, grad, back! -export sgd, update!, params +export sgd, update!, params, train! include("params.jl") include("optimisers.jl") +include("train.jl") end diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl index c333b3ae..8c9db7a8 100644 --- a/src/optimise/optimisers.jl +++ b/src/optimise/optimisers.jl @@ -9,5 +9,6 @@ function update!(o::SGD) for p in o.ps x, Δ = data(p), grad(p) x .-= Δ .* o.η + Δ .= 0 end end diff --git a/src/optimise/train.jl b/src/optimise/train.jl new file mode 100644 index 00000000..b8adcb4e --- /dev/null +++ b/src/optimise/train.jl @@ -0,0 +1,12 @@ +function train!(m, data, opt; epoch = 1) + for e in 1:epoch + epoch > 1 && info("Epoch $e") + for (x, y) in data + loss = m(x, y) + @show loss + back!(loss) + update!(opt) + end + end + return m +end