diff --git a/src/backend/mxnet/graph.jl b/src/backend/mxnet/graph.jl index 821247fb..ebc5c5b1 100644 --- a/src/backend/mxnet/graph.jl +++ b/src/backend/mxnet/graph.jl @@ -57,12 +57,14 @@ end register(ctx::Context, node) = node -function graph{T<:AArray}(ctx::Context, p::Constant{Flux.Param{T}}) +function var(ctx::Context, p::Flux.Param) id = gensym() - ctx[:params][id] = p.value.x + ctx[:params][id] = p.x return mx.Variable(id) end +graph{T<:AArray}(ctx::Context, p::Constant{Flux.Param{T}}) = var(ctx, p.value) + graph(ctx::Context, p::Constant) = node(p.value) function graph(ctx::Context, model, args...)