get forward pass working

This commit is contained in:
Mike J Innes 2016-08-22 21:58:33 +01:00
parent cab43611e3
commit 54dd0a1e0e
2 changed files with 32 additions and 8 deletions

View File

@ -14,8 +14,36 @@ function mxargs(args)
end
end
function mxnet(model::Model, input)
vars = Dict{Symbol,Any}(:input => mx.zeros(mxdims(input)))
node = graph(vars, model, mx.Variable(:input))
MXModel(model, vars, mx.bind(node, args = mxargs(vars), grad_req = mx.GRAD_NOP))
function load!(model::MXModel)
for (name, arr) in model.exec.arg_dict
# TODO: don't allocate here
haskey(model.params, name) && mx.copy_ignore_shape!(arr, model.params[name]')
end
return model
end
function mxnet(model::Model, input)
vars = Dict{Symbol,Any}()
node = graph(vars, model, mx.Variable(:input))
args = merge(mxargs(vars), Dict(:input => mx.zeros(mxdims(input))))
model = MXModel(model, vars, mx.bind(node, args = args, grad_req = mx.GRAD_NOP))
load!(model)
return model
end
function (model::MXModel)(input)
inputnd = model.exec.arg_dict[:input]
mx.copy_ignore_shape!(inputnd, input')
mx.forward(model.exec)
copy(model.exec.outputs[1])'
end
# d = Dense(20, 10)
# x = randn(20)
# model = mxnet(d, (20,))
# d(x)
# model(x)

View File

@ -5,8 +5,4 @@ using MXNet, Flow, ..Flux
include("graph.jl")
include("model.jl")
# d = Dense(20, 10)
# model = mxnet(d, (1,20))
end