diff options
author | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2021-12-17 17:27:37 +0000 |
---|---|---|
committer | Francis Murtagh <francis.murtagh@arm.com> | 2022-01-11 11:59:22 +0000 |
commit | b275da5afe480e994fe6cd897b3090c52f1bcdea (patch) | |
tree | 390246b981d5d39f79099b57a56c0f21c2932c05 /src/backends/backendsCommon | |
parent | 2048bcf8ed671b593ac9af2974e10319b9058b20 (diff) | |
download | armnn-b275da5afe480e994fe6cd897b3090c52f1bcdea.tar.gz |
IVGCVSW-6675 Add functionality for CopyAndImportFactoryPair
to TensorHandleFactoryRegistry
* Add RegisterCopyAndImportFactoryPair
* Add GetMatchingImportFactoryId
* Unit tests
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: I91e71bdeadec8fedbce7088016c06a47a03bdbaa
Diffstat (limited to 'src/backends/backendsCommon')
-rw-r--r-- | src/backends/backendsCommon/TensorHandleFactoryRegistry.cpp | 12 | ||||
-rw-r--r-- | src/backends/backendsCommon/TensorHandleFactoryRegistry.hpp | 10 |
2 files changed, 22 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/TensorHandleFactoryRegistry.cpp b/src/backends/backendsCommon/TensorHandleFactoryRegistry.cpp index cc8a1361a3..8094f04959 100644 --- a/src/backends/backendsCommon/TensorHandleFactoryRegistry.cpp +++ b/src/backends/backendsCommon/TensorHandleFactoryRegistry.cpp @@ -63,6 +63,18 @@ ITensorHandleFactory* TensorHandleFactoryRegistry::GetFactory(ITensorHandleFacto return nullptr; } +void TensorHandleFactoryRegistry::RegisterCopyAndImportFactoryPair(ITensorHandleFactory::FactoryId copyFactoryId, + ITensorHandleFactory::FactoryId importFactoryId) +{ + m_FactoryMappings[copyFactoryId] = importFactoryId; +} + +ITensorHandleFactory::FactoryId TensorHandleFactoryRegistry::GetMatchingImportFactoryId( + ITensorHandleFactory::FactoryId copyFactoryId) +{ + return m_FactoryMappings[copyFactoryId]; +} + void TensorHandleFactoryRegistry::AquireMemory() { for (auto& mgr : m_MemoryManagers) diff --git a/src/backends/backendsCommon/TensorHandleFactoryRegistry.hpp b/src/backends/backendsCommon/TensorHandleFactoryRegistry.hpp index 525db56216..2a7c6f36d9 100644 --- a/src/backends/backendsCommon/TensorHandleFactoryRegistry.hpp +++ b/src/backends/backendsCommon/TensorHandleFactoryRegistry.hpp @@ -16,6 +16,8 @@ namespace armnn //Forward class IMemoryManager; +using CopyAndImportFactoryPairs = std::map<ITensorHandleFactory::FactoryId, ITensorHandleFactory::FactoryId>; + /// class TensorHandleFactoryRegistry { @@ -39,6 +41,13 @@ public: ITensorHandleFactory* GetFactory(ITensorHandleFactory::FactoryId id, MemorySource memSource) const; + /// Register a pair of TensorHandleFactory Id for Memory Copy and TensorHandleFactory Id for Memory Import + void RegisterCopyAndImportFactoryPair(ITensorHandleFactory::FactoryId copyFactoryId, + ITensorHandleFactory::FactoryId importFactoryId); + + /// Get a matching TensorHandleFatory Id for Memory Import given TensorHandleFactory Id for Memory Copy + ITensorHandleFactory::FactoryId GetMatchingImportFactoryId(ITensorHandleFactory::FactoryId copyFactoryId); + /// Aquire memory required for inference void AquireMemory(); @@ -53,6 +62,7 @@ public: private: std::vector<std::unique_ptr<ITensorHandleFactory>> m_Factories; std::vector<std::shared_ptr<IMemoryManager>> m_MemoryManagers; + CopyAndImportFactoryPairs m_FactoryMappings; }; } // namespace armnn |