feedforward fix
This commit is contained in:
parent
9b18fd639a
commit
9c8dbb6b4b
@ -145,7 +145,7 @@ end
|
|||||||
function FeedForward(model; input = :data, label = :softmax, ctx = mx.cpu())
|
function FeedForward(model; input = :data, label = :softmax, ctx = mx.cpu())
|
||||||
model = rewrite_softmax(model, label)
|
model = rewrite_softmax(model, label)
|
||||||
graph = tograph(model, input, feedforward=true)
|
graph = tograph(model, input, feedforward=true)
|
||||||
ff = mx.FeedForward(graph.output, context = context)
|
ff = mx.FeedForward(graph.output, context = ctx)
|
||||||
isempty(graph.params) || (ff.arg_params = ndparams(mxparams(graph.params, ctx)))
|
isempty(graph.params) || (ff.arg_params = ndparams(mxparams(graph.params, ctx)))
|
||||||
return ff
|
return ff
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user