aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/WorkingMemHandle.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/WorkingMemHandle.hpp')
-rw-r--r--src/armnn/WorkingMemHandle.hpp59
1 files changed, 13 insertions, 46 deletions
diff --git a/src/armnn/WorkingMemHandle.hpp b/src/armnn/WorkingMemHandle.hpp
index cef6fb6fd3..92b0acaec3 100644
--- a/src/armnn/WorkingMemHandle.hpp
+++ b/src/armnn/WorkingMemHandle.hpp
@@ -26,10 +26,12 @@ class WorkingMemHandle final : public IWorkingMemHandle
public:
WorkingMemHandle(NetworkId networkId,
std::vector<WorkingMemDescriptor> workingMemDescriptors,
- std::unordered_map<LayerGuid, WorkingMemDescriptor> workingMemDescriptorMap);
+ std::unordered_map<LayerGuid, WorkingMemDescriptor> workingMemDescriptorMap,
+ std::vector<std::shared_ptr<IMemoryManager>> memoryManagers,
+ std::unordered_map<LayerGuid, std::vector<std::unique_ptr<ITensorHandle> > > ownedTensorHandles);
~WorkingMemHandle()
- { FreeWorkingMemory(); }
+ { Free(); }
NetworkId GetNetworkId() override
{
@@ -38,50 +40,10 @@ public:
/// Allocate the backing memory required for execution. If this is not called, then allocation will be
/// deferred to execution time. The mutex must be locked.
- void Allocate() override
- {
- if (m_IsAllocated)
- {
- return;
- }
- m_IsAllocated = true;
-
- // Iterate through all WorkingMemDescriptors calling allocate() on each input and output in turn
- for (auto workingMemDescriptor : m_WorkingMemDescriptors)
- {
- for (auto& input : workingMemDescriptor.m_Inputs)
- {
- input->Allocate();
- }
- for (auto& output : workingMemDescriptor.m_Outputs)
- {
- output->Allocate();
- }
- }
- }
+ void Allocate() override;
/// Free the backing memory required for execution. The mutex must be locked.
- void Free() override
- {
- if (!m_IsAllocated)
- {
- return;
- }
- m_IsAllocated = false;
-
- // Iterate through all WorkingMemDescriptors calling free() on each input and output in turn
- for (auto workingMemDescriptor : m_WorkingMemDescriptors)
- {
- for (auto& input : workingMemDescriptor.m_Inputs)
- {
- input->Unmap();
- }
- for (auto& output : workingMemDescriptor.m_Outputs)
- {
- output->Unmap();
- }
- }
- }
+ void Free() override;
/// IsAllocated returns true if the backing memory is currently allocated. The mutex must be locked.
bool IsAllocated() override
@@ -111,13 +73,18 @@ public:
}
private:
- void FreeWorkingMemory();
-
NetworkId m_NetworkId;
std::shared_ptr<ProfilerImpl> m_Profiler;
std::vector<WorkingMemDescriptor> m_WorkingMemDescriptors;
std::unordered_map<LayerGuid, WorkingMemDescriptor> m_WorkingMemDescriptorMap;
+
+ // Vector of IMemoryManagers that manage the WorkingMemHandle's memory
+ std::vector<std::shared_ptr<IMemoryManager>> m_MemoryManagers;
+ // TensorHandles owned by this WorkingMemHandle
+ // constant tensor's can be shared by multiple WorkingMemHandles and so will not be stored here
+ std::unordered_map<LayerGuid, std::vector<std::unique_ptr<ITensorHandle> > > m_OwnedTensorHandles;
+
bool m_IsAllocated;
std::mutex m_Mutex;
};