feedforward fix

This commit is contained in:
Mike J Innes 2017-06-09 18:54:35 +01:00
parent 9b18fd639a
commit 9c8dbb6b4b

View File

@ -145,7 +145,7 @@ end
function FeedForward(model; input = :data, label = :softmax, ctx = mx.cpu()) function FeedForward(model; input = :data, label = :softmax, ctx = mx.cpu())
model = rewrite_softmax(model, label) model = rewrite_softmax(model, label)
graph = tograph(model, input, feedforward=true) graph = tograph(model, input, feedforward=true)
ff = mx.FeedForward(graph.output, context = context) ff = mx.FeedForward(graph.output, context = ctx)
isempty(graph.params) || (ff.arg_params = ndparams(mxparams(graph.params, ctx))) isempty(graph.params) || (ff.arg_params = ndparams(mxparams(graph.params, ctx)))
return ff return ff
end end