diff options
Diffstat (limited to 'shim/sl/canonical/ArmnnPreparedModel.cpp')
-rw-r--r-- | shim/sl/canonical/ArmnnPreparedModel.cpp | 30 |
1 files changed, 30 insertions, 0 deletions
diff --git a/shim/sl/canonical/ArmnnPreparedModel.cpp b/shim/sl/canonical/ArmnnPreparedModel.cpp index 54a019004c..79cd241348 100644 --- a/shim/sl/canonical/ArmnnPreparedModel.cpp +++ b/shim/sl/canonical/ArmnnPreparedModel.cpp @@ -393,7 +393,37 @@ ErrorStatus ArmnnPreparedModel::ExecuteGraph( armnn::Status status; VLOG(DRIVER) << "ArmnnPreparedModel::ExecuteGraph m_AsyncModelExecutionEnabled false"; importedInputIds = m_Runtime->ImportInputs(m_NetworkId, inputTensors, armnn::MemorySource::Malloc); + if (!importedInputIds.empty()) + { + // Some or all of the input tensors been imported. We need to remove the ones that could from + // inputTensors. + for (armnn::ImportedInputId& importedId : importedInputIds) + { + inputTensors.erase( + std::remove_if( + inputTensors.begin(), inputTensors.end(), + [&importedId](std::pair<armnn::LayerBindingId, class armnn::ConstTensor>& element) { + return (element.first == static_cast<int>(importedId)); + }), + inputTensors.end()); + } + } importedOutputIds = m_Runtime->ImportOutputs(m_NetworkId, outputTensors, armnn::MemorySource::Malloc); + if (!importedOutputIds.empty()) + { + // Some or all of the output tensors could not be imported. We need to remove the ones that could + // from outputTensors. + for (armnn::ImportedInputId& importedId : importedOutputIds) + { + outputTensors.erase( + std::remove_if( + outputTensors.begin(), outputTensors.end(), + [&importedId](std::pair<armnn::LayerBindingId, class armnn::Tensor>& element) { + return (element.first == static_cast<int>(importedId)); + }), + outputTensors.end()); + } + } status = m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors, |