author    David Beck <david.beck@arm.com>            2018-10-23 16:09:36 +0100
committer Matthew Bentham <matthew.bentham@arm.com>  2018-10-25 09:49:58 +0100
commit    79141b9662547eeefb3ad533637223de40726e12 (patch)
tree      9869465c1769fb1f6e2bf86306fbac3bc1b6add4
parent    29c75de868ac3a86a70b25f8da0d0c7e47d40803 (diff)
IVGCVSW-2071 : remove GetCompute() from the WorkloadFactory interface
Change-Id: I44a9d26d1a5d876d381aee4c6450af62811d0dbb
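
For context, a minimal sketch of the pattern this commit introduces, assuming only the armnn headers touched here. The SampleWorkloadFactory / "SampleBackend" names are hypothetical and used purely for illustration; the real factories use ClBackendId(), NeonBackendId() and RefBackendId(), as the diff below shows.

    // Sketch: each backend .cpp keeps a file-local BackendId and returns it by
    // const reference from GetBackendId(), replacing the old Compute enum getter.
    #include <armnn/BackendId.hpp>   // string-backed backend identifier (header path assumed)

    namespace
    {
    // Hypothetical backend name, for illustration only.
    static const armnn::BackendId s_Id{"SampleBackend"};
    }

    const armnn::BackendId& SampleWorkloadFactory::GetBackendId() const
    {
        return s_Id;
    }

The same BackendId value is then forwarded to IWorkloadFactory::IsLayerSupported(), so backend identification is carried by a single, extensible identifier rather than the fixed Compute enum.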
-rw-r--r--  src/armnn/test/CreateWorkload.hpp               2
-rw-r--r--  src/backends/WorkloadFactory.hpp                2
-rw-r--r--  src/backends/cl/ClWorkloadFactory.cpp          13
-rw-r--r--  src/backends/cl/ClWorkloadFactory.hpp           2
-rw-r--r--  src/backends/neon/NeonWorkloadFactory.cpp      13
-rw-r--r--  src/backends/neon/NeonWorkloadFactory.hpp       2
-rw-r--r--  src/backends/reference/RefWorkloadFactory.cpp  13
-rw-r--r--  src/backends/reference/RefWorkloadFactory.hpp   2
-rw-r--r--  src/backends/test/NormTestImpl.hpp              4
-rw-r--r--  src/backends/test/Pooling2dTestImpl.hpp         8
10 files changed, 47 insertions, 14 deletions
diff --git a/src/armnn/test/CreateWorkload.hpp b/src/armnn/test/CreateWorkload.hpp
index aac0a4ae6d..5308a1c1dc 100644
--- a/src/armnn/test/CreateWorkload.hpp
+++ b/src/armnn/test/CreateWorkload.hpp
@@ -32,7 +32,7 @@ std::unique_ptr<Workload> MakeAndCheckWorkload(Layer& layer, Graph& graph, const
BOOST_TEST(workload.get() == boost::polymorphic_downcast<Workload*>(workload.get()),
"Cannot convert to derived class");
std::string reasonIfUnsupported;
- layer.SetBackendId(factory.GetCompute());
+ layer.SetBackendId(factory.GetBackendId());
BOOST_TEST(factory.IsLayerSupported(layer, layer.GetDataType(), reasonIfUnsupported));
return std::unique_ptr<Workload>(static_cast<Workload*>(workload.release()));
}
diff --git a/src/backends/WorkloadFactory.hpp b/src/backends/WorkloadFactory.hpp
index 2d482e0911..2f422ab4f6 100644
--- a/src/backends/WorkloadFactory.hpp
+++ b/src/backends/WorkloadFactory.hpp
@@ -21,7 +21,7 @@ class IWorkloadFactory
public:
virtual ~IWorkloadFactory() { }
- virtual Compute GetCompute() const = 0;
+ virtual const BackendId& GetBackendId() const = 0;
/// Informs the memory manager that the network is finalized and ready for execution.
virtual void Finalize() { }
diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp
index c697d90950..fd92db34d5 100644
--- a/src/backends/cl/ClWorkloadFactory.cpp
+++ b/src/backends/cl/ClWorkloadFactory.cpp
@@ -3,6 +3,7 @@
// SPDX-License-Identifier: MIT
//
#include "ClWorkloadFactory.hpp"
+#include "ClBackendId.hpp"
#include <armnn/Exceptions.hpp>
#include <armnn/Utils.hpp>
@@ -34,11 +35,21 @@
namespace armnn
{
+namespace
+{
+static const BackendId s_Id{ClBackendId()};
+}
+
bool ClWorkloadFactory::IsLayerSupported(const Layer& layer,
Optional<DataType> dataType,
std::string& outReasonIfUnsupported)
{
- return IWorkloadFactory::IsLayerSupported(Compute::GpuAcc, layer, dataType, outReasonIfUnsupported);
+ return IWorkloadFactory::IsLayerSupported(s_Id, layer, dataType, outReasonIfUnsupported);
+}
+
+const BackendId& ClWorkloadFactory::GetBackendId() const
+{
+ return s_Id;
}
#ifdef ARMCOMPUTECL_ENABLED
diff --git a/src/backends/cl/ClWorkloadFactory.hpp b/src/backends/cl/ClWorkloadFactory.hpp
index 1441b71e61..ba7cf6931f 100644
--- a/src/backends/cl/ClWorkloadFactory.hpp
+++ b/src/backends/cl/ClWorkloadFactory.hpp
@@ -19,7 +19,7 @@ class ClWorkloadFactory : public IWorkloadFactory
public:
ClWorkloadFactory();
- virtual Compute GetCompute() const override { return Compute::GpuAcc; }
+ const BackendId& GetBackendId() const override;
static bool IsLayerSupported(const Layer& layer,
Optional<DataType> dataType,
diff --git a/src/backends/neon/NeonWorkloadFactory.cpp b/src/backends/neon/NeonWorkloadFactory.cpp
index f0a9e76de1..c16d383554 100644
--- a/src/backends/neon/NeonWorkloadFactory.cpp
+++ b/src/backends/neon/NeonWorkloadFactory.cpp
@@ -3,6 +3,7 @@
// SPDX-License-Identifier: MIT
//
#include "NeonWorkloadFactory.hpp"
+#include "NeonBackendId.hpp"
#include <armnn/Utils.hpp>
#include <backends/CpuTensorHandle.hpp>
#include <Layer.hpp>
@@ -25,11 +26,21 @@
namespace armnn
{
+namespace
+{
+static const BackendId s_Id{NeonBackendId()};
+}
+
bool NeonWorkloadFactory::IsLayerSupported(const Layer& layer,
Optional<DataType> dataType,
std::string& outReasonIfUnsupported)
{
- return IWorkloadFactory::IsLayerSupported(Compute::CpuAcc, layer, dataType, outReasonIfUnsupported);
+ return IWorkloadFactory::IsLayerSupported(s_Id, layer, dataType, outReasonIfUnsupported);
+}
+
+const BackendId& NeonWorkloadFactory::GetBackendId() const
+{
+ return s_Id;
}
#ifdef ARMCOMPUTENEON_ENABLED
diff --git a/src/backends/neon/NeonWorkloadFactory.hpp b/src/backends/neon/NeonWorkloadFactory.hpp
index d1dd2c85fe..030e982a20 100644
--- a/src/backends/neon/NeonWorkloadFactory.hpp
+++ b/src/backends/neon/NeonWorkloadFactory.hpp
@@ -20,7 +20,7 @@ class NeonWorkloadFactory : public IWorkloadFactory
public:
NeonWorkloadFactory();
- virtual Compute GetCompute() const override { return Compute::CpuAcc; }
+ const BackendId& GetBackendId() const override;
static bool IsLayerSupported(const Layer& layer,
Optional<DataType> dataType,
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index 783e5fba2e..864ffdbf4f 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -6,6 +6,7 @@
#include <backends/MemCopyWorkload.hpp>
#include <backends/MakeWorkloadHelper.hpp>
#include "RefWorkloadFactory.hpp"
+#include "RefBackendId.hpp"
#include "workloads/RefWorkloads.hpp"
#include "Layer.hpp"
@@ -14,6 +15,11 @@
namespace armnn
{
+namespace
+{
+static const BackendId s_Id{RefBackendId()};
+}
+
template <typename F32Workload, typename U8Workload, typename QueueDescriptorType>
std::unique_ptr<IWorkload> RefWorkloadFactory::MakeWorkload(const QueueDescriptorType& descriptor,
const WorkloadInfo& info) const
@@ -25,11 +31,16 @@ RefWorkloadFactory::RefWorkloadFactory()
{
}
+const BackendId& RefWorkloadFactory::GetBackendId() const
+{
+ return s_Id;
+}
+
bool RefWorkloadFactory::IsLayerSupported(const Layer& layer,
Optional<DataType> dataType,
std::string& outReasonIfUnsupported)
{
- return IWorkloadFactory::IsLayerSupported(Compute::CpuRef, layer, dataType, outReasonIfUnsupported);
+ return IWorkloadFactory::IsLayerSupported(s_Id, layer, dataType, outReasonIfUnsupported);
}
std::unique_ptr<ITensorHandle> RefWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo) const
diff --git a/src/backends/reference/RefWorkloadFactory.hpp b/src/backends/reference/RefWorkloadFactory.hpp
index ef2e1abfaa..be0dafc159 100644
--- a/src/backends/reference/RefWorkloadFactory.hpp
+++ b/src/backends/reference/RefWorkloadFactory.hpp
@@ -33,7 +33,7 @@ public:
explicit RefWorkloadFactory();
virtual ~RefWorkloadFactory() {}
- virtual Compute GetCompute() const override { return Compute::CpuRef; }
+ const BackendId& GetBackendId() const override;
static bool IsLayerSupported(const Layer& layer,
Optional<DataType> dataType,
diff --git a/src/backends/test/NormTestImpl.hpp b/src/backends/test/NormTestImpl.hpp
index f4e6aea008..de954b95e0 100644
--- a/src/backends/test/NormTestImpl.hpp
+++ b/src/backends/test/NormTestImpl.hpp
@@ -308,10 +308,10 @@ LayerTestResult<float,4> CompareNormalizationTestImpl(armnn::IWorkloadFactory& w
SetWorkloadOutput(refData, refInfo, 0, outputTensorInfo, outputHandleRef.get());
// Don't execute if Normalization is not supported for the method and channel types, as an exception will be raised.
- armnn::Compute compute = workloadFactory.GetCompute();
+ armnn::BackendId backend = workloadFactory.GetBackendId();
const size_t reasonIfUnsupportedMaxLen = 255;
char reasonIfUnsupported[reasonIfUnsupportedMaxLen+1];
- ret.supported = armnn::IsNormalizationSupported(compute, inputTensorInfo, outputTensorInfo, data.m_Parameters,
+ ret.supported = armnn::IsNormalizationSupported(backend, inputTensorInfo, outputTensorInfo, data.m_Parameters,
reasonIfUnsupported, reasonIfUnsupportedMaxLen);
if (!ret.supported)
{
diff --git a/src/backends/test/Pooling2dTestImpl.hpp b/src/backends/test/Pooling2dTestImpl.hpp
index 29263af9bc..90be2897e8 100644
--- a/src/backends/test/Pooling2dTestImpl.hpp
+++ b/src/backends/test/Pooling2dTestImpl.hpp
@@ -77,10 +77,10 @@ LayerTestResult<T, 4> SimplePooling2dTestImpl(armnn::IWorkloadFactory& workloadF
AddOutputToWorkload(queueDescriptor, workloadInfo, outputTensorInfo, outputHandle.get());
// Don't execute if Pooling is not supported, as an exception will be raised.
- armnn::Compute compute = workloadFactory.GetCompute();
+ armnn::BackendId backend = workloadFactory.GetBackendId();
const size_t reasonIfUnsupportedMaxLen = 255;
char reasonIfUnsupported[reasonIfUnsupportedMaxLen+1];
- result.supported = armnn::IsPooling2dSupported(compute, inputTensorInfo, outputTensorInfo,
+ result.supported = armnn::IsPooling2dSupported(backend, inputTensorInfo, outputTensorInfo,
queueDescriptor.m_Parameters,
reasonIfUnsupported, reasonIfUnsupportedMaxLen);
if (!result.supported)
@@ -650,10 +650,10 @@ LayerTestResult<T, 4> ComparePooling2dTestCommon(armnn::IWorkloadFactory& worklo
std::unique_ptr<armnn::ITensorHandle> inputHandleRef = refWorkloadFactory.CreateTensorHandle(inputTensorInfo);
// Don't execute if Pooling is not supported, as an exception will be raised.
- armnn::Compute compute = workloadFactory.GetCompute();
+ armnn::BackendId backend = workloadFactory.GetBackendId();
const size_t reasonIfUnsupportedMaxLen = 255;
char reasonIfUnsupported[reasonIfUnsupportedMaxLen+1];
- comparisonResult.supported = armnn::IsPooling2dSupported(compute, inputTensorInfo, outputTensorInfo,
+ comparisonResult.supported = armnn::IsPooling2dSupported(backend, inputTensorInfo, outputTensorInfo,
data.m_Parameters,
reasonIfUnsupported, reasonIfUnsupportedMaxLen);
if (!comparisonResult.supported)
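
For completeness, a hedged usage sketch mirroring the test-side change above; workloadFactory, the tensor infos and queueDescriptor are assumed to be set up exactly as in SimplePooling2dTestImpl.

    // Query the factory for its BackendId and pass it to the IsXxxSupported
    // helpers in place of the old Compute enum value.
    armnn::BackendId backend = workloadFactory.GetBackendId();
    const size_t reasonIfUnsupportedMaxLen = 255;
    char reasonIfUnsupported[reasonIfUnsupportedMaxLen + 1];
    bool supported = armnn::IsPooling2dSupported(backend, inputTensorInfo, outputTensorInfo,
                                                 queueDescriptor.m_Parameters,
                                                 reasonIfUnsupported, reasonIfUnsupportedMaxLen);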