awful hack to get both examples working
This commit is contained in:
parent
62ede8cd80
commit
bec7219a93
@ -7,7 +7,7 @@ graph(vars, model, args...) = node(model, args...)
|
|||||||
graph(vars, x::mx.SymbolicNode) = x
|
graph(vars, x::mx.SymbolicNode) = x
|
||||||
|
|
||||||
# TODO: detect parameters used more than once
|
# TODO: detect parameters used more than once
|
||||||
function graph{T<:AArray}(vars, p::Flux.Param{T})
|
function graph{T<:AArray}(vars::Associative, p::Flux.Param{T})
|
||||||
value = p.x
|
value = p.x
|
||||||
id = gensym()
|
id = gensym()
|
||||||
vars[id] = value
|
vars[id] = value
|
||||||
@ -68,7 +68,7 @@ graph(vars, p::MaxPool, x) =
|
|||||||
stride = p.stride)
|
stride = p.stride)
|
||||||
|
|
||||||
# TODO: fix the initialisation issue
|
# TODO: fix the initialisation issue
|
||||||
graph(vars, d::Dense, x) =
|
graph(vars::Void, d::Dense, x) =
|
||||||
mx.FullyConnected(data = x,
|
mx.FullyConnected(data = x,
|
||||||
num_hidden = size(d.W.x, 1),
|
num_hidden = size(d.W.x, 1),
|
||||||
# weight = graph(vars, d.W),
|
# weight = graph(vars, d.W),
|
||||||
|
@ -47,8 +47,8 @@ function load!(model::MXModel)
|
|||||||
return model
|
return model
|
||||||
end
|
end
|
||||||
|
|
||||||
function mxgraph(model, input)
|
function mxgraph(model, input; vars = true)
|
||||||
vars = Dict{Symbol,Any}()
|
vars = vars ? Dict{Symbol,Any}() : nothing
|
||||||
node = graph(vars, model, mx.Variable(input))
|
node = graph(vars, model, mx.Variable(input))
|
||||||
return node, vars
|
return node, vars
|
||||||
end
|
end
|
||||||
@ -91,6 +91,6 @@ end
|
|||||||
|
|
||||||
function mx.FeedForward(model::Model; input = :data, label = :softmax, context = mx.cpu())
|
function mx.FeedForward(model::Model; input = :data, label = :softmax, context = mx.cpu())
|
||||||
model = rewrite_softmax(model, label)
|
model = rewrite_softmax(model, label)
|
||||||
node, _ = mxgraph(model, input)
|
node, _ = mxgraph(model, input, vars = false)
|
||||||
return mx.FeedForward(node, context = context)
|
return mx.FeedForward(node, context = context)
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user