store graphs
This commit is contained in:
parent
e9c781cd10
commit
194bf6d3fb
@ -4,7 +4,7 @@ module Flux
|
|||||||
|
|
||||||
# Zero Flux Given
|
# Zero Flux Given
|
||||||
|
|
||||||
using Juno, Requires
|
using Juno, Requires, DataFlow
|
||||||
using MacroTools: @forward
|
using MacroTools: @forward
|
||||||
|
|
||||||
export Chain, Dense, RNN, LSTM,
|
export Chain, Dense, RNN, LSTM,
|
||||||
|
@ -1,11 +1,24 @@
|
|||||||
module Graph
|
module Graph
|
||||||
|
|
||||||
|
using DataFlow, MacroTools
|
||||||
|
using DataFlow: graphm, il, constructor, mapconst
|
||||||
|
|
||||||
export @net
|
export @net
|
||||||
|
|
||||||
graph(x) = nothing
|
graph(x) = nothing
|
||||||
|
|
||||||
|
function graphdef(m, T, args, body)
|
||||||
|
v = il(graphm(body, args = args)) |> DataFlow.striplines
|
||||||
|
:(Graph.graph($(esc(m))::$(esc(T))) = $(constructor(mapconst(esc, v))))
|
||||||
|
end
|
||||||
|
|
||||||
macro net(ex)
|
macro net(ex)
|
||||||
esc(ex)
|
@capture(shortdef(ex), (m_::T_)(args__) = body_) ||
|
||||||
|
error("@net requires a forward pass")
|
||||||
|
quote
|
||||||
|
$(esc(ex))
|
||||||
|
$(graphdef(m, T, args, body))
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user