This commit is contained in:
Mike J Innes 2017-02-20 23:15:27 +00:00
parent 3fdffea37d
commit 54011045e7
2 changed files with 3 additions and 2 deletions

View File

@ -88,7 +88,7 @@ end
function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, context = mx.cpu())
model = rewrite_softmax(model, label)
vars, stacks, node = tograph(model, mx.Variable(:input))
vars, stacks, node = tograph(model, mx.Variable(input))
ff = mx.FeedForward(node, context = context)
ff.arg_params = mxargs(vars)
return ff

View File

@ -92,7 +92,8 @@ end
state(x) = x
accumulate!(x, Δ) = x
@forward Param.x Base.size
Base.size(p::Param) = size(p.x)
Base.size(p::Param, n) = size(p.x, n)
function Base.show(io::IO, p::Param)
print(io, "Param", size(p.x))