graph struct
This commit is contained in:
parent
f230b7cebf
commit
2f2ff0b03b
@ -85,7 +85,7 @@ function tograph(model, args...; feedforward = false)
|
|||||||
params = Dict(), stacks = Dict(),
|
params = Dict(), stacks = Dict(),
|
||||||
feedforward = feedforward)
|
feedforward = feedforward)
|
||||||
out = @ithrow graph(ctx, model, args...)
|
out = @ithrow graph(ctx, model, args...)
|
||||||
return ctx[:params], ctx[:stacks], out
|
return Graph(out, ctx[:params], ctx[:stacks])
|
||||||
end
|
end
|
||||||
|
|
||||||
# Error Handling
|
# Error Handling
|
||||||
|
@ -1,10 +1,15 @@
|
|||||||
using Flux: batchone, rebatch
|
using Flux: batchone, rebatch
|
||||||
|
|
||||||
|
type Graph
|
||||||
|
node::mx.SymbolicNode
|
||||||
|
params::Dict{Symbol,Any}
|
||||||
|
stacks::Dict{Any,Any}
|
||||||
|
end
|
||||||
|
|
||||||
type Model <: Flux.Model
|
type Model <: Flux.Model
|
||||||
model::Any
|
model::Any
|
||||||
params::Dict{Symbol,Any}
|
graph::Graph
|
||||||
grads::Dict{Symbol,Any}
|
grads::Dict{Symbol,Any}
|
||||||
stack::Dict{Any,Any}
|
|
||||||
exec::mx.Executor
|
exec::mx.Executor
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -30,19 +35,19 @@ end
|
|||||||
|
|
||||||
function loadparams!(model::Model)
|
function loadparams!(model::Model)
|
||||||
for (name, arr) in model.exec.arg_dict
|
for (name, arr) in model.exec.arg_dict
|
||||||
haskey(model.params, name) && copy!(arr, model.params[name])
|
haskey(model.graph.params, name) && copy!(arr, model.graph.params[name])
|
||||||
end
|
end
|
||||||
return model
|
return model
|
||||||
end
|
end
|
||||||
|
|
||||||
function mxnet(model::Flux.Model, input)
|
function mxnet(model::Flux.Model, input)
|
||||||
params, stacks, node = tograph(model, mx.Variable(:input))
|
graph = tograph(model, mx.Variable(:input))
|
||||||
args = merge(mxargs(params), Dict(:input => mx.zeros(input)))
|
args = merge(mxargs(graph.params), Dict(:input => mx.zeros(input)))
|
||||||
grads = mxgrads(args)
|
grads = mxgrads(args)
|
||||||
model = @mxerr stacks Model(model, params, grads, stacks,
|
model = @mxerr graph.stacks Model(model, graph, grads,
|
||||||
mx.bind(node, args = args,
|
mx.bind(graph.node, args = args,
|
||||||
args_grad = grads,
|
args_grad = grads,
|
||||||
grad_req = mx.GRAD_ADD))
|
grad_req = mx.GRAD_ADD))
|
||||||
loadparams!(model)
|
loadparams!(model)
|
||||||
return model
|
return model
|
||||||
end
|
end
|
||||||
@ -94,8 +99,8 @@ end
|
|||||||
|
|
||||||
function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, context = mx.cpu())
|
function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, context = mx.cpu())
|
||||||
model = rewrite_softmax(model, label)
|
model = rewrite_softmax(model, label)
|
||||||
vars, stacks, node = tograph(model, mx.Variable(input), feedforward=true)
|
graph = tograph(model, mx.Variable(input), feedforward=true)
|
||||||
ff = mx.FeedForward(node, context = context)
|
ff = mx.FeedForward(graph.node, context = context)
|
||||||
isempty(vars) || (ff.arg_params = mxargs(vars))
|
isempty(graph.params) || (ff.arg_params = mxargs(graph.params))
|
||||||
return ff
|
return ff
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user