From 54011045e788c070c5d23dc395575532ab3977f0 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Mon, 20 Feb 2017 23:15:27 +0000 Subject: [PATCH] fix --- src/backend/mxnet/model.jl | 2 +- src/model.jl | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl index 655ee425..e6001543 100644 --- a/src/backend/mxnet/model.jl +++ b/src/backend/mxnet/model.jl @@ -88,7 +88,7 @@ end function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, context = mx.cpu()) model = rewrite_softmax(model, label) - vars, stacks, node = tograph(model, mx.Variable(:input)) + vars, stacks, node = tograph(model, mx.Variable(input)) ff = mx.FeedForward(node, context = context) ff.arg_params = mxargs(vars) return ff diff --git a/src/model.jl b/src/model.jl index 9bca1e8c..ca52154e 100644 --- a/src/model.jl +++ b/src/model.jl @@ -92,7 +92,8 @@ end state(x) = x accumulate!(x, Δ) = x -@forward Param.x Base.size +Base.size(p::Param) = size(p.x) +Base.size(p::Param, n) = size(p.x, n) function Base.show(io::IO, p::Param) print(io, "Param", size(p.x))