graph struct

This commit is contained in:
Mike J Innes 2017-02-23 17:32:06 +00:00
parent f230b7cebf
commit 2f2ff0b03b
2 changed files with 18 additions and 13 deletions

View File

@ -85,7 +85,7 @@ function tograph(model, args...; feedforward = false)
params = Dict(), stacks = Dict(), params = Dict(), stacks = Dict(),
feedforward = feedforward) feedforward = feedforward)
out = @ithrow graph(ctx, model, args...) out = @ithrow graph(ctx, model, args...)
return ctx[:params], ctx[:stacks], out return Graph(out, ctx[:params], ctx[:stacks])
end end
# Error Handling # Error Handling

View File

@ -1,10 +1,15 @@
using Flux: batchone, rebatch using Flux: batchone, rebatch
type Graph
node::mx.SymbolicNode
params::Dict{Symbol,Any}
stacks::Dict{Any,Any}
end
type Model <: Flux.Model type Model <: Flux.Model
model::Any model::Any
params::Dict{Symbol,Any} graph::Graph
grads::Dict{Symbol,Any} grads::Dict{Symbol,Any}
stack::Dict{Any,Any}
exec::mx.Executor exec::mx.Executor
end end
@ -30,19 +35,19 @@ end
function loadparams!(model::Model) function loadparams!(model::Model)
for (name, arr) in model.exec.arg_dict for (name, arr) in model.exec.arg_dict
haskey(model.params, name) && copy!(arr, model.params[name]) haskey(model.graph.params, name) && copy!(arr, model.graph.params[name])
end end
return model return model
end end
function mxnet(model::Flux.Model, input) function mxnet(model::Flux.Model, input)
params, stacks, node = tograph(model, mx.Variable(:input)) graph = tograph(model, mx.Variable(:input))
args = merge(mxargs(params), Dict(:input => mx.zeros(input))) args = merge(mxargs(graph.params), Dict(:input => mx.zeros(input)))
grads = mxgrads(args) grads = mxgrads(args)
model = @mxerr stacks Model(model, params, grads, stacks, model = @mxerr graph.stacks Model(model, graph, grads,
mx.bind(node, args = args, mx.bind(graph.node, args = args,
args_grad = grads, args_grad = grads,
grad_req = mx.GRAD_ADD)) grad_req = mx.GRAD_ADD))
loadparams!(model) loadparams!(model)
return model return model
end end
@ -94,8 +99,8 @@ end
function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, context = mx.cpu()) function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, context = mx.cpu())
model = rewrite_softmax(model, label) model = rewrite_softmax(model, label)
vars, stacks, node = tograph(model, mx.Variable(input), feedforward=true) graph = tograph(model, mx.Variable(input), feedforward=true)
ff = mx.FeedForward(node, context = context) ff = mx.FeedForward(graph.node, context = context)
isempty(vars) || (ff.arg_params = mxargs(vars)) isempty(graph.params) || (ff.arg_params = mxargs(graph.params))
return ff return ff
end end