From dba634fd6a66a9e033a1925b0b26c80b270bbf21 Mon Sep 17 00:00:00 2001 From: Matthew Jackson Date: Thu, 15 Aug 2019 15:14:18 +0100 Subject: IVGCVSW-3639 Add 5d tensor support * Increased MaxNumOfTensorDimensions and fixed issues related to its use * Fixed issues caused by assuming 5d tensors are invalid * Updated ArmComputeTensorUtils for 5d tensors * Added 5d tensor unit tests for add, mul, stack and reshape (needed by IVGCVSW-3527) Signed-off-by: Matthew Jackson Change-Id: I5bcd64942d0d04efcc6c5acb240ad4b88e010743 --- src/backends/backendsCommon/WorkloadUtils.hpp | 54 +++++++++++++++++---------- 1 file changed, 35 insertions(+), 19 deletions(-) (limited to 'src/backends/backendsCommon/WorkloadUtils.hpp') diff --git a/src/backends/backendsCommon/WorkloadUtils.hpp b/src/backends/backendsCommon/WorkloadUtils.hpp index a1a8d2a475..7e3ac395e4 100644 --- a/src/backends/backendsCommon/WorkloadUtils.hpp +++ b/src/backends/backendsCommon/WorkloadUtils.hpp @@ -46,13 +46,14 @@ void AssignValues(unsigned int num, unsigned int idx, const ArrayType& array, T& template void CopyTensorContentsGeneric(const ITensorHandle* srcTensor, ITensorHandle* dstTensor, CopyFunc copy) { - static_assert(MaxNumOfTensorDimensions == 4, "Please update CopyTensorContents"); + static_assert(MaxNumOfTensorDimensions == 5, "Please update CopyTensorContents"); TensorShape srcStrides = srcTensor->GetStrides(); const TensorShape& srcShape = srcTensor->GetShape(); TensorShape dstStrides = dstTensor->GetStrides(); const TensorShape& dstShape = dstTensor->GetShape(); + size_t srcDepth = 1; size_t srcBatches = 1; size_t srcChannels = 1; size_t srcHeight = 1; @@ -61,8 +62,10 @@ void CopyTensorContentsGeneric(const ITensorHandle* srcTensor, ITensorHandle* ds srcWidth, srcHeight, srcChannels, - srcBatches); + srcBatches, + srcDepth); + size_t srcDepthStride = 0; size_t srcBatchStride = 0; size_t srcChannelStride = 0; size_t srcHeightStride = 0; @@ -71,8 +74,10 @@ void CopyTensorContentsGeneric(const ITensorHandle* srcTensor, ITensorHandle* ds srcWidthStride, srcHeightStride, srcChannelStride, - srcBatchStride); + srcBatchStride, + srcDepthStride); + size_t dstDepth = 1; size_t dstBatches = 1; size_t dstChannels = 1; size_t dstHeight = 1; @@ -81,8 +86,10 @@ void CopyTensorContentsGeneric(const ITensorHandle* srcTensor, ITensorHandle* ds dstWidth, dstHeight, dstChannels, - dstBatches); + dstBatches, + dstDepth); + size_t dstDepthStride = 0; size_t dstBatchStride = 0; size_t dstChannelStride = 0; size_t dstHeightStride = 0; @@ -91,7 +98,8 @@ void CopyTensorContentsGeneric(const ITensorHandle* srcTensor, ITensorHandle* ds dstWidthStride, dstHeightStride, dstChannelStride, - dstBatchStride); + dstBatchStride, + dstDepthStride); const unsigned char* srcData; unsigned char* dstData; @@ -105,26 +113,34 @@ void CopyTensorContentsGeneric(const ITensorHandle* srcTensor, ITensorHandle* ds size_t copyHeight = std::min(srcHeight, dstHeight); size_t copyChannels = std::min(srcChannels, dstChannels); size_t copyBatches = std::min(srcBatches, dstBatches); + size_t copyDepth = std::min(srcDepth, dstDepth); - for(unsigned int b=0; b < copyBatches; ++b) + for (unsigned int d=0; d < copyDepth; ++d) { - auto srcPtrBatch = srcData; - auto dstPtrBatch = dstData; - for (unsigned int c=0; c< copyChannels; ++c) + auto srcPtrDepth = srcData; + auto dstPtrDepth = dstData; + for (unsigned int b=0; b < copyBatches; ++b) { - auto srcPtrChannel = srcData; - auto dstPtrChannel = dstData; - for (unsigned int h=0; h < copyHeight; ++h) + auto srcPtrBatch = srcData; + auto dstPtrBatch = dstData; + for (unsigned int c=0; c< copyChannels; ++c) { - copy(dstData, srcData, copyLength); - dstData += dstHeightStride; - srcData += srcHeightStride; + auto srcPtrChannel = srcData; + auto dstPtrChannel = dstData; + for (unsigned int h=0; h < copyHeight; ++h) + { + copy(dstData, srcData, copyLength); + dstData += dstHeightStride; + srcData += srcHeightStride; + } + dstData += (static_cast(dstChannelStride) - (dstData - dstPtrChannel)); + srcData += (static_cast(srcChannelStride) - (srcData - srcPtrChannel)); } - dstData += (static_cast(dstChannelStride) - (dstData - dstPtrChannel)); - srcData += (static_cast(srcChannelStride) - (srcData - srcPtrChannel)); + dstData += (static_cast(dstBatchStride)-(dstData - dstPtrBatch)); + srcData += (static_cast(srcBatchStride)-(srcData - srcPtrBatch)); } - dstData += (static_cast(dstBatchStride)-(dstData - dstPtrBatch)); - srcData += (static_cast(srcBatchStride)-(srcData - srcPtrBatch)); + dstData += (static_cast(dstDepthStride)-(dstData - dstPtrDepth)); + srcData += (static_cast(srcDepthStride)-(srcData - srcPtrDepth)); } srcTensor->Unmap(); -- cgit v1.2.1