diff options
Diffstat (limited to 'src/backends')
-rw-r--r-- | src/backends/reference/workloads/Broadcast.cpp | 24 |
1 files changed, 21 insertions, 3 deletions
diff --git a/src/backends/reference/workloads/Broadcast.cpp b/src/backends/reference/workloads/Broadcast.cpp index 24af0fc4b1..f17ec6b311 100644 --- a/src/backends/reference/workloads/Broadcast.cpp +++ b/src/backends/reference/workloads/Broadcast.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2019 Arm Ltd. All rights reserved. +// Copyright © 2019,2024 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // @@ -38,13 +38,31 @@ BroadcastLoop::BroadcastLoop(const TensorShape& inShape, const TensorShape& outS unsigned int sIn = 1; unsigned int sOut = 1; + // Get the difference between the output dimension and input dimension + const unsigned int dimDifference = numDims - inShape.GetNumDimensions(); + for (unsigned int j = numDims - 1, k = 0; k < numDims ; k++, j--) { + m_DimData[j].m_DimSize = outShape[j]; - m_DimData[j].m_Stride1 = (inShape[j] > 1) ? sIn : 0; + // Pretend there are extra 1-dimensional tensors prepended + if (dimDifference > 0 && j < dimDifference) + { + m_DimData[j].m_Stride1 = 0; + sIn *= 1; + } + else if (dimDifference > 0) + { + m_DimData[j].m_Stride1 = (inShape[j - dimDifference] > 1) ? sIn : 0; + sIn *= inShape[j - dimDifference]; + } + else + { + m_DimData[j].m_Stride1 = (inShape[j] > 1) ? sIn : 0; + sIn *= inShape[j]; + } m_DimData[j].m_StrideOut = sOut; - sIn *= inShape[j]; sOut *= outShape[j]; } } |