This commit is contained in:
Mike J Innes 2017-08-22 22:25:18 +01:00
parent 5d6e8e2777
commit bafecfede1
3 changed files with 20 additions and 3 deletions

View File

@ -1,5 +1,10 @@
module Optimise
using ..Tracker: TrackedArray, data, grad
export sgd, update!, params
include("params.jl")
include("optimisers.jl")
end

View File

@ -0,0 +1,13 @@
struct SGD
ps::Vector{Any}
η::Float32
end
sgd(m, η) = SGD(params(m), η)
function update!(o::SGD)
for p in o.ps
x, Δ = data(p), grad(p)
x .-= Δ .* o.η
end
end

View File

@ -1,8 +1,7 @@
children(x) = ()
using ..Tracker.TrackedArray
using DataFlow: OSet
children(x) = ()
params(ps, p::TrackedArray) = push!(ps, p)
params(ps, m) = foreach(m -> params(ps, m), children(m))