collect line info in tensorflow

This commit is contained in:
Mike J Innes 2016-12-20 16:31:20 +00:00
parent 5f27e30e68
commit f74ca7f7cf

View File

@ -37,6 +37,12 @@ graph(p::MaxPool, x) =
graph(op::Op, xs...) = op.f(xs...)
function graph(ctx::Context, model, args...)
node = graph(model, interpret(ctx, args)...)
isa(node, Tensor) && (ctx[:stacks][node.op.name] = stack(ctx))
return node
end
interp(ctx, c::Conv2D, x) =
nn.conv2d(interpret(ctx, x), interp(ctx, Constant(c.filter)), [1,c.stride...,1], "VALID")
@ -49,13 +55,14 @@ interp(ctx, p::Constant) = p.value
function interp(ctx, model, args...)
g = Flux.graph(model)
g == nothing && return graph(model, interpret(ctx, args)...)
g == nothing && return graph(ctx, model, args...)
DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
interpret(ctx, g, interpret(ctx, args)...)
end
function tograph(model, args...)
ctx = Context(interpline(interplambda(interptuple(interpmap(interp)))), params = ObjectIdDict())
ctx = Context(interpline(interplambda(interptuple(interpmap(interp)))),
params = ObjectIdDict(), stacks = Dict())
out = interp(ctx, model, map(constant, args)...)
return ctx[:params], out
end