From 74e2551ee928a36bc06f1853c15b56cd7e7061a7 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Mon, 6 Jun 2016 13:08:22 +0100 Subject: [PATCH] use delta consistently --- src/rt/code.jl | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/rt/code.jl b/src/rt/code.jl index 612596d5..d529835d 100644 --- a/src/rt/code.jl +++ b/src/rt/code.jl @@ -9,15 +9,15 @@ function process_func(ex, params) @capture(shortdef(ex), (args__,) -> body_) body = il(graphm(body)) body = map(x -> x in params ? :(self.$x) : x, body) - ∇ = invert(body, @flow(∇)) - return args, body, ∇ + Δ = invert(body, @flow(Δ)) + return args, body, Δ end function build_type(T, params, temps) quote type $T $(params...) - $([symbol("∇", s) for s in params]...) + $([symbol("Δ", s) for s in params]...) $(temps...) end $T($(params...)) = $T($(params...), @@ -36,25 +36,25 @@ function build_forward(body, temps) cse(forward) end -function build_backward(∇s, x, params, temps) +function build_backward(Δs, x, params, temps) back = IVertex{Any}(Flow.Do()) tempify(v) = prewalk(v -> haskey(temps, v) ? @v(:(self.$(temps[v]))) : v, v) for param in params - haskey(∇s, :(self.$param)) || continue - k = symbol("∇", param) + haskey(Δs, :(self.$param)) || continue + k = symbol("Δ", param) ksym = Expr(:quote, k) - ex = tempify(∇s[:(self.$param)]) + ex = tempify(Δs[:(self.$param)]) thread!(back, @v(setfield!(:self, ksym, :(self.$k) + ex))) end - thread!(back, tempify(∇s[x])) + thread!(back, tempify(Δs[x])) cse(back) end function build_update(T, params) updates = [] for p in params - ∇p = symbol("∇", p) - push!(updates, :(self.$p += self.$∇p; fill!(self.$∇p, 0))) + Δp = symbol("Δ", p) + push!(updates, :(self.$p += self.$Δp; fill!(self.$Δp, 0))) end :(update!(self::$T) = $(updates...)) end @@ -64,13 +64,13 @@ function process_type(ex) @destruct [params = true || [], funcs = false || []] = groupby(x->isa(x, Symbol), fs) @assert length(funcs) == 1 - args, body, ∇s = process_func(funcs[1], params) + args, body, Δs = process_func(funcs[1], params) @assert length(args) == 1 - temps = forward_temporaries(body, ∇s) + temps = forward_temporaries(body, Δs) quote $(build_type(T, params, collect(values(temps)))) (self::$T)($(args...),) = $(syntax(build_forward(body, temps))) - back!(self::$T, ∇, $(args...)) = $(syntax(build_backward(∇s, args[1], params, temps))) + back!(self::$T, Δ, $(args...)) = $(syntax(build_backward(Δs, args[1], params, temps))) $(build_update(T, params)) end |> longdef |> MacroTools.flatten end