opt refactor

This commit is contained in:
Mike J Innes 2017-08-31 12:36:18 -04:00
parent 7cd13789dd
commit b95dae1868
4 changed files with 17 additions and 7 deletions

View File

@ -1,11 +1,15 @@
module Optimise
using ..Tracker: TrackedArray, data, grad, back!
export sgd, update!, params, train!
include("params.jl")
include("optimisers.jl")
include("train.jl")
using Flux.Tracker: TrackedArray
params(ps, p::TrackedArray) = push!(ps, p)
Base.convert(::Type{Param}, x::TrackedArray) = Param(x.x, x.Δ)
end

View File

@ -1,5 +1,5 @@
struct SGD
ps::Vector{Any}
ps::Vector{Param}
η::Float32
end
@ -7,8 +7,7 @@ sgd(m, η) = SGD(params(m), η)
function update!(o::SGD)
for p in o.ps
x, Δ = data(p), grad(p)
x .-= Δ .* o.η
p.x .-= p.Δ .* o.η
Δ .= 0
end
end

View File

@ -2,8 +2,6 @@ using DataFlow: OSet
children(x) = ()
params(ps, p::TrackedArray) = push!(ps, p)
params(ps, m) = foreach(m -> params(ps, m), children(m))
function params(m)
@ -11,3 +9,10 @@ function params(m)
params(ps, m)
return collect(ps)
end
struct Param{T}
x::T
Δ::T
end
convert(::Type{Param}, x::AbstractArray) = Param(x, zeros(x))

View File

@ -1,3 +1,5 @@
using Flux.Tracker: back!
function train!(m, data, opt; epoch = 1)
for e in 1:epoch
epoch > 1 && info("Epoch $e")