Flux.jl/src/compiler/code.jl

78 lines
1.9 KiB
Julia
Raw Normal View History

2016-08-18 21:31:49 +00:00
import Flow: mapconst, cse
2016-08-18 21:06:12 +00:00
2016-04-01 21:11:42 +00:00
# Split a model's forward function into its argument list and its graph.
#
# `ex` is the function expression taken from the model body; `params` is the
# collection of parameter names declared next to it. After normalisation with
# `shortdef`, `ex` must have the form `(args...,) -> body`.
#
# The body is passed through `graphm` and `Flow.il`; `mapconst` is then used to
# replace every value `x` it visits that is a member of `params` with the field
# access expression `self.<x>`.
#
# Returns the tuple `(args, body)`. Raises an error when `ex` does not have the
# expected lambda form.
function process_func(ex, params)
  # `@capture` signals a failed match through its Boolean result; checking it
  # here avoids continuing with meaningless `args` / `body` values.
  @capture(shortdef(ex), (args__,) -> body_) ||
    error("Unsupported model function $ex")
  body = Flow.il(graphm(body))
  body = mapconst(x -> x in params ? :(self.$x) : x, body)
  return args, body
end
2016-08-22 13:49:12 +00:00
# Marker value that `makegraph` substitutes into a graph: `Parameter(1)` stands
# for the model input and `Parameter(p)` for a `self.p` reference.
immutable Parameter
  name::Any
end
# Produce a copy of `graph` (via `mapconst`) in which the single model input and
# all `self.<field>` references are replaced by `Parameter` markers.
#
# `args` must contain exactly one element, the input name. A value equal to it
# becomes `Parameter(1)`, a value of the form `self.p` becomes `Parameter(p)`,
# and anything else is passed through unchanged.
function makegraph(graph, args)
  @assert length(args) == 1
  input = args[1]
  mapconst(graph) do x
    if x == input
      Parameter(1)
    elseif @capture(x, self.p_)
      Parameter(p)
    else
      x
    end
  end
end
2016-04-01 21:11:42 +00:00
# Build the expression that declares the model type `T`.
#
# `params` holds the field names of the model. The returned block contains
#   * the declaration of `T` as a subtype of `Model` with one field per entry
#     of `params`, and
#   * when `params` is non-empty, a convenience constructor that accepts one
#     `AArray` per field and wraps each argument with `param` before forwarding
#     to the default constructor.
#
# Returns an `Expr` block; nothing is evaluated here.
function build_type(T, params)
  def = quote
    type $T <: Model
      $(params...)
    end
  end
  # For a model without parameters the convenience constructor would read
  # `T() = T()`: it would replace the default constructor and call itself
  # forever, so it is only emitted when there is something to wrap.
  isempty(params) && return def
  typed = map(x -> :($x::AArray), params)
  wrapped = map(x -> :(param($x)), params)
  push!(def.args, :($T($(typed...)) = $T($(wrapped...))))
  return def
end
# Rewrite `v` with `mapconst` so that every value of the form `self.p` becomes
# the call expression `state(self.p)`; all other values are left as they are.
function deref_params(v)
  function deref(x)
    @capture(x, self.p_) || return x
    return :(state(self.$p))
  end
  return mapconst(deref, v)
end
# Build the forward-pass graph: wrap parameter references in `state(...)` with
# `deref_params`, then run the result through `cse`.
#
# `args` is accepted but not used by this function.
function build_forward(body, args)
  derefed = deref_params(body)
  return cse(derefed)
end
# Build the backward-pass graph for a model.
#
# `body` is the model graph, `x` the name of the model input and `params` the
# parameter names. `invert(body)` yields a lookup table (presumably of gradient
# expressions -- confirm against `invert`) keyed by `self.<param>` expressions
# and by `x`.
#
# For every parameter that has an entry, an `accumulate!` vertex is threaded
# onto a `Flow.Do()` vertex; finally a `tuple(...)` of the entry for `x` is
# threaded on. Parameter references in each entry are wrapped with
# `deref_params` first. The result is passed through `cse` and returned.
function build_backward(body, x, params)
  grads = invert(body)
  chain = IVertex{Any}(Flow.Do())
  for param in params
    key = :(self.$param)
    if haskey(grads, key)
      ex = deref_params(grads[key])
      # The macro argument is kept literal: `@dvertex` receives raw syntax.
      thread!(chain, @dvertex(accumulate!(:(self.$param), ex)))
    end
  end
  ex = deref_params(grads[x])
  thread!(chain, @flow(tuple($ex)))
  return cse(chain)
end
# Expand a `type T ... end` model definition into the code that implements it.
#
# The entries of the type body are split with `groupby` into function
# expressions (`->` or `function`) and everything else; the latter are treated
# as the model's parameters. Exactly one function taking exactly one argument is
# required.
#
# The returned (escaped) block contains, in order:
#   * the type declaration from `build_type`,
#   * a call overload on `T` running the forward graph,
#   * `back!` running the backward graph,
#   * `update!` calling `update!(self.p, η)` for every parameter `p`,
#   * `graph` returning the result of `Flow.constructor` on `makegraph`'s output,
#   * a trailing `nothing`.
#
# Raises an error if `ex` is not of the form `type T ... end`, and an
# `AssertionError` if the function count or argument count is not one.
function process_type(ex)
  # `@capture` reports a failed match via its return value; do not carry on
  # with unmatched `T` / `fs`.
  @capture(ex, type T_ fs__ end) ||
    error("Unsupported model expression $ex")
  # The `|| []` entries look like defaults for an absent group -- confirm
  # against `@destruct`.
  @destruct [params = false || [],
             funcs = true || []] = groupby(x->isexpr(x, :->, :function), fs)
  @assert length(funcs) == 1 "A model must define exactly one function, found $(length(funcs))"
  args, body = process_func(funcs[1], params)
  @assert length(args) == 1 "The model function must take exactly one argument, found $(length(args))"
  quote
    $(build_type(T, params))
    (self::$T)($(args...),) = $(syntax(build_forward(body, args)))
    back!(self::$T, Δ, $(args...)) = $(syntax(build_backward(body, args[1], params)))
    # The splat lands in the body of the short-form definition, one call per
    # parameter.
    update!(self::$T, η) = $(map(p -> :(update!(self.$p, η)), params)...)
    graph(::$T) = $(Flow.constructor(makegraph(body, args)))
    nothing
  end |> esc
end
2016-08-22 13:49:34 +00:00
# Entry point: `@model type T ... end` expands to the code generated by
# `process_type`. Any other kind of expression is rejected with an error.
macro model(ex)
  if isexpr(ex, :type)
    return process_type(ex)
  end
  error("Unsupported model expression $ex")
end