fix mx.FeedForward

This commit is contained in:
Mike J Innes 2017-02-20 19:35:32 +00:00
parent 28202792bc
commit ad4d60f90d

View File

@ -82,13 +82,13 @@ graph(s::SoftmaxOutput, xs) = mx.SoftmaxOutput(data = xs, name = s.name)
function rewrite_softmax(model, name)
model == softmax && return SoftmaxOutput(name)
g = Flux.graph(model)
(g == nothing || value(g) softmax || DataFlow.nin(g) 1) && error("mx.FeedForward models must end with `softmax`")
(g == nothing || g.value softmax || DataFlow.nin(g) 1) && error("mx.FeedForward models must end with `softmax`")
return Flux.Capacitor(vertex(SoftmaxOutput(name), g[1]))
end
function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, context = mx.cpu())
model = rewrite_softmax(model, label)
node, vars = mxgraph(model, input)
vars, stacks, node = tograph(model, mx.Variable(:input))
ff = mx.FeedForward(node, context = context)
ff.arg_params = mxargs(vars)
return ff