get forward pass working
This commit is contained in:
parent
cab43611e3
commit
54dd0a1e0e
|
@ -14,8 +14,36 @@ function mxargs(args)
|
|||
end
|
||||
end
|
||||
|
||||
function mxnet(model::Model, input)
|
||||
vars = Dict{Symbol,Any}(:input => mx.zeros(mxdims(input)))
|
||||
node = graph(vars, model, mx.Variable(:input))
|
||||
MXModel(model, vars, mx.bind(node, args = mxargs(vars), grad_req = mx.GRAD_NOP))
|
||||
function load!(model::MXModel)
|
||||
for (name, arr) in model.exec.arg_dict
|
||||
# TODO: don't allocate here
|
||||
haskey(model.params, name) && mx.copy_ignore_shape!(arr, model.params[name]')
|
||||
end
|
||||
return model
|
||||
end
|
||||
|
||||
function mxnet(model::Model, input)
|
||||
vars = Dict{Symbol,Any}()
|
||||
node = graph(vars, model, mx.Variable(:input))
|
||||
args = merge(mxargs(vars), Dict(:input => mx.zeros(mxdims(input))))
|
||||
model = MXModel(model, vars, mx.bind(node, args = args, grad_req = mx.GRAD_NOP))
|
||||
load!(model)
|
||||
return model
|
||||
end
|
||||
|
||||
function (model::MXModel)(input)
|
||||
inputnd = model.exec.arg_dict[:input]
|
||||
mx.copy_ignore_shape!(inputnd, input')
|
||||
mx.forward(model.exec)
|
||||
copy(model.exec.outputs[1])'
|
||||
end
|
||||
|
||||
# d = Dense(20, 10)
|
||||
|
||||
# x = randn(20)
|
||||
|
||||
# model = mxnet(d, (20,))
|
||||
|
||||
# d(x)
|
||||
|
||||
# model(x)
|
||||
|
|
|
@ -5,8 +5,4 @@ using MXNet, Flow, ..Flux
|
|||
include("graph.jl")
|
||||
include("model.jl")
|
||||
|
||||
# d = Dense(20, 10)
|
||||
|
||||
# model = mxnet(d, (1,20))
|
||||
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue