opt refactor

This commit is contained in:
Mike J Innes 2017-08-31 12:36:18 -04:00
parent 7cd13789dd
commit b95dae1868
4 changed files with 17 additions and 7 deletions

View File

@ -1,11 +1,15 @@
module Optimise module Optimise
using ..Tracker: TrackedArray, data, grad, back!
export sgd, update!, params, train! export sgd, update!, params, train!
include("params.jl") include("params.jl")
include("optimisers.jl") include("optimisers.jl")
include("train.jl") include("train.jl")
using Flux.Tracker: TrackedArray
params(ps, p::TrackedArray) = push!(ps, p)
Base.convert(::Type{Param}, x::TrackedArray) = Param(x.x, x.Δ)
end end

View File

@ -1,5 +1,5 @@
struct SGD struct SGD
ps::Vector{Any} ps::Vector{Param}
η::Float32 η::Float32
end end
@ -7,8 +7,7 @@ sgd(m, η) = SGD(params(m), η)
function update!(o::SGD) function update!(o::SGD)
for p in o.ps for p in o.ps
x, Δ = data(p), grad(p) p.x .-= p.Δ .* o.η
x .-= Δ .* o.η
Δ .= 0 Δ .= 0
end end
end end

View File

@ -2,8 +2,6 @@ using DataFlow: OSet
children(x) = () children(x) = ()
params(ps, p::TrackedArray) = push!(ps, p)
params(ps, m) = foreach(m -> params(ps, m), children(m)) params(ps, m) = foreach(m -> params(ps, m), children(m))
function params(m) function params(m)
@ -11,3 +9,10 @@ function params(m)
params(ps, m) params(ps, m)
return collect(ps) return collect(ps)
end end
struct Param{T}
x::T
Δ::T
end
convert(::Type{Param}, x::AbstractArray) = Param(x, zeros(x))

View File

@ -1,3 +1,5 @@
using Flux.Tracker: back!
function train!(m, data, opt; epoch = 1) function train!(m, data, opt; epoch = 1)
for e in 1:epoch for e in 1:epoch
epoch > 1 && info("Epoch $e") epoch > 1 && info("Epoch $e")