awful hack to get both examples working

Author: Mike J Innes 2016-09-06 18:42:08 +01:00
parent 62ede8cd80
commit bec7219a93
2 changed files with 5 additions and 5 deletions


@@ -7,7 +7,7 @@ graph(vars, model, args...) = node(model, args...)
 graph(vars, x::mx.SymbolicNode) = x
 # TODO: detect parameters used more than once
-function graph{T<:AArray}(vars, p::Flux.Param{T})
+function graph{T<:AArray}(vars::Associative, p::Flux.Param{T})
   value = p.x
   id = gensym()
   vars[id] = value
@@ -68,7 +68,7 @@ graph(vars, p::MaxPool, x) =
            stride = p.stride)
 # TODO: fix the initialisation issue
-graph(vars, d::Dense, x) =
+graph(vars::Void, d::Dense, x) =
   mx.FullyConnected(data = x,
                     num_hidden = size(d.W.x, 1),
                     # weight = graph(vars, d.W),
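
The change above makes graph dispatch on the type of vars: an Associative (such as the Dict{Symbol,Any} that mxgraph builds in the second file) records each parameter value under a fresh gensym id, while Void (vars === nothing) selects the Dense method that builds mx.FullyConnected without binding Flux's weights (note the commented-out weight = graph(vars, d.W)). A minimal sketch of that dispatch trick, using the same pre-0.7 Julia type names as the diff (Associative, Void); the register helper is hypothetical and not part of Flux:

    # Hypothetical helper illustrating the dispatch, not actual Flux code.
    register(vars::Associative, value) = (id = gensym(); vars[id] = value; id)
    register(vars::Void, value) = value         # no bookkeeping when vars === nothing

    vars = Dict{Symbol,Any}()
    register(vars, rand(3, 3))     # stores the array under a gensym'd Symbol key
    register(nothing, rand(3, 3))  # passes the value straight through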


@@ -47,8 +47,8 @@ function load!(model::MXModel)
   return model
 end
-function mxgraph(model, input)
-  vars = Dict{Symbol,Any}()
+function mxgraph(model, input; vars = true)
+  vars = vars ? Dict{Symbol,Any}() : nothing
   node = graph(vars, model, mx.Variable(input))
   return node, vars
 end
@@ -91,6 +91,6 @@ end
 function mx.FeedForward(model::Model; input = :data, label = :softmax, context = mx.cpu())
   model = rewrite_softmax(model, label)
-  node, _ = mxgraph(model, input)
+  node, _ = mxgraph(model, input, vars = false)
   return mx.FeedForward(node, context = context)
 end
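
Taken together, the hack gives mxgraph two modes selected by the new vars keyword: with vars = true (the default) it returns the symbolic node plus a Dict{Symbol,Any} of parameter values, and with vars = false it threads nothing through graph, so the Void method above builds the Dense node without binding Flux's weights and mx.FeedForward is presumably left to initialise them itself, which is what gets both examples working. A hedged usage sketch, with model and :data standing in for a real Flux model and input name:

    node, vars = mxgraph(model, :data)               # vars is a Dict{Symbol,Any} of parameter values
    node, _    = mxgraph(model, :data, vars = false) # vars is nothing; Void methods skip weight binding
    ff = mx.FeedForward(node, context = mx.cpu())    # same call shape as the method above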