simplify graph

This commit is contained in:
Mike J Innes 2017-01-30 23:19:18 +05:30
parent de72d83f7c
commit cd0aa26b0e

View File

@ -51,14 +51,6 @@ graph(::Input, x) = x
# weight = graph(vars, d.W),
# bias = graph(vars, d.b))
function interp{T<:AArray}(ctx, p::Constant{Flux.Param{T}})
id = gensym()
ctx[:params][id] = p.value.x
return mx.Variable(id)
end
interp(ctx, p::Constant) = node(p.value)
function register(ctx::Context, node::mx.SymbolicNode)
ctx[:stacks][nodename(node)] = stack(ctx)
return node
@ -66,21 +58,25 @@ end
register(ctx::Context, node) = node
function graph(ctx::Context, model, args...)
register(ctx, graph(model, args...))
function graph{T<:AArray}(ctx::Context, p::Constant{Flux.Param{T}})
id = gensym()
ctx[:params][id] = p.value.x
return mx.Variable(id)
end
function interp(ctx, model, args...)
graph(ctx::Context, p::Constant) = node(p.value)
function graph(ctx::Context, model, args...)
g = Flux.graph(model)
g == nothing && return graph(ctx, model, args...)
g == nothing && return register(ctx, graph(model, args...))
DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
interpret(ctx, g, args...)
end
function tograph(model, args...)
ctx = Context(mux(iline, ilambda, imap, iargs, ituple, interp),
ctx = Context(mux(iline, ilambda, imap, iargs, ituple, graph),
params = Dict(), stacks = Dict())
out = interp(ctx, model, map(constant, args)...)
out = graph(ctx, model, args...)
return ctx[:params], ctx[:stacks], out
end