From b95dae186841236607da06ae7d70c60b25e9b65a Mon Sep 17 00:00:00 2001
From: Mike J Innes
Date: Thu, 31 Aug 2017 12:36:18 -0400
Subject: [PATCH] opt refactor

---
 src/optimise/Optimise.jl   | 8 ++++++--
 src/optimise/optimisers.jl | 7 +++----
 src/optimise/params.jl     | 9 +++++++--
 src/optimise/train.jl      | 2 ++
 4 files changed, 18 insertions(+), 8 deletions(-)

diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl
index fd987861..5d9b7e89 100644
--- a/src/optimise/Optimise.jl
+++ b/src/optimise/Optimise.jl
@@ -1,11 +1,15 @@
 module Optimise
 
-using ..Tracker: TrackedArray, data, grad, back!
-
 export sgd, update!, params, train!
 
 include("params.jl")
 include("optimisers.jl")
 include("train.jl")
 
+using Flux.Tracker: TrackedArray
+
+params(ps, p::TrackedArray) = push!(ps, p)
+
+Base.convert(::Type{Param}, x::TrackedArray) = Param(x.x, x.Δ)
+
 end
diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl
index 8c9db7a8..741e5b1a 100644
--- a/src/optimise/optimisers.jl
+++ b/src/optimise/optimisers.jl
@@ -1,5 +1,5 @@
 struct SGD
-  ps::Vector{Any}
+  ps::Vector{Param}
   η::Float32
 end
 
@@ -7,8 +7,7 @@ sgd(m, η) = SGD(params(m), η)
 
 function update!(o::SGD)
   for p in o.ps
-    x, Δ = data(p), grad(p)
-    x .-= Δ .* o.η
-    Δ .= 0
+    p.x .-= p.Δ .* o.η
+    p.Δ .= 0
   end
 end
diff --git a/src/optimise/params.jl b/src/optimise/params.jl
index f7810fd5..c5163dbe 100644
--- a/src/optimise/params.jl
+++ b/src/optimise/params.jl
@@ -2,8 +2,6 @@ using DataFlow: OSet
 
 children(x) = ()
 
-params(ps, p::TrackedArray) = push!(ps, p)
-
 params(ps, m) = foreach(m -> params(ps, m), children(m))
 
 function params(m)
@@ -11,3 +9,10 @@ function params(m)
   params(ps, m)
   return collect(ps)
 end
+
+struct Param{T}
+  x::T
+  Δ::T
+end
+
+Base.convert(::Type{Param}, x::AbstractArray) = Param(x, zeros(x))
diff --git a/src/optimise/train.jl b/src/optimise/train.jl
index 9817a95d..ded47921 100644
--- a/src/optimise/train.jl
+++ b/src/optimise/train.jl
@@ -1,3 +1,5 @@
+using Flux.Tracker: back!
+
 function train!(m, data, opt; epoch = 1)
   for e in 1:epoch
     epoch > 1 && info("Epoch $e")