use delta consistently
This commit is contained in:
parent
9a461c6824
commit
74e2551ee9
@ -9,15 +9,15 @@ function process_func(ex, params)
|
||||
@capture(shortdef(ex), (args__,) -> body_)
|
||||
body = il(graphm(body))
|
||||
body = map(x -> x in params ? :(self.$x) : x, body)
|
||||
∇ = invert(body, @flow(∇))
|
||||
return args, body, ∇
|
||||
Δ = invert(body, @flow(Δ))
|
||||
return args, body, Δ
|
||||
end
|
||||
|
||||
function build_type(T, params, temps)
|
||||
quote
|
||||
type $T
|
||||
$(params...)
|
||||
$([symbol("∇", s) for s in params]...)
|
||||
$([symbol("Δ", s) for s in params]...)
|
||||
$(temps...)
|
||||
end
|
||||
$T($(params...)) = $T($(params...),
|
||||
@ -36,25 +36,25 @@ function build_forward(body, temps)
|
||||
cse(forward)
|
||||
end
|
||||
|
||||
function build_backward(∇s, x, params, temps)
|
||||
function build_backward(Δs, x, params, temps)
|
||||
back = IVertex{Any}(Flow.Do())
|
||||
tempify(v) = prewalk(v -> haskey(temps, v) ? @v(:(self.$(temps[v]))) : v, v)
|
||||
for param in params
|
||||
haskey(∇s, :(self.$param)) || continue
|
||||
k = symbol("∇", param)
|
||||
haskey(Δs, :(self.$param)) || continue
|
||||
k = symbol("Δ", param)
|
||||
ksym = Expr(:quote, k)
|
||||
ex = tempify(∇s[:(self.$param)])
|
||||
ex = tempify(Δs[:(self.$param)])
|
||||
thread!(back, @v(setfield!(:self, ksym, :(self.$k) + ex)))
|
||||
end
|
||||
thread!(back, tempify(∇s[x]))
|
||||
thread!(back, tempify(Δs[x]))
|
||||
cse(back)
|
||||
end
|
||||
|
||||
function build_update(T, params)
|
||||
updates = []
|
||||
for p in params
|
||||
∇p = symbol("∇", p)
|
||||
push!(updates, :(self.$p += self.$∇p; fill!(self.$∇p, 0)))
|
||||
Δp = symbol("Δ", p)
|
||||
push!(updates, :(self.$p += self.$Δp; fill!(self.$Δp, 0)))
|
||||
end
|
||||
:(update!(self::$T) = $(updates...))
|
||||
end
|
||||
@ -64,13 +64,13 @@ function process_type(ex)
|
||||
@destruct [params = true || [],
|
||||
funcs = false || []] = groupby(x->isa(x, Symbol), fs)
|
||||
@assert length(funcs) == 1
|
||||
args, body, ∇s = process_func(funcs[1], params)
|
||||
args, body, Δs = process_func(funcs[1], params)
|
||||
@assert length(args) == 1
|
||||
temps = forward_temporaries(body, ∇s)
|
||||
temps = forward_temporaries(body, Δs)
|
||||
quote
|
||||
$(build_type(T, params, collect(values(temps))))
|
||||
(self::$T)($(args...),) = $(syntax(build_forward(body, temps)))
|
||||
back!(self::$T, ∇, $(args...)) = $(syntax(build_backward(∇s, args[1], params, temps)))
|
||||
back!(self::$T, Δ, $(args...)) = $(syntax(build_backward(Δs, args[1], params, temps)))
|
||||
$(build_update(T, params))
|
||||
end |> longdef |> MacroTools.flatten
|
||||
end
|
||||
|
Loading…
Reference in New Issue
Block a user