diff options
author | Nattapat Chaimanowong <nattapat.chaimanowong@arm.com> | 2018-10-11 10:29:15 +0100 |
---|---|---|
committer | Matthew Bentham <matthew.bentham@arm.com> | 2018-10-22 16:57:53 +0100 |
commit | a76698c34941ad5cf67fe114be05b038a31d98a7 (patch) | |
tree | 482b56a4449cc259708f899356bfa0c1610b05aa /src/backends | |
parent | ac9e096a574db91fbcc42c2ee919a1d1e57b7fd3 (diff) | |
download | armnn-a76698c34941ad5cf67fe114be05b038a31d98a7.tar.gz |
IVGCVSW-1951 Remove type templating from ClReshapeWorkload
Change-Id: I5349629bc5b36e5b5029a158bf888c09c3bda4b0
Diffstat (limited to 'src/backends')
-rw-r--r-- | src/backends/cl/ClWorkloadFactory.cpp | 2 | ||||
-rw-r--r-- | src/backends/cl/backend.mk | 3 | ||||
-rw-r--r-- | src/backends/cl/test/ClCreateWorkloadTests.cpp | 10 | ||||
-rw-r--r-- | src/backends/cl/workloads/CMakeLists.txt | 6 | ||||
-rw-r--r-- | src/backends/cl/workloads/ClReshapeFloatWorkload.cpp | 33 | ||||
-rw-r--r-- | src/backends/cl/workloads/ClReshapeUint8Workload.hpp | 29 | ||||
-rw-r--r-- | src/backends/cl/workloads/ClReshapeWorkload.cpp (renamed from src/backends/cl/workloads/ClReshapeUint8Workload.cpp) | 15 | ||||
-rw-r--r-- | src/backends/cl/workloads/ClReshapeWorkload.hpp (renamed from src/backends/cl/workloads/ClReshapeFloatWorkload.hpp) | 6 | ||||
-rw-r--r-- | src/backends/cl/workloads/ClWorkloads.hpp | 3 |
9 files changed, 20 insertions, 87 deletions
diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp index fa86840e5f..499d13d8e1 100644 --- a/src/backends/cl/ClWorkloadFactory.cpp +++ b/src/backends/cl/ClWorkloadFactory.cpp @@ -229,7 +229,7 @@ std::unique_ptr<IWorkload> ClWorkloadFactory::CreateConstant(const ConstantQueue std::unique_ptr<IWorkload> ClWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return MakeWorkload<ClReshapeFloatWorkload, ClReshapeUint8Workload>(descriptor, info); + return std::make_unique<ClReshapeWorkload>(descriptor, info); } std::unique_ptr<IWorkload> ClWorkloadFactory::CreateFloor(const FloorQueueDescriptor& descriptor, diff --git a/src/backends/cl/backend.mk b/src/backends/cl/backend.mk index 810bb20859..4375d9496c 100644 --- a/src/backends/cl/backend.mk +++ b/src/backends/cl/backend.mk @@ -30,8 +30,7 @@ BACKEND_SOURCES := \ workloads/ClPadWorkload.cpp \ workloads/ClPermuteWorkload.cpp \ workloads/ClPooling2dWorkload.cpp \ - workloads/ClReshapeFloatWorkload.cpp \ - workloads/ClReshapeUint8Workload.cpp \ + workloads/ClReshapeWorkload.cpp \ workloads/ClResizeBilinearFloatWorkload.cpp \ workloads/ClSoftmaxBaseWorkload.cpp \ workloads/ClSoftmaxFloatWorkload.cpp \ diff --git a/src/backends/cl/test/ClCreateWorkloadTests.cpp b/src/backends/cl/test/ClCreateWorkloadTests.cpp index 29f7cddc44..67f3e3c5bb 100644 --- a/src/backends/cl/test/ClCreateWorkloadTests.cpp +++ b/src/backends/cl/test/ClCreateWorkloadTests.cpp @@ -380,13 +380,13 @@ BOOST_AUTO_TEST_CASE(CreatePooling2dFloat16NhwcWorkload) ClPooling2dWorkloadTest<armnn::DataType::Float16>(DataLayout::NHWC); } -template <typename ReshapeWorkloadType, typename armnn::DataType DataType> +template <typename armnn::DataType DataType> static void ClCreateReshapeWorkloadTest() { Graph graph; ClWorkloadFactory factory; - auto workload = CreateReshapeWorkloadTest<ReshapeWorkloadType, DataType>(factory, graph); + auto workload = CreateReshapeWorkloadTest<ClReshapeWorkload, DataType>(factory, graph); // Checks that outputs and inputs are as we expect them (see definition of CreateReshapeWorkloadTest). ReshapeQueueDescriptor queueDescriptor = workload->GetData(); @@ -399,17 +399,17 @@ static void ClCreateReshapeWorkloadTest() BOOST_AUTO_TEST_CASE(CreateReshapeFloatWorkload) { - ClCreateReshapeWorkloadTest<ClReshapeFloatWorkload, armnn::DataType::Float32>(); + ClCreateReshapeWorkloadTest<armnn::DataType::Float32>(); } BOOST_AUTO_TEST_CASE(CreateReshapeFloat16Workload) { - ClCreateReshapeWorkloadTest<ClReshapeFloatWorkload, armnn::DataType::Float16>(); + ClCreateReshapeWorkloadTest<armnn::DataType::Float16>(); } BOOST_AUTO_TEST_CASE(CreateReshapeUint8Workload) { - ClCreateReshapeWorkloadTest<ClReshapeUint8Workload, armnn::DataType::QuantisedAsymm8>(); + ClCreateReshapeWorkloadTest<armnn::DataType::QuantisedAsymm8>(); } template <typename SoftmaxWorkloadType, typename armnn::DataType DataType> diff --git a/src/backends/cl/workloads/CMakeLists.txt b/src/backends/cl/workloads/CMakeLists.txt index 222748b89a..5bd217295e 100644 --- a/src/backends/cl/workloads/CMakeLists.txt +++ b/src/backends/cl/workloads/CMakeLists.txt @@ -41,10 +41,8 @@ list(APPEND armnnClBackendWorkloads_sources ClPermuteWorkload.hpp ClPooling2dWorkload.cpp ClPooling2dWorkload.hpp - ClReshapeFloatWorkload.cpp - ClReshapeFloatWorkload.hpp - ClReshapeUint8Workload.cpp - ClReshapeUint8Workload.hpp + ClReshapeWorkload.cpp + ClReshapeWorkload.hpp ClResizeBilinearFloatWorkload.cpp ClResizeBilinearFloatWorkload.hpp ClSoftmaxBaseWorkload.cpp diff --git a/src/backends/cl/workloads/ClReshapeFloatWorkload.cpp b/src/backends/cl/workloads/ClReshapeFloatWorkload.cpp deleted file mode 100644 index 4da3bbd703..0000000000 --- a/src/backends/cl/workloads/ClReshapeFloatWorkload.cpp +++ /dev/null @@ -1,33 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#include "ClReshapeFloatWorkload.hpp" -#include <backends/cl/ClTensorHandle.hpp> -#include <backends/CpuTensorHandle.hpp> - -#include "ClWorkloadUtils.hpp" - -namespace armnn -{ - -ClReshapeFloatWorkload::ClReshapeFloatWorkload(const ReshapeQueueDescriptor& descriptor, const WorkloadInfo& info) - : FloatWorkload<ReshapeQueueDescriptor>(descriptor, info) -{ - m_Data.ValidateInputsOutputs("ClReshapeFloatWorkload", 1, 1); - - arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor(); - arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor(); - - m_Layer.configure(&input, &output); -} - -void ClReshapeFloatWorkload::Execute() const -{ - ARMNN_SCOPED_PROFILING_EVENT_CL("ClReshapeFloatWorkload_Execute"); - m_Layer.run(); -} - -} //namespace armnn - diff --git a/src/backends/cl/workloads/ClReshapeUint8Workload.hpp b/src/backends/cl/workloads/ClReshapeUint8Workload.hpp deleted file mode 100644 index 654437a4c1..0000000000 --- a/src/backends/cl/workloads/ClReshapeUint8Workload.hpp +++ /dev/null @@ -1,29 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#pragma once - -#include <backends/Workload.hpp> - -#include <arm_compute/runtime/CL/CLFunctions.h> - -namespace armnn -{ - -// Reshape -class ClReshapeUint8Workload : public Uint8Workload<ReshapeQueueDescriptor> -{ -public: - ClReshapeUint8Workload( const ReshapeQueueDescriptor& descriptor, const WorkloadInfo& info); - - void Execute() const override; - -private: - mutable arm_compute::CLReshapeLayer m_Layer; -}; - -} //namespace armnn - - diff --git a/src/backends/cl/workloads/ClReshapeUint8Workload.cpp b/src/backends/cl/workloads/ClReshapeWorkload.cpp index 8fbee151fc..43a53cb7a1 100644 --- a/src/backends/cl/workloads/ClReshapeUint8Workload.cpp +++ b/src/backends/cl/workloads/ClReshapeWorkload.cpp @@ -3,7 +3,7 @@ // SPDX-License-Identifier: MIT // -#include "ClReshapeUint8Workload.hpp" +#include "ClReshapeWorkload.hpp" #include <backends/cl/ClTensorHandle.hpp> #include <backends/CpuTensorHandle.hpp> @@ -11,20 +11,21 @@ namespace armnn { -ClReshapeUint8Workload::ClReshapeUint8Workload(const ReshapeQueueDescriptor& descriptor, const WorkloadInfo& info) - : Uint8Workload<ReshapeQueueDescriptor>(descriptor, info) + +ClReshapeWorkload::ClReshapeWorkload(const ReshapeQueueDescriptor& descriptor, const WorkloadInfo& info) + : BaseWorkload<ReshapeQueueDescriptor>(descriptor, info) { - m_Data.ValidateInputsOutputs("ClReshapeUint8Workload", 1, 1); + m_Data.ValidateInputsOutputs("ClReshapeWorkload", 1, 1); arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor(); arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor(); + m_Layer.configure(&input, &output); } -void ClReshapeUint8Workload::Execute() const +void ClReshapeWorkload::Execute() const { - ARMNN_SCOPED_PROFILING_EVENT_CL("ClReshapeUint8Workload_Execute"); - + ARMNN_SCOPED_PROFILING_EVENT_CL("ClReshapeWorkload_Execute"); m_Layer.run(); } diff --git a/src/backends/cl/workloads/ClReshapeFloatWorkload.hpp b/src/backends/cl/workloads/ClReshapeWorkload.hpp index e5fc20ec8b..f949f764b2 100644 --- a/src/backends/cl/workloads/ClReshapeFloatWorkload.hpp +++ b/src/backends/cl/workloads/ClReshapeWorkload.hpp @@ -12,10 +12,10 @@ namespace armnn { -class ClReshapeFloatWorkload : public FloatWorkload<ReshapeQueueDescriptor> +class ClReshapeWorkload : public BaseWorkload<ReshapeQueueDescriptor> { public: - ClReshapeFloatWorkload(const ReshapeQueueDescriptor& descriptor, const WorkloadInfo& info); + ClReshapeWorkload(const ReshapeQueueDescriptor& descriptor, const WorkloadInfo& info); void Execute() const override; @@ -24,5 +24,3 @@ private: }; } //namespace armnn - - diff --git a/src/backends/cl/workloads/ClWorkloads.hpp b/src/backends/cl/workloads/ClWorkloads.hpp index e2dede68e2..63de744be5 100644 --- a/src/backends/cl/workloads/ClWorkloads.hpp +++ b/src/backends/cl/workloads/ClWorkloads.hpp @@ -21,8 +21,7 @@ #include "ClPermuteWorkload.hpp" #include "ClPadWorkload.hpp" #include "ClPooling2dWorkload.hpp" -#include "ClReshapeFloatWorkload.hpp" -#include "ClReshapeUint8Workload.hpp" +#include "ClReshapeWorkload.hpp" #include "ClResizeBilinearFloatWorkload.hpp" #include "ClSoftmaxFloatWorkload.hpp" #include "ClSoftmaxUint8Workload.hpp" |