Flux.jl/src/compiler/code.jl

80 lines
2.2 KiB
Julia
Raw Normal View History

2016-10-31 12:38:18 +00:00
import DataFlow: mapconst, cse
2016-08-18 21:06:12 +00:00
2016-11-14 20:14:53 +00:00
export @net, @ml
2016-08-23 15:22:11 +00:00
"""
    process_func(ex, params = [])

Normalise the function expression `ex` with `shortdef`, capture its argument list and
body, and lower the body to a DataFlow graph (`MacroTools.flatten` → `block` →
`liftloops(params)` → `graphm` → `DataFlow.il`). Graph constants that are one of
`params` are rewritten to field accesses `self.<name>`.

Returns `(args, body)`: the captured argument expressions and the lowered graph.
Throws an error if `ex` is not of the form `(args...) -> body` after `shortdef`.
"""
function process_func(ex, params = [])
  # Fail early with a readable message instead of letting `nothing` flow into the
  # graph-building pipeline below.
  @capture(shortdef(ex), (args__,) -> body_) ||
    error("Expected a function of the form `(args...) -> body`, got $ex")
  body = @> body MacroTools.flatten block liftloops(params) graphm DataFlow.il
  # Parameter names become accesses on the model instance `self`.
  body = mapconst(x -> x in params ? :(self.$x) : x, body)
  return args, body
end
2016-08-22 13:49:12 +00:00
"""
    makegraph(graph, args)

Prepare the DataFlow `graph` of a single-argument model for code generation.

- Every vertex whose value is a `Constant` equal to the argument name `args[1]` is
  replaced by `inputnode(1)`.
- Every `Offset` value is replaced by the expression
  `Offset(:name, n, self.name)`.

Returns the rewritten graph.
"""
function makegraph(graph, args)
  @assert length(args) == 1
  input = args[1]
  withinput = prewalk(graph) do v
    if isa(value(v), Constant) && value(v).value == input
      inputnode(1)
    else
      v
    end
  end
  map(withinput) do x
    if isa(x, Offset)
      :(Offset($(Expr(:quote, x.name)), $(x.n), self.$(x.name)))
    else
      x
    end
  end
end
2016-04-01 21:11:42 +00:00
"""
    build_type(T, params)

Return a `quote` declaring `type T <: Model` with the field declarations `params`.

When at least one field is untyped (a bare `Symbol`), an extra outer constructor is
appended. Its untyped arguments are annotated `::AArray` and wrapped with `param`
before calling `T`; typed fields are forwarded by name (via `namify`).
"""
function build_type(T, params)
  @esc T
  isuntyped(p) = isexpr(p, Symbol)
  ex = quote
    type $T <: Model
      $(params...)
    end
  end
  # Only untyped fields get wrapped with `param`; with none there is nothing to add.
  any(isuntyped, params) || return ex
  sig  = [isuntyped(p) ? :($p::AArray) : p for p in params]
  call = [isuntyped(p) ? :(param($p)) : namify(p) for p in params]
  push!(ex.args, :($T($(sig...)) = $T($(call...))))
  return ex
end
2016-10-29 22:36:39 +00:00
import Lazy: groupby
2016-11-13 20:46:35 +00:00
# Replace every `Param` constant in the graph `v` with the value stored in its `x` field;
# all other constants are left unchanged.
function reifyparams(v::IVertex)
  mapconst(v) do c
    if isa(c, Param)
      c.x
    else
      c
    end
  end
end
2016-04-01 21:11:42 +00:00
"""
    process_type(ex)

Expand a `@net type T ... end` definition. The type body is split into field
declarations and exactly one function (`->` or `function`) that describes the model;
that function must take a single argument.

Returns a `quote` that contains:
- the type declaration produced by `build_type`;
- a call method that interprets `reifyparams(graph(self))` on the argument;
- `Flux.update!(self::T, η)`, which calls `update!` on every declared field;
- `Flux.graph(self::T)`, which rebuilds the model's DataFlow graph;
- a final `nothing`.
"""
function process_type(ex)
  @capture(ex, type T_ fs__ end) ||
    error("@net expects a type definition, got $ex")
  # Partition the body: `true` => function definitions, `false` => field declarations.
  @destruct [params = false || [],
             funcs = true || []] = groupby(x->isexpr(x, :->, :function), fs)
  # These are user-input checks, so raise descriptive errors rather than `@assert`.
  length(funcs) == 1 ||
    error("@net type $T must define exactly one model function, " *
          "found $(length(funcs))")
  pnames = namify.(params)
  args, body = process_func(funcs[1], pnames)
  length(args) == 1 ||
    error("@net type $T: the model function must take exactly one argument, " *
          "got $(length(args))")
  quote
    $(build_type(T, params))
    (self::$(esc(T)))($(args...),) = interpret(reifyparams(graph(self)), $(args...))
    $(esc(:(Flux.update!(self::$T, η)))) = ($(map(p -> :(update!(self.$p, η)), pnames)...);)
    $(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(mapconst(esc, makegraph(body, args))))
    nothing
  end
end
2016-11-14 20:14:53 +00:00
"""
    @net type T ... end

Define a model type whose single function body is compiled to a DataFlow graph
(see `process_type`). The function form of `@net` is not implemented yet, and any
other expression is rejected with an error.
"""
macro net(ex)
  if isexpr(ex, :type)
    return process_type(ex)
  elseif isexpr(ex, :->, :function)
    error("@net functions not implemented")
  else
    error("Unsupported model expression $ex")
  end
end
2016-08-23 15:22:11 +00:00
"""
    process_anon(ex)

Compile a single-argument anonymous function expression `ex` into an expression that
constructs a `Flux.Capacitor` from the function's DataFlow graph.
"""
function process_anon(ex)
  args, body = process_func(ex)
  @assert length(args) == 1
  g = makegraph(body, args)
  ctor = DataFlow.constructor(mapconst(esc, g))
  return :(Flux.Capacitor($ctor))
end
2016-11-14 20:14:53 +00:00
"""
    @ml x -> body
    @ml f(x) = body

Build a `Flux.Capacitor` from a single-argument function (see `process_anon`).
The anonymous form evaluates to the capacitor. The named form also assigns it to `f`
in the caller's scope.
"""
macro ml(ex)
  @capture(shortdef(ex), ((xs__,) -> body_ ) | (f_(xs__,) = body_)) ||
    error("@ml requires a function definition")
  model = process_anon(:($(xs...) -> $body))
  # `f` is `nothing` when the anonymous alternative matched. Only the named form binds
  # a name; the original code had these branches inverted.
  f === nothing ? model : :($(esc(f)) = $model)
end