From 092f2038b3e1f78763743a688e4b8f002c3f1f42 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 21 Feb 2017 14:12:11 +0000 Subject: [PATCH] use affine only for compat --- src/backend/mxnet/graph.jl | 16 +++++++++------- src/backend/mxnet/model.jl | 2 +- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/backend/mxnet/graph.jl b/src/backend/mxnet/graph.jl index a4f879b9..ad238815 100644 --- a/src/backend/mxnet/graph.jl +++ b/src/backend/mxnet/graph.jl @@ -34,11 +34,12 @@ graph(::typeof(vcat), a...) = graph(cat, 1, a...) graph(::Input, x) = x graph(ctx::Context, d::Affine, x) = - register(ctx, - mx.FullyConnected(data = x, - num_hidden = size(d.W.x, 2), - weight = var(ctx, d.W), - bias = var(ctx, d.b, size(d.b, 2)))) + !ctx[:feedforward] ? invoke(graph, (Context, Any, typeof(x)), ctx, d, x) : + register(ctx, + mx.FullyConnected(data = x, + num_hidden = size(d.W.x, 2), + weight = var(ctx, d.W), + bias = var(ctx, d.b, size(d.b, 2)))) # TODO: use actual params} graph(ctx::Context, c::Conv2D, x) = @@ -79,9 +80,10 @@ end graph′(ctx::Context, args...) = @icatch ctx graph(ctx, args...) -function tograph(model, args...) +function tograph(model, args...; feedforward = false) ctx = Context(mux(iline, ilambda, imap, iargs, ituple, graph′), - params = Dict(), stacks = Dict()) + params = Dict(), stacks = Dict(), + feedforward = feedforward) out = @ithrow graph(ctx, model, args...) return ctx[:params], ctx[:stacks], out end diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl index 5297a01d..9b371e9d 100644 --- a/src/backend/mxnet/model.jl +++ b/src/backend/mxnet/model.jl @@ -94,7 +94,7 @@ end function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, context = mx.cpu()) model = rewrite_softmax(model, label) - vars, stacks, node = tograph(model, mx.Variable(input)) + vars, stacks, node = tograph(model, mx.Variable(input), feedforward=true) ff = mx.FeedForward(node, context = context) isempty(vars) || (ff.arg_params = mxargs(vars)) return ff