aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/backends/RefWorkloads/Multiplication.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/backends/RefWorkloads/Multiplication.cpp')
-rw-r--r--src/armnn/backends/RefWorkloads/Multiplication.cpp42
1 files changed, 36 insertions, 6 deletions
diff --git a/src/armnn/backends/RefWorkloads/Multiplication.cpp b/src/armnn/backends/RefWorkloads/Multiplication.cpp
index 7f558d83c5..47c0f1cef1 100644
--- a/src/armnn/backends/RefWorkloads/Multiplication.cpp
+++ b/src/armnn/backends/RefWorkloads/Multiplication.cpp
@@ -4,18 +4,48 @@
//
#include "Multiplication.hpp"
+#include "Broadcast.hpp"
-namespace armnn
+#include <functional>
+
+namespace
{
-void Multiplication(const float* in0,
- const float* in1,
- unsigned int numElements,
- float* out)
+void ElementwiseMultiplication(unsigned int numElements,
+ const float* inData0,
+ const float* inData1,
+ float* outData)
{
for (unsigned int i = 0; i < numElements; ++i)
{
- out[i] = in0[i] * in1[i];
+ outData[i] = inData0[i] * inData1[i];
+ }
+}
+
+} // namespace
+
+namespace armnn
+{
+
+void Multiplication(const TensorShape& inShape0,
+ const TensorShape& inShape1,
+ const TensorShape& outShape,
+ const float* inData0,
+ const float* inData1,
+ float* outData)
+{
+ if (inShape0 == inShape1)
+ {
+ ElementwiseMultiplication(inShape0.GetNumElements(), inData0, inData1, outData);
+ }
+ else
+ {
+ BroadcastLoop(inShape0, inShape1, outShape).Unroll(
+ std::multiplies<float>(),
+ 0,
+ inData0,
+ inData1,
+ outData);
}
}