From efdbca6b1a25dbd491be7f308dd729ff2255bb28 Mon Sep 17 00:00:00 2001 From: Matthew Bentham Date: Sat, 14 Sep 2019 23:35:28 +0100 Subject: Rename variables in CopyTensorContents to assume NHWC Change-Id: I533991c8829256570529c18023a5e882878cc85a Signed-off-by: Matthew Bentham --- src/backends/backendsCommon/WorkloadUtils.hpp | 40 ++++++++++++++------------- 1 file changed, 21 insertions(+), 19 deletions(-) (limited to 'src/backends/backendsCommon/WorkloadUtils.hpp') diff --git a/src/backends/backendsCommon/WorkloadUtils.hpp b/src/backends/backendsCommon/WorkloadUtils.hpp index ba69255183..3e0c40d890 100644 --- a/src/backends/backendsCommon/WorkloadUtils.hpp +++ b/src/backends/backendsCommon/WorkloadUtils.hpp @@ -46,6 +46,8 @@ void AssignValues(unsigned int num, unsigned int idx, const ArrayType& array, T& template void CopyTensorContentsGeneric(const ITensorHandle* srcTensor, ITensorHandle* dstTensor, CopyFunc copy) { + // For ease of understanding, names are assigned to the dimensions + // of the tensor as if NHWC, however this routine works with any 5D tensor static_assert(MaxNumOfTensorDimensions == 5, "Please update CopyTensorContents"); TensorShape srcStrides = srcTensor->GetStrides(); @@ -55,57 +57,57 @@ void CopyTensorContentsGeneric(const ITensorHandle* srcTensor, ITensorHandle* ds size_t srcDepth = 1; size_t srcBatches = 1; - size_t srcChannels = 1; size_t srcHeight = 1; size_t srcWidth = 1; + size_t srcChannels = 1; AssignValues(srcShape.GetNumDimensions(), 0, srcShape, + srcChannels, srcWidth, srcHeight, - srcChannels, srcBatches, srcDepth); size_t srcDepthStride = 0; size_t srcBatchStride = 0; - size_t srcChannelStride = 0; size_t srcHeightStride = 0; size_t srcWidthStride = 0; + size_t srcChannelStride = 0; AssignValues(srcStrides.GetNumDimensions(), 0, srcStrides, + srcChannelStride, srcWidthStride, srcHeightStride, - srcChannelStride, srcBatchStride, srcDepthStride); size_t dstDepth = 1; size_t dstBatches = 1; - size_t dstChannels = 1; size_t dstHeight = 1; size_t dstWidth = 1; + size_t dstChannels = 1; AssignValues(dstShape.GetNumDimensions(), 0, dstShape, + dstChannels, dstWidth, dstHeight, - dstChannels, dstBatches, dstDepth); size_t dstDepthStride = 0; size_t dstBatchStride = 0; - size_t dstChannelStride = 0; size_t dstHeightStride = 0; size_t dstWidthStride = 0; + size_t dstChannelStride = 0; AssignValues(dstStrides.GetNumDimensions(), 0, dstStrides, + dstChannelStride, dstWidthStride, dstHeightStride, - dstChannelStride, dstBatchStride, dstDepthStride); @@ -117,11 +119,11 @@ void CopyTensorContentsGeneric(const ITensorHandle* srcTensor, ITensorHandle* ds dstData = static_cast(dstTensor->Map()); } - size_t copyLength = std::min(srcWidth * srcWidthStride, dstWidth * dstWidthStride); - size_t copyHeight = std::min(srcHeight, dstHeight); - size_t copyChannels = std::min(srcChannels, dstChannels); - size_t copyBatches = std::min(srcBatches, dstBatches); - size_t copyDepth = std::min(srcDepth, dstDepth); + size_t copyLength = std::min(srcChannels*srcChannelStride, dstChannels*dstChannelStride); + size_t copyWidth = std::min(srcWidth, dstWidth); + size_t copyHeight = std::min(srcHeight, dstHeight); + size_t copyBatches = std::min(srcBatches, dstBatches); + size_t copyDepth = std::min(srcDepth, dstDepth); for (unsigned int d = 0; d < copyDepth; ++d) { @@ -131,18 +133,18 @@ void CopyTensorContentsGeneric(const ITensorHandle* srcTensor, ITensorHandle* ds { auto srcPtrBatch = srcData; auto dstPtrBatch = dstData; - for (unsigned int c = 0; c < copyChannels; ++c) + for (unsigned int h = 0; h < copyHeight; ++h) { auto srcPtrChannel = srcData; auto dstPtrChannel = dstData; - for (unsigned int h = 0; h < copyHeight; ++h) + for (unsigned int w = 0; w < copyWidth; ++w) { copy(dstData, srcData, copyLength); - dstData += dstHeightStride; - srcData += srcHeightStride; + dstData += dstWidthStride; + srcData += srcWidthStride; } - dstData += (static_cast(dstChannelStride) - (dstData - dstPtrChannel)); - srcData += (static_cast(srcChannelStride) - (srcData - srcPtrChannel)); + dstData += (static_cast(dstHeightStride) - (dstData - dstPtrChannel)); + srcData += (static_cast(srcHeightStride) - (srcData - srcPtrChannel)); } dstData += (static_cast(dstBatchStride) - (dstData - dstPtrBatch)); srcData += (static_cast(srcBatchStride) - (srcData - srcPtrBatch)); -- cgit v1.2.1