sgd
This commit is contained in:
parent
5d6e8e2777
commit
bafecfede1
@ -1,5 +1,10 @@
|
||||
module Optimise
|
||||
|
||||
using ..Tracker: TrackedArray, data, grad
|
||||
|
||||
export sgd, update!, params
|
||||
|
||||
include("params.jl")
|
||||
include("optimisers.jl")
|
||||
|
||||
end
|
||||
|
13
src/optimise/optimisers.jl
Normal file
13
src/optimise/optimisers.jl
Normal file
@ -0,0 +1,13 @@
|
||||
struct SGD
|
||||
ps::Vector{Any}
|
||||
η::Float32
|
||||
end
|
||||
|
||||
sgd(m, η) = SGD(params(m), η)
|
||||
|
||||
function update!(o::SGD)
|
||||
for p in o.ps
|
||||
x, Δ = data(p), grad(p)
|
||||
x .-= Δ .* o.η
|
||||
end
|
||||
end
|
@ -1,8 +1,7 @@
|
||||
children(x) = ()
|
||||
|
||||
using ..Tracker.TrackedArray
|
||||
using DataFlow: OSet
|
||||
|
||||
children(x) = ()
|
||||
|
||||
params(ps, p::TrackedArray) = push!(ps, p)
|
||||
|
||||
params(ps, m) = foreach(m -> params(ps, m), children(m))
|
||||
|
Loading…
Reference in New Issue
Block a user