From bafecfede1efcf9362323eb62ec44767ad6157ed Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 22 Aug 2017 22:25:18 +0100 Subject: [PATCH] sgd --- src/optimise/Optimise.jl | 5 +++++ src/optimise/optimisers.jl | 13 +++++++++++++ src/optimise/params.jl | 5 ++--- 3 files changed, 20 insertions(+), 3 deletions(-) create mode 100644 src/optimise/optimisers.jl diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl index cc64bb31..8fe141f8 100644 --- a/src/optimise/Optimise.jl +++ b/src/optimise/Optimise.jl @@ -1,5 +1,10 @@ module Optimise +using ..Tracker: TrackedArray, data, grad + +export sgd, update!, params + include("params.jl") +include("optimisers.jl") end diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl new file mode 100644 index 00000000..c333b3ae --- /dev/null +++ b/src/optimise/optimisers.jl @@ -0,0 +1,13 @@ +struct SGD + ps::Vector{Any} + η::Float32 +end + +sgd(m, η) = SGD(params(m), η) + +function update!(o::SGD) + for p in o.ps + x, Δ = data(p), grad(p) + x .-= Δ .* o.η + end +end diff --git a/src/optimise/params.jl b/src/optimise/params.jl index e3fff208..f7810fd5 100644 --- a/src/optimise/params.jl +++ b/src/optimise/params.jl @@ -1,8 +1,7 @@ -children(x) = () - -using ..Tracker.TrackedArray using DataFlow: OSet +children(x) = () + params(ps, p::TrackedArray) = push!(ps, p) params(ps, m) = foreach(m -> params(ps, m), children(m))