diff options
Diffstat (limited to 'src/armnn/WorkingMemHandle.hpp')
-rw-r--r-- | src/armnn/WorkingMemHandle.hpp | 59 |
1 files changed, 13 insertions, 46 deletions
diff --git a/src/armnn/WorkingMemHandle.hpp b/src/armnn/WorkingMemHandle.hpp index cef6fb6fd3..92b0acaec3 100644 --- a/src/armnn/WorkingMemHandle.hpp +++ b/src/armnn/WorkingMemHandle.hpp @@ -26,10 +26,12 @@ class WorkingMemHandle final : public IWorkingMemHandle public: WorkingMemHandle(NetworkId networkId, std::vector<WorkingMemDescriptor> workingMemDescriptors, - std::unordered_map<LayerGuid, WorkingMemDescriptor> workingMemDescriptorMap); + std::unordered_map<LayerGuid, WorkingMemDescriptor> workingMemDescriptorMap, + std::vector<std::shared_ptr<IMemoryManager>> memoryManagers, + std::unordered_map<LayerGuid, std::vector<std::unique_ptr<ITensorHandle> > > ownedTensorHandles); ~WorkingMemHandle() - { FreeWorkingMemory(); } + { Free(); } NetworkId GetNetworkId() override { @@ -38,50 +40,10 @@ public: /// Allocate the backing memory required for execution. If this is not called, then allocation will be /// deferred to execution time. The mutex must be locked. - void Allocate() override - { - if (m_IsAllocated) - { - return; - } - m_IsAllocated = true; - - // Iterate through all WorkingMemDescriptors calling allocate() on each input and output in turn - for (auto workingMemDescriptor : m_WorkingMemDescriptors) - { - for (auto& input : workingMemDescriptor.m_Inputs) - { - input->Allocate(); - } - for (auto& output : workingMemDescriptor.m_Outputs) - { - output->Allocate(); - } - } - } + void Allocate() override; /// Free the backing memory required for execution. The mutex must be locked. - void Free() override - { - if (!m_IsAllocated) - { - return; - } - m_IsAllocated = false; - - // Iterate through all WorkingMemDescriptors calling free() on each input and output in turn - for (auto workingMemDescriptor : m_WorkingMemDescriptors) - { - for (auto& input : workingMemDescriptor.m_Inputs) - { - input->Unmap(); - } - for (auto& output : workingMemDescriptor.m_Outputs) - { - output->Unmap(); - } - } - } + void Free() override; /// IsAllocated returns true if the backing memory is currently allocated. The mutex must be locked. bool IsAllocated() override @@ -111,13 +73,18 @@ public: } private: - void FreeWorkingMemory(); - NetworkId m_NetworkId; std::shared_ptr<ProfilerImpl> m_Profiler; std::vector<WorkingMemDescriptor> m_WorkingMemDescriptors; std::unordered_map<LayerGuid, WorkingMemDescriptor> m_WorkingMemDescriptorMap; + + // Vector of IMemoryManagers that manage the WorkingMemHandle's memory + std::vector<std::shared_ptr<IMemoryManager>> m_MemoryManagers; + // TensorHandles owned by this WorkingMemHandle + // constant tensor's can be shared by multiple WorkingMemHandles and so will not be stored here + std::unordered_map<LayerGuid, std::vector<std::unique_ptr<ITensorHandle> > > m_OwnedTensorHandles; + bool m_IsAllocated; std::mutex m_Mutex; }; |