From d35191595d6fbb6680c5dd46a9bb8a4e4d0fba01 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 8 Mar 2017 01:19:51 +0000 Subject: [PATCH] mxarray --- src/backend/mxnet/mxarray.jl | 39 ++++++++++++++++++++++++++++++++++++ src/backend/mxnet/mxnet.jl | 1 + 2 files changed, 40 insertions(+) create mode 100644 src/backend/mxnet/mxarray.jl diff --git a/src/backend/mxnet/mxarray.jl b/src/backend/mxnet/mxarray.jl new file mode 100644 index 00000000..a7941ac7 --- /dev/null +++ b/src/backend/mxnet/mxarray.jl @@ -0,0 +1,39 @@ +using MXNet + +# NDArray is row-major so by default all dimensions are reversed in MXNet. +# MXArray tranposes when loading/storing to fix this. + +reversedims!(dest, xs) = permutedims!(dest, xs, ndims(xs):-1:1) + +immutable MXArray{N} + data::mx.NDArray + scratch::Array{Float32,N} +end + +MXArray{T}(data::mx.NDArray, scratch::Array{Float32, T}) = MXArray{T}(data, scratch) + +MXArray(data::mx.NDArray) = MXArray(data, Array{Float32}(size(data))) + +MXArray(dims::Integer...) = MXArray(mx.zeros(reverse(dims))) + +Base.size(xs::MXArray) = reverse(size(xs.data)) + +function Base.copy!(mx::MXArray, xs::AbstractArray) + @assert size(mx) == size(xs) + reversedims!(mx.scratch, xs) + copy!(mx.data, mx.scratch) + return mx +end + +function Base.copy!(xs::AbstractArray, mx::MXArray) + @assert size(xs) == size(mx) + copy!(mx.scratch, mx.data) + reversedims!(xs, mx.scratch) +end + +Base.copy(mx::MXArray) = copy!(Array{Float32}(size(mx)), mx) + +function MXArray(xs::AbstractArray) + mx = MXArray(size(xs)...) + copy!(mx, xs) +end diff --git a/src/backend/mxnet/mxnet.jl b/src/backend/mxnet/mxnet.jl index 8f8be89a..fb9db1dd 100644 --- a/src/backend/mxnet/mxnet.jl +++ b/src/backend/mxnet/mxnet.jl @@ -4,6 +4,7 @@ using MXNet, DataFlow, ..Flux export mxnet +include("mxarray.jl") include("graph.jl") include("model.jl")