This commit is contained in:
Mike J Innes 2016-11-14 20:14:53 +00:00
parent 91652e5b44
commit c597d3a793

View File

@ -2,7 +2,7 @@
import DataFlow: mapconst, cse import DataFlow: mapconst, cse
export @net export @net, @ml
function process_func(ex, params = []) function process_func(ex, params = [])
@capture(shortdef(ex), (args__,) -> body_) @capture(shortdef(ex), (args__,) -> body_)
@ -55,14 +55,21 @@ function process_type(ex)
end |> esc end |> esc
end end
macro net(ex)
isexpr(ex, :type) ? process_type(ex) :
isexpr(ex, :->, :function) ? error("@net functions not implemented") :
error("Unsupported model expression $ex")
end
function process_anon(ex) function process_anon(ex)
args, body = process_func(ex) args, body = process_func(ex)
@assert length(args) == 1 @assert length(args) == 1
:(Flux.Capacitor($(DataFlow.constructor(makegraph(body, args))))) |> esc :(Flux.Capacitor($(DataFlow.constructor(makegraph(body, args)))))
end end
macro net(ex) macro ml(ex)
isexpr(ex, :type) ? process_type(ex) : @capture(shortdef(ex), ((xs__,) -> body_ ) | (f_(xs__,) = body_)) ||
isexpr(ex, :->, :function) ? process_anon(ex) : error("@ml requires a function definition")
error("Unsupported model expression $ex") ex = process_anon(:($(xs...) -> $body))
(f == nothing ? :($f = $ex) : ex) |> esc
end end