Flux.jl/src/compiler/code.jl

85 lines
2.0 KiB
Julia
Raw Normal View History

2016-04-01 21:11:42 +00:00
function process_func(ex, params)
@capture(shortdef(ex), (args__,) -> body_)
body = il(graphm(body))
body = map(x -> x in params ? :(self.$x) : x, body)
return args, body
end
function build_type(T, params)
quote
type $T
$(params...)
$([symbol("Δ", s) for s in params]...)
end
$T($(params...)) = $T($(params...),
$((:(zeros($p)) for p in params)...))
end
end
function build_forward(body, args)
body = cut_forward(body, args)
cse(body)
end
function build_backward(body, x, params)
Δs, Δloops = cut_backward(body, [x])
back = IVertex{Any}(Flow.Do())
for param in params
haskey(Δs, :(self.$param)) || continue
k = symbol("Δ", param)
ksym = Expr(:quote, k)
ex = Δs[:(self.$param)]
for Δloop in Δloops
ex = addΔ(ex, get(Δloop, :(self.$param), vertex(0)))
end
thread!(back, @v(setfield!(:self, ksym, :(self.$k) + ex)))
end
ex = Δs[x]
for Δloop in Δloops
ex = addΔ(ex, get(Δloop, x, vertex(0)))
end
thread!(back, @flow(tuple($ex)))
cse(back)
end
function build_update(T, params)
updates = []
for p in params
Δp = symbol("Δ", p)
push!(updates, :(self.$p += self.$Δp; fill!(self.$Δp, 0)))
end
:(update!(self::$T) = $(updates...))
end
function process_type(ex)
@capture(ex, type T_ fs__ end)
@destruct [params = true || [],
funcs = false || []] = groupby(x->isa(x, Symbol), fs)
@assert length(funcs) == 1
args, body = process_func(funcs[1], params)
@assert length(args) == 1
quote
$(build_type(T, params))
(self::$T)($(args...),) = $(syntax(build_forward(body, args)))
back!(self::$T, Δ, $(args...)) = $(syntax(build_backward(body, args[1], params)))
$(build_update(T, params))
end |> longdef
end
# process_type(:(type Sigmoid
# W
# b
# bp
# x -> σ(W*x+b)
# end)) |> prettify
process_type(:(type Recurrent
Wxh; Whh; Bh
Why; By
function (x)
hidden = σ( Wxh*x + Whh*Delay(hidden) + Bh )
y = σ( Why*hidden + By )
end
end)) |> prettify