aboutsummaryrefslogtreecommitdiff
path: root/shim/sl/canonical/ArmnnPreparedModel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'shim/sl/canonical/ArmnnPreparedModel.cpp')
-rw-r--r--shim/sl/canonical/ArmnnPreparedModel.cpp30
1 files changed, 30 insertions, 0 deletions
diff --git a/shim/sl/canonical/ArmnnPreparedModel.cpp b/shim/sl/canonical/ArmnnPreparedModel.cpp
index 54a019004c..79cd241348 100644
--- a/shim/sl/canonical/ArmnnPreparedModel.cpp
+++ b/shim/sl/canonical/ArmnnPreparedModel.cpp
@@ -393,7 +393,37 @@ ErrorStatus ArmnnPreparedModel::ExecuteGraph(
armnn::Status status;
VLOG(DRIVER) << "ArmnnPreparedModel::ExecuteGraph m_AsyncModelExecutionEnabled false";
importedInputIds = m_Runtime->ImportInputs(m_NetworkId, inputTensors, armnn::MemorySource::Malloc);
+ if (!importedInputIds.empty())
+ {
+ // Some or all of the input tensors been imported. We need to remove the ones that could from
+ // inputTensors.
+ for (armnn::ImportedInputId& importedId : importedInputIds)
+ {
+ inputTensors.erase(
+ std::remove_if(
+ inputTensors.begin(), inputTensors.end(),
+ [&importedId](std::pair<armnn::LayerBindingId, class armnn::ConstTensor>& element) {
+ return (element.first == static_cast<int>(importedId));
+ }),
+ inputTensors.end());
+ }
+ }
importedOutputIds = m_Runtime->ImportOutputs(m_NetworkId, outputTensors, armnn::MemorySource::Malloc);
+ if (!importedOutputIds.empty())
+ {
+ // Some or all of the output tensors could not be imported. We need to remove the ones that could
+ // from outputTensors.
+ for (armnn::ImportedInputId& importedId : importedOutputIds)
+ {
+ outputTensors.erase(
+ std::remove_if(
+ outputTensors.begin(), outputTensors.end(),
+ [&importedId](std::pair<armnn::LayerBindingId, class armnn::Tensor>& element) {
+ return (element.first == static_cast<int>(importedId));
+ }),
+ outputTensors.end());
+ }
+ }
status = m_Runtime->EnqueueWorkload(m_NetworkId,
inputTensors,
outputTensors,