simplify graph
This commit is contained in:
parent
de72d83f7c
commit
cd0aa26b0e
@ -51,14 +51,6 @@ graph(::Input, x) = x
|
|||||||
# weight = graph(vars, d.W),
|
# weight = graph(vars, d.W),
|
||||||
# bias = graph(vars, d.b))
|
# bias = graph(vars, d.b))
|
||||||
|
|
||||||
function interp{T<:AArray}(ctx, p::Constant{Flux.Param{T}})
|
|
||||||
id = gensym()
|
|
||||||
ctx[:params][id] = p.value.x
|
|
||||||
return mx.Variable(id)
|
|
||||||
end
|
|
||||||
|
|
||||||
interp(ctx, p::Constant) = node(p.value)
|
|
||||||
|
|
||||||
function register(ctx::Context, node::mx.SymbolicNode)
|
function register(ctx::Context, node::mx.SymbolicNode)
|
||||||
ctx[:stacks][nodename(node)] = stack(ctx)
|
ctx[:stacks][nodename(node)] = stack(ctx)
|
||||||
return node
|
return node
|
||||||
@ -66,21 +58,25 @@ end
|
|||||||
|
|
||||||
register(ctx::Context, node) = node
|
register(ctx::Context, node) = node
|
||||||
|
|
||||||
function graph(ctx::Context, model, args...)
|
function graph{T<:AArray}(ctx::Context, p::Constant{Flux.Param{T}})
|
||||||
register(ctx, graph(model, args...))
|
id = gensym()
|
||||||
|
ctx[:params][id] = p.value.x
|
||||||
|
return mx.Variable(id)
|
||||||
end
|
end
|
||||||
|
|
||||||
function interp(ctx, model, args...)
|
graph(ctx::Context, p::Constant) = node(p.value)
|
||||||
|
|
||||||
|
function graph(ctx::Context, model, args...)
|
||||||
g = Flux.graph(model)
|
g = Flux.graph(model)
|
||||||
g == nothing && return graph(ctx, model, args...)
|
g == nothing && return register(ctx, graph(model, args...))
|
||||||
DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
|
DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
|
||||||
interpret(ctx, g, args...)
|
interpret(ctx, g, args...)
|
||||||
end
|
end
|
||||||
|
|
||||||
function tograph(model, args...)
|
function tograph(model, args...)
|
||||||
ctx = Context(mux(iline, ilambda, imap, iargs, ituple, interp),
|
ctx = Context(mux(iline, ilambda, imap, iargs, ituple, graph),
|
||||||
params = Dict(), stacks = Dict())
|
params = Dict(), stacks = Dict())
|
||||||
out = interp(ctx, model, map(constant, args)...)
|
out = graph(ctx, model, args...)
|
||||||
return ctx[:params], ctx[:stacks], out
|
return ctx[:params], ctx[:stacks], out
|
||||||
end
|
end
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user