diff options
Diffstat (limited to 'src/armnn/test/TensorHandleStrategyTest.cpp')
-rw-r--r-- | src/armnn/test/TensorHandleStrategyTest.cpp | 119 |
1 files changed, 106 insertions, 13 deletions
diff --git a/src/armnn/test/TensorHandleStrategyTest.cpp b/src/armnn/test/TensorHandleStrategyTest.cpp index 3bb1c68169..c391b04d97 100644 --- a/src/armnn/test/TensorHandleStrategyTest.cpp +++ b/src/armnn/test/TensorHandleStrategyTest.cpp @@ -50,9 +50,11 @@ public: return nullptr; } - virtual const FactoryId GetId() const override { return m_Id; } + const FactoryId GetId() const override { return m_Id; } - virtual bool SupportsSubTensors() const override { return true; } + bool SupportsSubTensors() const override { return true; } + + MemorySourceFlags GetExportFlags() const override { return 1; } private: FactoryId m_Id = "UninitializedId"; @@ -60,6 +62,38 @@ private: std::weak_ptr<IMemoryManager> m_MemMgr; }; +class TestFactoryImport : public ITensorHandleFactory +{ +public: + TestFactoryImport(std::weak_ptr<IMemoryManager> mgr, ITensorHandleFactory::FactoryId id) + : m_Id(id) + , m_MemMgr(mgr) + {} + + std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent, + TensorShape const& subTensorShape, + unsigned int const* subTensorOrigin) const override + { + return nullptr; + } + + std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override + { + return nullptr; + } + + const FactoryId GetId() const override { return m_Id; } + + bool SupportsSubTensors() const override { return true; } + + MemorySourceFlags GetImportFlags() const override { return 1; } + +private: + FactoryId m_Id = "ImporterId"; + + std::weak_ptr<IMemoryManager> m_MemMgr; +}; + class TestBackendA : public IBackendInternal { public: @@ -173,6 +207,42 @@ private: BackendId m_Id = "BackendC"; }; +class TestBackendD : public IBackendInternal +{ +public: + TestBackendD() = default; + + const BackendId& GetId() const override { return m_Id; } + + IWorkloadFactoryPtr CreateWorkloadFactory(const IMemoryManagerSharedPtr& memoryManager = nullptr) const override + { + return IWorkloadFactoryPtr{}; + } + + IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override + { + return ILayerSupportSharedPtr{}; + } + + std::vector<ITensorHandleFactory::FactoryId> GetHandleFactoryPreferences() const override + { + return std::vector<ITensorHandleFactory::FactoryId>{ + "TestHandleFactoryD1" + }; + } + + void RegisterTensorHandleFactories(TensorHandleFactoryRegistry& registry) override + { + auto mgr = std::make_shared<TestMemMgr>(); + + registry.RegisterMemoryManager(mgr); + registry.RegisterFactory(std::make_unique<TestFactoryImport>(mgr, "TestHandleFactoryD1")); + } + +private: + BackendId m_Id = "BackendD"; +}; + BOOST_AUTO_TEST_SUITE(TensorHandle) @@ -200,16 +270,19 @@ BOOST_AUTO_TEST_CASE(TensorHandleSelectionStrategy) auto backendA = std::make_unique<TestBackendA>(); auto backendB = std::make_unique<TestBackendB>(); auto backendC = std::make_unique<TestBackendC>(); + auto backendD = std::make_unique<TestBackendD>(); TensorHandleFactoryRegistry registry; backendA->RegisterTensorHandleFactories(registry); backendB->RegisterTensorHandleFactories(registry); backendC->RegisterTensorHandleFactories(registry); + backendD->RegisterTensorHandleFactories(registry); BackendsMap backends; backends["BackendA"] = std::move(backendA); backends["BackendB"] = std::move(backendB); backends["BackendC"] = std::move(backendC); + backends["BackendD"] = std::move(backendD); armnn::Graph graph; @@ -226,13 +299,17 @@ BOOST_AUTO_TEST_CASE(TensorHandleSelectionStrategy) armnn::SoftmaxLayer* const softmaxLayer3 = graph.AddLayer<armnn::SoftmaxLayer>(smDesc, "softmax3"); softmaxLayer3->SetBackendId("BackendC"); + armnn::SoftmaxLayer* const softmaxLayer4 = graph.AddLayer<armnn::SoftmaxLayer>(smDesc, "softmax4"); + softmaxLayer4->SetBackendId("BackendD"); + armnn::OutputLayer* const outputLayer = graph.AddLayer<armnn::OutputLayer>(0, "output"); outputLayer->SetBackendId("BackendA"); inputLayer->GetOutputSlot(0).Connect(softmaxLayer1->GetInputSlot(0)); softmaxLayer1->GetOutputSlot(0).Connect(softmaxLayer2->GetInputSlot(0)); softmaxLayer2->GetOutputSlot(0).Connect(softmaxLayer3->GetInputSlot(0)); - softmaxLayer3->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); + softmaxLayer3->GetOutputSlot(0).Connect(softmaxLayer4->GetInputSlot(0)); + softmaxLayer4->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); graph.TopologicalSort(); @@ -246,29 +323,45 @@ BOOST_AUTO_TEST_CASE(TensorHandleSelectionStrategy) OutputSlot& softmaxLayer1Out = softmaxLayer1->GetOutputSlot(0); OutputSlot& softmaxLayer2Out = softmaxLayer2->GetOutputSlot(0); OutputSlot& softmaxLayer3Out = softmaxLayer3->GetOutputSlot(0); + OutputSlot& softmaxLayer4Out = softmaxLayer4->GetOutputSlot(0); // Check that the correct factory was selected BOOST_TEST(inputLayerOut.GetTensorHandleFactoryId() == "TestHandleFactoryA1"); BOOST_TEST(softmaxLayer1Out.GetTensorHandleFactoryId() == "TestHandleFactoryB1"); BOOST_TEST(softmaxLayer2Out.GetTensorHandleFactoryId() == "TestHandleFactoryB1"); BOOST_TEST(softmaxLayer3Out.GetTensorHandleFactoryId() == "TestHandleFactoryC1"); + BOOST_TEST(softmaxLayer4Out.GetTensorHandleFactoryId() == "TestHandleFactoryD1"); // Check that the correct strategy was selected - BOOST_TEST((inputLayerOut.GetMemoryStrategyForConnection(0) == MemoryStrategy::DirectCompatibility)); - BOOST_TEST((softmaxLayer1Out.GetMemoryStrategyForConnection(0) == MemoryStrategy::DirectCompatibility)); - BOOST_TEST((softmaxLayer2Out.GetMemoryStrategyForConnection(0) == MemoryStrategy::CopyToTarget)); - BOOST_TEST((softmaxLayer3Out.GetMemoryStrategyForConnection(0) == MemoryStrategy::DirectCompatibility)); - - graph.AddCopyLayers(backends, registry); - int count= 0; - graph.ForEachLayer([&count](Layer* layer) + BOOST_TEST((inputLayerOut.GetEdgeStrategyForConnection(0) == EdgeStrategy::DirectCompatibility)); + BOOST_TEST((softmaxLayer1Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::DirectCompatibility)); + BOOST_TEST((softmaxLayer2Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::CopyToTarget)); + BOOST_TEST((softmaxLayer3Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::ExportToTarget)); + BOOST_TEST((softmaxLayer4Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::DirectCompatibility)); + + graph.AddCompatibilityLayers(backends, registry); + + // Test for copy layers + int copyCount= 0; + graph.ForEachLayer([©Count](Layer* layer) { if (layer->GetType() == LayerType::MemCopy) { - count++; + copyCount++; + } + }); + BOOST_TEST(copyCount == 1); + + // Test for import layers + int importCount= 0; + graph.ForEachLayer([&importCount](Layer *layer) + { + if (layer->GetType() == LayerType::MemImport) + { + importCount++; } }); - BOOST_TEST(count == 1); + BOOST_TEST(importCount == 1); } BOOST_AUTO_TEST_SUITE_END() |