aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthew Bentham <matthew.bentham@arm.com>2018-10-01 11:32:48 +0100
committerMatthew Bentham <matthew.bentham@arm.com>2018-10-10 16:16:57 +0100
commite2ec330b015fc34331f0023283b964ef97e1c5bb (patch)
tree76ce917f5d243ceb473dcdc2263e09781954b635
parentb30c53342756699dfaa6effd0a94aa9e69c2063a (diff)
downloadarmnn-e2ec330b015fc34331f0023283b964ef97e1c5bb.tar.gz
IVGCVSW-1207 - Remove typing from ClMultiplicationWorkload
Don't need this now as it uses the compute library validation function, and all of the code for the supported types is identical. Adds Uint8 support to Cl backend, and unit test cases. Change-Id: I35d4edacc1aca241e95d1b19ae525a23d9513c99
-rw-r--r--src/backends/cl/ClLayerSupport.cpp2
-rw-r--r--src/backends/cl/ClWorkloadFactory.cpp2
-rw-r--r--src/backends/cl/backend.mk2
-rw-r--r--src/backends/cl/workloads/CMakeLists.txt4
-rw-r--r--src/backends/cl/workloads/ClMultiplicationWorkload.cpp (renamed from src/backends/cl/workloads/ClMultiplicationFloatWorkload.cpp)14
-rw-r--r--src/backends/cl/workloads/ClMultiplicationWorkload.hpp (renamed from src/backends/cl/workloads/ClMultiplicationFloatWorkload.hpp)6
-rw-r--r--src/backends/cl/workloads/ClWorkloads.hpp2
-rw-r--r--src/backends/test/ArmComputeCl.cpp3
-rw-r--r--src/backends/test/CreateWorkloadCl.cpp12
9 files changed, 29 insertions, 18 deletions
diff --git a/src/backends/cl/ClLayerSupport.cpp b/src/backends/cl/ClLayerSupport.cpp
index 6c1940b02f..a17997b184 100644
--- a/src/backends/cl/ClLayerSupport.cpp
+++ b/src/backends/cl/ClLayerSupport.cpp
@@ -23,7 +23,7 @@
#include "workloads/ClDepthwiseConvolutionBaseWorkload.hpp"
#include "workloads/ClDivisionFloatWorkload.hpp"
#include "workloads/ClL2NormalizationFloatWorkload.hpp"
-#include "workloads/ClMultiplicationFloatWorkload.hpp"
+#include "workloads/ClMultiplicationWorkload.hpp"
#include "workloads/ClFullyConnectedWorkload.hpp"
#include "workloads/ClPadWorkload.hpp"
#include "workloads/ClPooling2dBaseWorkload.hpp"
diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp
index 46a96559bf..685696c502 100644
--- a/src/backends/cl/ClWorkloadFactory.cpp
+++ b/src/backends/cl/ClWorkloadFactory.cpp
@@ -170,7 +170,7 @@ std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateAddition(const Additi
std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateMultiplication(
const MultiplicationQueueDescriptor& descriptor, const WorkloadInfo& info) const
{
- return MakeWorkload<ClMultiplicationFloatWorkload, ClMultiplicationFloatWorkload>(descriptor, info);
+ return std::make_unique<ClMultiplicationWorkload>(descriptor, info);
}
std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateDivision(
diff --git a/src/backends/cl/backend.mk b/src/backends/cl/backend.mk
index 2418a24249..f54a8deff7 100644
--- a/src/backends/cl/backend.mk
+++ b/src/backends/cl/backend.mk
@@ -33,7 +33,7 @@ BACKEND_SOURCES := \
workloads/ClLstmFloatWorkload.cpp \
workloads/ClMergerFloatWorkload.cpp \
workloads/ClMergerUint8Workload.cpp \
- workloads/ClMultiplicationFloatWorkload.cpp \
+ workloads/ClMultiplicationWorkload.cpp \
workloads/ClNormalizationFloatWorkload.cpp \
workloads/ClPadWorkload.cpp \
workloads/ClPermuteWorkload.cpp \
diff --git a/src/backends/cl/workloads/CMakeLists.txt b/src/backends/cl/workloads/CMakeLists.txt
index 066c37f083..959d3e25df 100644
--- a/src/backends/cl/workloads/CMakeLists.txt
+++ b/src/backends/cl/workloads/CMakeLists.txt
@@ -50,8 +50,8 @@ list(APPEND armnnClBackendWorkloads_sources
ClMergerFloatWorkload.hpp
ClMergerUint8Workload.cpp
ClMergerUint8Workload.hpp
- ClMultiplicationFloatWorkload.cpp
- ClMultiplicationFloatWorkload.hpp
+ ClMultiplicationWorkload.cpp
+ ClMultiplicationWorkload.hpp
ClNormalizationFloatWorkload.cpp
ClNormalizationFloatWorkload.hpp
ClPadWorkload.cpp
diff --git a/src/backends/cl/workloads/ClMultiplicationFloatWorkload.cpp b/src/backends/cl/workloads/ClMultiplicationWorkload.cpp
index d53e149129..9d23caa695 100644
--- a/src/backends/cl/workloads/ClMultiplicationFloatWorkload.cpp
+++ b/src/backends/cl/workloads/ClMultiplicationWorkload.cpp
@@ -3,7 +3,7 @@
// SPDX-License-Identifier: MIT
//
-#include "ClMultiplicationFloatWorkload.hpp"
+#include "ClMultiplicationWorkload.hpp"
#include <backends/cl/ClTensorHandle.hpp>
#include <backends/CpuTensorHandle.hpp>
#include "ClWorkloadUtils.hpp"
@@ -31,11 +31,11 @@ arm_compute::Status ClMultiplicationWorkloadValidate(const TensorInfo& input0,
}
-ClMultiplicationFloatWorkload::ClMultiplicationFloatWorkload(const MultiplicationQueueDescriptor& descriptor,
- const WorkloadInfo& info)
- : FloatWorkload<MultiplicationQueueDescriptor>(descriptor, info)
+ClMultiplicationWorkload::ClMultiplicationWorkload(const MultiplicationQueueDescriptor& descriptor,
+ const WorkloadInfo& info)
+ : BaseWorkload<MultiplicationQueueDescriptor>(descriptor, info)
{
- m_Data.ValidateInputsOutputs("ClMultiplicationFloatWorkload", 2, 1);
+ m_Data.ValidateInputsOutputs("ClMultiplicationWorkload", 2, 1);
arm_compute::ICLTensor& input0 = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
arm_compute::ICLTensor& input1 = static_cast<IClTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
@@ -49,9 +49,9 @@ ClMultiplicationFloatWorkload::ClMultiplicationFloatWorkload(const Multiplicatio
arm_compute::RoundingPolicy::TO_NEAREST_EVEN);
}
-void ClMultiplicationFloatWorkload::Execute() const
+void ClMultiplicationWorkload::Execute() const
{
- ARMNN_SCOPED_PROFILING_EVENT_CL("ClMultiplicationFloatWorkload_Execute");
+ ARMNN_SCOPED_PROFILING_EVENT_CL("ClMultiplicationWorkload_Execute");
// Executes the layer.
m_PixelWiseMultiplication.run();
diff --git a/src/backends/cl/workloads/ClMultiplicationFloatWorkload.hpp b/src/backends/cl/workloads/ClMultiplicationWorkload.hpp
index a793ac64df..0586be96ed 100644
--- a/src/backends/cl/workloads/ClMultiplicationFloatWorkload.hpp
+++ b/src/backends/cl/workloads/ClMultiplicationWorkload.hpp
@@ -16,12 +16,12 @@ arm_compute::Status ClMultiplicationWorkloadValidate(const TensorInfo& input0,
const TensorInfo& input1,
const TensorInfo& output);
-class ClMultiplicationFloatWorkload : public FloatWorkload<MultiplicationQueueDescriptor>
+class ClMultiplicationWorkload : public BaseWorkload<MultiplicationQueueDescriptor>
{
public:
- ClMultiplicationFloatWorkload(const MultiplicationQueueDescriptor& descriptor, const WorkloadInfo& info);
+ ClMultiplicationWorkload(const MultiplicationQueueDescriptor& descriptor, const WorkloadInfo& info);
- using FloatWorkload<MultiplicationQueueDescriptor>::FloatWorkload;
+ using BaseWorkload<MultiplicationQueueDescriptor>::BaseWorkload;
void Execute() const override;
private:
diff --git a/src/backends/cl/workloads/ClWorkloads.hpp b/src/backends/cl/workloads/ClWorkloads.hpp
index 3329f42e08..c0625f6791 100644
--- a/src/backends/cl/workloads/ClWorkloads.hpp
+++ b/src/backends/cl/workloads/ClWorkloads.hpp
@@ -23,7 +23,7 @@
#include "ClLstmFloatWorkload.hpp"
#include "ClMergerFloatWorkload.hpp"
#include "ClMergerUint8Workload.hpp"
-#include "ClMultiplicationFloatWorkload.hpp"
+#include "ClMultiplicationWorkload.hpp"
#include "ClNormalizationFloatWorkload.hpp"
#include "ClPermuteWorkload.hpp"
#include "ClPadWorkload.hpp"
diff --git a/src/backends/test/ArmComputeCl.cpp b/src/backends/test/ArmComputeCl.cpp
index 4f1a84dfad..d83f812cd0 100644
--- a/src/backends/test/ArmComputeCl.cpp
+++ b/src/backends/test/ArmComputeCl.cpp
@@ -162,6 +162,9 @@ ARMNN_AUTO_TEST_CASE(DivisionBroadcast1DVector, DivisionBroadcast1DVectorTest)
ARMNN_AUTO_TEST_CASE(SimpleMultiplication, MultiplicationTest)
ARMNN_AUTO_TEST_CASE(MultiplicationBroadcast1Element, MultiplicationBroadcast1ElementTest)
ARMNN_AUTO_TEST_CASE(MultiplicationBroadcast1DVector, MultiplicationBroadcast1DVectorTest)
+ARMNN_AUTO_TEST_CASE(MultiplicationUint8, MultiplicationUint8Test)
+ARMNN_AUTO_TEST_CASE(MultiplicationBroadcast1ElementUint8, MultiplicationBroadcast1ElementUint8Test)
+ARMNN_AUTO_TEST_CASE(MultiplicationBroadcast1DVectorUint8, MultiplicationBroadcast1DVectorUint8Test)
// Batch Norm
ARMNN_AUTO_TEST_CASE(BatchNorm, BatchNormTest)
diff --git a/src/backends/test/CreateWorkloadCl.cpp b/src/backends/test/CreateWorkloadCl.cpp
index 078ef8c52d..e48cd97d6f 100644
--- a/src/backends/test/CreateWorkloadCl.cpp
+++ b/src/backends/test/CreateWorkloadCl.cpp
@@ -101,7 +101,7 @@ BOOST_AUTO_TEST_CASE(CreateSubtractionFloat16Workload)
BOOST_AUTO_TEST_CASE(CreateMultiplicationFloatWorkloadTest)
{
- ClCreateArithmethicWorkloadTest<ClMultiplicationFloatWorkload,
+ ClCreateArithmethicWorkloadTest<ClMultiplicationWorkload,
MultiplicationQueueDescriptor,
MultiplicationLayer,
armnn::DataType::Float32>();
@@ -109,12 +109,20 @@ BOOST_AUTO_TEST_CASE(CreateMultiplicationFloatWorkloadTest)
BOOST_AUTO_TEST_CASE(CreateMultiplicationFloat16WorkloadTest)
{
- ClCreateArithmethicWorkloadTest<ClMultiplicationFloatWorkload,
+ ClCreateArithmethicWorkloadTest<ClMultiplicationWorkload,
MultiplicationQueueDescriptor,
MultiplicationLayer,
armnn::DataType::Float16>();
}
+BOOST_AUTO_TEST_CASE(CreateMultiplicationUint8WorkloadTest)
+{
+ ClCreateArithmethicWorkloadTest<ClMultiplicationWorkload,
+ MultiplicationQueueDescriptor,
+ MultiplicationLayer,
+ armnn::DataType::QuantisedAsymm8>();
+}
+
BOOST_AUTO_TEST_CASE(CreateDivisionFloatWorkloadTest)
{
ClCreateArithmethicWorkloadTest<ClDivisionFloatWorkload,