diff --git a/src/backend/mxnet/graph.jl b/src/backend/mxnet/graph.jl index 2e2bc120..a3479bf6 100644 --- a/src/backend/mxnet/graph.jl +++ b/src/backend/mxnet/graph.jl @@ -8,7 +8,7 @@ end using Base: @get! using DataFlow: Constant, constant, Context, interpret, Split, - interpv, ituple, ilambda, iconst, iline, stack, mux + interpv, ituple, ilambda, iconst, iline, iargs, stack, mux using Flux: imap # TODO: implement Julia's type promotion rules @@ -61,7 +61,7 @@ end interp(ctx, p::Constant) = node(p.value) function graph(ctx::Context, model, args...) - node = graph(model, interpv(ctx, args)...) + node = graph(model, args...) isa(node, mx.SymbolicNode) && (ctx[:stacks][nodename(node)] = stack(ctx)) return node end @@ -70,11 +70,11 @@ function interp(ctx, model, args...) g = Flux.graph(model) g == nothing && return graph(ctx, model, args...) DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.") - interpret(ctx, g, interpv(ctx, args)...) + interpret(ctx, g, args...) end function tograph(model, args...) - ctx = Context(mux(iline, ilambda, ituple, imap, interp), + ctx = Context(mux(iline, ilambda, imap, iargs, interp), params = Dict(), stacks = Dict()) out = interp(ctx, model, map(constant, args)...) return ctx[:params], ctx[:stacks], out