opt refactor
This commit is contained in:
parent
7cd13789dd
commit
b95dae1868
@ -1,11 +1,15 @@
|
|||||||
module Optimise
|
module Optimise
|
||||||
|
|
||||||
using ..Tracker: TrackedArray, data, grad, back!
|
|
||||||
|
|
||||||
export sgd, update!, params, train!
|
export sgd, update!, params, train!
|
||||||
|
|
||||||
include("params.jl")
|
include("params.jl")
|
||||||
include("optimisers.jl")
|
include("optimisers.jl")
|
||||||
include("train.jl")
|
include("train.jl")
|
||||||
|
|
||||||
|
using Flux.Tracker: TrackedArray
|
||||||
|
|
||||||
|
params(ps, p::TrackedArray) = push!(ps, p)
|
||||||
|
|
||||||
|
Base.convert(::Type{Param}, x::TrackedArray) = Param(x.x, x.Δ)
|
||||||
|
|
||||||
end
|
end
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
struct SGD
|
struct SGD
|
||||||
ps::Vector{Any}
|
ps::Vector{Param}
|
||||||
η::Float32
|
η::Float32
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -7,8 +7,7 @@ sgd(m, η) = SGD(params(m), η)
|
|||||||
|
|
||||||
function update!(o::SGD)
|
function update!(o::SGD)
|
||||||
for p in o.ps
|
for p in o.ps
|
||||||
x, Δ = data(p), grad(p)
|
p.x .-= p.Δ .* o.η
|
||||||
x .-= Δ .* o.η
|
|
||||||
Δ .= 0
|
Δ .= 0
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -2,8 +2,6 @@ using DataFlow: OSet
|
|||||||
|
|
||||||
children(x) = ()
|
children(x) = ()
|
||||||
|
|
||||||
params(ps, p::TrackedArray) = push!(ps, p)
|
|
||||||
|
|
||||||
params(ps, m) = foreach(m -> params(ps, m), children(m))
|
params(ps, m) = foreach(m -> params(ps, m), children(m))
|
||||||
|
|
||||||
function params(m)
|
function params(m)
|
||||||
@ -11,3 +9,10 @@ function params(m)
|
|||||||
params(ps, m)
|
params(ps, m)
|
||||||
return collect(ps)
|
return collect(ps)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
struct Param{T}
|
||||||
|
x::T
|
||||||
|
Δ::T
|
||||||
|
end
|
||||||
|
|
||||||
|
convert(::Type{Param}, x::AbstractArray) = Param(x, zeros(x))
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
using Flux.Tracker: back!
|
||||||
|
|
||||||
function train!(m, data, opt; epoch = 1)
|
function train!(m, data, opt; epoch = 1)
|
||||||
for e in 1:epoch
|
for e in 1:epoch
|
||||||
epoch > 1 && info("Epoch $e")
|
epoch > 1 && info("Epoch $e")
|
||||||
|
Loading…
Reference in New Issue
Block a user