ArmNN
 21.08
TensorHandleFactoryRegistry.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
8 
9 namespace armnn
10 {
11 
12 void TensorHandleFactoryRegistry::RegisterFactory(std::unique_ptr <ITensorHandleFactory> newFactory)
13 {
14  if (!newFactory)
15  {
16  return;
17  }
18 
19  ITensorHandleFactory::FactoryId id = newFactory->GetId();
20 
21  // Don't register duplicates
22  for (auto& registeredFactory : m_Factories)
23  {
24  if (id == registeredFactory->GetId())
25  {
26  return;
27  }
28  }
29 
30  // Take ownership of the new allocator
31  m_Factories.push_back(std::move(newFactory));
32 }
33 
34 void TensorHandleFactoryRegistry::RegisterMemoryManager(std::shared_ptr<armnn::IMemoryManager> memoryManger)
35 {
36  m_MemoryManagers.push_back(memoryManger);
37 }
38 
40 {
41  for (auto& factory : m_Factories)
42  {
43  if (factory->GetId() == id)
44  {
45  return factory.get();
46  }
47  }
48 
49  return nullptr;
50 }
51 
53  MemorySource memSource) const
54 {
55  for (auto& factory : m_Factories)
56  {
57  if (factory->GetId() == id && factory->GetImportFlags() == static_cast<MemorySourceFlags>(memSource))
58  {
59  return factory.get();
60  }
61  }
62 
63  return nullptr;
64 }
65 
67 {
68  for (auto& mgr : m_MemoryManagers)
69  {
70  mgr->Acquire();
71  }
72 }
73 
75 {
76  for (auto& mgr : m_MemoryManagers)
77  {
78  mgr->Release();
79  }
80 }
81 
82 } // namespace armnn
void RegisterMemoryManager(std::shared_ptr< IMemoryManager > memoryManger)
Register a memory manager with shared ownership.
void RegisterFactory(std::unique_ptr< ITensorHandleFactory > allocator)
Register a TensorHandleFactory and transfer ownership.
void AquireMemory()
Acquire memory required for inference.
unsigned int MemorySourceFlags
Copyright (c) 2021 ARM Limited and Contributors.
void ReleaseMemory()
Release memory required for inference.
MemorySource
Define the Memory Source to reduce copies.
Definition: Types.hpp:198
ITensorHandleFactory * GetFactory(ITensorHandleFactory::FactoryId id) const
Find a TensorHandleFactory by Id. Returns nullptr if not found.