aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/backends/RefWorkloads/RefMultiplicationUint8Workload.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/backends/RefWorkloads/RefMultiplicationUint8Workload.cpp')
-rw-r--r--src/armnn/backends/RefWorkloads/RefMultiplicationUint8Workload.cpp12
1 files changed, 8 insertions, 4 deletions
diff --git a/src/armnn/backends/RefWorkloads/RefMultiplicationUint8Workload.cpp b/src/armnn/backends/RefWorkloads/RefMultiplicationUint8Workload.cpp
index 8e0a617bf5..b929a53808 100644
--- a/src/armnn/backends/RefWorkloads/RefMultiplicationUint8Workload.cpp
+++ b/src/armnn/backends/RefWorkloads/RefMultiplicationUint8Workload.cpp
@@ -5,7 +5,7 @@
#include "RefMultiplicationUint8Workload.hpp"
-#include "Multiplication.hpp"
+#include "ArithmeticFunction.hpp"
#include "RefWorkloadUtils.hpp"
#include "Profiling.hpp"
@@ -27,9 +27,13 @@ void RefMultiplicationUint8Workload::Execute() const
auto dequant1 = Dequantize(GetInputTensorDataU8(1, m_Data), inputInfo1);
std::vector<float> results(outputInfo.GetNumElements());
- Multiplication(
- inputInfo0.GetShape(), inputInfo1.GetShape(), outputInfo.GetShape(),
- dequant0.data(), dequant1.data(),results.data());
+
+ ArithmeticFunction<std::multiplies<float>>(inputInfo0.GetShape(),
+ inputInfo1.GetShape(),
+ outputInfo.GetShape(),
+ dequant0.data(),
+ dequant1.data(),
+ results.data());
Quantize(GetOutputTensorDataU8(0, m_Data), results.data(), outputInfo);
}