fix a todo, houdini-style

This commit is contained in:
Mike J Innes 2016-11-14 21:56:40 +00:00
parent c597d3a793
commit 199765354e

View File

@ -1,5 +1,3 @@
# TODO: proper escaping
import DataFlow: mapconst, cse import DataFlow: mapconst, cse
export @net, @ml export @net, @ml
@ -13,14 +11,20 @@ end
function makegraph(graph, args) function makegraph(graph, args)
@assert length(args) == 1 @assert length(args) == 1
mapconst(graph) do x graph = prewalk(graph) do v
x == args[1] ? inputnode(1) : isa(value(v), Constant) && value(v).value == args[1] ?
isa(x, Offset) ? :(Offset($(Expr(:quote, x.name)), $(x.n), self.$(x.name))) : inputnode(1) :
x v
end
graph = map(graph) do x
isa(x, Offset) ?
:(Offset($(Expr(:quote, x.name)), $(x.n), self.$(x.name))) :
x
end end
end end
function build_type(T, params) function build_type(T, params)
@esc T
ex = quote ex = quote
type $T <: Model type $T <: Model
$(params...) $(params...)
@ -48,11 +52,11 @@ function process_type(ex)
@assert length(args) == 1 @assert length(args) == 1
quote quote
$(build_type(T, params)) $(build_type(T, params))
(self::$T)($(args...),) = interpret(reifyparams(graph(self)), $(args...)) (self::$(esc(T)))($(args...),) = interpret(reifyparams(graph(self)), $(args...))
update!(self::$T, η) = $(map(p -> :(update!(self.$p, η)), pnames)...) $(esc(:(Flux.update!(self::$T, η)))) = ($(map(p -> :(update!(self.$p, η)), pnames)...);)
graph(self::$T) = $(DataFlow.constructor(makegraph(body, args))) $(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(mapconst(esc, makegraph(body, args))))
nothing nothing
end |> esc end
end end
macro net(ex) macro net(ex)
@ -64,12 +68,12 @@ end
function process_anon(ex) function process_anon(ex)
args, body = process_func(ex) args, body = process_func(ex)
@assert length(args) == 1 @assert length(args) == 1
:(Flux.Capacitor($(DataFlow.constructor(makegraph(body, args))))) :(Flux.Capacitor($(DataFlow.constructor(mapconst(esc, makegraph(body, args))))))
end end
macro ml(ex) macro ml(ex)
@capture(shortdef(ex), ((xs__,) -> body_ ) | (f_(xs__,) = body_)) || @capture(shortdef(ex), ((xs__,) -> body_ ) | (f_(xs__,) = body_)) ||
error("@ml requires a function definition") error("@ml requires a function definition")
ex = process_anon(:($(xs...) -> $body)) ex = process_anon(:($(xs...) -> $body))
(f == nothing ? :($f = $ex) : ex) |> esc f == nothing ? :($(esc(f)) = $ex) : ex
end end