aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/Broadcast.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads/Broadcast.cpp')
-rw-r--r--src/backends/reference/workloads/Broadcast.cpp24
1 files changed, 21 insertions, 3 deletions
diff --git a/src/backends/reference/workloads/Broadcast.cpp b/src/backends/reference/workloads/Broadcast.cpp
index 24af0fc4b1..f17ec6b311 100644
--- a/src/backends/reference/workloads/Broadcast.cpp
+++ b/src/backends/reference/workloads/Broadcast.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2019 Arm Ltd. All rights reserved.
+// Copyright © 2019,2024 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -38,13 +38,31 @@ BroadcastLoop::BroadcastLoop(const TensorShape& inShape, const TensorShape& outS
unsigned int sIn = 1;
unsigned int sOut = 1;
+ // Get the difference between the output dimension and input dimension
+ const unsigned int dimDifference = numDims - inShape.GetNumDimensions();
+
for (unsigned int j = numDims - 1, k = 0; k < numDims ; k++, j--)
{
+
m_DimData[j].m_DimSize = outShape[j];
- m_DimData[j].m_Stride1 = (inShape[j] > 1) ? sIn : 0;
+ // Pretend there are extra 1-dimensional tensors prepended
+ if (dimDifference > 0 && j < dimDifference)
+ {
+ m_DimData[j].m_Stride1 = 0;
+ sIn *= 1;
+ }
+ else if (dimDifference > 0)
+ {
+ m_DimData[j].m_Stride1 = (inShape[j - dimDifference] > 1) ? sIn : 0;
+ sIn *= inShape[j - dimDifference];
+ }
+ else
+ {
+ m_DimData[j].m_Stride1 = (inShape[j] > 1) ? sIn : 0;
+ sIn *= inShape[j];
+ }
m_DimData[j].m_StrideOut = sOut;
- sIn *= inShape[j];
sOut *= outShape[j];
}
}