diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl
index 56ececdc..655ee425 100644
--- a/src/backend/mxnet/model.jl
+++ b/src/backend/mxnet/model.jl
@@ -82,13 +82,13 @@ graph(s::SoftmaxOutput, xs) = mx.SoftmaxOutput(data = xs, name = s.name)
 function rewrite_softmax(model, name)
   model == softmax && return SoftmaxOutput(name)
   g = Flux.graph(model)
-  (g == nothing || value(g) ≠ softmax || DataFlow.nin(g) ≠ 1) && error("mx.FeedForward models must end with `softmax`")
+  (g == nothing || g.value ≠ softmax || DataFlow.nin(g) ≠ 1) && error("mx.FeedForward models must end with `softmax`")
   return Flux.Capacitor(vertex(SoftmaxOutput(name), g[1]))
 end
 
 function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, context = mx.cpu())
   model = rewrite_softmax(model, label)
-  node, vars = mxgraph(model, input)
+  vars, stacks, node = tograph(model, mx.Variable(:input))
   ff = mx.FeedForward(node, context = context)
   ff.arg_params = mxargs(vars)
   return ff
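
For reference, a minimal usage sketch of the constructor touched above, assuming the `Chain`/`Affine` layer API of Flux around this revision and MXNet.jl's standard trainer calls; none of these names are part of this diff and they are included only as an illustration:

# Hypothetical example; the layer names and training calls below are
# assumptions, not verified against this exact revision.
using Flux, MXNet

# A Flux model that ends in `softmax`, as `rewrite_softmax` requires.
m = Chain(
  Affine(784, 128), relu,
  Affine(128, 10), softmax)

# Convert to an MXNet FeedForward model; the trailing `softmax` is rewritten
# into an `mx.SoftmaxOutput` node named after the `label` keyword (`:softmax`),
# and the Flux parameters are copied into `ff.arg_params` via `mxargs(vars)`.
ff = mx.FeedForward(m, context = mx.cpu())

# Training would then go through MXNet.jl's usual fit loop, e.g.
#   provider = mx.ArrayDataProvider(:data => xs, :softmax_label => ys)
#   mx.fit(ff, mx.SGD(lr = 0.01), provider, n_epoch = 10)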