graph struct
This commit is contained in:
parent
f230b7cebf
commit
2f2ff0b03b
@ -85,7 +85,7 @@ function tograph(model, args...; feedforward = false)
|
||||
params = Dict(), stacks = Dict(),
|
||||
feedforward = feedforward)
|
||||
out = @ithrow graph(ctx, model, args...)
|
||||
return ctx[:params], ctx[:stacks], out
|
||||
return Graph(out, ctx[:params], ctx[:stacks])
|
||||
end
|
||||
|
||||
# Error Handling
|
||||
|
@ -1,10 +1,15 @@
|
||||
using Flux: batchone, rebatch
|
||||
|
||||
type Graph
|
||||
node::mx.SymbolicNode
|
||||
params::Dict{Symbol,Any}
|
||||
stacks::Dict{Any,Any}
|
||||
end
|
||||
|
||||
type Model <: Flux.Model
|
||||
model::Any
|
||||
params::Dict{Symbol,Any}
|
||||
graph::Graph
|
||||
grads::Dict{Symbol,Any}
|
||||
stack::Dict{Any,Any}
|
||||
exec::mx.Executor
|
||||
end
|
||||
|
||||
@ -30,19 +35,19 @@ end
|
||||
|
||||
function loadparams!(model::Model)
|
||||
for (name, arr) in model.exec.arg_dict
|
||||
haskey(model.params, name) && copy!(arr, model.params[name])
|
||||
haskey(model.graph.params, name) && copy!(arr, model.graph.params[name])
|
||||
end
|
||||
return model
|
||||
end
|
||||
|
||||
function mxnet(model::Flux.Model, input)
|
||||
params, stacks, node = tograph(model, mx.Variable(:input))
|
||||
args = merge(mxargs(params), Dict(:input => mx.zeros(input)))
|
||||
graph = tograph(model, mx.Variable(:input))
|
||||
args = merge(mxargs(graph.params), Dict(:input => mx.zeros(input)))
|
||||
grads = mxgrads(args)
|
||||
model = @mxerr stacks Model(model, params, grads, stacks,
|
||||
mx.bind(node, args = args,
|
||||
args_grad = grads,
|
||||
grad_req = mx.GRAD_ADD))
|
||||
model = @mxerr graph.stacks Model(model, graph, grads,
|
||||
mx.bind(graph.node, args = args,
|
||||
args_grad = grads,
|
||||
grad_req = mx.GRAD_ADD))
|
||||
loadparams!(model)
|
||||
return model
|
||||
end
|
||||
@ -94,8 +99,8 @@ end
|
||||
|
||||
function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, context = mx.cpu())
|
||||
model = rewrite_softmax(model, label)
|
||||
vars, stacks, node = tograph(model, mx.Variable(input), feedforward=true)
|
||||
ff = mx.FeedForward(node, context = context)
|
||||
isempty(vars) || (ff.arg_params = mxargs(vars))
|
||||
graph = tograph(model, mx.Variable(input), feedforward=true)
|
||||
ff = mx.FeedForward(graph.node, context = context)
|
||||
isempty(graph.params) || (ff.arg_params = mxargs(graph.params))
|
||||
return ff
|
||||
end
|
||||
|
Loading…
Reference in New Issue
Block a user