graph struct

This commit is contained in:
Mike J Innes 2017-02-23 17:32:06 +00:00
parent f230b7cebf
commit 2f2ff0b03b
2 changed files with 18 additions and 13 deletions

View File

@ -85,7 +85,7 @@ function tograph(model, args...; feedforward = false)
params = Dict(), stacks = Dict(),
feedforward = feedforward)
out = @ithrow graph(ctx, model, args...)
return ctx[:params], ctx[:stacks], out
return Graph(out, ctx[:params], ctx[:stacks])
end
# Error Handling

View File

@ -1,10 +1,15 @@
using Flux: batchone, rebatch
type Graph
node::mx.SymbolicNode
params::Dict{Symbol,Any}
stacks::Dict{Any,Any}
end
type Model <: Flux.Model
model::Any
params::Dict{Symbol,Any}
graph::Graph
grads::Dict{Symbol,Any}
stack::Dict{Any,Any}
exec::mx.Executor
end
@ -30,19 +35,19 @@ end
function loadparams!(model::Model)
for (name, arr) in model.exec.arg_dict
haskey(model.params, name) && copy!(arr, model.params[name])
haskey(model.graph.params, name) && copy!(arr, model.graph.params[name])
end
return model
end
function mxnet(model::Flux.Model, input)
params, stacks, node = tograph(model, mx.Variable(:input))
args = merge(mxargs(params), Dict(:input => mx.zeros(input)))
graph = tograph(model, mx.Variable(:input))
args = merge(mxargs(graph.params), Dict(:input => mx.zeros(input)))
grads = mxgrads(args)
model = @mxerr stacks Model(model, params, grads, stacks,
mx.bind(node, args = args,
args_grad = grads,
grad_req = mx.GRAD_ADD))
model = @mxerr graph.stacks Model(model, graph, grads,
mx.bind(graph.node, args = args,
args_grad = grads,
grad_req = mx.GRAD_ADD))
loadparams!(model)
return model
end
@ -94,8 +99,8 @@ end
function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, context = mx.cpu())
model = rewrite_softmax(model, label)
vars, stacks, node = tograph(model, mx.Variable(input), feedforward=true)
ff = mx.FeedForward(node, context = context)
isempty(vars) || (ff.arg_params = mxargs(vars))
graph = tograph(model, mx.Variable(input), feedforward=true)
ff = mx.FeedForward(graph.node, context = context)
isempty(graph.params) || (ff.arg_params = mxargs(graph.params))
return ff
end