From b275da5afe480e994fe6cd897b3090c52f1bcdea Mon Sep 17 00:00:00 2001 From: Narumol Prangnawarat Date: Fri, 17 Dec 2021 17:27:37 +0000 Subject: IVGCVSW-6675 Add functionality for CopyAndImportFactoryPair to TensorHandleFactoryRegistry * Add RegisterCopyAndImportFactoryPair * Add GetMatchingImportFactoryId * Unit tests Signed-off-by: Narumol Prangnawarat Change-Id: I91e71bdeadec8fedbce7088016c06a47a03bdbaa --- src/armnn/test/TensorHandleStrategyTest.cpp | 14 ++++++++++++++ .../backendsCommon/TensorHandleFactoryRegistry.cpp | 12 ++++++++++++ .../backendsCommon/TensorHandleFactoryRegistry.hpp | 10 ++++++++++ 3 files changed, 36 insertions(+) diff --git a/src/armnn/test/TensorHandleStrategyTest.cpp b/src/armnn/test/TensorHandleStrategyTest.cpp index fb26880d0c..374479b941 100644 --- a/src/armnn/test/TensorHandleStrategyTest.cpp +++ b/src/armnn/test/TensorHandleStrategyTest.cpp @@ -392,4 +392,18 @@ TEST_CASE("TensorHandleSelectionStrategy") CHECK(importCount == 1); } +TEST_CASE("RegisterCopyAndImportFactoryPairTest") +{ + TensorHandleFactoryRegistry registry; + ITensorHandleFactory::FactoryId copyId = "CopyFactoryId"; + ITensorHandleFactory::FactoryId importId = "ImportFactoryId"; + registry.RegisterCopyAndImportFactoryPair(copyId, importId); + + // Get mathing import factory id correctly + CHECK((registry.GetMatchingImportFactoryId(copyId) == importId)); + + // Return empty id when Invailid Id is given + CHECK((registry.GetMatchingImportFactoryId("InvalidFactoryId") == "")); +} + } diff --git a/src/backends/backendsCommon/TensorHandleFactoryRegistry.cpp b/src/backends/backendsCommon/TensorHandleFactoryRegistry.cpp index cc8a1361a3..8094f04959 100644 --- a/src/backends/backendsCommon/TensorHandleFactoryRegistry.cpp +++ b/src/backends/backendsCommon/TensorHandleFactoryRegistry.cpp @@ -63,6 +63,18 @@ ITensorHandleFactory* TensorHandleFactoryRegistry::GetFactory(ITensorHandleFacto return nullptr; } +void TensorHandleFactoryRegistry::RegisterCopyAndImportFactoryPair(ITensorHandleFactory::FactoryId copyFactoryId, + ITensorHandleFactory::FactoryId importFactoryId) +{ + m_FactoryMappings[copyFactoryId] = importFactoryId; +} + +ITensorHandleFactory::FactoryId TensorHandleFactoryRegistry::GetMatchingImportFactoryId( + ITensorHandleFactory::FactoryId copyFactoryId) +{ + return m_FactoryMappings[copyFactoryId]; +} + void TensorHandleFactoryRegistry::AquireMemory() { for (auto& mgr : m_MemoryManagers) diff --git a/src/backends/backendsCommon/TensorHandleFactoryRegistry.hpp b/src/backends/backendsCommon/TensorHandleFactoryRegistry.hpp index 525db56216..2a7c6f36d9 100644 --- a/src/backends/backendsCommon/TensorHandleFactoryRegistry.hpp +++ b/src/backends/backendsCommon/TensorHandleFactoryRegistry.hpp @@ -16,6 +16,8 @@ namespace armnn //Forward class IMemoryManager; +using CopyAndImportFactoryPairs = std::map; + /// class TensorHandleFactoryRegistry { @@ -39,6 +41,13 @@ public: ITensorHandleFactory* GetFactory(ITensorHandleFactory::FactoryId id, MemorySource memSource) const; + /// Register a pair of TensorHandleFactory Id for Memory Copy and TensorHandleFactory Id for Memory Import + void RegisterCopyAndImportFactoryPair(ITensorHandleFactory::FactoryId copyFactoryId, + ITensorHandleFactory::FactoryId importFactoryId); + + /// Get a matching TensorHandleFatory Id for Memory Import given TensorHandleFactory Id for Memory Copy + ITensorHandleFactory::FactoryId GetMatchingImportFactoryId(ITensorHandleFactory::FactoryId copyFactoryId); + /// Aquire memory required for inference void AquireMemory(); @@ -53,6 +62,7 @@ public: private: std::vector> m_Factories; std::vector> m_MemoryManagers; + CopyAndImportFactoryPairs m_FactoryMappings; }; } // namespace armnn -- cgit v1.2.1