collect line info in tensorflow
This commit is contained in:
parent
5f27e30e68
commit
f74ca7f7cf
@ -37,6 +37,12 @@ graph(p::MaxPool, x) =
|
||||
|
||||
graph(op::Op, xs...) = op.f(xs...)
|
||||
|
||||
function graph(ctx::Context, model, args...)
|
||||
node = graph(model, interpret(ctx, args)...)
|
||||
isa(node, Tensor) && (ctx[:stacks][node.op.name] = stack(ctx))
|
||||
return node
|
||||
end
|
||||
|
||||
interp(ctx, c::Conv2D, x) =
|
||||
nn.conv2d(interpret(ctx, x), interp(ctx, Constant(c.filter)), [1,c.stride...,1], "VALID")
|
||||
|
||||
@ -49,13 +55,14 @@ interp(ctx, p::Constant) = p.value
|
||||
|
||||
function interp(ctx, model, args...)
|
||||
g = Flux.graph(model)
|
||||
g == nothing && return graph(model, interpret(ctx, args)...)
|
||||
g == nothing && return graph(ctx, model, args...)
|
||||
DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
|
||||
interpret(ctx, g, interpret(ctx, args)...)
|
||||
end
|
||||
|
||||
function tograph(model, args...)
|
||||
ctx = Context(interpline(interplambda(interptuple(interpmap(interp)))), params = ObjectIdDict())
|
||||
ctx = Context(interpline(interplambda(interptuple(interpmap(interp)))),
|
||||
params = ObjectIdDict(), stacks = Dict())
|
||||
out = interp(ctx, model, map(constant, args)...)
|
||||
return ctx[:params], out
|
||||
end
|
||||
|
Loading…
Reference in New Issue
Block a user