build executor correctly

This commit is contained in:
Mike J Innes 2017-03-30 19:16:24 +01:00
parent 5df56b6073
commit 4df97bf607

View File

@ -30,7 +30,7 @@ function mxparams(g::Graph)
return params
end
ndparams(d::Dict{Symbol,MXArray}) = Dict(k => v.data for (k, v) in d)
ndparams(d) = Dict{Symbol,mx.NDArray}(k => v.data for (k, v) in d)
struct Exec
graph::Graph
@ -48,9 +48,18 @@ mxgroup(x::Tuple) = mx.Group(mxgroup.(x)...)
mxungroup(x, outs) = copy(shift!(outs))
mxungroup(x::Tuple, outs) = map(x -> mxungroup(x, outs), x)
function executor(graph::Graph, input)
args = merge(mxparams(graph), Dict(graph.input[1] => MXArray(input)))
grads = merge(mxparams(graph), Dict(graph.input[1] => MXArray(input)))
function dictt(ks::Tuple, vs, d = Dict())
for i = 1:length(ks)
dictt(ks[i], vs[i], d)
end
return d
end
dictt(k, v, d = Dict()) = (d[k] = v; d)
function executor(graph::Graph, input...)
args = merge(mxparams(graph), dictt(graph.input, mapt(d->MXArray(size(d)), input)))
grads = merge(mxparams(graph), dictt(graph.input, mapt(d->MXArray(size(d)), input)))
exec = mx.bind(mxgroup(graph.output),
args = ndparams(args),
args_grad = ndparams(grads),
@ -97,13 +106,13 @@ mxnet(model) = Model(model)
import Base: @get!
executor(m::Model, input) = @get!(m.execs, input, executor(m.graph, input))
executor(m::Model, input) = @get!(m.execs, size(input), executor(m.graph, input))
function (m::Model)(x)
!isdefined(m, :graph) &&
(m.graph = tograph(m.model, mapt(_ -> gensym("input"), input)))
@mxerr m.graph.stacks runrawbatched(x) do x
m.last = exec = executor(m, size(x))
m.last = exec = executor(m, x)
exec(x)
end
end