aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNattapat Chaimanowong <nattapat.chaimanowong@arm.com>2018-10-10 17:18:35 +0100
committerMatthew Bentham <matthew.bentham@arm.com>2018-10-22 16:57:53 +0100
commitac9e096a574db91fbcc42c2ee919a1d1e57b7fd3 (patch)
tree9490fcc4bef5a88e074820fa88b088b785d55810
parent02f8bc10bc670dd694eeda2db8e0a43a1c84320b (diff)
downloadarmnn-ac9e096a574db91fbcc42c2ee919a1d1e57b7fd3.tar.gz
IVGCVSW-1951 Remove type templating from ClPooling2dWorkload
Change-Id: Iaa3158487b58964d8a3b98acadde7c10172a3860
-rw-r--r--src/backends/cl/ClLayerSupport.cpp2
-rw-r--r--src/backends/cl/ClWorkloadFactory.cpp2
-rw-r--r--src/backends/cl/backend.mk4
-rw-r--r--src/backends/cl/test/ClCreateWorkloadTests.cpp12
-rw-r--r--src/backends/cl/workloads/CMakeLists.txt8
-rw-r--r--src/backends/cl/workloads/ClPooling2dBaseWorkload.hpp33
-rw-r--r--src/backends/cl/workloads/ClPooling2dFloatWorkload.cpp26
-rw-r--r--src/backends/cl/workloads/ClPooling2dFloatWorkload.hpp22
-rw-r--r--src/backends/cl/workloads/ClPooling2dUint8Workload.cpp27
-rw-r--r--src/backends/cl/workloads/ClPooling2dUint8Workload.hpp25
-rw-r--r--src/backends/cl/workloads/ClPooling2dWorkload.cpp (renamed from src/backends/cl/workloads/ClPooling2dBaseWorkload.cpp)20
-rw-r--r--src/backends/cl/workloads/ClPooling2dWorkload.hpp33
-rw-r--r--src/backends/cl/workloads/ClWorkloads.hpp3
13 files changed, 57 insertions, 160 deletions
diff --git a/src/backends/cl/ClLayerSupport.cpp b/src/backends/cl/ClLayerSupport.cpp
index 124fc8c230..9088da8645 100644
--- a/src/backends/cl/ClLayerSupport.cpp
+++ b/src/backends/cl/ClLayerSupport.cpp
@@ -26,7 +26,7 @@
#include "workloads/ClNormalizationFloatWorkload.hpp"
#include "workloads/ClPadWorkload.hpp"
#include "workloads/ClPermuteWorkload.hpp"
-#include "workloads/ClPooling2dBaseWorkload.hpp"
+#include "workloads/ClPooling2dWorkload.hpp"
#include "workloads/ClSoftmaxBaseWorkload.hpp"
#include "workloads/ClSubtractionWorkload.hpp"
#endif
diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp
index d293759bc9..fa86840e5f 100644
--- a/src/backends/cl/ClWorkloadFactory.cpp
+++ b/src/backends/cl/ClWorkloadFactory.cpp
@@ -138,7 +138,7 @@ std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreatePermute(const Permute
std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return MakeWorkload<ClPooling2dFloatWorkload, ClPooling2dUint8Workload>(descriptor, info);
+ return std::make_unique<ClPooling2dWorkload>(descriptor, info);
}
std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateConvolution2d(const Convolution2dQueueDescriptor& descriptor,
diff --git a/src/backends/cl/backend.mk b/src/backends/cl/backend.mk
index d5aa945ee2..810bb20859 100644
--- a/src/backends/cl/backend.mk
+++ b/src/backends/cl/backend.mk
@@ -29,9 +29,7 @@ BACKEND_SOURCES := \
workloads/ClNormalizationFloatWorkload.cpp \
workloads/ClPadWorkload.cpp \
workloads/ClPermuteWorkload.cpp \
- workloads/ClPooling2dBaseWorkload.cpp \
- workloads/ClPooling2dFloatWorkload.cpp \
- workloads/ClPooling2dUint8Workload.cpp \
+ workloads/ClPooling2dWorkload.cpp \
workloads/ClReshapeFloatWorkload.cpp \
workloads/ClReshapeUint8Workload.cpp \
workloads/ClResizeBilinearFloatWorkload.cpp \
diff --git a/src/backends/cl/test/ClCreateWorkloadTests.cpp b/src/backends/cl/test/ClCreateWorkloadTests.cpp
index d3bae440bd..29f7cddc44 100644
--- a/src/backends/cl/test/ClCreateWorkloadTests.cpp
+++ b/src/backends/cl/test/ClCreateWorkloadTests.cpp
@@ -338,13 +338,13 @@ BOOST_AUTO_TEST_CASE(CreateNormalizationFloat16NhwcWorkload)
ClNormalizationWorkloadTest<ClNormalizationFloatWorkload, armnn::DataType::Float16>(DataLayout::NHWC);
}
-template <typename Pooling2dWorkloadType, typename armnn::DataType DataType>
+template <typename armnn::DataType DataType>
static void ClPooling2dWorkloadTest(DataLayout dataLayout)
{
Graph graph;
ClWorkloadFactory factory;
- auto workload = CreatePooling2dWorkloadTest<Pooling2dWorkloadType, DataType>(factory, graph, dataLayout);
+ auto workload = CreatePooling2dWorkloadTest<ClPooling2dWorkload, DataType>(factory, graph, dataLayout);
std::initializer_list<unsigned int> inputShape = (dataLayout == DataLayout::NCHW) ?
std::initializer_list<unsigned int>({3, 2, 5, 5}) : std::initializer_list<unsigned int>({3, 5, 5, 2});
@@ -362,22 +362,22 @@ static void ClPooling2dWorkloadTest(DataLayout dataLayout)
BOOST_AUTO_TEST_CASE(CreatePooling2dFloatNchwWorkload)
{
- ClPooling2dWorkloadTest<ClPooling2dFloatWorkload, armnn::DataType::Float32>(DataLayout::NCHW);
+ ClPooling2dWorkloadTest<armnn::DataType::Float32>(DataLayout::NCHW);
}
BOOST_AUTO_TEST_CASE(CreatePooling2dFloatNhwcWorkload)
{
- ClPooling2dWorkloadTest<ClPooling2dFloatWorkload, armnn::DataType::Float32>(DataLayout::NHWC);
+ ClPooling2dWorkloadTest<armnn::DataType::Float32>(DataLayout::NHWC);
}
BOOST_AUTO_TEST_CASE(CreatePooling2dFloat16NchwWorkload)
{
- ClPooling2dWorkloadTest<ClPooling2dFloatWorkload, armnn::DataType::Float16>(DataLayout::NCHW);
+ ClPooling2dWorkloadTest<armnn::DataType::Float16>(DataLayout::NCHW);
}
BOOST_AUTO_TEST_CASE(CreatePooling2dFloat16NhwcWorkload)
{
- ClPooling2dWorkloadTest<ClPooling2dFloatWorkload, armnn::DataType::Float16>(DataLayout::NHWC);
+ ClPooling2dWorkloadTest<armnn::DataType::Float16>(DataLayout::NHWC);
}
template <typename ReshapeWorkloadType, typename armnn::DataType DataType>
diff --git a/src/backends/cl/workloads/CMakeLists.txt b/src/backends/cl/workloads/CMakeLists.txt
index e7877f2d45..222748b89a 100644
--- a/src/backends/cl/workloads/CMakeLists.txt
+++ b/src/backends/cl/workloads/CMakeLists.txt
@@ -39,12 +39,8 @@ list(APPEND armnnClBackendWorkloads_sources
ClPadWorkload.hpp
ClPermuteWorkload.cpp
ClPermuteWorkload.hpp
- ClPooling2dBaseWorkload.cpp
- ClPooling2dBaseWorkload.hpp
- ClPooling2dFloatWorkload.cpp
- ClPooling2dFloatWorkload.hpp
- ClPooling2dUint8Workload.cpp
- ClPooling2dUint8Workload.hpp
+ ClPooling2dWorkload.cpp
+ ClPooling2dWorkload.hpp
ClReshapeFloatWorkload.cpp
ClReshapeFloatWorkload.hpp
ClReshapeUint8Workload.cpp
diff --git a/src/backends/cl/workloads/ClPooling2dBaseWorkload.hpp b/src/backends/cl/workloads/ClPooling2dBaseWorkload.hpp
deleted file mode 100644
index 8f9db08ddc..0000000000
--- a/src/backends/cl/workloads/ClPooling2dBaseWorkload.hpp
+++ /dev/null
@@ -1,33 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#pragma once
-
-#include <backends/Workload.hpp>
-
-#include <arm_compute/runtime/CL/CLFunctions.h>
-
-namespace armnn
-{
-
-arm_compute::Status ClPooling2dWorkloadValidate(const TensorInfo& input,
- const TensorInfo& output,
- const Pooling2dDescriptor& descriptor);
-
-// Base class template providing an implementation of the Pooling2d layer common to all data types.
-template <armnn::DataType... dataTypes>
-class ClPooling2dBaseWorkload : public TypedWorkload<Pooling2dQueueDescriptor, dataTypes...>
-{
-public:
- using TypedWorkload<Pooling2dQueueDescriptor, dataTypes...>::m_Data;
-
- ClPooling2dBaseWorkload(const Pooling2dQueueDescriptor& descriptor, const WorkloadInfo& info,
- const std::string& name);
-
-protected:
- mutable arm_compute::CLPoolingLayer m_PoolingLayer;
-};
-
-} //namespace armnn
diff --git a/src/backends/cl/workloads/ClPooling2dFloatWorkload.cpp b/src/backends/cl/workloads/ClPooling2dFloatWorkload.cpp
deleted file mode 100644
index dc9d17f0ae..0000000000
--- a/src/backends/cl/workloads/ClPooling2dFloatWorkload.cpp
+++ /dev/null
@@ -1,26 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "ClPooling2dFloatWorkload.hpp"
-
-#include "ClWorkloadUtils.hpp"
-
-namespace armnn
-{
-
-ClPooling2dFloatWorkload::ClPooling2dFloatWorkload(const Pooling2dQueueDescriptor& descriptor,
- const WorkloadInfo& info)
- : ClPooling2dBaseWorkload<DataType::Float16, DataType::Float32>(descriptor, info, "ClPooling2dFloatWorkload")
-{
-}
-
-void ClPooling2dFloatWorkload::Execute() const
-{
- ARMNN_SCOPED_PROFILING_EVENT_CL("ClPooling2dFloatWorkload_Execute");
- m_PoolingLayer.run();
-}
-
-} //namespace armnn
-
diff --git a/src/backends/cl/workloads/ClPooling2dFloatWorkload.hpp b/src/backends/cl/workloads/ClPooling2dFloatWorkload.hpp
deleted file mode 100644
index ba9294c40f..0000000000
--- a/src/backends/cl/workloads/ClPooling2dFloatWorkload.hpp
+++ /dev/null
@@ -1,22 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#pragma once
-
-#include <backends/Workload.hpp>
-
-#include "ClPooling2dBaseWorkload.hpp"
-
-namespace armnn
-{
-class ClPooling2dFloatWorkload : public ClPooling2dBaseWorkload<DataType::Float16, DataType::Float32>
-{
-public:
- ClPooling2dFloatWorkload(const Pooling2dQueueDescriptor& descriptor, const WorkloadInfo& info);
- void Execute() const override;
-
-};
-
-} //namespace armnn
diff --git a/src/backends/cl/workloads/ClPooling2dUint8Workload.cpp b/src/backends/cl/workloads/ClPooling2dUint8Workload.cpp
deleted file mode 100644
index 0b4b15f806..0000000000
--- a/src/backends/cl/workloads/ClPooling2dUint8Workload.cpp
+++ /dev/null
@@ -1,27 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "ClPooling2dUint8Workload.hpp"
-
-#include "ClWorkloadUtils.hpp"
-
-namespace armnn
-{
-
-ClPooling2dUint8Workload::ClPooling2dUint8Workload(const Pooling2dQueueDescriptor& descriptor,
- const WorkloadInfo& info)
- : ClPooling2dBaseWorkload<DataType::QuantisedAsymm8>(descriptor, info, "ClPooling2dUint8Workload")
-{
-}
-
-void ClPooling2dUint8Workload::Execute() const
-{
- ARMNN_SCOPED_PROFILING_EVENT_CL("ClPooling2dUint8Workload_Execute");
- m_PoolingLayer.run();
-}
-
-} //namespace armnn
-
-
diff --git a/src/backends/cl/workloads/ClPooling2dUint8Workload.hpp b/src/backends/cl/workloads/ClPooling2dUint8Workload.hpp
deleted file mode 100644
index b07f955343..0000000000
--- a/src/backends/cl/workloads/ClPooling2dUint8Workload.hpp
+++ /dev/null
@@ -1,25 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#pragma once
-
-#include <backends/Workload.hpp>
-
-#include "ClPooling2dBaseWorkload.hpp"
-
-namespace armnn
-{
-
-class ClPooling2dUint8Workload : public ClPooling2dBaseWorkload<DataType::QuantisedAsymm8>
-{
-public:
- ClPooling2dUint8Workload(const Pooling2dQueueDescriptor& descriptor, const WorkloadInfo& info);
- void Execute() const override;
-
-};
-
-} //namespace armnn
-
-
diff --git a/src/backends/cl/workloads/ClPooling2dBaseWorkload.cpp b/src/backends/cl/workloads/ClPooling2dWorkload.cpp
index e61ad4f28e..255f57341e 100644
--- a/src/backends/cl/workloads/ClPooling2dBaseWorkload.cpp
+++ b/src/backends/cl/workloads/ClPooling2dWorkload.cpp
@@ -3,12 +3,14 @@
// SPDX-License-Identifier: MIT
//
-#include "ClPooling2dBaseWorkload.hpp"
+#include "ClPooling2dWorkload.hpp"
#include <backends/cl/ClLayerSupport.hpp>
#include <backends/cl/ClTensorHandle.hpp>
#include <backends/aclCommon/ArmComputeUtils.hpp>
#include <backends/aclCommon/ArmComputeTensorUtils.hpp>
+#include "ClWorkloadUtils.hpp"
+
namespace armnn
{
using namespace armcomputetensorutils;
@@ -25,12 +27,11 @@ arm_compute::Status ClPooling2dWorkloadValidate(const TensorInfo& input,
return arm_compute::CLPoolingLayer::validate(&aclInputInfo, &aclOutputInfo, layerInfo);
}
-template <armnn::DataType... dataTypes>
-ClPooling2dBaseWorkload<dataTypes...>::ClPooling2dBaseWorkload(
- const Pooling2dQueueDescriptor& descriptor, const WorkloadInfo& info, const std::string& name)
- : TypedWorkload<Pooling2dQueueDescriptor, dataTypes...>(descriptor, info)
+ClPooling2dWorkload::ClPooling2dWorkload(
+ const Pooling2dQueueDescriptor& descriptor, const WorkloadInfo& info)
+ : BaseWorkload<Pooling2dQueueDescriptor>(descriptor, info)
{
- m_Data.ValidateInputsOutputs(name, 1, 1);
+ m_Data.ValidateInputsOutputs("ClPooling2dWorkload", 1, 1);
arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
@@ -45,7 +46,10 @@ ClPooling2dBaseWorkload<dataTypes...>::ClPooling2dBaseWorkload(
m_PoolingLayer.configure(&input, &output, layerInfo);
}
-template class ClPooling2dBaseWorkload<DataType::Float16, DataType::Float32>;
-template class ClPooling2dBaseWorkload<DataType::QuantisedAsymm8>;
+void ClPooling2dWorkload::Execute() const
+{
+ ARMNN_SCOPED_PROFILING_EVENT_CL("ClPooling2dWorkload_Execute");
+ m_PoolingLayer.run();
+}
}
diff --git a/src/backends/cl/workloads/ClPooling2dWorkload.hpp b/src/backends/cl/workloads/ClPooling2dWorkload.hpp
new file mode 100644
index 0000000000..0812e33a52
--- /dev/null
+++ b/src/backends/cl/workloads/ClPooling2dWorkload.hpp
@@ -0,0 +1,33 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <backends/Workload.hpp>
+
+#include <arm_compute/runtime/CL/CLFunctions.h>
+
+namespace armnn
+{
+
+arm_compute::Status ClPooling2dWorkloadValidate(const TensorInfo& input,
+ const TensorInfo& output,
+ const Pooling2dDescriptor& descriptor);
+
+class ClPooling2dWorkload : public BaseWorkload<Pooling2dQueueDescriptor>
+{
+public:
+ using BaseWorkload<Pooling2dQueueDescriptor>::m_Data;
+
+ ClPooling2dWorkload(const Pooling2dQueueDescriptor& descriptor,
+ const WorkloadInfo& info);
+
+ void Execute() const override;
+
+private:
+ mutable arm_compute::CLPoolingLayer m_PoolingLayer;
+};
+
+} //namespace armnn
diff --git a/src/backends/cl/workloads/ClWorkloads.hpp b/src/backends/cl/workloads/ClWorkloads.hpp
index 8257218291..e2dede68e2 100644
--- a/src/backends/cl/workloads/ClWorkloads.hpp
+++ b/src/backends/cl/workloads/ClWorkloads.hpp
@@ -20,8 +20,7 @@
#include "ClNormalizationFloatWorkload.hpp"
#include "ClPermuteWorkload.hpp"
#include "ClPadWorkload.hpp"
-#include "ClPooling2dFloatWorkload.hpp"
-#include "ClPooling2dUint8Workload.hpp"
+#include "ClPooling2dWorkload.hpp"
#include "ClReshapeFloatWorkload.hpp"
#include "ClReshapeUint8Workload.hpp"
#include "ClResizeBilinearFloatWorkload.hpp"