build executor correctly
This commit is contained in:
parent
5df56b6073
commit
4df97bf607
@ -30,7 +30,7 @@ function mxparams(g::Graph)
|
||||
return params
|
||||
end
|
||||
|
||||
ndparams(d::Dict{Symbol,MXArray}) = Dict(k => v.data for (k, v) in d)
|
||||
ndparams(d) = Dict{Symbol,mx.NDArray}(k => v.data for (k, v) in d)
|
||||
|
||||
struct Exec
|
||||
graph::Graph
|
||||
@ -48,9 +48,18 @@ mxgroup(x::Tuple) = mx.Group(mxgroup.(x)...)
|
||||
mxungroup(x, outs) = copy(shift!(outs))
|
||||
mxungroup(x::Tuple, outs) = map(x -> mxungroup(x, outs), x)
|
||||
|
||||
function executor(graph::Graph, input)
|
||||
args = merge(mxparams(graph), Dict(graph.input[1] => MXArray(input)))
|
||||
grads = merge(mxparams(graph), Dict(graph.input[1] => MXArray(input)))
|
||||
function dictt(ks::Tuple, vs, d = Dict())
|
||||
for i = 1:length(ks)
|
||||
dictt(ks[i], vs[i], d)
|
||||
end
|
||||
return d
|
||||
end
|
||||
|
||||
dictt(k, v, d = Dict()) = (d[k] = v; d)
|
||||
|
||||
function executor(graph::Graph, input...)
|
||||
args = merge(mxparams(graph), dictt(graph.input, mapt(d->MXArray(size(d)), input)))
|
||||
grads = merge(mxparams(graph), dictt(graph.input, mapt(d->MXArray(size(d)), input)))
|
||||
exec = mx.bind(mxgroup(graph.output),
|
||||
args = ndparams(args),
|
||||
args_grad = ndparams(grads),
|
||||
@ -97,13 +106,13 @@ mxnet(model) = Model(model)
|
||||
|
||||
import Base: @get!
|
||||
|
||||
executor(m::Model, input) = @get!(m.execs, input, executor(m.graph, input))
|
||||
executor(m::Model, input) = @get!(m.execs, size(input), executor(m.graph, input))
|
||||
|
||||
function (m::Model)(x)
|
||||
!isdefined(m, :graph) &&
|
||||
(m.graph = tograph(m.model, mapt(_ -> gensym("input"), input)))
|
||||
@mxerr m.graph.stacks runrawbatched(x) do x
|
||||
m.last = exec = executor(m, size(x))
|
||||
m.last = exec = executor(m, x)
|
||||
exec(x)
|
||||
end
|
||||
end
|
||||
|
Loading…
Reference in New Issue
Block a user