diff options
Diffstat (limited to 'src/backends/backendsCommon/TensorHandleFactoryRegistry.hpp')
-rw-r--r-- | src/backends/backendsCommon/TensorHandleFactoryRegistry.hpp | 10 |
1 files changed, 10 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/TensorHandleFactoryRegistry.hpp b/src/backends/backendsCommon/TensorHandleFactoryRegistry.hpp index 525db56216..2a7c6f36d9 100644 --- a/src/backends/backendsCommon/TensorHandleFactoryRegistry.hpp +++ b/src/backends/backendsCommon/TensorHandleFactoryRegistry.hpp @@ -16,6 +16,8 @@ namespace armnn //Forward class IMemoryManager; +using CopyAndImportFactoryPairs = std::map<ITensorHandleFactory::FactoryId, ITensorHandleFactory::FactoryId>; + /// class TensorHandleFactoryRegistry { @@ -39,6 +41,13 @@ public: ITensorHandleFactory* GetFactory(ITensorHandleFactory::FactoryId id, MemorySource memSource) const; + /// Register a pair of TensorHandleFactory Id for Memory Copy and TensorHandleFactory Id for Memory Import + void RegisterCopyAndImportFactoryPair(ITensorHandleFactory::FactoryId copyFactoryId, + ITensorHandleFactory::FactoryId importFactoryId); + + /// Get a matching TensorHandleFatory Id for Memory Import given TensorHandleFactory Id for Memory Copy + ITensorHandleFactory::FactoryId GetMatchingImportFactoryId(ITensorHandleFactory::FactoryId copyFactoryId); + /// Aquire memory required for inference void AquireMemory(); @@ -53,6 +62,7 @@ public: private: std::vector<std::unique_ptr<ITensorHandleFactory>> m_Factories; std::vector<std::shared_ptr<IMemoryManager>> m_MemoryManagers; + CopyAndImportFactoryPairs m_FactoryMappings; }; } // namespace armnn |