diff options
Diffstat (limited to 'src/backends')
-rw-r--r-- | src/backends/backendsCommon/TensorHandleFactoryRegistry.cpp | 12 | ||||
-rw-r--r-- | src/backends/backendsCommon/TensorHandleFactoryRegistry.hpp | 10 |
2 files changed, 22 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/TensorHandleFactoryRegistry.cpp b/src/backends/backendsCommon/TensorHandleFactoryRegistry.cpp index cc8a1361a3..8094f04959 100644 --- a/src/backends/backendsCommon/TensorHandleFactoryRegistry.cpp +++ b/src/backends/backendsCommon/TensorHandleFactoryRegistry.cpp @@ -63,6 +63,18 @@ ITensorHandleFactory* TensorHandleFactoryRegistry::GetFactory(ITensorHandleFacto return nullptr; } +void TensorHandleFactoryRegistry::RegisterCopyAndImportFactoryPair(ITensorHandleFactory::FactoryId copyFactoryId, + ITensorHandleFactory::FactoryId importFactoryId) +{ + m_FactoryMappings[copyFactoryId] = importFactoryId; +} + +ITensorHandleFactory::FactoryId TensorHandleFactoryRegistry::GetMatchingImportFactoryId( + ITensorHandleFactory::FactoryId copyFactoryId) +{ + return m_FactoryMappings[copyFactoryId]; +} + void TensorHandleFactoryRegistry::AquireMemory() { for (auto& mgr : m_MemoryManagers) diff --git a/src/backends/backendsCommon/TensorHandleFactoryRegistry.hpp b/src/backends/backendsCommon/TensorHandleFactoryRegistry.hpp index 525db56216..2a7c6f36d9 100644 --- a/src/backends/backendsCommon/TensorHandleFactoryRegistry.hpp +++ b/src/backends/backendsCommon/TensorHandleFactoryRegistry.hpp @@ -16,6 +16,8 @@ namespace armnn //Forward class IMemoryManager; +using CopyAndImportFactoryPairs = std::map<ITensorHandleFactory::FactoryId, ITensorHandleFactory::FactoryId>; + /// class TensorHandleFactoryRegistry { @@ -39,6 +41,13 @@ public: ITensorHandleFactory* GetFactory(ITensorHandleFactory::FactoryId id, MemorySource memSource) const; + /// Register a pair of TensorHandleFactory Id for Memory Copy and TensorHandleFactory Id for Memory Import + void RegisterCopyAndImportFactoryPair(ITensorHandleFactory::FactoryId copyFactoryId, + ITensorHandleFactory::FactoryId importFactoryId); + + /// Get a matching TensorHandleFatory Id for Memory Import given TensorHandleFactory Id for Memory Copy + ITensorHandleFactory::FactoryId GetMatchingImportFactoryId(ITensorHandleFactory::FactoryId copyFactoryId); + /// Aquire memory required for inference void AquireMemory(); @@ -53,6 +62,7 @@ public: private: std::vector<std::unique_ptr<ITensorHandleFactory>> m_Factories; std::vector<std::shared_ptr<IMemoryManager>> m_MemoryManagers; + CopyAndImportFactoryPairs m_FactoryMappings; }; } // namespace armnn |