new params approach

This commit is contained in:
Mike J Innes 2016-08-23 12:38:58 +01:00
parent 4ad4e20c3e
commit 8e92403436
3 changed files with 36 additions and 19 deletions

View File

@ -16,6 +16,7 @@ include("compiler/code.jl")
include("cost.jl")
include("activation.jl")
include("layers/params.jl")
include("layers/input.jl")
include("layers/dense.jl")
include("layers/sequence.jl")

View File

@ -24,15 +24,19 @@ function build_type(T, params)
quote
type $T <: Model
$(params...)
$([Symbol("Δ", s) for s in params]...)
end
$T($(params...)) = $T($(params...),
$((:(zeros($p)) for p in params)...))
$T($(map(x->:($x::AArray), params)...)) = $T($(map(x->:(param($x)), params)...))
end
end
"""
    deref_params(v)

Run `mapconst` over `v`, replacing every visited item of the form `self.<field>`
with the expression `state(self.<field>)`. Items that do not match are returned
unchanged.
"""
function deref_params(v)
  return mapconst(v) do node
    if @capture(node, self.field_)
      :(state(self.$field))
    else
      node
    end
  end
end
# Build the code used as the body of the generated model call method: the model
# body is passed through `deref_params` and then `cse`. `args` is accepted but not
# used here.
# NOTE(review): this listing is a rendered diff without +/- markers, so the two
# `cse` calls below are most likely the old and the new version of one line rather
# than two consecutive statements — confirm against the actual file.
function build_forward(body, args)
cse(body)
cse(deref_params(body))
end
function build_backward(body, x, params)
@ -40,29 +44,20 @@ function build_backward(body, x, params)
back = IVertex{Any}(Flow.Do())
for param in params
haskey(Δs, :(self.$param)) || continue
k = Symbol("Δ", param)
ksym = Expr(:quote, k)
ex = Δs[:(self.$param)]
thread!(back, @dvertex(setfield!(:self, ksym, :(self.$k) + ex)))
ex = deref_params(ex)
thread!(back, @dvertex(accumulate!(:(self.$param), ex)))
end
ex = Δs[x]
ex = deref_params(ex)
thread!(back, @flow(tuple($ex)))
cse(back)
end
# Build the expression defining an `update!(self::T)` method. For every name `p`
# in `params`, the generated code adds the field `Δp` onto the field `p` and then
# zeroes `Δp` in place with `fill!`.
# NOTE(review): this listing is a rendered diff without +/- markers; the call site
# further down also shows a per-field `update!(self.p, η)` form, so this builder
# may be the side of the diff that was removed — confirm before relying on it.
function build_update(T, params)
updates = []
for p in params
# Name of the companion delta field, e.g. `W` -> `ΔW`.
Δp = Symbol("Δ", p)
push!(updates, :(self.$p += self.$Δp; fill!(self.$Δp, 0)))
end
:(update!(self::$T) = $(updates...))
end
function process_type(ex)
@capture(ex, type T_ fs__ end)
@destruct [params = true || [],
funcs = false || []] = groupby(x->isa(x, Symbol), fs)
@destruct [params = false || [],
funcs = true || []] = groupby(x->isexpr(x, :->, :function), fs)
@assert length(funcs) == 1
args, body = process_func(funcs[1], params)
@assert length(args) == 1
@ -70,7 +65,7 @@ function process_type(ex)
$(build_type(T, params))
(self::$T)($(args...),) = $(syntax(build_forward(body, args)))
back!(self::$T, Δ, $(args...)) = $(syntax(build_backward(body, args[1], params)))
$(build_update(T, params))
update!(self::$T, η) = $(map(p -> :(update!(self.$p, η)), params)...)
graph(::$T) = $(Flow.constructor(makegraph(body, args)))
nothing
end |> esc

21
src/layers/params.jl Normal file
View File

@ -0,0 +1,21 @@
# Mutable container pairing a value with a pending delta of the same type `T`.
#   x  — the current value.
#   Δx — the delta that `accumulate!` adds into and `update!` applies to `x`.
type Param{T}
x::T
Δx::T
end
"""
    param(x)

Wrap `x` in a `Param` whose delta starts out as `zero(x)`.
"""
function param(x)
  return Param(x, zero(x))
end
"""
    state(p)

Return the value that computations should operate on. A `Param` yields its
wrapped value `p.x`; any other argument is already a plain value and is returned
as is.
"""
state(p::Param) = p.x  # unwrap: returning `p` itself made `state` an identity for every input
state(x) = x
"""
    accumulate!(p::Param, Δ)

Add `Δ` onto the delta stored in `p` and return `p`.
"""
function accumulate!(p::Param, Δ)
  p.Δx = p.Δx + Δ
  return p
end
"""
    update!(p::Param, η)

Apply the delta accumulated in `p`, scaled by `η`, to the stored value and return
`p`. The delta is reset to zero afterwards so that it is applied exactly once.
"""
function update!(p::Param, η)
  p.x += p.Δx * η
  # Without this reset, later `accumulate!` calls would pile onto a delta that
  # has already been applied, and every subsequent update would re-apply it.
  p.Δx = zero(p.Δx)
  return p
end
"""
    accumulate!(x, Δ)

Generic fallback: the delta `Δ` is ignored and `x` is returned unchanged.
"""
function accumulate!(x, Δ)
  return x
end