Flux.jl/src/compiler/graph.jl

31 lines
637 B
Julia
Raw Normal View History

2016-10-25 22:10:35 +00:00
"""
    ModelInput()

Field-less marker type. A vertex holding `Constant(ModelInput())` is a
placeholder for the model's inputs inside a graph; `spliceinput` replaces such
vertices with a real input vertex.
"""
# NOTE(review): `immutable` is pre-Julia-0.6 syntax; switch to `struct` once
# Julia 0.5 support is dropped.
immutable ModelInput
end
"""
    inputnode(n)

Return a vertex selecting the `n`-th model input: a `Split(n)` vertex whose
single input is the `ModelInput` placeholder constant.
"""
function inputnode(n)
    selector = Split(n)
    placeholder = constant(ModelInput())
    return vertex(selector, placeholder)
end
2016-10-25 23:39:16 +00:00
"""
    bumpinputs(v::IVertex)

Walk the graph rooted at `v` (via `prewalk`) and replace every input selector,
i.e. a `Split(n)` vertex fed directly by the `ModelInput` placeholder, with
`inputnode(n + 1)`, shifting each model-input index up by one. All other
vertices are returned unchanged.
"""
function bumpinputs(v::IVertex)
    return prewalk(v) do node
        # Only a `Split` whose first input is the placeholder counts as a selector.
        if isa(value(node), Split) && value(node[1]) == Constant(ModelInput())
            return inputnode(value(node).n + 1)
        end
        return node
    end
end
2016-08-31 01:37:53 +00:00
2016-10-25 22:10:35 +00:00
"""
    spliceinput(v::IVertex, input::IVertex)

Walk the graph rooted at `v` with `postwalk` and substitute `input` for every
vertex whose value equals the `Constant(ModelInput())` placeholder.
"""
function spliceinput(v::IVertex, input::IVertex)
    isplaceholder(node) = value(node) == Constant(ModelInput())
    # Equivalent to the do-block form: the mapping function is the first argument.
    return postwalk(node -> isplaceholder(node) ? input : node, v)
end
2016-08-31 01:37:53 +00:00
2016-10-25 22:10:35 +00:00
"""
    spliceinputs(v::IVertex, inputs::Vertex...)

Bundle `inputs` into one `Group()` vertex and splice that group into `v` in
place of the `ModelInput` placeholder (see `spliceinput`).
"""
function spliceinputs(v::IVertex, inputs::Vertex...)
    grouped = vertex(Group(), inputs...)
    return spliceinput(v, grouped)
end
2016-08-31 01:37:53 +00:00
2016-10-25 22:10:35 +00:00
"""
    detuple(v::IVertex)

Walk the graph rooted at `v` with `postwalk` and collapse every `Split(n)`
vertex whose first input is a `Group` vertex into that group's `n`-th input.
Any other vertex is returned unchanged.
"""
function detuple(v::IVertex)
    return postwalk(v) do node
        isa(value(node), Split) || return node
        group = node[1]
        isa(value(group), Group) || return node
        # Select the group member the Split was projecting out.
        return group[value(node).n]
    end
end