From 79141b9662547eeefb3ad533637223de40726e12 Mon Sep 17 00:00:00 2001
From: David Beck
Date: Tue, 23 Oct 2018 16:09:36 +0100
Subject: IVGCVSW-2071 : remove GetCompute() from the WorkloadFactory interface

Change-Id: I44a9d26d1a5d876d381aee4c6450af62811d0dbb
---
 src/armnn/test/CreateWorkload.hpp             |  2 +-
 src/backends/WorkloadFactory.hpp              |  2 +-
 src/backends/cl/ClWorkloadFactory.cpp         | 13 ++++++++++++-
 src/backends/cl/ClWorkloadFactory.hpp         |  2 +-
 src/backends/neon/NeonWorkloadFactory.cpp     | 13 ++++++++++++-
 src/backends/neon/NeonWorkloadFactory.hpp     |  2 +-
 src/backends/reference/RefWorkloadFactory.cpp | 13 ++++++++++++-
 src/backends/reference/RefWorkloadFactory.hpp |  2 +-
 src/backends/test/NormTestImpl.hpp            |  4 ++--
 src/backends/test/Pooling2dTestImpl.hpp       |  8 ++++----
 10 files changed, 47 insertions(+), 14 deletions(-)

diff --git a/src/armnn/test/CreateWorkload.hpp b/src/armnn/test/CreateWorkload.hpp
index aac0a4ae6d..5308a1c1dc 100644
--- a/src/armnn/test/CreateWorkload.hpp
+++ b/src/armnn/test/CreateWorkload.hpp
@@ -32,7 +32,7 @@ std::unique_ptr MakeAndCheckWorkload(Layer& layer, Graph& graph, const
     BOOST_TEST(workload.get() == boost::polymorphic_downcast<Workload*>(workload.get()),
                "Cannot convert to derived class");
     std::string reasonIfUnsupported;
-    layer.SetBackendId(factory.GetCompute());
+    layer.SetBackendId(factory.GetBackendId());
     BOOST_TEST(factory.IsLayerSupported(layer, layer.GetDataType(), reasonIfUnsupported));
     return std::unique_ptr<Workload>(static_cast<Workload*>(workload.release()));
 }
diff --git a/src/backends/WorkloadFactory.hpp b/src/backends/WorkloadFactory.hpp
index 2d482e0911..2f422ab4f6 100644
--- a/src/backends/WorkloadFactory.hpp
+++ b/src/backends/WorkloadFactory.hpp
@@ -21,7 +21,7 @@ class IWorkloadFactory
 public:
     virtual ~IWorkloadFactory() { }
 
-    virtual Compute GetCompute() const = 0;
+    virtual const BackendId& GetBackendId() const = 0;
 
     /// Informs the memory manager that the network is finalized and ready for execution.
     virtual void Finalize() { }
diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp
index c697d90950..fd92db34d5 100644
--- a/src/backends/cl/ClWorkloadFactory.cpp
+++ b/src/backends/cl/ClWorkloadFactory.cpp
@@ -3,6 +3,7 @@
 // SPDX-License-Identifier: MIT
 //
 #include "ClWorkloadFactory.hpp"
+#include "ClBackendId.hpp"
 
 #include
 #include
@@ -34,11 +35,21 @@
 namespace armnn
 {
 
+namespace
+{
+static const BackendId s_Id{ClBackendId()};
+}
+
 bool ClWorkloadFactory::IsLayerSupported(const Layer& layer,
                                          Optional<DataType> dataType,
                                          std::string& outReasonIfUnsupported)
 {
-    return IWorkloadFactory::IsLayerSupported(Compute::GpuAcc, layer, dataType, outReasonIfUnsupported);
+    return IWorkloadFactory::IsLayerSupported(s_Id, layer, dataType, outReasonIfUnsupported);
+}
+
+const BackendId& ClWorkloadFactory::GetBackendId() const
+{
+    return s_Id;
 }
 
 #ifdef ARMCOMPUTECL_ENABLED
diff --git a/src/backends/cl/ClWorkloadFactory.hpp b/src/backends/cl/ClWorkloadFactory.hpp
index 1441b71e61..ba7cf6931f 100644
--- a/src/backends/cl/ClWorkloadFactory.hpp
+++ b/src/backends/cl/ClWorkloadFactory.hpp
@@ -19,7 +19,7 @@ class ClWorkloadFactory : public IWorkloadFactory
 public:
     ClWorkloadFactory();
 
-    virtual Compute GetCompute() const override { return Compute::GpuAcc; }
+    const BackendId& GetBackendId() const override;
 
     static bool IsLayerSupported(const Layer& layer,
                                  Optional<DataType> dataType,
diff --git a/src/backends/neon/NeonWorkloadFactory.cpp b/src/backends/neon/NeonWorkloadFactory.cpp
index f0a9e76de1..c16d383554 100644
--- a/src/backends/neon/NeonWorkloadFactory.cpp
+++ b/src/backends/neon/NeonWorkloadFactory.cpp
@@ -3,6 +3,7 @@
 // SPDX-License-Identifier: MIT
 //
 #include "NeonWorkloadFactory.hpp"
+#include "NeonBackendId.hpp"
 #include
 #include
 #include
@@ -25,11 +26,21 @@
 namespace armnn
 {
 
+namespace
+{
+static const BackendId s_Id{NeonBackendId()};
+}
+
 bool NeonWorkloadFactory::IsLayerSupported(const Layer& layer,
                                            Optional<DataType> dataType,
                                            std::string& outReasonIfUnsupported)
 {
-    return IWorkloadFactory::IsLayerSupported(Compute::CpuAcc, layer, dataType, outReasonIfUnsupported);
+    return IWorkloadFactory::IsLayerSupported(s_Id, layer, dataType, outReasonIfUnsupported);
+}
+
+const BackendId& NeonWorkloadFactory::GetBackendId() const
+{
+    return s_Id;
 }
 
 #ifdef ARMCOMPUTENEON_ENABLED
diff --git a/src/backends/neon/NeonWorkloadFactory.hpp b/src/backends/neon/NeonWorkloadFactory.hpp
index d1dd2c85fe..030e982a20 100644
--- a/src/backends/neon/NeonWorkloadFactory.hpp
+++ b/src/backends/neon/NeonWorkloadFactory.hpp
@@ -20,7 +20,7 @@ class NeonWorkloadFactory : public IWorkloadFactory
 public:
     NeonWorkloadFactory();
 
-    virtual Compute GetCompute() const override { return Compute::CpuAcc; }
+    const BackendId& GetBackendId() const override;
 
     static bool IsLayerSupported(const Layer& layer,
                                  Optional<DataType> dataType,
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index 783e5fba2e..864ffdbf4f 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -6,6 +6,7 @@
 #include
 #include
 #include "RefWorkloadFactory.hpp"
+#include "RefBackendId.hpp"
 #include "workloads/RefWorkloads.hpp"
 #include "Layer.hpp"
 
@@ -14,6 +15,11 @@
 namespace armnn
 {
 
+namespace
+{
+static const BackendId s_Id{RefBackendId()};
+}
+
 template <typename F32Workload, typename U8Workload, typename QueueDescriptorType>
 std::unique_ptr<IWorkload> RefWorkloadFactory::MakeWorkload(const QueueDescriptorType& descriptor,
                                                             const WorkloadInfo& info) const
@@ -25,11 +31,16 @@ RefWorkloadFactory::RefWorkloadFactory()
 {
 }
 
+const BackendId& RefWorkloadFactory::GetBackendId() const
+{
+    return s_Id;
+}
+
 bool RefWorkloadFactory::IsLayerSupported(const Layer& layer,
                                           Optional<DataType> dataType,
                                           std::string& outReasonIfUnsupported)
 {
-    return IWorkloadFactory::IsLayerSupported(Compute::CpuRef, layer, dataType, outReasonIfUnsupported);
+    return IWorkloadFactory::IsLayerSupported(s_Id, layer, dataType, outReasonIfUnsupported);
 }
 
 std::unique_ptr<ITensorHandle> RefWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo) const
diff --git a/src/backends/reference/RefWorkloadFactory.hpp b/src/backends/reference/RefWorkloadFactory.hpp
index ef2e1abfaa..be0dafc159 100644
--- a/src/backends/reference/RefWorkloadFactory.hpp
+++ b/src/backends/reference/RefWorkloadFactory.hpp
@@ -33,7 +33,7 @@ public:
     explicit RefWorkloadFactory();
     virtual ~RefWorkloadFactory() {}
 
-    virtual Compute GetCompute() const override { return Compute::CpuRef; }
+    const BackendId& GetBackendId() const override;
 
     static bool IsLayerSupported(const Layer& layer,
                                  Optional<DataType> dataType,
diff --git a/src/backends/test/NormTestImpl.hpp b/src/backends/test/NormTestImpl.hpp
index f4e6aea008..de954b95e0 100644
--- a/src/backends/test/NormTestImpl.hpp
+++ b/src/backends/test/NormTestImpl.hpp
@@ -308,10 +308,10 @@ LayerTestResult CompareNormalizationTestImpl(armnn::IWorkloadFactory& w
     SetWorkloadOutput(refData, refInfo, 0, outputTensorInfo, outputHandleRef.get());
 
     // Don't execute if Normalization is not supported for the method and channel types, as an exception will be raised.
-    armnn::Compute compute = workloadFactory.GetCompute();
+    armnn::BackendId backend = workloadFactory.GetBackendId();
     const size_t reasonIfUnsupportedMaxLen = 255;
     char reasonIfUnsupported[reasonIfUnsupportedMaxLen+1];
-    ret.supported = armnn::IsNormalizationSupported(compute, inputTensorInfo, outputTensorInfo, data.m_Parameters,
+    ret.supported = armnn::IsNormalizationSupported(backend, inputTensorInfo, outputTensorInfo, data.m_Parameters,
                                                     reasonIfUnsupported, reasonIfUnsupportedMaxLen);
     if (!ret.supported)
     {
diff --git a/src/backends/test/Pooling2dTestImpl.hpp b/src/backends/test/Pooling2dTestImpl.hpp
index 29263af9bc..90be2897e8 100644
--- a/src/backends/test/Pooling2dTestImpl.hpp
+++ b/src/backends/test/Pooling2dTestImpl.hpp
@@ -77,10 +77,10 @@ LayerTestResult SimplePooling2dTestImpl(armnn::IWorkloadFactory& workloadF
     AddOutputToWorkload(queueDescriptor, workloadInfo, outputTensorInfo, outputHandle.get());
 
     // Don't execute if Pooling is not supported, as an exception will be raised.
-    armnn::Compute compute = workloadFactory.GetCompute();
+    armnn::BackendId backend = workloadFactory.GetBackendId();
     const size_t reasonIfUnsupportedMaxLen = 255;
     char reasonIfUnsupported[reasonIfUnsupportedMaxLen+1];
-    result.supported = armnn::IsPooling2dSupported(compute, inputTensorInfo, outputTensorInfo,
+    result.supported = armnn::IsPooling2dSupported(backend, inputTensorInfo, outputTensorInfo,
                                                    queueDescriptor.m_Parameters, reasonIfUnsupported, reasonIfUnsupportedMaxLen);
 
     if (!result.supported)
@@ -650,10 +650,10 @@ LayerTestResult ComparePooling2dTestCommon(armnn::IWorkloadFactory& worklo
     std::unique_ptr<armnn::ITensorHandle> inputHandleRef = refWorkloadFactory.CreateTensorHandle(inputTensorInfo);
 
     // Don't execute if Pooling is not supported, as an exception will be raised.
-    armnn::Compute compute = workloadFactory.GetCompute();
+    armnn::BackendId backend = workloadFactory.GetBackendId();
     const size_t reasonIfUnsupportedMaxLen = 255;
     char reasonIfUnsupported[reasonIfUnsupportedMaxLen+1];
-    comparisonResult.supported = armnn::IsPooling2dSupported(compute, inputTensorInfo, outputTensorInfo,
+    comparisonResult.supported = armnn::IsPooling2dSupported(backend, inputTensorInfo, outputTensorInfo,
                                                              data.m_Parameters, reasonIfUnsupported, reasonIfUnsupportedMaxLen);
 
     if (!comparisonResult.supported)
--
cgit v1.2.1
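
Note: a rough caller-side sketch of what this interface change implies for code that previously
branched on IWorkloadFactory::GetCompute(). Only GetBackendId() returning a const BackendId& comes
from the patch above; the helper name, the include paths, and the comparison against the "GpuAcc"
name string below are illustrative assumptions, not part of this change, so verify them against the
BackendId definition in your ArmNN tree.

    #include <armnn/BackendId.hpp>
    #include <backends/WorkloadFactory.hpp> // assumed include path within src/

    namespace
    {
    // Hypothetical helper: true when the factory targets the GPU backend.
    // Assumes armnn::BackendId is constructible from a backend name string and
    // equality-comparable; GetBackendId() itself is the interface added here.
    bool IsGpuFactory(const armnn::IWorkloadFactory& factory)
    {
        return factory.GetBackendId() == armnn::BackendId("GpuAcc");
    }
    } // anonymous namespace

Code that used to switch on the Compute enum can migrate to comparisons like the one above, or keep
using the per-backend IsLayerSupported overloads, which now take a BackendId as their first argument.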