diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl index e6001543..0a1e76f8 100644 --- a/src/backend/mxnet/model.jl +++ b/src/backend/mxnet/model.jl @@ -90,6 +90,6 @@ function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, cont model = rewrite_softmax(model, label) vars, stacks, node = tograph(model, mx.Variable(input)) ff = mx.FeedForward(node, context = context) - ff.arg_params = mxargs(vars) + isempty(vars) || (ff.arg_params = mxargs(vars)) return ff end