fix a todo, houdini-style
This commit is contained in:
parent
c597d3a793
commit
199765354e
@ -1,5 +1,3 @@
|
|||||||
# TODO: proper escaping
|
|
||||||
|
|
||||||
import DataFlow: mapconst, cse
|
import DataFlow: mapconst, cse
|
||||||
|
|
||||||
export @net, @ml
|
export @net, @ml
|
||||||
@ -13,14 +11,20 @@ end
|
|||||||
|
|
||||||
function makegraph(graph, args)
|
function makegraph(graph, args)
|
||||||
@assert length(args) == 1
|
@assert length(args) == 1
|
||||||
mapconst(graph) do x
|
graph = prewalk(graph) do v
|
||||||
x == args[1] ? inputnode(1) :
|
isa(value(v), Constant) && value(v).value == args[1] ?
|
||||||
isa(x, Offset) ? :(Offset($(Expr(:quote, x.name)), $(x.n), self.$(x.name))) :
|
inputnode(1) :
|
||||||
x
|
v
|
||||||
|
end
|
||||||
|
graph = map(graph) do x
|
||||||
|
isa(x, Offset) ?
|
||||||
|
:(Offset($(Expr(:quote, x.name)), $(x.n), self.$(x.name))) :
|
||||||
|
x
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
function build_type(T, params)
|
function build_type(T, params)
|
||||||
|
@esc T
|
||||||
ex = quote
|
ex = quote
|
||||||
type $T <: Model
|
type $T <: Model
|
||||||
$(params...)
|
$(params...)
|
||||||
@ -48,11 +52,11 @@ function process_type(ex)
|
|||||||
@assert length(args) == 1
|
@assert length(args) == 1
|
||||||
quote
|
quote
|
||||||
$(build_type(T, params))
|
$(build_type(T, params))
|
||||||
(self::$T)($(args...),) = interpret(reifyparams(graph(self)), $(args...))
|
(self::$(esc(T)))($(args...),) = interpret(reifyparams(graph(self)), $(args...))
|
||||||
update!(self::$T, η) = $(map(p -> :(update!(self.$p, η)), pnames)...)
|
$(esc(:(Flux.update!(self::$T, η)))) = ($(map(p -> :(update!(self.$p, η)), pnames)...);)
|
||||||
graph(self::$T) = $(DataFlow.constructor(makegraph(body, args)))
|
$(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(mapconst(esc, makegraph(body, args))))
|
||||||
nothing
|
nothing
|
||||||
end |> esc
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
macro net(ex)
|
macro net(ex)
|
||||||
@ -64,12 +68,12 @@ end
|
|||||||
function process_anon(ex)
|
function process_anon(ex)
|
||||||
args, body = process_func(ex)
|
args, body = process_func(ex)
|
||||||
@assert length(args) == 1
|
@assert length(args) == 1
|
||||||
:(Flux.Capacitor($(DataFlow.constructor(makegraph(body, args)))))
|
:(Flux.Capacitor($(DataFlow.constructor(mapconst(esc, makegraph(body, args))))))
|
||||||
end
|
end
|
||||||
|
|
||||||
macro ml(ex)
|
macro ml(ex)
|
||||||
@capture(shortdef(ex), ((xs__,) -> body_ ) | (f_(xs__,) = body_)) ||
|
@capture(shortdef(ex), ((xs__,) -> body_ ) | (f_(xs__,) = body_)) ||
|
||||||
error("@ml requires a function definition")
|
error("@ml requires a function definition")
|
||||||
ex = process_anon(:($(xs...) -> $body))
|
ex = process_anon(:($(xs...) -> $body))
|
||||||
(f == nothing ? :($f = $ex) : ex) |> esc
|
f == nothing ? :($(esc(f)) = $ex) : ex
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user