use delta consistently
This commit is contained in:
parent
9a461c6824
commit
74e2551ee9
@ -9,15 +9,15 @@ function process_func(ex, params)
|
|||||||
@capture(shortdef(ex), (args__,) -> body_)
|
@capture(shortdef(ex), (args__,) -> body_)
|
||||||
body = il(graphm(body))
|
body = il(graphm(body))
|
||||||
body = map(x -> x in params ? :(self.$x) : x, body)
|
body = map(x -> x in params ? :(self.$x) : x, body)
|
||||||
∇ = invert(body, @flow(∇))
|
Δ = invert(body, @flow(Δ))
|
||||||
return args, body, ∇
|
return args, body, Δ
|
||||||
end
|
end
|
||||||
|
|
||||||
function build_type(T, params, temps)
|
function build_type(T, params, temps)
|
||||||
quote
|
quote
|
||||||
type $T
|
type $T
|
||||||
$(params...)
|
$(params...)
|
||||||
$([symbol("∇", s) for s in params]...)
|
$([symbol("Δ", s) for s in params]...)
|
||||||
$(temps...)
|
$(temps...)
|
||||||
end
|
end
|
||||||
$T($(params...)) = $T($(params...),
|
$T($(params...)) = $T($(params...),
|
||||||
@ -36,25 +36,25 @@ function build_forward(body, temps)
|
|||||||
cse(forward)
|
cse(forward)
|
||||||
end
|
end
|
||||||
|
|
||||||
function build_backward(∇s, x, params, temps)
|
function build_backward(Δs, x, params, temps)
|
||||||
back = IVertex{Any}(Flow.Do())
|
back = IVertex{Any}(Flow.Do())
|
||||||
tempify(v) = prewalk(v -> haskey(temps, v) ? @v(:(self.$(temps[v]))) : v, v)
|
tempify(v) = prewalk(v -> haskey(temps, v) ? @v(:(self.$(temps[v]))) : v, v)
|
||||||
for param in params
|
for param in params
|
||||||
haskey(∇s, :(self.$param)) || continue
|
haskey(Δs, :(self.$param)) || continue
|
||||||
k = symbol("∇", param)
|
k = symbol("Δ", param)
|
||||||
ksym = Expr(:quote, k)
|
ksym = Expr(:quote, k)
|
||||||
ex = tempify(∇s[:(self.$param)])
|
ex = tempify(Δs[:(self.$param)])
|
||||||
thread!(back, @v(setfield!(:self, ksym, :(self.$k) + ex)))
|
thread!(back, @v(setfield!(:self, ksym, :(self.$k) + ex)))
|
||||||
end
|
end
|
||||||
thread!(back, tempify(∇s[x]))
|
thread!(back, tempify(Δs[x]))
|
||||||
cse(back)
|
cse(back)
|
||||||
end
|
end
|
||||||
|
|
||||||
function build_update(T, params)
|
function build_update(T, params)
|
||||||
updates = []
|
updates = []
|
||||||
for p in params
|
for p in params
|
||||||
∇p = symbol("∇", p)
|
Δp = symbol("Δ", p)
|
||||||
push!(updates, :(self.$p += self.$∇p; fill!(self.$∇p, 0)))
|
push!(updates, :(self.$p += self.$Δp; fill!(self.$Δp, 0)))
|
||||||
end
|
end
|
||||||
:(update!(self::$T) = $(updates...))
|
:(update!(self::$T) = $(updates...))
|
||||||
end
|
end
|
||||||
@ -64,13 +64,13 @@ function process_type(ex)
|
|||||||
@destruct [params = true || [],
|
@destruct [params = true || [],
|
||||||
funcs = false || []] = groupby(x->isa(x, Symbol), fs)
|
funcs = false || []] = groupby(x->isa(x, Symbol), fs)
|
||||||
@assert length(funcs) == 1
|
@assert length(funcs) == 1
|
||||||
args, body, ∇s = process_func(funcs[1], params)
|
args, body, Δs = process_func(funcs[1], params)
|
||||||
@assert length(args) == 1
|
@assert length(args) == 1
|
||||||
temps = forward_temporaries(body, ∇s)
|
temps = forward_temporaries(body, Δs)
|
||||||
quote
|
quote
|
||||||
$(build_type(T, params, collect(values(temps))))
|
$(build_type(T, params, collect(values(temps))))
|
||||||
(self::$T)($(args...),) = $(syntax(build_forward(body, temps)))
|
(self::$T)($(args...),) = $(syntax(build_forward(body, temps)))
|
||||||
back!(self::$T, ∇, $(args...)) = $(syntax(build_backward(∇s, args[1], params, temps)))
|
back!(self::$T, Δ, $(args...)) = $(syntax(build_backward(Δs, args[1], params, temps)))
|
||||||
$(build_update(T, params))
|
$(build_update(T, params))
|
||||||
end |> longdef |> MacroTools.flatten
|
end |> longdef |> MacroTools.flatten
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user