store graphs
This commit is contained in:
parent
e9c781cd10
commit
194bf6d3fb
|
@ -4,7 +4,7 @@ module Flux
|
|||
|
||||
# Zero Flux Given
|
||||
|
||||
using Juno, Requires
|
||||
using Juno, Requires, DataFlow
|
||||
using MacroTools: @forward
|
||||
|
||||
export Chain, Dense, RNN, LSTM,
|
||||
|
|
|
@ -1,11 +1,24 @@
|
|||
module Graph
|
||||
|
||||
using DataFlow, MacroTools
|
||||
using DataFlow: graphm, il, constructor, mapconst
|
||||
|
||||
export @net
|
||||
|
||||
graph(x) = nothing
|
||||
|
||||
function graphdef(m, T, args, body)
|
||||
v = il(graphm(body, args = args)) |> DataFlow.striplines
|
||||
:(Graph.graph($(esc(m))::$(esc(T))) = $(constructor(mapconst(esc, v))))
|
||||
end
|
||||
|
||||
macro net(ex)
|
||||
esc(ex)
|
||||
@capture(shortdef(ex), (m_::T_)(args__) = body_) ||
|
||||
error("@net requires a forward pass")
|
||||
quote
|
||||
$(esc(ex))
|
||||
$(graphdef(m, T, args, body))
|
||||
end
|
||||
end
|
||||
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue