diff options
Diffstat (limited to 'src/armnn/WorkingMemHandle.cpp')
-rw-r--r-- | src/armnn/WorkingMemHandle.cpp | 28 |
1 files changed, 16 insertions, 12 deletions
diff --git a/src/armnn/WorkingMemHandle.cpp b/src/armnn/WorkingMemHandle.cpp index e2ad52a772..2cb47fbfc7 100644 --- a/src/armnn/WorkingMemHandle.cpp +++ b/src/armnn/WorkingMemHandle.cpp @@ -17,16 +17,20 @@ namespace experimental WorkingMemHandle::WorkingMemHandle(NetworkId networkId, std::vector<InputMemDescriptorCoords> inputLayerInfo, - std::vector<OutputMemDescriptorCoords> ouputLayerInfo, + std::vector<OutputMemDescriptorCoords> outputLayerInfo, std::vector<WorkingMemDescriptor> workingMemDescriptors, std::unordered_map<LayerGuid, WorkingMemDescriptor> workingMemDescriptorMap, - std::vector<std::shared_ptr<IMemoryManager>> memoryManagers, - std::unordered_map<LayerGuid, std::vector<std::unique_ptr<ITensorHandle> > > ownedTensorHandles) + std::unique_ptr<MemoryManager> memoryManager, + std::vector<std::pair<std::shared_ptr<TensorMemory>, MemorySource>> tensorMemory, + std::vector<std::unique_ptr<ITensorHandle>> managedTensorHandles, + std::vector<std::unique_ptr<ITensorHandle>> unmanagedTensorHandles) : m_NetworkId(networkId) , m_WorkingMemDescriptors(workingMemDescriptors) , m_WorkingMemDescriptorMap(workingMemDescriptorMap) - , m_MemoryManagers(memoryManagers) - , m_OwnedTensorHandles(std::move(ownedTensorHandles)) + , m_MemoryManager(std::move(memoryManager)) + , m_TensorMemory(std::move(tensorMemory)) + , m_ManagedTensorHandles(std::move(managedTensorHandles)) + , m_UnmanagedTensorHandles(std::move(unmanagedTensorHandles)) , m_InputSize(numeric_cast<DifferenceType>(inputLayerInfo.size())) , m_IsAllocated(false) { @@ -54,7 +58,7 @@ WorkingMemHandle::WorkingMemHandle(NetworkId networkId, } } size_t bindingIdCount = inputLayerInfo.size(); - for (const auto& outputInfo : ouputLayerInfo) + for (const auto& outputInfo : outputLayerInfo) { for (auto bindingId : outputInfo.m_LayerBindingIds) { @@ -88,6 +92,7 @@ WorkingMemHandle::WorkingMemHandle(NetworkId networkId, } } m_BindingIdVec = std::vector<LayerBindingId>(bindingIdCount); + IgnoreUnused(m_UnmanagedTensorHandles); } void WorkingMemHandle::Allocate() @@ -98,9 +103,11 @@ void WorkingMemHandle::Allocate() } m_IsAllocated = true; - for (auto& mgr : m_MemoryManagers) + m_MemoryManager->Allocate(); + + for (unsigned int i = 0; i < m_TensorMemory.size(); ++i) { - mgr->Acquire(); + m_ManagedTensorHandles[i]->Import(m_TensorMemory[i].first->m_Data, m_TensorMemory[i].second); } } @@ -112,10 +119,7 @@ void WorkingMemHandle::Free() } m_IsAllocated = false; - for (auto& mgr : m_MemoryManagers) - { - mgr->Release(); - } + m_MemoryManager->Deallocate(); } void WorkingMemHandle::MemSyncOutputs() |