simplify graph
This commit is contained in:
parent
de72d83f7c
commit
cd0aa26b0e
@ -51,14 +51,6 @@ graph(::Input, x) = x
|
||||
# weight = graph(vars, d.W),
|
||||
# bias = graph(vars, d.b))
|
||||
|
||||
function interp{T<:AArray}(ctx, p::Constant{Flux.Param{T}})
|
||||
id = gensym()
|
||||
ctx[:params][id] = p.value.x
|
||||
return mx.Variable(id)
|
||||
end
|
||||
|
||||
interp(ctx, p::Constant) = node(p.value)
|
||||
|
||||
function register(ctx::Context, node::mx.SymbolicNode)
|
||||
ctx[:stacks][nodename(node)] = stack(ctx)
|
||||
return node
|
||||
@ -66,21 +58,25 @@ end
|
||||
|
||||
register(ctx::Context, node) = node
|
||||
|
||||
function graph(ctx::Context, model, args...)
|
||||
register(ctx, graph(model, args...))
|
||||
function graph{T<:AArray}(ctx::Context, p::Constant{Flux.Param{T}})
|
||||
id = gensym()
|
||||
ctx[:params][id] = p.value.x
|
||||
return mx.Variable(id)
|
||||
end
|
||||
|
||||
function interp(ctx, model, args...)
|
||||
graph(ctx::Context, p::Constant) = node(p.value)
|
||||
|
||||
function graph(ctx::Context, model, args...)
|
||||
g = Flux.graph(model)
|
||||
g == nothing && return graph(ctx, model, args...)
|
||||
g == nothing && return register(ctx, graph(model, args...))
|
||||
DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
|
||||
interpret(ctx, g, args...)
|
||||
end
|
||||
|
||||
function tograph(model, args...)
|
||||
ctx = Context(mux(iline, ilambda, imap, iargs, ituple, interp),
|
||||
ctx = Context(mux(iline, ilambda, imap, iargs, ituple, graph),
|
||||
params = Dict(), stacks = Dict())
|
||||
out = interp(ctx, model, map(constant, args)...)
|
||||
out = graph(ctx, model, args...)
|
||||
return ctx[:params], ctx[:stacks], out
|
||||
end
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user