diff --git a/src/backend/mxnet/mxarray.jl b/src/backend/mxnet/mxarray.jl index f89ed734..ecfe7439 100644 --- a/src/backend/mxnet/mxarray.jl +++ b/src/backend/mxnet/mxarray.jl @@ -32,8 +32,8 @@ end Base.copy(mx::MXArray) = copy!(Array{Float32}(size(mx)), mx) -function MXArray(xs::AbstractArray) - mx = MXArray(size(xs)) +function MXArray(xs::AbstractArray, ctx = mx.cpu()) + mx = MXArray(size(xs), ctx) copy!(mx, xs) end