diff options
Diffstat (limited to 'src/armnn/LoadedNetwork.cpp')
-rw-r--r-- | src/armnn/LoadedNetwork.cpp | 52 |
1 file changed, 38 insertions, 14 deletions
diff --git a/src/armnn/LoadedNetwork.cpp b/src/armnn/LoadedNetwork.cpp index 7ee4e612e0..5e3e3f24fe 100644 --- a/src/armnn/LoadedNetwork.cpp +++ b/src/armnn/LoadedNetwork.cpp @@ -132,7 +132,18 @@ LoadedNetwork::LoadedNetwork(std::unique_ptr<OptimizedNetwork> net, } default: { - layer->CreateTensorHandles(m_TensorHandleFactoryRegistry, workloadFactory); + // Look for the layer with 1 OutputSlot which has 1 connection and that connection is an Output Layer + // If Export is enabled disable memory management so we can export, otherwise we do a copy + if((layer->GetNumOutputSlots() == 1) && + (layer->GetOutputSlots()[0].GetNumConnections() == 1) && + (layer->GetOutputSlots()[0].GetConnection(0)->GetOwningLayer().GetType() == LayerType::Output)) + { + layer->CreateTensorHandles(m_TensorHandleFactoryRegistry, workloadFactory, !m_IsExportEnabled); + } + else + { + layer->CreateTensorHandles(m_TensorHandleFactoryRegistry, workloadFactory); + } } } } @@ -409,17 +420,24 @@ void LoadedNetwork::EnqueueInput(const BindableLayer& layer, ITensorHandle* tens info.m_OutputTensorInfos.push_back(outputTensorInfo); MemorySourceFlags importFlags = outputTensorHandle->GetImportFlags(); - if (CheckFlag(importFlags, MemorySource::Malloc) && m_IsImportEnabled) // Try import the input tensor + if (m_IsImportEnabled) // Try import the input tensor { - // This assumes a CPU Tensor handle - void* mem = tensorHandle->Map(false); - if (outputTensorHandle->Import(mem, MemorySource::Malloc)) + if(CheckFlag(importFlags, MemorySource::Malloc) ) { + // This assumes a CPU Tensor handle + void* mem = tensorHandle->Map(false); + if (outputTensorHandle->Import(mem, MemorySource::Malloc)) + { + tensorHandle->Unmap(); + return; // No need for a workload since the import has been done. + } tensorHandle->Unmap(); - return; // No need for a workload since the import has been done. 
+ throw MemoryImportException("EnqueueInput: Memory Import failed"); + } + else + { + throw MemoryImportException("EnqueueInput: Memory Import failed, backend does not support Import"); } - tensorHandle->Unmap(); - throw MemoryImportException("EnqueueInput: Memory Import failed"); } else { @@ -464,10 +482,10 @@ void LoadedNetwork::EnqueueOutput(const BindableLayer& layer, ITensorHandle* ten // b) The tensor has zero padding // c) There is only one connection to the OutputSlot and it is to an OutputLayer. // d) The output pointer is allocated via malloc. (Other types will be supported in a later release) - if (layer.GetInputSlots()[0].GetConnectedOutputSlot()->GetOwningLayer().GetType() != LayerType::Input - && m_IsExportEnabled) + // e) m_IsExportEnabled must be set to true + if (m_IsExportEnabled && (layer.GetInputSlots()[0].GetConnectedOutputSlot()->GetNumConnections() == 1)) { - if (layer.GetInputSlots()[0].GetConnectedOutputSlot()->GetNumConnections() == 1) + if(layer.GetInputSlots()[0].GetConnectedOutputSlot()->GetOwningLayer().GetType() != LayerType::Input) { MemorySourceFlags importFlags = inputTensorHandle->GetImportFlags(); if (CheckFlag(importFlags, MemorySource::Malloc)) @@ -485,19 +503,25 @@ void LoadedNetwork::EnqueueOutput(const BindableLayer& layer, ITensorHandle* ten auto syncWorkload = std::make_unique<SyncMemGenericWorkload>(syncDesc, info); BOOST_ASSERT_MSG(syncWorkload, "No sync workload created"); m_OutputQueue.push_back(move(syncWorkload)); - - return; //No need to add the output workload below } else { throw MemoryExportException("EnqueueOutput: Memory Export failed"); } } + else + { + throw MemoryExportException("EnqueueOutput: Memory Export failed, backend does not support Export"); + } + } + else + { + throw MemoryExportException("EnqueueOutput: Memory Export failed, attempting to export Input Layer"); } } else { - // If we got here then we couldn't import the memory, so add an output workload which performs a memcopy. 
+ // If we got here then we didn't export the memory, so add an output workload which performs a memcopy. outputQueueDescriptor.m_Inputs.push_back(inputTensorHandle); info.m_InputTensorInfos.push_back(inputTensorInfo); |