From f74ca7f7cf3965a0744c2bd612c462cb58b6cf58 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 20 Dec 2016 16:31:20 +0000 Subject: [PATCH] collect line info in tensorflow --- src/backend/tensorflow/graph.jl | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/backend/tensorflow/graph.jl b/src/backend/tensorflow/graph.jl index 772e61c7..e50efca4 100644 --- a/src/backend/tensorflow/graph.jl +++ b/src/backend/tensorflow/graph.jl @@ -37,6 +37,12 @@ graph(p::MaxPool, x) = graph(op::Op, xs...) = op.f(xs...) +function graph(ctx::Context, model, args...) + node = graph(model, interpret(ctx, args)...) + isa(node, Tensor) && (ctx[:stacks][node.op.name] = stack(ctx)) + return node +end + interp(ctx, c::Conv2D, x) = nn.conv2d(interpret(ctx, x), interp(ctx, Constant(c.filter)), [1,c.stride...,1], "VALID") @@ -49,13 +55,14 @@ interp(ctx, p::Constant) = p.value function interp(ctx, model, args...) g = Flux.graph(model) - g == nothing && return graph(model, interpret(ctx, args)...) + g == nothing && return graph(ctx, model, args...) DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.") interpret(ctx, g, interpret(ctx, args)...) end function tograph(model, args...) - ctx = Context(interpline(interplambda(interptuple(interpmap(interp)))), params = ObjectIdDict()) + ctx = Context(interpline(interplambda(interptuple(interpmap(interp)))), + params = ObjectIdDict(), stacks = Dict()) out = interp(ctx, model, map(constant, args)...) return ctx[:params], out end