diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/backends/cl/ClWorkloadFactory.cpp | 3 | ||||
-rw-r--r-- | src/backends/cl/workloads/ClSubtractionWorkload.cpp | 13 | ||||
-rw-r--r-- | src/backends/cl/workloads/ClSubtractionWorkload.hpp | 3 | ||||
-rw-r--r-- | src/backends/test/CreateWorkloadCl.cpp | 4 |
4 files changed, 8 insertions, 15 deletions
diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp index 685696c502..0b7e539202 100644 --- a/src/backends/cl/ClWorkloadFactory.cpp +++ b/src/backends/cl/ClWorkloadFactory.cpp @@ -182,8 +182,7 @@ std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateDivision( std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return MakeWorkload<ClSubtractionWorkload<armnn::DataType::Float16, armnn::DataType::Float32>, - ClSubtractionWorkload<armnn::DataType::QuantisedAsymm8>>(descriptor, info); + return std::make_unique<ClSubtractionWorkload>(descriptor, info); } std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateBatchNormalization( diff --git a/src/backends/cl/workloads/ClSubtractionWorkload.cpp b/src/backends/cl/workloads/ClSubtractionWorkload.cpp index 37b334d94e..8efed94293 100644 --- a/src/backends/cl/workloads/ClSubtractionWorkload.cpp +++ b/src/backends/cl/workloads/ClSubtractionWorkload.cpp @@ -17,10 +17,9 @@ using namespace armcomputetensorutils; static constexpr arm_compute::ConvertPolicy g_AclConvertPolicy = arm_compute::ConvertPolicy::SATURATE; -template <armnn::DataType... T> -ClSubtractionWorkload<T...>::ClSubtractionWorkload(const SubtractionQueueDescriptor& descriptor, - const WorkloadInfo& info) - : TypedWorkload<SubtractionQueueDescriptor, T...>(descriptor, info) +ClSubtractionWorkload::ClSubtractionWorkload(const SubtractionQueueDescriptor& descriptor, + const WorkloadInfo& info) + : BaseWorkload<SubtractionQueueDescriptor>(descriptor, info) { this->m_Data.ValidateInputsOutputs("ClSubtractionWorkload", 2, 1); @@ -30,8 +29,7 @@ ClSubtractionWorkload<T...>::ClSubtractionWorkload(const SubtractionQueueDescrip m_Layer.configure(&input0, &input1, &output, g_AclConvertPolicy); } -template <armnn::DataType... T> -void ClSubtractionWorkload<T...>::Execute() const +void ClSubtractionWorkload::Execute() const { ARMNN_SCOPED_PROFILING_EVENT_CL("ClSubtractionWorkload_Execute"); m_Layer.run(); @@ -61,6 +59,3 @@ bool ClSubtractionValidate(const TensorInfo& input0, } } //namespace armnn - -template class armnn::ClSubtractionWorkload<armnn::DataType::Float16, armnn::DataType::Float32>; -template class armnn::ClSubtractionWorkload<armnn::DataType::QuantisedAsymm8>; diff --git a/src/backends/cl/workloads/ClSubtractionWorkload.hpp b/src/backends/cl/workloads/ClSubtractionWorkload.hpp index 67b219b09d..7dd608bf8a 100644 --- a/src/backends/cl/workloads/ClSubtractionWorkload.hpp +++ b/src/backends/cl/workloads/ClSubtractionWorkload.hpp @@ -12,8 +12,7 @@ namespace armnn { -template <armnn::DataType... dataTypes> -class ClSubtractionWorkload : public TypedWorkload<SubtractionQueueDescriptor, dataTypes...> +class ClSubtractionWorkload : public BaseWorkload<SubtractionQueueDescriptor> { public: ClSubtractionWorkload(const SubtractionQueueDescriptor& descriptor, const WorkloadInfo& info); diff --git a/src/backends/test/CreateWorkloadCl.cpp b/src/backends/test/CreateWorkloadCl.cpp index e48cd97d6f..9b68546c93 100644 --- a/src/backends/test/CreateWorkloadCl.cpp +++ b/src/backends/test/CreateWorkloadCl.cpp @@ -85,7 +85,7 @@ BOOST_AUTO_TEST_CASE(CreateAdditionFloat16Workload) BOOST_AUTO_TEST_CASE(CreateSubtractionFloatWorkload) { - ClCreateArithmethicWorkloadTest<ClSubtractionWorkload<armnn::DataType::Float16, armnn::DataType::Float32>, + ClCreateArithmethicWorkloadTest<ClSubtractionWorkload, SubtractionQueueDescriptor, SubtractionLayer, armnn::DataType::Float32>(); @@ -93,7 +93,7 @@ BOOST_AUTO_TEST_CASE(CreateSubtractionFloatWorkload) BOOST_AUTO_TEST_CASE(CreateSubtractionFloat16Workload) { - ClCreateArithmethicWorkloadTest<ClSubtractionWorkload<armnn::DataType::Float16, armnn::DataType::Float32>, + ClCreateArithmethicWorkloadTest<ClSubtractionWorkload, SubtractionQueueDescriptor, SubtractionLayer, armnn::DataType::Float16>(); |