fix a todo, houdini-style

This commit is contained in:
Mike J Innes 2016-11-14 21:56:40 +00:00
parent c597d3a793
commit 199765354e
1 changed files with 16 additions and 12 deletions

View File

@ -1,5 +1,3 @@
# TODO: proper escaping
import DataFlow: mapconst, cse
export @net, @ml
@ -13,14 +11,20 @@ end
function makegraph(graph, args)
@assert length(args) == 1
mapconst(graph) do x
x == args[1] ? inputnode(1) :
isa(x, Offset) ? :(Offset($(Expr(:quote, x.name)), $(x.n), self.$(x.name))) :
x
graph = prewalk(graph) do v
isa(value(v), Constant) && value(v).value == args[1] ?
inputnode(1) :
v
end
graph = map(graph) do x
isa(x, Offset) ?
:(Offset($(Expr(:quote, x.name)), $(x.n), self.$(x.name))) :
x
end
end
function build_type(T, params)
@esc T
ex = quote
type $T <: Model
$(params...)
@ -48,11 +52,11 @@ function process_type(ex)
@assert length(args) == 1
quote
$(build_type(T, params))
(self::$T)($(args...),) = interpret(reifyparams(graph(self)), $(args...))
update!(self::$T, η) = $(map(p -> :(update!(self.$p, η)), pnames)...)
graph(self::$T) = $(DataFlow.constructor(makegraph(body, args)))
(self::$(esc(T)))($(args...),) = interpret(reifyparams(graph(self)), $(args...))
$(esc(:(Flux.update!(self::$T, η)))) = ($(map(p -> :(update!(self.$p, η)), pnames)...);)
$(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(mapconst(esc, makegraph(body, args))))
nothing
end |> esc
end
end
macro net(ex)
@ -64,12 +68,12 @@ end
function process_anon(ex)
args, body = process_func(ex)
@assert length(args) == 1
:(Flux.Capacitor($(DataFlow.constructor(makegraph(body, args)))))
:(Flux.Capacitor($(DataFlow.constructor(mapconst(esc, makegraph(body, args))))))
end
macro ml(ex)
@capture(shortdef(ex), ((xs__,) -> body_ ) | (f_(xs__,) = body_)) ||
error("@ml requires a function definition")
ex = process_anon(:($(xs...) -> $body))
(f == nothing ? :($f = $ex) : ex) |> esc
f == nothing ? :($(esc(f)) = $ex) : ex
end