diff options
Diffstat (limited to 'src/backends/backendsCommon/WorkloadUtils.hpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadUtils.hpp | 54 |
1 files changed, 35 insertions, 19 deletions
diff --git a/src/backends/backendsCommon/WorkloadUtils.hpp b/src/backends/backendsCommon/WorkloadUtils.hpp index a1a8d2a475..7e3ac395e4 100644 --- a/src/backends/backendsCommon/WorkloadUtils.hpp +++ b/src/backends/backendsCommon/WorkloadUtils.hpp @@ -46,13 +46,14 @@ void AssignValues(unsigned int num, unsigned int idx, const ArrayType& array, T& template<typename CopyFunc> void CopyTensorContentsGeneric(const ITensorHandle* srcTensor, ITensorHandle* dstTensor, CopyFunc copy) { - static_assert(MaxNumOfTensorDimensions == 4, "Please update CopyTensorContents"); + static_assert(MaxNumOfTensorDimensions == 5, "Please update CopyTensorContents"); TensorShape srcStrides = srcTensor->GetStrides(); const TensorShape& srcShape = srcTensor->GetShape(); TensorShape dstStrides = dstTensor->GetStrides(); const TensorShape& dstShape = dstTensor->GetShape(); + size_t srcDepth = 1; size_t srcBatches = 1; size_t srcChannels = 1; size_t srcHeight = 1; @@ -61,8 +62,10 @@ void CopyTensorContentsGeneric(const ITensorHandle* srcTensor, ITensorHandle* ds srcWidth, srcHeight, srcChannels, - srcBatches); + srcBatches, + srcDepth); + size_t srcDepthStride = 0; size_t srcBatchStride = 0; size_t srcChannelStride = 0; size_t srcHeightStride = 0; @@ -71,8 +74,10 @@ void CopyTensorContentsGeneric(const ITensorHandle* srcTensor, ITensorHandle* ds srcWidthStride, srcHeightStride, srcChannelStride, - srcBatchStride); + srcBatchStride, + srcDepthStride); + size_t dstDepth = 1; size_t dstBatches = 1; size_t dstChannels = 1; size_t dstHeight = 1; @@ -81,8 +86,10 @@ void CopyTensorContentsGeneric(const ITensorHandle* srcTensor, ITensorHandle* ds dstWidth, dstHeight, dstChannels, - dstBatches); + dstBatches, + dstDepth); + size_t dstDepthStride = 0; size_t dstBatchStride = 0; size_t dstChannelStride = 0; size_t dstHeightStride = 0; @@ -91,7 +98,8 @@ void CopyTensorContentsGeneric(const ITensorHandle* srcTensor, ITensorHandle* ds dstWidthStride, dstHeightStride, dstChannelStride, - dstBatchStride); + dstBatchStride, + dstDepthStride); const unsigned char* srcData; unsigned char* dstData; @@ -105,26 +113,34 @@ void CopyTensorContentsGeneric(const ITensorHandle* srcTensor, ITensorHandle* ds size_t copyHeight = std::min(srcHeight, dstHeight); size_t copyChannels = std::min(srcChannels, dstChannels); size_t copyBatches = std::min(srcBatches, dstBatches); + size_t copyDepth = std::min(srcDepth, dstDepth); - for(unsigned int b=0; b < copyBatches; ++b) + for (unsigned int d=0; d < copyDepth; ++d) { - auto srcPtrBatch = srcData; - auto dstPtrBatch = dstData; - for (unsigned int c=0; c< copyChannels; ++c) + auto srcPtrDepth = srcData; + auto dstPtrDepth = dstData; + for (unsigned int b=0; b < copyBatches; ++b) { - auto srcPtrChannel = srcData; - auto dstPtrChannel = dstData; - for (unsigned int h=0; h < copyHeight; ++h) + auto srcPtrBatch = srcData; + auto dstPtrBatch = dstData; + for (unsigned int c=0; c< copyChannels; ++c) { - copy(dstData, srcData, copyLength); - dstData += dstHeightStride; - srcData += srcHeightStride; + auto srcPtrChannel = srcData; + auto dstPtrChannel = dstData; + for (unsigned int h=0; h < copyHeight; ++h) + { + copy(dstData, srcData, copyLength); + dstData += dstHeightStride; + srcData += srcHeightStride; + } + dstData += (static_cast<long>(dstChannelStride) - (dstData - dstPtrChannel)); + srcData += (static_cast<long>(srcChannelStride) - (srcData - srcPtrChannel)); } - dstData += (static_cast<long>(dstChannelStride) - (dstData - dstPtrChannel)); - srcData += (static_cast<long>(srcChannelStride) - (srcData - srcPtrChannel)); + dstData += (static_cast<long>(dstBatchStride)-(dstData - dstPtrBatch)); + srcData += (static_cast<long>(srcBatchStride)-(srcData - srcPtrBatch)); } - dstData += (static_cast<long>(dstBatchStride)-(dstData - dstPtrBatch)); - srcData += (static_cast<long>(srcBatchStride)-(srcData - srcPtrBatch)); + dstData += (static_cast<long>(dstDepthStride)-(dstData - dstPtrDepth)); + srcData += (static_cast<long>(srcDepthStride)-(srcData - srcPtrDepth)); } srcTensor->Unmap(); |