aboutsummaryrefslogtreecommitdiff
path: root/shim/sl/canonical/ArmnnPreparedModel.cpp
diff options
context:
space:
mode:
authorSadik Armagan <sadik.armagan@arm.com>2022-07-19 12:37:20 +0100
committerNikhil Raj <nikhil.raj@arm.com>2022-07-27 15:57:46 +0100
commit1e276f38e67af7505a25010eee579034ee83d12b (patch)
tree48607813d793d4142c0a2e4bc0b0b4cf15cf8285 /shim/sl/canonical/ArmnnPreparedModel.cpp
parent07389192266eedac50a64c7d66ef62c1532e06f2 (diff)
downloadarmnn-1e276f38e67af7505a25010eee579034ee83d12b.tar.gz
IVGCVSW-6954 'Arm NN Support Library Implementation'
* Fixed model converting issue * Fixed import memory issue Signed-off-by: Sadik Armagan <sadik.armagan@arm.com> Change-Id: Ied61810b308e0c5d5754f122a6ea2bac1d0725f1
Diffstat (limited to 'shim/sl/canonical/ArmnnPreparedModel.cpp')
-rw-r--r--shim/sl/canonical/ArmnnPreparedModel.cpp30
1 files changed, 30 insertions, 0 deletions
diff --git a/shim/sl/canonical/ArmnnPreparedModel.cpp b/shim/sl/canonical/ArmnnPreparedModel.cpp
index 54a019004c..79cd241348 100644
--- a/shim/sl/canonical/ArmnnPreparedModel.cpp
+++ b/shim/sl/canonical/ArmnnPreparedModel.cpp
@@ -393,7 +393,37 @@ ErrorStatus ArmnnPreparedModel::ExecuteGraph(
armnn::Status status;
VLOG(DRIVER) << "ArmnnPreparedModel::ExecuteGraph m_AsyncModelExecutionEnabled false";
importedInputIds = m_Runtime->ImportInputs(m_NetworkId, inputTensors, armnn::MemorySource::Malloc);
+ if (!importedInputIds.empty())
+ {
+ // Some or all of the input tensors been imported. We need to remove the ones that could from
+ // inputTensors.
+ for (armnn::ImportedInputId& importedId : importedInputIds)
+ {
+ inputTensors.erase(
+ std::remove_if(
+ inputTensors.begin(), inputTensors.end(),
+ [&importedId](std::pair<armnn::LayerBindingId, class armnn::ConstTensor>& element) {
+ return (element.first == static_cast<int>(importedId));
+ }),
+ inputTensors.end());
+ }
+ }
importedOutputIds = m_Runtime->ImportOutputs(m_NetworkId, outputTensors, armnn::MemorySource::Malloc);
+ if (!importedOutputIds.empty())
+ {
+ // Some or all of the output tensors could not be imported. We need to remove the ones that could
+ // from outputTensors.
+ for (armnn::ImportedInputId& importedId : importedOutputIds)
+ {
+ outputTensors.erase(
+ std::remove_if(
+ outputTensors.begin(), outputTensors.end(),
+ [&importedId](std::pair<armnn::LayerBindingId, class armnn::Tensor>& element) {
+ return (element.first == static_cast<int>(importedId));
+ }),
+ outputTensors.end());
+ }
+ }
status = m_Runtime->EnqueueWorkload(m_NetworkId,
inputTensors,
outputTensors,