diff --git a/src/backend/mxnet/graph.jl b/src/backend/mxnet/graph.jl index f235d7eb..e19706c8 100644 --- a/src/backend/mxnet/graph.jl +++ b/src/backend/mxnet/graph.jl @@ -51,14 +51,6 @@ graph(::Input, x) = x # weight = graph(vars, d.W), # bias = graph(vars, d.b)) -function interp{T<:AArray}(ctx, p::Constant{Flux.Param{T}}) - id = gensym() - ctx[:params][id] = p.value.x - return mx.Variable(id) -end - -interp(ctx, p::Constant) = node(p.value) - function register(ctx::Context, node::mx.SymbolicNode) ctx[:stacks][nodename(node)] = stack(ctx) return node @@ -66,21 +58,25 @@ end register(ctx::Context, node) = node -function graph(ctx::Context, model, args...) - register(ctx, graph(model, args...)) +function graph{T<:AArray}(ctx::Context, p::Constant{Flux.Param{T}}) + id = gensym() + ctx[:params][id] = p.value.x + return mx.Variable(id) end -function interp(ctx, model, args...) +graph(ctx::Context, p::Constant) = node(p.value) + +function graph(ctx::Context, model, args...) g = Flux.graph(model) - g == nothing && return graph(ctx, model, args...) + g == nothing && return register(ctx, graph(model, args...)) DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.") interpret(ctx, g, args...) end function tograph(model, args...) - ctx = Context(mux(iline, ilambda, imap, iargs, ituple, interp), + ctx = Context(mux(iline, ilambda, imap, iargs, ituple, graph), params = Dict(), stacks = Dict()) - out = interp(ctx, model, map(constant, args)...) + out = graph(ctx, model, args...) return ctx[:params], ctx[:stacks], out end