From 6b31a0745b2c9a62cce74fd774fb8c7d0881c222 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Mon, 22 Aug 2016 21:58:33 +0100 Subject: [PATCH] get forward pass working --- src/backend/mxnet/model.jl | 36 ++++++++++++++++++++++++++++++++---- src/backend/mxnet/mxnet.jl | 4 ---- 2 files changed, 32 insertions(+), 8 deletions(-) diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl index eea80ad6..68cc3693 100644 --- a/src/backend/mxnet/model.jl +++ b/src/backend/mxnet/model.jl @@ -14,8 +14,36 @@ function mxargs(args) end end -function mxnet(model::Model, input) - vars = Dict{Symbol,Any}(:input => mx.zeros(mxdims(input))) - node = graph(vars, model, mx.Variable(:input)) - MXModel(model, vars, mx.bind(node, args = mxargs(vars), grad_req = mx.GRAD_NOP)) +function load!(model::MXModel) + for (name, arr) in model.exec.arg_dict + # TODO: don't allocate here + haskey(model.params, name) && mx.copy_ignore_shape!(arr, model.params[name]') + end + return model end + +function mxnet(model::Model, input) + vars = Dict{Symbol,Any}() + node = graph(vars, model, mx.Variable(:input)) + args = merge(mxargs(vars), Dict(:input => mx.zeros(mxdims(input)))) + model = MXModel(model, vars, mx.bind(node, args = args, grad_req = mx.GRAD_NOP)) + load!(model) + return model +end + +function (model::MXModel)(input) + inputnd = model.exec.arg_dict[:input] + mx.copy_ignore_shape!(inputnd, input') + mx.forward(model.exec) + copy(model.exec.outputs[1])' +end + +# d = Dense(20, 10) + +# x = randn(20) + +# model = mxnet(d, (20,)) + +# d(x) + +# model(x) diff --git a/src/backend/mxnet/mxnet.jl b/src/backend/mxnet/mxnet.jl index 14e23be7..ae532f85 100644 --- a/src/backend/mxnet/mxnet.jl +++ b/src/backend/mxnet/mxnet.jl @@ -5,8 +5,4 @@ using MXNet, Flow, ..Flux include("graph.jl") include("model.jl") -# d = Dense(20, 10) - -# model = mxnet(d, (1,20)) - end