aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/TensorHandleStrategyTest.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/test/TensorHandleStrategyTest.cpp')
-rw-r--r--src/armnn/test/TensorHandleStrategyTest.cpp119
1 files changed, 106 insertions, 13 deletions
diff --git a/src/armnn/test/TensorHandleStrategyTest.cpp b/src/armnn/test/TensorHandleStrategyTest.cpp
index 3bb1c68169..c391b04d97 100644
--- a/src/armnn/test/TensorHandleStrategyTest.cpp
+++ b/src/armnn/test/TensorHandleStrategyTest.cpp
@@ -50,9 +50,11 @@ public:
return nullptr;
}
- virtual const FactoryId GetId() const override { return m_Id; }
+ const FactoryId GetId() const override { return m_Id; }
- virtual bool SupportsSubTensors() const override { return true; }
+ bool SupportsSubTensors() const override { return true; }
+
+ MemorySourceFlags GetExportFlags() const override { return 1; }
private:
FactoryId m_Id = "UninitializedId";
@@ -60,6 +62,38 @@ private:
std::weak_ptr<IMemoryManager> m_MemMgr;
};
+class TestFactoryImport : public ITensorHandleFactory
+{
+public:
+ TestFactoryImport(std::weak_ptr<IMemoryManager> mgr, ITensorHandleFactory::FactoryId id)
+ : m_Id(id)
+ , m_MemMgr(mgr)
+ {}
+
+ std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent,
+ TensorShape const& subTensorShape,
+ unsigned int const* subTensorOrigin) const override
+ {
+ return nullptr;
+ }
+
+ std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override
+ {
+ return nullptr;
+ }
+
+ const FactoryId GetId() const override { return m_Id; }
+
+ bool SupportsSubTensors() const override { return true; }
+
+ MemorySourceFlags GetImportFlags() const override { return 1; }
+
+private:
+ FactoryId m_Id = "ImporterId";
+
+ std::weak_ptr<IMemoryManager> m_MemMgr;
+};
+
class TestBackendA : public IBackendInternal
{
public:
@@ -173,6 +207,42 @@ private:
BackendId m_Id = "BackendC";
};
+class TestBackendD : public IBackendInternal
+{
+public:
+ TestBackendD() = default;
+
+ const BackendId& GetId() const override { return m_Id; }
+
+ IWorkloadFactoryPtr CreateWorkloadFactory(const IMemoryManagerSharedPtr& memoryManager = nullptr) const override
+ {
+ return IWorkloadFactoryPtr{};
+ }
+
+ IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override
+ {
+ return ILayerSupportSharedPtr{};
+ }
+
+ std::vector<ITensorHandleFactory::FactoryId> GetHandleFactoryPreferences() const override
+ {
+ return std::vector<ITensorHandleFactory::FactoryId>{
+ "TestHandleFactoryD1"
+ };
+ }
+
+ void RegisterTensorHandleFactories(TensorHandleFactoryRegistry& registry) override
+ {
+ auto mgr = std::make_shared<TestMemMgr>();
+
+ registry.RegisterMemoryManager(mgr);
+ registry.RegisterFactory(std::make_unique<TestFactoryImport>(mgr, "TestHandleFactoryD1"));
+ }
+
+private:
+ BackendId m_Id = "BackendD";
+};
+
BOOST_AUTO_TEST_SUITE(TensorHandle)
@@ -200,16 +270,19 @@ BOOST_AUTO_TEST_CASE(TensorHandleSelectionStrategy)
auto backendA = std::make_unique<TestBackendA>();
auto backendB = std::make_unique<TestBackendB>();
auto backendC = std::make_unique<TestBackendC>();
+ auto backendD = std::make_unique<TestBackendD>();
TensorHandleFactoryRegistry registry;
backendA->RegisterTensorHandleFactories(registry);
backendB->RegisterTensorHandleFactories(registry);
backendC->RegisterTensorHandleFactories(registry);
+ backendD->RegisterTensorHandleFactories(registry);
BackendsMap backends;
backends["BackendA"] = std::move(backendA);
backends["BackendB"] = std::move(backendB);
backends["BackendC"] = std::move(backendC);
+ backends["BackendD"] = std::move(backendD);
armnn::Graph graph;
@@ -226,13 +299,17 @@ BOOST_AUTO_TEST_CASE(TensorHandleSelectionStrategy)
armnn::SoftmaxLayer* const softmaxLayer3 = graph.AddLayer<armnn::SoftmaxLayer>(smDesc, "softmax3");
softmaxLayer3->SetBackendId("BackendC");
+ armnn::SoftmaxLayer* const softmaxLayer4 = graph.AddLayer<armnn::SoftmaxLayer>(smDesc, "softmax4");
+ softmaxLayer4->SetBackendId("BackendD");
+
armnn::OutputLayer* const outputLayer = graph.AddLayer<armnn::OutputLayer>(0, "output");
outputLayer->SetBackendId("BackendA");
inputLayer->GetOutputSlot(0).Connect(softmaxLayer1->GetInputSlot(0));
softmaxLayer1->GetOutputSlot(0).Connect(softmaxLayer2->GetInputSlot(0));
softmaxLayer2->GetOutputSlot(0).Connect(softmaxLayer3->GetInputSlot(0));
- softmaxLayer3->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
+ softmaxLayer3->GetOutputSlot(0).Connect(softmaxLayer4->GetInputSlot(0));
+ softmaxLayer4->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
graph.TopologicalSort();
@@ -246,29 +323,45 @@ BOOST_AUTO_TEST_CASE(TensorHandleSelectionStrategy)
OutputSlot& softmaxLayer1Out = softmaxLayer1->GetOutputSlot(0);
OutputSlot& softmaxLayer2Out = softmaxLayer2->GetOutputSlot(0);
OutputSlot& softmaxLayer3Out = softmaxLayer3->GetOutputSlot(0);
+ OutputSlot& softmaxLayer4Out = softmaxLayer4->GetOutputSlot(0);
// Check that the correct factory was selected
BOOST_TEST(inputLayerOut.GetTensorHandleFactoryId() == "TestHandleFactoryA1");
BOOST_TEST(softmaxLayer1Out.GetTensorHandleFactoryId() == "TestHandleFactoryB1");
BOOST_TEST(softmaxLayer2Out.GetTensorHandleFactoryId() == "TestHandleFactoryB1");
BOOST_TEST(softmaxLayer3Out.GetTensorHandleFactoryId() == "TestHandleFactoryC1");
+ BOOST_TEST(softmaxLayer4Out.GetTensorHandleFactoryId() == "TestHandleFactoryD1");
// Check that the correct strategy was selected
- BOOST_TEST((inputLayerOut.GetMemoryStrategyForConnection(0) == MemoryStrategy::DirectCompatibility));
- BOOST_TEST((softmaxLayer1Out.GetMemoryStrategyForConnection(0) == MemoryStrategy::DirectCompatibility));
- BOOST_TEST((softmaxLayer2Out.GetMemoryStrategyForConnection(0) == MemoryStrategy::CopyToTarget));
- BOOST_TEST((softmaxLayer3Out.GetMemoryStrategyForConnection(0) == MemoryStrategy::DirectCompatibility));
-
- graph.AddCopyLayers(backends, registry);
- int count= 0;
- graph.ForEachLayer([&count](Layer* layer)
+ BOOST_TEST((inputLayerOut.GetEdgeStrategyForConnection(0) == EdgeStrategy::DirectCompatibility));
+ BOOST_TEST((softmaxLayer1Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::DirectCompatibility));
+ BOOST_TEST((softmaxLayer2Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::CopyToTarget));
+ BOOST_TEST((softmaxLayer3Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::ExportToTarget));
+ BOOST_TEST((softmaxLayer4Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::DirectCompatibility));
+
+ graph.AddCompatibilityLayers(backends, registry);
+
+ // Test for copy layers
+ int copyCount= 0;
+ graph.ForEachLayer([&copyCount](Layer* layer)
{
if (layer->GetType() == LayerType::MemCopy)
{
- count++;
+ copyCount++;
+ }
+ });
+ BOOST_TEST(copyCount == 1);
+
+ // Test for import layers
+ int importCount= 0;
+ graph.ForEachLayer([&importCount](Layer *layer)
+ {
+ if (layer->GetType() == LayerType::MemImport)
+ {
+ importCount++;
}
});
- BOOST_TEST(count == 1);
+ BOOST_TEST(importCount == 1);
}
BOOST_AUTO_TEST_SUITE_END()