aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/Broadcast.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads/Broadcast.cpp')
-rw-r--r--src/backends/reference/workloads/Broadcast.cpp21
1 files changed, 20 insertions, 1 deletions
diff --git a/src/backends/reference/workloads/Broadcast.cpp b/src/backends/reference/workloads/Broadcast.cpp
index 8421a0a7ed..24af0fc4b1 100644
--- a/src/backends/reference/workloads/Broadcast.cpp
+++ b/src/backends/reference/workloads/Broadcast.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2019 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -30,4 +30,23 @@ BroadcastLoop::BroadcastLoop(const TensorShape& inShape0, const TensorShape& inS
}
}
+BroadcastLoop::BroadcastLoop(const TensorShape& inShape, const TensorShape& outShape)
+: m_DimData(outShape.GetNumDimensions())
+{
+ const unsigned int numDims = GetNumDimensions();
+
+ unsigned int sIn = 1;
+ unsigned int sOut = 1;
+
+ for (unsigned int j = numDims - 1, k = 0; k < numDims ; k++, j--)
+ {
+ m_DimData[j].m_DimSize = outShape[j];
+ m_DimData[j].m_Stride1 = (inShape[j] > 1) ? sIn : 0;
+ m_DimData[j].m_StrideOut = sOut;
+
+ sIn *= inShape[j];
+ sOut *= outShape[j];
+ }
+}
+
} // namespace armnn