fix a todo, houdini-style
This commit is contained in:
parent
c597d3a793
commit
199765354e
|
@ -1,5 +1,3 @@
|
|||
# TODO: proper escaping
|
||||
|
||||
import DataFlow: mapconst, cse
|
||||
|
||||
export @net, @ml
|
||||
|
@ -13,14 +11,20 @@ end
|
|||
|
||||
function makegraph(graph, args)
|
||||
@assert length(args) == 1
|
||||
mapconst(graph) do x
|
||||
x == args[1] ? inputnode(1) :
|
||||
isa(x, Offset) ? :(Offset($(Expr(:quote, x.name)), $(x.n), self.$(x.name))) :
|
||||
x
|
||||
graph = prewalk(graph) do v
|
||||
isa(value(v), Constant) && value(v).value == args[1] ?
|
||||
inputnode(1) :
|
||||
v
|
||||
end
|
||||
graph = map(graph) do x
|
||||
isa(x, Offset) ?
|
||||
:(Offset($(Expr(:quote, x.name)), $(x.n), self.$(x.name))) :
|
||||
x
|
||||
end
|
||||
end
|
||||
|
||||
function build_type(T, params)
|
||||
@esc T
|
||||
ex = quote
|
||||
type $T <: Model
|
||||
$(params...)
|
||||
|
@ -48,11 +52,11 @@ function process_type(ex)
|
|||
@assert length(args) == 1
|
||||
quote
|
||||
$(build_type(T, params))
|
||||
(self::$T)($(args...),) = interpret(reifyparams(graph(self)), $(args...))
|
||||
update!(self::$T, η) = $(map(p -> :(update!(self.$p, η)), pnames)...)
|
||||
graph(self::$T) = $(DataFlow.constructor(makegraph(body, args)))
|
||||
(self::$(esc(T)))($(args...),) = interpret(reifyparams(graph(self)), $(args...))
|
||||
$(esc(:(Flux.update!(self::$T, η)))) = ($(map(p -> :(update!(self.$p, η)), pnames)...);)
|
||||
$(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(mapconst(esc, makegraph(body, args))))
|
||||
nothing
|
||||
end |> esc
|
||||
end
|
||||
end
|
||||
|
||||
macro net(ex)
|
||||
|
@ -64,12 +68,12 @@ end
|
|||
function process_anon(ex)
|
||||
args, body = process_func(ex)
|
||||
@assert length(args) == 1
|
||||
:(Flux.Capacitor($(DataFlow.constructor(makegraph(body, args)))))
|
||||
:(Flux.Capacitor($(DataFlow.constructor(mapconst(esc, makegraph(body, args))))))
|
||||
end
|
||||
|
||||
macro ml(ex)
|
||||
@capture(shortdef(ex), ((xs__,) -> body_ ) | (f_(xs__,) = body_)) ||
|
||||
error("@ml requires a function definition")
|
||||
ex = process_anon(:($(xs...) -> $body))
|
||||
(f == nothing ? :($f = $ex) : ex) |> esc
|
||||
f == nothing ? :($(esc(f)) = $ex) : ex
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue