aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorPatryk Kaiser <patryk.kaiser@arm.com>2024-06-27 12:29:48 +0100
committerColm Donelan <colm.donelan@arm.com>2024-07-01 16:55:55 +0000
commit656285153399d96ead5925db907d0ec1961dfd76 (patch)
tree74ae6dbd3199c0297c0321d5998cfd546537af26 /src
parent443804adee542d4330713e8dda6357b9495856fa (diff)
downloadarmnn-656285153399d96ead5925db907d0ec1961dfd76.tar.gz
IVGCVSW-8139 Fixing Broadcast OP DTS tests
* In ref broadcast layer added broadcast support to expand tensor shapes * Added function to check for zero dimension tensors * Added check for unsupported zero dimension tensors during broadcast * Added DelegateUtils unit test file with unit tests for the new function Signed-off-by: Patryk Kaiser <patryk.kaiser@arm.com> Change-Id: If4e786f7ba580399e781c48335888e8da8458019
Diffstat (limited to 'src')
-rw-r--r--src/backends/reference/workloads/Broadcast.cpp24
1 files changed, 21 insertions, 3 deletions
diff --git a/src/backends/reference/workloads/Broadcast.cpp b/src/backends/reference/workloads/Broadcast.cpp
index 24af0fc4b1..f17ec6b311 100644
--- a/src/backends/reference/workloads/Broadcast.cpp
+++ b/src/backends/reference/workloads/Broadcast.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2019 Arm Ltd. All rights reserved.
+// Copyright © 2019,2024 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -38,13 +38,31 @@ BroadcastLoop::BroadcastLoop(const TensorShape& inShape, const TensorShape& outS
unsigned int sIn = 1;
unsigned int sOut = 1;
+ // Get the difference between the output dimension and input dimension
+ const unsigned int dimDifference = numDims - inShape.GetNumDimensions();
+
for (unsigned int j = numDims - 1, k = 0; k < numDims ; k++, j--)
{
+
m_DimData[j].m_DimSize = outShape[j];
- m_DimData[j].m_Stride1 = (inShape[j] > 1) ? sIn : 0;
+ // Pretend there are extra 1-dimensional tensors prepended
+ if (dimDifference > 0 && j < dimDifference)
+ {
+ m_DimData[j].m_Stride1 = 0;
+ sIn *= 1;
+ }
+ else if (dimDifference > 0)
+ {
+ m_DimData[j].m_Stride1 = (inShape[j - dimDifference] > 1) ? sIn : 0;
+ sIn *= inShape[j - dimDifference];
+ }
+ else
+ {
+ m_DimData[j].m_Stride1 = (inShape[j] > 1) ? sIn : 0;
+ sIn *= inShape[j];
+ }
m_DimData[j].m_StrideOut = sOut;
- sIn *= inShape[j];
sOut *= outShape[j];
}
}