collect line info in tensorflow

This commit is contained in:
Mike J Innes 2016-12-20 16:31:20 +00:00
parent 5f27e30e68
commit f74ca7f7cf

View File

@ -37,6 +37,12 @@ graph(p::MaxPool, x) =
graph(op::Op, xs...) = op.f(xs...) graph(op::Op, xs...) = op.f(xs...)
function graph(ctx::Context, model, args...)
node = graph(model, interpret(ctx, args)...)
isa(node, Tensor) && (ctx[:stacks][node.op.name] = stack(ctx))
return node
end
interp(ctx, c::Conv2D, x) = interp(ctx, c::Conv2D, x) =
nn.conv2d(interpret(ctx, x), interp(ctx, Constant(c.filter)), [1,c.stride...,1], "VALID") nn.conv2d(interpret(ctx, x), interp(ctx, Constant(c.filter)), [1,c.stride...,1], "VALID")
@ -49,13 +55,14 @@ interp(ctx, p::Constant) = p.value
function interp(ctx, model, args...) function interp(ctx, model, args...)
g = Flux.graph(model) g = Flux.graph(model)
g == nothing && return graph(model, interpret(ctx, args)...) g == nothing && return graph(ctx, model, args...)
DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.") DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
interpret(ctx, g, interpret(ctx, args)...) interpret(ctx, g, interpret(ctx, args)...)
end end
function tograph(model, args...) function tograph(model, args...)
ctx = Context(interpline(interplambda(interptuple(interpmap(interp)))), params = ObjectIdDict()) ctx = Context(interpline(interplambda(interptuple(interpmap(interp)))),
params = ObjectIdDict(), stacks = Dict())
out = interp(ctx, model, map(constant, args)...) out = interp(ctx, model, map(constant, args)...)
return ctx[:params], out return ctx[:params], out
end end