diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/backends/cl/ClLayerSupport.cpp | 2 | ||||
-rw-r--r-- | src/backends/cl/ClWorkloadFactory.cpp | 2 | ||||
-rw-r--r-- | src/backends/cl/backend.mk | 2 | ||||
-rw-r--r-- | src/backends/cl/workloads/CMakeLists.txt | 4 | ||||
-rw-r--r-- | src/backends/cl/workloads/ClMultiplicationWorkload.cpp (renamed from src/backends/cl/workloads/ClMultiplicationFloatWorkload.cpp) | 14 | ||||
-rw-r--r-- | src/backends/cl/workloads/ClMultiplicationWorkload.hpp (renamed from src/backends/cl/workloads/ClMultiplicationFloatWorkload.hpp) | 6 | ||||
-rw-r--r-- | src/backends/cl/workloads/ClWorkloads.hpp | 2 | ||||
-rw-r--r-- | src/backends/test/ArmComputeCl.cpp | 3 | ||||
-rw-r--r-- | src/backends/test/CreateWorkloadCl.cpp | 12 |
9 files changed, 29 insertions, 18 deletions
diff --git a/src/backends/cl/ClLayerSupport.cpp b/src/backends/cl/ClLayerSupport.cpp index 6c1940b02f..a17997b184 100644 --- a/src/backends/cl/ClLayerSupport.cpp +++ b/src/backends/cl/ClLayerSupport.cpp @@ -23,7 +23,7 @@ #include "workloads/ClDepthwiseConvolutionBaseWorkload.hpp" #include "workloads/ClDivisionFloatWorkload.hpp" #include "workloads/ClL2NormalizationFloatWorkload.hpp" -#include "workloads/ClMultiplicationFloatWorkload.hpp" +#include "workloads/ClMultiplicationWorkload.hpp" #include "workloads/ClFullyConnectedWorkload.hpp" #include "workloads/ClPadWorkload.hpp" #include "workloads/ClPooling2dBaseWorkload.hpp" diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp index 46a96559bf..685696c502 100644 --- a/src/backends/cl/ClWorkloadFactory.cpp +++ b/src/backends/cl/ClWorkloadFactory.cpp @@ -170,7 +170,7 @@ std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateAddition(const Additi std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateMultiplication( const MultiplicationQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return MakeWorkload<ClMultiplicationFloatWorkload, ClMultiplicationFloatWorkload>(descriptor, info); + return std::make_unique<ClMultiplicationWorkload>(descriptor, info); } std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateDivision( diff --git a/src/backends/cl/backend.mk b/src/backends/cl/backend.mk index 2418a24249..f54a8deff7 100644 --- a/src/backends/cl/backend.mk +++ b/src/backends/cl/backend.mk @@ -33,7 +33,7 @@ BACKEND_SOURCES := \ workloads/ClLstmFloatWorkload.cpp \ workloads/ClMergerFloatWorkload.cpp \ workloads/ClMergerUint8Workload.cpp \ - workloads/ClMultiplicationFloatWorkload.cpp \ + workloads/ClMultiplicationWorkload.cpp \ workloads/ClNormalizationFloatWorkload.cpp \ workloads/ClPadWorkload.cpp \ workloads/ClPermuteWorkload.cpp \ diff --git a/src/backends/cl/workloads/CMakeLists.txt b/src/backends/cl/workloads/CMakeLists.txt index 066c37f083..959d3e25df 100644 --- a/src/backends/cl/workloads/CMakeLists.txt +++ b/src/backends/cl/workloads/CMakeLists.txt @@ -50,8 +50,8 @@ list(APPEND armnnClBackendWorkloads_sources ClMergerFloatWorkload.hpp ClMergerUint8Workload.cpp ClMergerUint8Workload.hpp - ClMultiplicationFloatWorkload.cpp - ClMultiplicationFloatWorkload.hpp + ClMultiplicationWorkload.cpp + ClMultiplicationWorkload.hpp ClNormalizationFloatWorkload.cpp ClNormalizationFloatWorkload.hpp ClPadWorkload.cpp diff --git a/src/backends/cl/workloads/ClMultiplicationFloatWorkload.cpp b/src/backends/cl/workloads/ClMultiplicationWorkload.cpp index d53e149129..9d23caa695 100644 --- a/src/backends/cl/workloads/ClMultiplicationFloatWorkload.cpp +++ b/src/backends/cl/workloads/ClMultiplicationWorkload.cpp @@ -3,7 +3,7 @@ // SPDX-License-Identifier: MIT // -#include "ClMultiplicationFloatWorkload.hpp" +#include "ClMultiplicationWorkload.hpp" #include <backends/cl/ClTensorHandle.hpp> #include <backends/CpuTensorHandle.hpp> #include "ClWorkloadUtils.hpp" @@ -31,11 +31,11 @@ arm_compute::Status ClMultiplicationWorkloadValidate(const TensorInfo& input0, } -ClMultiplicationFloatWorkload::ClMultiplicationFloatWorkload(const MultiplicationQueueDescriptor& descriptor, - const WorkloadInfo& info) - : FloatWorkload<MultiplicationQueueDescriptor>(descriptor, info) +ClMultiplicationWorkload::ClMultiplicationWorkload(const MultiplicationQueueDescriptor& descriptor, + const WorkloadInfo& info) + : BaseWorkload<MultiplicationQueueDescriptor>(descriptor, info) { - m_Data.ValidateInputsOutputs("ClMultiplicationFloatWorkload", 2, 1); + m_Data.ValidateInputsOutputs("ClMultiplicationWorkload", 2, 1); arm_compute::ICLTensor& input0 = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor(); arm_compute::ICLTensor& input1 = static_cast<IClTensorHandle*>(m_Data.m_Inputs[1])->GetTensor(); @@ -49,9 +49,9 @@ ClMultiplicationFloatWorkload::ClMultiplicationFloatWorkload(const Multiplicatio arm_compute::RoundingPolicy::TO_NEAREST_EVEN); } -void ClMultiplicationFloatWorkload::Execute() const +void ClMultiplicationWorkload::Execute() const { - ARMNN_SCOPED_PROFILING_EVENT_CL("ClMultiplicationFloatWorkload_Execute"); + ARMNN_SCOPED_PROFILING_EVENT_CL("ClMultiplicationWorkload_Execute"); // Executes the layer. m_PixelWiseMultiplication.run(); diff --git a/src/backends/cl/workloads/ClMultiplicationFloatWorkload.hpp b/src/backends/cl/workloads/ClMultiplicationWorkload.hpp index a793ac64df..0586be96ed 100644 --- a/src/backends/cl/workloads/ClMultiplicationFloatWorkload.hpp +++ b/src/backends/cl/workloads/ClMultiplicationWorkload.hpp @@ -16,12 +16,12 @@ arm_compute::Status ClMultiplicationWorkloadValidate(const TensorInfo& input0, const TensorInfo& input1, const TensorInfo& output); -class ClMultiplicationFloatWorkload : public FloatWorkload<MultiplicationQueueDescriptor> +class ClMultiplicationWorkload : public BaseWorkload<MultiplicationQueueDescriptor> { public: - ClMultiplicationFloatWorkload(const MultiplicationQueueDescriptor& descriptor, const WorkloadInfo& info); + ClMultiplicationWorkload(const MultiplicationQueueDescriptor& descriptor, const WorkloadInfo& info); - using FloatWorkload<MultiplicationQueueDescriptor>::FloatWorkload; + using BaseWorkload<MultiplicationQueueDescriptor>::BaseWorkload; void Execute() const override; private: diff --git a/src/backends/cl/workloads/ClWorkloads.hpp b/src/backends/cl/workloads/ClWorkloads.hpp index 3329f42e08..c0625f6791 100644 --- a/src/backends/cl/workloads/ClWorkloads.hpp +++ b/src/backends/cl/workloads/ClWorkloads.hpp @@ -23,7 +23,7 @@ #include "ClLstmFloatWorkload.hpp" #include "ClMergerFloatWorkload.hpp" #include "ClMergerUint8Workload.hpp" -#include "ClMultiplicationFloatWorkload.hpp" +#include "ClMultiplicationWorkload.hpp" #include "ClNormalizationFloatWorkload.hpp" #include "ClPermuteWorkload.hpp" #include "ClPadWorkload.hpp" diff --git a/src/backends/test/ArmComputeCl.cpp b/src/backends/test/ArmComputeCl.cpp index 4f1a84dfad..d83f812cd0 100644 --- a/src/backends/test/ArmComputeCl.cpp +++ b/src/backends/test/ArmComputeCl.cpp @@ -162,6 +162,9 @@ ARMNN_AUTO_TEST_CASE(DivisionBroadcast1DVector, DivisionBroadcast1DVectorTest) ARMNN_AUTO_TEST_CASE(SimpleMultiplication, MultiplicationTest) ARMNN_AUTO_TEST_CASE(MultiplicationBroadcast1Element, MultiplicationBroadcast1ElementTest) ARMNN_AUTO_TEST_CASE(MultiplicationBroadcast1DVector, MultiplicationBroadcast1DVectorTest) +ARMNN_AUTO_TEST_CASE(MultiplicationUint8, MultiplicationUint8Test) +ARMNN_AUTO_TEST_CASE(MultiplicationBroadcast1ElementUint8, MultiplicationBroadcast1ElementUint8Test) +ARMNN_AUTO_TEST_CASE(MultiplicationBroadcast1DVectorUint8, MultiplicationBroadcast1DVectorUint8Test) // Batch Norm ARMNN_AUTO_TEST_CASE(BatchNorm, BatchNormTest) diff --git a/src/backends/test/CreateWorkloadCl.cpp b/src/backends/test/CreateWorkloadCl.cpp index 078ef8c52d..e48cd97d6f 100644 --- a/src/backends/test/CreateWorkloadCl.cpp +++ b/src/backends/test/CreateWorkloadCl.cpp @@ -101,7 +101,7 @@ BOOST_AUTO_TEST_CASE(CreateSubtractionFloat16Workload) BOOST_AUTO_TEST_CASE(CreateMultiplicationFloatWorkloadTest) { - ClCreateArithmethicWorkloadTest<ClMultiplicationFloatWorkload, + ClCreateArithmethicWorkloadTest<ClMultiplicationWorkload, MultiplicationQueueDescriptor, MultiplicationLayer, armnn::DataType::Float32>(); @@ -109,12 +109,20 @@ BOOST_AUTO_TEST_CASE(CreateMultiplicationFloatWorkloadTest) BOOST_AUTO_TEST_CASE(CreateMultiplicationFloat16WorkloadTest) { - ClCreateArithmethicWorkloadTest<ClMultiplicationFloatWorkload, + ClCreateArithmethicWorkloadTest<ClMultiplicationWorkload, MultiplicationQueueDescriptor, MultiplicationLayer, armnn::DataType::Float16>(); } +BOOST_AUTO_TEST_CASE(CreateMultiplicationUint8WorkloadTest) +{ + ClCreateArithmethicWorkloadTest<ClMultiplicationWorkload, + MultiplicationQueueDescriptor, + MultiplicationLayer, + armnn::DataType::QuantisedAsymm8>(); +} + BOOST_AUTO_TEST_CASE(CreateDivisionFloatWorkloadTest) { ClCreateArithmethicWorkloadTest<ClDivisionFloatWorkload, |