aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/backends/cl/ClLayerSupport.cpp2
-rw-r--r--src/backends/cl/ClWorkloadFactory.cpp2
-rw-r--r--src/backends/cl/backend.mk2
-rw-r--r--src/backends/cl/workloads/CMakeLists.txt4
-rw-r--r--src/backends/cl/workloads/ClMultiplicationWorkload.cpp (renamed from src/backends/cl/workloads/ClMultiplicationFloatWorkload.cpp)14
-rw-r--r--src/backends/cl/workloads/ClMultiplicationWorkload.hpp (renamed from src/backends/cl/workloads/ClMultiplicationFloatWorkload.hpp)6
-rw-r--r--src/backends/cl/workloads/ClWorkloads.hpp2
-rw-r--r--src/backends/test/ArmComputeCl.cpp3
-rw-r--r--src/backends/test/CreateWorkloadCl.cpp12
9 files changed, 29 insertions, 18 deletions
diff --git a/src/backends/cl/ClLayerSupport.cpp b/src/backends/cl/ClLayerSupport.cpp
index 6c1940b02f..a17997b184 100644
--- a/src/backends/cl/ClLayerSupport.cpp
+++ b/src/backends/cl/ClLayerSupport.cpp
@@ -23,7 +23,7 @@
#include "workloads/ClDepthwiseConvolutionBaseWorkload.hpp"
#include "workloads/ClDivisionFloatWorkload.hpp"
#include "workloads/ClL2NormalizationFloatWorkload.hpp"
-#include "workloads/ClMultiplicationFloatWorkload.hpp"
+#include "workloads/ClMultiplicationWorkload.hpp"
#include "workloads/ClFullyConnectedWorkload.hpp"
#include "workloads/ClPadWorkload.hpp"
#include "workloads/ClPooling2dBaseWorkload.hpp"
diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp
index 46a96559bf..685696c502 100644
--- a/src/backends/cl/ClWorkloadFactory.cpp
+++ b/src/backends/cl/ClWorkloadFactory.cpp
@@ -170,7 +170,7 @@ std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateAddition(const Additi
std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateMultiplication(
const MultiplicationQueueDescriptor& descriptor, const WorkloadInfo& info) const
{
- return MakeWorkload<ClMultiplicationFloatWorkload, ClMultiplicationFloatWorkload>(descriptor, info);
+ return std::make_unique<ClMultiplicationWorkload>(descriptor, info);
}
std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateDivision(
diff --git a/src/backends/cl/backend.mk b/src/backends/cl/backend.mk
index 2418a24249..f54a8deff7 100644
--- a/src/backends/cl/backend.mk
+++ b/src/backends/cl/backend.mk
@@ -33,7 +33,7 @@ BACKEND_SOURCES := \
workloads/ClLstmFloatWorkload.cpp \
workloads/ClMergerFloatWorkload.cpp \
workloads/ClMergerUint8Workload.cpp \
- workloads/ClMultiplicationFloatWorkload.cpp \
+ workloads/ClMultiplicationWorkload.cpp \
workloads/ClNormalizationFloatWorkload.cpp \
workloads/ClPadWorkload.cpp \
workloads/ClPermuteWorkload.cpp \
diff --git a/src/backends/cl/workloads/CMakeLists.txt b/src/backends/cl/workloads/CMakeLists.txt
index 066c37f083..959d3e25df 100644
--- a/src/backends/cl/workloads/CMakeLists.txt
+++ b/src/backends/cl/workloads/CMakeLists.txt
@@ -50,8 +50,8 @@ list(APPEND armnnClBackendWorkloads_sources
ClMergerFloatWorkload.hpp
ClMergerUint8Workload.cpp
ClMergerUint8Workload.hpp
- ClMultiplicationFloatWorkload.cpp
- ClMultiplicationFloatWorkload.hpp
+ ClMultiplicationWorkload.cpp
+ ClMultiplicationWorkload.hpp
ClNormalizationFloatWorkload.cpp
ClNormalizationFloatWorkload.hpp
ClPadWorkload.cpp
diff --git a/src/backends/cl/workloads/ClMultiplicationFloatWorkload.cpp b/src/backends/cl/workloads/ClMultiplicationWorkload.cpp
index d53e149129..9d23caa695 100644
--- a/src/backends/cl/workloads/ClMultiplicationFloatWorkload.cpp
+++ b/src/backends/cl/workloads/ClMultiplicationWorkload.cpp
@@ -3,7 +3,7 @@
// SPDX-License-Identifier: MIT
//
-#include "ClMultiplicationFloatWorkload.hpp"
+#include "ClMultiplicationWorkload.hpp"
#include <backends/cl/ClTensorHandle.hpp>
#include <backends/CpuTensorHandle.hpp>
#include "ClWorkloadUtils.hpp"
@@ -31,11 +31,11 @@ arm_compute::Status ClMultiplicationWorkloadValidate(const TensorInfo& input0,
}
-ClMultiplicationFloatWorkload::ClMultiplicationFloatWorkload(const MultiplicationQueueDescriptor& descriptor,
- const WorkloadInfo& info)
- : FloatWorkload<MultiplicationQueueDescriptor>(descriptor, info)
+ClMultiplicationWorkload::ClMultiplicationWorkload(const MultiplicationQueueDescriptor& descriptor,
+ const WorkloadInfo& info)
+ : BaseWorkload<MultiplicationQueueDescriptor>(descriptor, info)
{
- m_Data.ValidateInputsOutputs("ClMultiplicationFloatWorkload", 2, 1);
+ m_Data.ValidateInputsOutputs("ClMultiplicationWorkload", 2, 1);
arm_compute::ICLTensor& input0 = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
arm_compute::ICLTensor& input1 = static_cast<IClTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
@@ -49,9 +49,9 @@ ClMultiplicationFloatWorkload::ClMultiplicationFloatWorkload(const Multiplicatio
arm_compute::RoundingPolicy::TO_NEAREST_EVEN);
}
-void ClMultiplicationFloatWorkload::Execute() const
+void ClMultiplicationWorkload::Execute() const
{
- ARMNN_SCOPED_PROFILING_EVENT_CL("ClMultiplicationFloatWorkload_Execute");
+ ARMNN_SCOPED_PROFILING_EVENT_CL("ClMultiplicationWorkload_Execute");
// Executes the layer.
m_PixelWiseMultiplication.run();
diff --git a/src/backends/cl/workloads/ClMultiplicationFloatWorkload.hpp b/src/backends/cl/workloads/ClMultiplicationWorkload.hpp
index a793ac64df..0586be96ed 100644
--- a/src/backends/cl/workloads/ClMultiplicationFloatWorkload.hpp
+++ b/src/backends/cl/workloads/ClMultiplicationWorkload.hpp
@@ -16,12 +16,12 @@ arm_compute::Status ClMultiplicationWorkloadValidate(const TensorInfo& input0,
const TensorInfo& input1,
const TensorInfo& output);
-class ClMultiplicationFloatWorkload : public FloatWorkload<MultiplicationQueueDescriptor>
+class ClMultiplicationWorkload : public BaseWorkload<MultiplicationQueueDescriptor>
{
public:
- ClMultiplicationFloatWorkload(const MultiplicationQueueDescriptor& descriptor, const WorkloadInfo& info);
+ ClMultiplicationWorkload(const MultiplicationQueueDescriptor& descriptor, const WorkloadInfo& info);
- using FloatWorkload<MultiplicationQueueDescriptor>::FloatWorkload;
+ using BaseWorkload<MultiplicationQueueDescriptor>::BaseWorkload;
void Execute() const override;
private:
diff --git a/src/backends/cl/workloads/ClWorkloads.hpp b/src/backends/cl/workloads/ClWorkloads.hpp
index 3329f42e08..c0625f6791 100644
--- a/src/backends/cl/workloads/ClWorkloads.hpp
+++ b/src/backends/cl/workloads/ClWorkloads.hpp
@@ -23,7 +23,7 @@
#include "ClLstmFloatWorkload.hpp"
#include "ClMergerFloatWorkload.hpp"
#include "ClMergerUint8Workload.hpp"
-#include "ClMultiplicationFloatWorkload.hpp"
+#include "ClMultiplicationWorkload.hpp"
#include "ClNormalizationFloatWorkload.hpp"
#include "ClPermuteWorkload.hpp"
#include "ClPadWorkload.hpp"
diff --git a/src/backends/test/ArmComputeCl.cpp b/src/backends/test/ArmComputeCl.cpp
index 4f1a84dfad..d83f812cd0 100644
--- a/src/backends/test/ArmComputeCl.cpp
+++ b/src/backends/test/ArmComputeCl.cpp
@@ -162,6 +162,9 @@ ARMNN_AUTO_TEST_CASE(DivisionBroadcast1DVector, DivisionBroadcast1DVectorTest)
ARMNN_AUTO_TEST_CASE(SimpleMultiplication, MultiplicationTest)
ARMNN_AUTO_TEST_CASE(MultiplicationBroadcast1Element, MultiplicationBroadcast1ElementTest)
ARMNN_AUTO_TEST_CASE(MultiplicationBroadcast1DVector, MultiplicationBroadcast1DVectorTest)
+ARMNN_AUTO_TEST_CASE(MultiplicationUint8, MultiplicationUint8Test)
+ARMNN_AUTO_TEST_CASE(MultiplicationBroadcast1ElementUint8, MultiplicationBroadcast1ElementUint8Test)
+ARMNN_AUTO_TEST_CASE(MultiplicationBroadcast1DVectorUint8, MultiplicationBroadcast1DVectorUint8Test)
// Batch Norm
ARMNN_AUTO_TEST_CASE(BatchNorm, BatchNormTest)
diff --git a/src/backends/test/CreateWorkloadCl.cpp b/src/backends/test/CreateWorkloadCl.cpp
index 078ef8c52d..e48cd97d6f 100644
--- a/src/backends/test/CreateWorkloadCl.cpp
+++ b/src/backends/test/CreateWorkloadCl.cpp
@@ -101,7 +101,7 @@ BOOST_AUTO_TEST_CASE(CreateSubtractionFloat16Workload)
BOOST_AUTO_TEST_CASE(CreateMultiplicationFloatWorkloadTest)
{
- ClCreateArithmethicWorkloadTest<ClMultiplicationFloatWorkload,
+ ClCreateArithmethicWorkloadTest<ClMultiplicationWorkload,
MultiplicationQueueDescriptor,
MultiplicationLayer,
armnn::DataType::Float32>();
@@ -109,12 +109,20 @@ BOOST_AUTO_TEST_CASE(CreateMultiplicationFloatWorkloadTest)
BOOST_AUTO_TEST_CASE(CreateMultiplicationFloat16WorkloadTest)
{
- ClCreateArithmethicWorkloadTest<ClMultiplicationFloatWorkload,
+ ClCreateArithmethicWorkloadTest<ClMultiplicationWorkload,
MultiplicationQueueDescriptor,
MultiplicationLayer,
armnn::DataType::Float16>();
}
+BOOST_AUTO_TEST_CASE(CreateMultiplicationUint8WorkloadTest)
+{
+ ClCreateArithmethicWorkloadTest<ClMultiplicationWorkload,
+ MultiplicationQueueDescriptor,
+ MultiplicationLayer,
+ armnn::DataType::QuantisedAsymm8>();
+}
+
BOOST_AUTO_TEST_CASE(CreateDivisionFloatWorkloadTest)
{
ClCreateArithmethicWorkloadTest<ClDivisionFloatWorkload,