fix
This commit is contained in:
parent
3fdffea37d
commit
54011045e7
@ -88,7 +88,7 @@ end
|
|||||||
|
|
||||||
function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, context = mx.cpu())
|
function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, context = mx.cpu())
|
||||||
model = rewrite_softmax(model, label)
|
model = rewrite_softmax(model, label)
|
||||||
vars, stacks, node = tograph(model, mx.Variable(:input))
|
vars, stacks, node = tograph(model, mx.Variable(input))
|
||||||
ff = mx.FeedForward(node, context = context)
|
ff = mx.FeedForward(node, context = context)
|
||||||
ff.arg_params = mxargs(vars)
|
ff.arg_params = mxargs(vars)
|
||||||
return ff
|
return ff
|
||||||
|
@ -92,7 +92,8 @@ end
|
|||||||
state(x) = x
|
state(x) = x
|
||||||
accumulate!(x, Δ) = x
|
accumulate!(x, Δ) = x
|
||||||
|
|
||||||
@forward Param.x Base.size
|
Base.size(p::Param) = size(p.x)
|
||||||
|
Base.size(p::Param, n) = size(p.x, n)
|
||||||
|
|
||||||
function Base.show(io::IO, p::Param)
|
function Base.show(io::IO, p::Param)
|
||||||
print(io, "Param", size(p.x))
|
print(io, "Param", size(p.x))
|
||||||
|
Loading…
Reference in New Issue
Block a user