ml macro
This commit is contained in:
parent
91652e5b44
commit
c597d3a793
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import DataFlow: mapconst, cse
|
import DataFlow: mapconst, cse
|
||||||
|
|
||||||
export @net
|
export @net, @ml
|
||||||
|
|
||||||
function process_func(ex, params = [])
|
function process_func(ex, params = [])
|
||||||
@capture(shortdef(ex), (args__,) -> body_)
|
@capture(shortdef(ex), (args__,) -> body_)
|
||||||
@ -55,14 +55,21 @@ function process_type(ex)
|
|||||||
end |> esc
|
end |> esc
|
||||||
end
|
end
|
||||||
|
|
||||||
|
macro net(ex)
|
||||||
|
isexpr(ex, :type) ? process_type(ex) :
|
||||||
|
isexpr(ex, :->, :function) ? error("@net functions not implemented") :
|
||||||
|
error("Unsupported model expression $ex")
|
||||||
|
end
|
||||||
|
|
||||||
function process_anon(ex)
|
function process_anon(ex)
|
||||||
args, body = process_func(ex)
|
args, body = process_func(ex)
|
||||||
@assert length(args) == 1
|
@assert length(args) == 1
|
||||||
:(Flux.Capacitor($(DataFlow.constructor(makegraph(body, args))))) |> esc
|
:(Flux.Capacitor($(DataFlow.constructor(makegraph(body, args)))))
|
||||||
end
|
end
|
||||||
|
|
||||||
macro net(ex)
|
macro ml(ex)
|
||||||
isexpr(ex, :type) ? process_type(ex) :
|
@capture(shortdef(ex), ((xs__,) -> body_ ) | (f_(xs__,) = body_)) ||
|
||||||
isexpr(ex, :->, :function) ? process_anon(ex) :
|
error("@ml requires a function definition")
|
||||||
error("Unsupported model expression $ex")
|
ex = process_anon(:($(xs...) -> $body))
|
||||||
|
(f == nothing ? :($f = $ex) : ex) |> esc
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user