fix mx.FeedForward
This commit is contained in:
parent
28202792bc
commit
ad4d60f90d
@ -82,13 +82,13 @@ graph(s::SoftmaxOutput, xs) = mx.SoftmaxOutput(data = xs, name = s.name)
|
||||
function rewrite_softmax(model, name)
|
||||
model == softmax && return SoftmaxOutput(name)
|
||||
g = Flux.graph(model)
|
||||
(g == nothing || value(g) ≠ softmax || DataFlow.nin(g) ≠ 1) && error("mx.FeedForward models must end with `softmax`")
|
||||
(g == nothing || g.value ≠ softmax || DataFlow.nin(g) ≠ 1) && error("mx.FeedForward models must end with `softmax`")
|
||||
return Flux.Capacitor(vertex(SoftmaxOutput(name), g[1]))
|
||||
end
|
||||
|
||||
function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, context = mx.cpu())
|
||||
model = rewrite_softmax(model, label)
|
||||
node, vars = mxgraph(model, input)
|
||||
vars, stacks, node = tograph(model, mx.Variable(:input))
|
||||
ff = mx.FeedForward(node, context = context)
|
||||
ff.arg_params = mxargs(vars)
|
||||
return ff
|
||||
|
Loading…
Reference in New Issue
Block a user