get forward pass working
This commit is contained in:
parent
cab43611e3
commit
54dd0a1e0e
@ -14,8 +14,36 @@ function mxargs(args)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
function mxnet(model::Model, input)
|
function load!(model::MXModel)
|
||||||
vars = Dict{Symbol,Any}(:input => mx.zeros(mxdims(input)))
|
for (name, arr) in model.exec.arg_dict
|
||||||
node = graph(vars, model, mx.Variable(:input))
|
# TODO: don't allocate here
|
||||||
MXModel(model, vars, mx.bind(node, args = mxargs(vars), grad_req = mx.GRAD_NOP))
|
haskey(model.params, name) && mx.copy_ignore_shape!(arr, model.params[name]')
|
||||||
|
end
|
||||||
|
return model
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function mxnet(model::Model, input)
|
||||||
|
vars = Dict{Symbol,Any}()
|
||||||
|
node = graph(vars, model, mx.Variable(:input))
|
||||||
|
args = merge(mxargs(vars), Dict(:input => mx.zeros(mxdims(input))))
|
||||||
|
model = MXModel(model, vars, mx.bind(node, args = args, grad_req = mx.GRAD_NOP))
|
||||||
|
load!(model)
|
||||||
|
return model
|
||||||
|
end
|
||||||
|
|
||||||
|
function (model::MXModel)(input)
|
||||||
|
inputnd = model.exec.arg_dict[:input]
|
||||||
|
mx.copy_ignore_shape!(inputnd, input')
|
||||||
|
mx.forward(model.exec)
|
||||||
|
copy(model.exec.outputs[1])'
|
||||||
|
end
|
||||||
|
|
||||||
|
# d = Dense(20, 10)
|
||||||
|
|
||||||
|
# x = randn(20)
|
||||||
|
|
||||||
|
# model = mxnet(d, (20,))
|
||||||
|
|
||||||
|
# d(x)
|
||||||
|
|
||||||
|
# model(x)
|
||||||
|
@ -5,8 +5,4 @@ using MXNet, Flow, ..Flux
|
|||||||
include("graph.jl")
|
include("graph.jl")
|
||||||
include("model.jl")
|
include("model.jl")
|
||||||
|
|
||||||
# d = Dense(20, 10)
|
|
||||||
|
|
||||||
# model = mxnet(d, (1,20))
|
|
||||||
|
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user