Flux.jl/src/backend/mxnet/mxarray.jl

40 lines
972 B
Julia
Raw Normal View History

2017-03-08 01:19:51 +00:00
using MXNet
# MXNet's NDArray is row-major, so relative to Julia's column-major arrays all
# dimensions appear reversed. MXArray transposes when loading/storing to fix this.
# Write `xs` into `dest` with its axis order reversed (axis 1 becomes the last
# axis, and so on). `dest` must be sized `reverse(size(xs))`. Returns `dest`.
function reversedims!(dest, xs)
    axis_order = reverse(1:ndims(xs))
    return permutedims!(dest, xs, axis_order)
end
# Julia-facing wrapper around an MXNet `NDArray`.
# `data` holds the values in MXNet's layout, whose dimensions are the reverse
# of the Julia-visible size (see `Base.size(::MXArray)`).
# `scratch` is a host-side Float32 staging buffer; the `MXArray(::mx.NDArray)`
# constructor allocates it with `size(data)`, and `copy!` uses it to transpose
# values on their way into and out of `data`.
immutable MXArray{N}
data::mx.NDArray
scratch::Array{Float32,N}
end
# Wrap an existing NDArray, allocating an (uninitialised) Float32 staging
# buffer with the NDArray's own dimensions.
function MXArray(data::mx.NDArray)
    staging = Array{Float32}(size(data))
    return MXArray(data, staging)
end
2017-03-08 15:36:38 +00:00
# Create a zero-filled MXArray whose Julia-visible size is `dims`; the
# underlying NDArray is allocated with the dimensions reversed.
function MXArray(dims::Dims)
    backing = mx.zeros(reverse(dims))
    return MXArray(backing)
end
2017-03-08 01:19:51 +00:00
# Report the Julia-visible size: the wrapped NDArray's dimensions, reversed.
function Base.size(xs::MXArray)
    raw = size(xs.data)
    return reverse(raw)
end
"""
    copy!(dest::MXArray, src::AbstractArray)

Copy the Julia array `src` into `dest`. The values are first written into
`dest.scratch` with their dimensions reversed, then copied into the wrapped
NDArray `dest.data`.

Throws `DimensionMismatch` if `size(dest) != size(src)`. Returns `dest`.
"""
function Base.copy!(dest::MXArray, src::AbstractArray)
    # Argument validation: use a proper exception rather than `@assert`,
    # which is meant for internal invariants.
    if size(dest) != size(src)
        throw(DimensionMismatch(
            "destination has size $(size(dest)), source has size $(size(src))"))
    end
    # Transpose into the host staging buffer first, so that the final copy into
    # the NDArray is a plain same-shape copy.
    reversedims!(dest.scratch, src)
    copy!(dest.data, dest.scratch)
    return dest
end
"""
    copy!(dest::AbstractArray, src::MXArray)

Copy the contents of `src` into the Julia array `dest`. The wrapped NDArray is
first copied into `src.scratch`, then transposed (dimensions reversed) into
`dest`.

Throws `DimensionMismatch` if `size(dest) != size(src)`. Returns `dest`.
"""
function Base.copy!(dest::AbstractArray, src::MXArray)
    # Argument validation: use a proper exception rather than `@assert`,
    # which is meant for internal invariants.
    if size(dest) != size(src)
        throw(DimensionMismatch(
            "destination has size $(size(dest)), source has size $(size(src))"))
    end
    copy!(src.scratch, src.data)
    reversedims!(dest, src.scratch)
    # Return the destination explicitly instead of relying on the return value
    # of the helper.
    return dest
end
# Materialise an MXArray as a freshly allocated Julia `Array{Float32}` of the
# same (Julia-visible) size.
function Base.copy(src::MXArray)
    out = Array{Float32}(size(src))
    return copy!(out, src)
end
# Build an MXArray holding a copy of the Julia array `xs`: allocate a
# zero-filled MXArray of the same size, then copy `xs` into it.
function MXArray(xs::AbstractArray)
    target = MXArray(size(xs))
    return copy!(target, xs)
end
2017-03-08 15:36:51 +00:00
"""
    setindex!(xs::MXArray, x::Real, ::Colon)

Assign the scalar `x` to every element of `xs` (as in `xs[:] = 0`), writing
directly into the wrapped NDArray. Returns `xs`, following the convention of
Base's `setindex!` methods; the `xs[:] = x` syntax itself still evaluates to `x`.
"""
function Base.setindex!(xs::MXArray, x::Real, ::Colon)
    xs.data[:] = x
    return xs
end