diff options
Diffstat (limited to 'src/armnn/LoadedNetwork.cpp')
-rw-r--r-- | src/armnn/LoadedNetwork.cpp | 30 |
1 files changed, 23 insertions, 7 deletions
diff --git a/src/armnn/LoadedNetwork.cpp b/src/armnn/LoadedNetwork.cpp index 45891f7dc3..48a3040b23 100644 --- a/src/armnn/LoadedNetwork.cpp +++ b/src/armnn/LoadedNetwork.cpp @@ -699,7 +699,7 @@ Status LoadedNetwork::EnqueueWorkload(const InputTensors& inputTensors, for (const BindableLayer* inputLayer : graph.GetInputLayers()) { - if (preImportedInputIds.size() != m_PreImportedInputHandles.size()) + if (preImportedInputIds.size() > graph.GetNumInputs()) { throw InvalidArgumentException("Invalid number of preImportedInputIds"); } @@ -727,7 +727,7 @@ Status LoadedNetwork::EnqueueWorkload(const InputTensors& inputTensors, for (const BindableLayer* outputLayer : graph.GetOutputLayers()) { - if (preImportedOutputIds.size() != m_PreImportedOutputHandles.size()) + if (preImportedOutputIds.size() > graph.GetNumOutputs()) { throw InvalidArgumentException("Invalid number of preImportedOutputIds"); } @@ -770,11 +770,6 @@ Status LoadedNetwork::EnqueueWorkload(const InputTensors& inputTensors, } } } - // Clear m_PreImportedInputHandles and m_PreImportedOutputHandles - m_PreImportedInputHandles.clear(); - m_PreImportedOutputHandles.clear(); - m_CurImportedInputId = 0; - m_CurImportedOutputId = 0; std::unique_ptr<TimelineUtilityMethods> timelineUtils = TimelineUtilityMethods::GetTimelineUtils(m_ProfilingService); @@ -1271,6 +1266,16 @@ std::vector<ImportedInputId> LoadedNetwork::ImportInputs(const InputTensors& inp { // Cannot import, use allocated data handler.UseAllocatedData(); + // Ensure that the workload get correct tensor + try + { + m_WorkloadQueue[m_InputWorkloadSlotPairs[layerBindingId].first].get()->ReplaceInputTensorHandle( + handler.GetData(), m_InputWorkloadSlotPairs[layerBindingId].second); + } + catch(armnn::UnimplementedException& e) + { + IgnoreUnused(e); + } } } @@ -1437,6 +1442,17 @@ std::vector<ImportedOutputId> LoadedNetwork::ImportOutputs(const OutputTensors& { // Cannot import, use allocated memory outputHandler.UseAllocatedData(); + // Ensure that the workload get correct tensor + try + { + m_WorkloadQueue[m_OutputWorkloadSlotPairs[layerBindingId].first].get()-> + ReplaceOutputTensorHandle(outputHandler.GetData(), + m_OutputWorkloadSlotPairs[layerBindingId].second); + } + catch(armnn::UnimplementedException& e) + { + IgnoreUnused(e); + } } } return importedOutputs; |