aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorMatthew Bentham <matthew.bentham@arm.com>2019-09-14 23:35:28 +0100
committerMatthew Bentham <matthew.bentham@arm.com>2019-09-16 09:07:06 +0000
commitefdbca6b1a25dbd491be7f308dd729ff2255bb28 (patch)
treeeaaea26eea415411715c0666e28327300a9ac1e8 /src
parent5bf1d321e480b9c0030368edd5138fec949975c1 (diff)
downloadarmnn-efdbca6b1a25dbd491be7f308dd729ff2255bb28.tar.gz
Rename variables in CopyTensorContents to assume NHWC
Change-Id: I533991c8829256570529c18023a5e882878cc85a Signed-off-by: Matthew Bentham <matthew.bentham@arm.com>
Diffstat (limited to 'src')
-rw-r--r--src/backends/backendsCommon/WorkloadUtils.hpp40
1 files changed, 21 insertions, 19 deletions
diff --git a/src/backends/backendsCommon/WorkloadUtils.hpp b/src/backends/backendsCommon/WorkloadUtils.hpp
index ba69255183..3e0c40d890 100644
--- a/src/backends/backendsCommon/WorkloadUtils.hpp
+++ b/src/backends/backendsCommon/WorkloadUtils.hpp
@@ -46,6 +46,8 @@ void AssignValues(unsigned int num, unsigned int idx, const ArrayType& array, T&
template <typename CopyFunc>
void CopyTensorContentsGeneric(const ITensorHandle* srcTensor, ITensorHandle* dstTensor, CopyFunc copy)
{
+ // For ease of understanding, names are assigned to the dimensions
+ // of the tensor as if NHWC, however this routine works with any 5D tensor
static_assert(MaxNumOfTensorDimensions == 5, "Please update CopyTensorContents");
TensorShape srcStrides = srcTensor->GetStrides();
@@ -55,57 +57,57 @@ void CopyTensorContentsGeneric(const ITensorHandle* srcTensor, ITensorHandle* ds
size_t srcDepth = 1;
size_t srcBatches = 1;
- size_t srcChannels = 1;
size_t srcHeight = 1;
size_t srcWidth = 1;
+ size_t srcChannels = 1;
AssignValues(srcShape.GetNumDimensions(),
0,
srcShape,
+ srcChannels,
srcWidth,
srcHeight,
- srcChannels,
srcBatches,
srcDepth);
size_t srcDepthStride = 0;
size_t srcBatchStride = 0;
- size_t srcChannelStride = 0;
size_t srcHeightStride = 0;
size_t srcWidthStride = 0;
+ size_t srcChannelStride = 0;
AssignValues(srcStrides.GetNumDimensions(),
0,
srcStrides,
+ srcChannelStride,
srcWidthStride,
srcHeightStride,
- srcChannelStride,
srcBatchStride,
srcDepthStride);
size_t dstDepth = 1;
size_t dstBatches = 1;
- size_t dstChannels = 1;
size_t dstHeight = 1;
size_t dstWidth = 1;
+ size_t dstChannels = 1;
AssignValues(dstShape.GetNumDimensions(),
0,
dstShape,
+ dstChannels,
dstWidth,
dstHeight,
- dstChannels,
dstBatches,
dstDepth);
size_t dstDepthStride = 0;
size_t dstBatchStride = 0;
- size_t dstChannelStride = 0;
size_t dstHeightStride = 0;
size_t dstWidthStride = 0;
+ size_t dstChannelStride = 0;
AssignValues(dstStrides.GetNumDimensions(),
0,
dstStrides,
+ dstChannelStride,
dstWidthStride,
dstHeightStride,
- dstChannelStride,
dstBatchStride,
dstDepthStride);
@@ -117,11 +119,11 @@ void CopyTensorContentsGeneric(const ITensorHandle* srcTensor, ITensorHandle* ds
dstData = static_cast<uint8_t*>(dstTensor->Map());
}
- size_t copyLength = std::min(srcWidth * srcWidthStride, dstWidth * dstWidthStride);
- size_t copyHeight = std::min(srcHeight, dstHeight);
- size_t copyChannels = std::min(srcChannels, dstChannels);
- size_t copyBatches = std::min(srcBatches, dstBatches);
- size_t copyDepth = std::min(srcDepth, dstDepth);
+ size_t copyLength = std::min(srcChannels*srcChannelStride, dstChannels*dstChannelStride);
+ size_t copyWidth = std::min(srcWidth, dstWidth);
+ size_t copyHeight = std::min(srcHeight, dstHeight);
+ size_t copyBatches = std::min(srcBatches, dstBatches);
+ size_t copyDepth = std::min(srcDepth, dstDepth);
for (unsigned int d = 0; d < copyDepth; ++d)
{
@@ -131,18 +133,18 @@ void CopyTensorContentsGeneric(const ITensorHandle* srcTensor, ITensorHandle* ds
{
auto srcPtrBatch = srcData;
auto dstPtrBatch = dstData;
- for (unsigned int c = 0; c < copyChannels; ++c)
+ for (unsigned int h = 0; h < copyHeight; ++h)
{
auto srcPtrChannel = srcData;
auto dstPtrChannel = dstData;
- for (unsigned int h = 0; h < copyHeight; ++h)
+ for (unsigned int w = 0; w < copyWidth; ++w)
{
copy(dstData, srcData, copyLength);
- dstData += dstHeightStride;
- srcData += srcHeightStride;
+ dstData += dstWidthStride;
+ srcData += srcWidthStride;
}
- dstData += (static_cast<long>(dstChannelStride) - (dstData - dstPtrChannel));
- srcData += (static_cast<long>(srcChannelStride) - (srcData - srcPtrChannel));
+ dstData += (static_cast<long>(dstHeightStride) - (dstData - dstPtrChannel));
+ srcData += (static_cast<long>(srcHeightStride) - (srcData - srcPtrChannel));
}
dstData += (static_cast<long>(dstBatchStride) - (dstData - dstPtrBatch));
srcData += (static_cast<long>(srcBatchStride) - (srcData - srcPtrBatch));