use affine only for compat
This commit is contained in:
parent
0bb44f5ace
commit
092f2038b3
@ -34,11 +34,12 @@ graph(::typeof(vcat), a...) = graph(cat, 1, a...)
|
|||||||
graph(::Input, x) = x
|
graph(::Input, x) = x
|
||||||
|
|
||||||
graph(ctx::Context, d::Affine, x) =
|
graph(ctx::Context, d::Affine, x) =
|
||||||
register(ctx,
|
!ctx[:feedforward] ? invoke(graph, (Context, Any, typeof(x)), ctx, d, x) :
|
||||||
mx.FullyConnected(data = x,
|
register(ctx,
|
||||||
num_hidden = size(d.W.x, 2),
|
mx.FullyConnected(data = x,
|
||||||
weight = var(ctx, d.W),
|
num_hidden = size(d.W.x, 2),
|
||||||
bias = var(ctx, d.b, size(d.b, 2))))
|
weight = var(ctx, d.W),
|
||||||
|
bias = var(ctx, d.b, size(d.b, 2))))
|
||||||
|
|
||||||
# TODO: use actual params}
|
# TODO: use actual params}
|
||||||
graph(ctx::Context, c::Conv2D, x) =
|
graph(ctx::Context, c::Conv2D, x) =
|
||||||
@ -79,9 +80,10 @@ end
|
|||||||
|
|
||||||
graph′(ctx::Context, args...) = @icatch ctx graph(ctx, args...)
|
graph′(ctx::Context, args...) = @icatch ctx graph(ctx, args...)
|
||||||
|
|
||||||
function tograph(model, args...)
|
function tograph(model, args...; feedforward = false)
|
||||||
ctx = Context(mux(iline, ilambda, imap, iargs, ituple, graph′),
|
ctx = Context(mux(iline, ilambda, imap, iargs, ituple, graph′),
|
||||||
params = Dict(), stacks = Dict())
|
params = Dict(), stacks = Dict(),
|
||||||
|
feedforward = feedforward)
|
||||||
out = @ithrow graph(ctx, model, args...)
|
out = @ithrow graph(ctx, model, args...)
|
||||||
return ctx[:params], ctx[:stacks], out
|
return ctx[:params], ctx[:stacks], out
|
||||||
end
|
end
|
||||||
|
@ -94,7 +94,7 @@ end
|
|||||||
|
|
||||||
function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, context = mx.cpu())
|
function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, context = mx.cpu())
|
||||||
model = rewrite_softmax(model, label)
|
model = rewrite_softmax(model, label)
|
||||||
vars, stacks, node = tograph(model, mx.Variable(input))
|
vars, stacks, node = tograph(model, mx.Variable(input), feedforward=true)
|
||||||
ff = mx.FeedForward(node, context = context)
|
ff = mx.FeedForward(node, context = context)
|
||||||
isempty(vars) || (ff.arg_params = mxargs(vars))
|
isempty(vars) || (ff.arg_params = mxargs(vars))
|
||||||
return ff
|
return ff
|
||||||
|
Loading…
Reference in New Issue
Block a user