opt refactor
This commit is contained in:
parent
7cd13789dd
commit
b95dae1868
@ -1,11 +1,15 @@
|
||||
module Optimise
|
||||
|
||||
using ..Tracker: TrackedArray, data, grad, back!
|
||||
|
||||
export sgd, update!, params, train!
|
||||
|
||||
include("params.jl")
|
||||
include("optimisers.jl")
|
||||
include("train.jl")
|
||||
|
||||
using Flux.Tracker: TrackedArray
|
||||
|
||||
params(ps, p::TrackedArray) = push!(ps, p)
|
||||
|
||||
Base.convert(::Type{Param}, x::TrackedArray) = Param(x.x, x.Δ)
|
||||
|
||||
end
|
||||
|
@ -1,5 +1,5 @@
|
||||
struct SGD
|
||||
ps::Vector{Any}
|
||||
ps::Vector{Param}
|
||||
η::Float32
|
||||
end
|
||||
|
||||
@ -7,8 +7,7 @@ sgd(m, η) = SGD(params(m), η)
|
||||
|
||||
function update!(o::SGD)
|
||||
for p in o.ps
|
||||
x, Δ = data(p), grad(p)
|
||||
x .-= Δ .* o.η
|
||||
p.x .-= p.Δ .* o.η
|
||||
Δ .= 0
|
||||
end
|
||||
end
|
||||
|
@ -2,8 +2,6 @@ using DataFlow: OSet
|
||||
|
||||
children(x) = ()
|
||||
|
||||
params(ps, p::TrackedArray) = push!(ps, p)
|
||||
|
||||
params(ps, m) = foreach(m -> params(ps, m), children(m))
|
||||
|
||||
function params(m)
|
||||
@ -11,3 +9,10 @@ function params(m)
|
||||
params(ps, m)
|
||||
return collect(ps)
|
||||
end
|
||||
|
||||
struct Param{T}
|
||||
x::T
|
||||
Δ::T
|
||||
end
|
||||
|
||||
convert(::Type{Param}, x::AbstractArray) = Param(x, zeros(x))
|
||||
|
@ -1,3 +1,5 @@
|
||||
using Flux.Tracker: back!
|
||||
|
||||
function train!(m, data, opt; epoch = 1)
|
||||
for e in 1:epoch
|
||||
epoch > 1 && info("Epoch $e")
|
||||
|
Loading…
Reference in New Issue
Block a user