diff options
author | Matthew Bentham <matthew.bentham@arm.com> | 2019-09-14 23:35:28 +0100 |
---|---|---|
committer | Matthew Bentham <matthew.bentham@arm.com> | 2019-09-16 09:07:06 +0000 |
commit | efdbca6b1a25dbd491be7f308dd729ff2255bb28 (patch) | |
tree | eaaea26eea415411715c0666e28327300a9ac1e8 /src/backends | |
parent | 5bf1d321e480b9c0030368edd5138fec949975c1 (diff) | |
download | armnn-efdbca6b1a25dbd491be7f308dd729ff2255bb28.tar.gz |
Rename variables in CopyTensorContents to assume NHWC
Change-Id: I533991c8829256570529c18023a5e882878cc85a
Signed-off-by: Matthew Bentham <matthew.bentham@arm.com>
Diffstat (limited to 'src/backends')
-rw-r--r-- | src/backends/backendsCommon/WorkloadUtils.hpp | 40 |
1 files changed, 21 insertions, 19 deletions
diff --git a/src/backends/backendsCommon/WorkloadUtils.hpp b/src/backends/backendsCommon/WorkloadUtils.hpp index ba69255183..3e0c40d890 100644 --- a/src/backends/backendsCommon/WorkloadUtils.hpp +++ b/src/backends/backendsCommon/WorkloadUtils.hpp @@ -46,6 +46,8 @@ void AssignValues(unsigned int num, unsigned int idx, const ArrayType& array, T& template <typename CopyFunc> void CopyTensorContentsGeneric(const ITensorHandle* srcTensor, ITensorHandle* dstTensor, CopyFunc copy) { + // For ease of understanding, names are assigned to the dimensions + // of the tensor as if NHWC, however this routine works with any 5D tensor static_assert(MaxNumOfTensorDimensions == 5, "Please update CopyTensorContents"); TensorShape srcStrides = srcTensor->GetStrides(); @@ -55,57 +57,57 @@ void CopyTensorContentsGeneric(const ITensorHandle* srcTensor, ITensorHandle* ds size_t srcDepth = 1; size_t srcBatches = 1; - size_t srcChannels = 1; size_t srcHeight = 1; size_t srcWidth = 1; + size_t srcChannels = 1; AssignValues(srcShape.GetNumDimensions(), 0, srcShape, + srcChannels, srcWidth, srcHeight, - srcChannels, srcBatches, srcDepth); size_t srcDepthStride = 0; size_t srcBatchStride = 0; - size_t srcChannelStride = 0; size_t srcHeightStride = 0; size_t srcWidthStride = 0; + size_t srcChannelStride = 0; AssignValues(srcStrides.GetNumDimensions(), 0, srcStrides, + srcChannelStride, srcWidthStride, srcHeightStride, - srcChannelStride, srcBatchStride, srcDepthStride); size_t dstDepth = 1; size_t dstBatches = 1; - size_t dstChannels = 1; size_t dstHeight = 1; size_t dstWidth = 1; + size_t dstChannels = 1; AssignValues(dstShape.GetNumDimensions(), 0, dstShape, + dstChannels, dstWidth, dstHeight, - dstChannels, dstBatches, dstDepth); size_t dstDepthStride = 0; size_t dstBatchStride = 0; - size_t dstChannelStride = 0; size_t dstHeightStride = 0; size_t dstWidthStride = 0; + size_t dstChannelStride = 0; AssignValues(dstStrides.GetNumDimensions(), 0, dstStrides, + dstChannelStride, dstWidthStride, dstHeightStride, - dstChannelStride, dstBatchStride, dstDepthStride); @@ -117,11 +119,11 @@ void CopyTensorContentsGeneric(const ITensorHandle* srcTensor, ITensorHandle* ds dstData = static_cast<uint8_t*>(dstTensor->Map()); } - size_t copyLength = std::min(srcWidth * srcWidthStride, dstWidth * dstWidthStride); - size_t copyHeight = std::min(srcHeight, dstHeight); - size_t copyChannels = std::min(srcChannels, dstChannels); - size_t copyBatches = std::min(srcBatches, dstBatches); - size_t copyDepth = std::min(srcDepth, dstDepth); + size_t copyLength = std::min(srcChannels*srcChannelStride, dstChannels*dstChannelStride); + size_t copyWidth = std::min(srcWidth, dstWidth); + size_t copyHeight = std::min(srcHeight, dstHeight); + size_t copyBatches = std::min(srcBatches, dstBatches); + size_t copyDepth = std::min(srcDepth, dstDepth); for (unsigned int d = 0; d < copyDepth; ++d) { @@ -131,18 +133,18 @@ void CopyTensorContentsGeneric(const ITensorHandle* srcTensor, ITensorHandle* ds { auto srcPtrBatch = srcData; auto dstPtrBatch = dstData; - for (unsigned int c = 0; c < copyChannels; ++c) + for (unsigned int h = 0; h < copyHeight; ++h) { auto srcPtrChannel = srcData; auto dstPtrChannel = dstData; - for (unsigned int h = 0; h < copyHeight; ++h) + for (unsigned int w = 0; w < copyWidth; ++w) { copy(dstData, srcData, copyLength); - dstData += dstHeightStride; - srcData += srcHeightStride; + dstData += dstWidthStride; + srcData += srcWidthStride; } - dstData += (static_cast<long>(dstChannelStride) - (dstData - dstPtrChannel)); - srcData += (static_cast<long>(srcChannelStride) - (srcData - srcPtrChannel)); + dstData += (static_cast<long>(dstHeightStride) - (dstData - dstPtrChannel)); + srcData += (static_cast<long>(srcHeightStride) - (srcData - srcPtrChannel)); } dstData += (static_cast<long>(dstBatchStride) - (dstData - dstPtrBatch)); srcData += (static_cast<long>(srcBatchStride) - (srcData - srcPtrBatch)); |