use delta consistently

This commit is contained in:
Mike J Innes 2016-06-06 13:08:22 +01:00
parent 9a461c6824
commit 74e2551ee9

View File

@ -9,15 +9,15 @@ function process_func(ex, params)
@capture(shortdef(ex), (args__,) -> body_) @capture(shortdef(ex), (args__,) -> body_)
body = il(graphm(body)) body = il(graphm(body))
body = map(x -> x in params ? :(self.$x) : x, body) body = map(x -> x in params ? :(self.$x) : x, body)
= invert(body, @flow()) Δ = invert(body, @flow(Δ))
return args, body, return args, body, Δ
end end
function build_type(T, params, temps) function build_type(T, params, temps)
quote quote
type $T type $T
$(params...) $(params...)
$([symbol("", s) for s in params]...) $([symbol("Δ", s) for s in params]...)
$(temps...) $(temps...)
end end
$T($(params...)) = $T($(params...), $T($(params...)) = $T($(params...),
@ -36,25 +36,25 @@ function build_forward(body, temps)
cse(forward) cse(forward)
end end
function build_backward(s, x, params, temps) function build_backward(Δs, x, params, temps)
back = IVertex{Any}(Flow.Do()) back = IVertex{Any}(Flow.Do())
tempify(v) = prewalk(v -> haskey(temps, v) ? @v(:(self.$(temps[v]))) : v, v) tempify(v) = prewalk(v -> haskey(temps, v) ? @v(:(self.$(temps[v]))) : v, v)
for param in params for param in params
haskey(s, :(self.$param)) || continue haskey(Δs, :(self.$param)) || continue
k = symbol("", param) k = symbol("Δ", param)
ksym = Expr(:quote, k) ksym = Expr(:quote, k)
ex = tempify(s[:(self.$param)]) ex = tempify(Δs[:(self.$param)])
thread!(back, @v(setfield!(:self, ksym, :(self.$k) + ex))) thread!(back, @v(setfield!(:self, ksym, :(self.$k) + ex)))
end end
thread!(back, tempify(s[x])) thread!(back, tempify(Δs[x]))
cse(back) cse(back)
end end
function build_update(T, params) function build_update(T, params)
updates = [] updates = []
for p in params for p in params
∇p = symbol("", p) Δp = symbol("Δ", p)
push!(updates, :(self.$p += self.$∇p; fill!(self.$p, 0))) push!(updates, :(self.$p += self.$Δp; fill!(self.$Δp, 0)))
end end
:(update!(self::$T) = $(updates...)) :(update!(self::$T) = $(updates...))
end end
@ -64,13 +64,13 @@ function process_type(ex)
@destruct [params = true || [], @destruct [params = true || [],
funcs = false || []] = groupby(x->isa(x, Symbol), fs) funcs = false || []] = groupby(x->isa(x, Symbol), fs)
@assert length(funcs) == 1 @assert length(funcs) == 1
args, body, s = process_func(funcs[1], params) args, body, Δs = process_func(funcs[1], params)
@assert length(args) == 1 @assert length(args) == 1
temps = forward_temporaries(body, s) temps = forward_temporaries(body, Δs)
quote quote
$(build_type(T, params, collect(values(temps)))) $(build_type(T, params, collect(values(temps))))
(self::$T)($(args...),) = $(syntax(build_forward(body, temps))) (self::$T)($(args...),) = $(syntax(build_forward(body, temps)))
back!(self::$T, , $(args...)) = $(syntax(build_backward(s, args[1], params, temps))) back!(self::$T, Δ, $(args...)) = $(syntax(build_backward(Δs, args[1], params, temps)))
$(build_update(T, params)) $(build_update(T, params))
end |> longdef |> MacroTools.flatten end |> longdef |> MacroTools.flatten
end end