From ad4d60f90db811a2b419f1c3e6d447a23638b897 Mon Sep 17 00:00:00 2001
From: Mike J Innes
Date: Mon, 20 Feb 2017 19:35:32 +0000
Subject: [PATCH] fix mx.FeedForward

---
 src/backend/mxnet/model.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl
index 56ececdc..655ee425 100644
--- a/src/backend/mxnet/model.jl
+++ b/src/backend/mxnet/model.jl
@@ -82,13 +82,13 @@ graph(s::SoftmaxOutput, xs) = mx.SoftmaxOutput(data = xs, name = s.name)
 function rewrite_softmax(model, name)
   model == softmax && return SoftmaxOutput(name)
   g = Flux.graph(model)
-  (g == nothing || value(g) ≠ softmax || DataFlow.nin(g) ≠ 1) && error("mx.FeedForward models must end with `softmax`")
+  (g == nothing || g.value ≠ softmax || DataFlow.nin(g) ≠ 1) && error("mx.FeedForward models must end with `softmax`")
   return Flux.Capacitor(vertex(SoftmaxOutput(name), g[1]))
 end
 
 function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, context = mx.cpu())
   model = rewrite_softmax(model, label)
-  node, vars = mxgraph(model, input)
+  vars, stacks, node = tograph(model, mx.Variable(:input))
   ff = mx.FeedForward(node, context = context)
   ff.arg_params = mxargs(vars)
   return ff