use affine only for compat

This commit is contained in:
Mike J Innes 2017-02-21 14:12:11 +00:00
parent 0bb44f5ace
commit 092f2038b3
2 changed files with 10 additions and 8 deletions

View File

@ -34,11 +34,12 @@ graph(::typeof(vcat), a...) = graph(cat, 1, a...)
graph(::Input, x) = x graph(::Input, x) = x
graph(ctx::Context, d::Affine, x) = graph(ctx::Context, d::Affine, x) =
register(ctx, !ctx[:feedforward] ? invoke(graph, (Context, Any, typeof(x)), ctx, d, x) :
mx.FullyConnected(data = x, register(ctx,
num_hidden = size(d.W.x, 2), mx.FullyConnected(data = x,
weight = var(ctx, d.W), num_hidden = size(d.W.x, 2),
bias = var(ctx, d.b, size(d.b, 2)))) weight = var(ctx, d.W),
bias = var(ctx, d.b, size(d.b, 2))))
# TODO: use actual params} # TODO: use actual params}
graph(ctx::Context, c::Conv2D, x) = graph(ctx::Context, c::Conv2D, x) =
@ -79,9 +80,10 @@ end
graph(ctx::Context, args...) = @icatch ctx graph(ctx, args...) graph(ctx::Context, args...) = @icatch ctx graph(ctx, args...)
function tograph(model, args...) function tograph(model, args...; feedforward = false)
ctx = Context(mux(iline, ilambda, imap, iargs, ituple, graph), ctx = Context(mux(iline, ilambda, imap, iargs, ituple, graph),
params = Dict(), stacks = Dict()) params = Dict(), stacks = Dict(),
feedforward = feedforward)
out = @ithrow graph(ctx, model, args...) out = @ithrow graph(ctx, model, args...)
return ctx[:params], ctx[:stacks], out return ctx[:params], ctx[:stacks], out
end end

View File

@ -94,7 +94,7 @@ end
function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, context = mx.cpu()) function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, context = mx.cpu())
model = rewrite_softmax(model, label) model = rewrite_softmax(model, label)
vars, stacks, node = tograph(model, mx.Variable(input)) vars, stacks, node = tograph(model, mx.Variable(input), feedforward=true)
ff = mx.FeedForward(node, context = context) ff = mx.FeedForward(node, context = context)
isempty(vars) || (ff.arg_params = mxargs(vars)) isempty(vars) || (ff.arg_params = mxargs(vars))
return ff return ff