aboutsummaryrefslogtreecommitdiff
path: root/shim/sl/canonical/ArmnnPreparedModel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'shim/sl/canonical/ArmnnPreparedModel.cpp')
-rw-r--r--shim/sl/canonical/ArmnnPreparedModel.cpp43
1 files changed, 17 insertions, 26 deletions
diff --git a/shim/sl/canonical/ArmnnPreparedModel.cpp b/shim/sl/canonical/ArmnnPreparedModel.cpp
index c0ce3e41a1..54a019004c 100644
--- a/shim/sl/canonical/ArmnnPreparedModel.cpp
+++ b/shim/sl/canonical/ArmnnPreparedModel.cpp
@@ -93,21 +93,21 @@ bool IsPointerTypeMemory(const Request& request)
{
for (auto& input : request.inputs)
{
- if (input.lifetime == Request::Argument::LifeTime::POINTER)
+ if (input.lifetime != Request::Argument::LifeTime::POINTER)
{
- return true;
+ return false;
}
}
for (auto& output: request.outputs)
{
- if (output.lifetime == Request::Argument::LifeTime::POINTER)
+ if (output.lifetime != Request::Argument::LifeTime::POINTER)
{
- return true;
+ return false;
}
}
- return false;
+ return true;
}
} // anonymous namespace
@@ -318,7 +318,8 @@ ExecutionResult<std::pair<std::vector<OutputShape>, Timing>> ArmnnPreparedModel:
}
VLOG(DRIVER) << "ArmnnPreparedModel::execute(): " << GetModelSummary(m_Model).c_str();
}
- if (hasDeadlinePassed(deadline)) {
+ if (hasDeadlinePassed(deadline))
+ {
return NN_ERROR(ErrorStatus::MISSED_DEADLINE_PERSISTENT);
}
@@ -381,7 +382,8 @@ ErrorStatus ArmnnPreparedModel::ExecuteGraph(
VLOG(DRIVER) << "ArmnnPreparedModel::ExecuteGraph(...)";
DumpTensorsIfRequired("Input", inputTensors);
-
+ std::vector<armnn::ImportedInputId> importedInputIds;
+ std::vector<armnn::ImportedOutputId> importedOutputIds;
try
{
if (ctx.measureTimings == MeasureTiming::YES)
@@ -390,24 +392,13 @@ ErrorStatus ArmnnPreparedModel::ExecuteGraph(
}
armnn::Status status;
VLOG(DRIVER) << "ArmnnPreparedModel::ExecuteGraph m_AsyncModelExecutionEnabled false";
-
- if (pointerMemory)
- {
- std::vector<armnn::ImportedInputId> importedInputIds;
- importedInputIds = m_Runtime->ImportInputs(m_NetworkId, inputTensors, armnn::MemorySource::Malloc);
-
- std::vector<armnn::ImportedOutputId> importedOutputIds;
- importedOutputIds = m_Runtime->ImportOutputs(m_NetworkId, outputTensors, armnn::MemorySource::Malloc);
- status = m_Runtime->EnqueueWorkload(m_NetworkId,
- inputTensors,
- outputTensors,
- importedInputIds,
- importedOutputIds);
- }
- else
- {
- status = m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors);
- }
+ importedInputIds = m_Runtime->ImportInputs(m_NetworkId, inputTensors, armnn::MemorySource::Malloc);
+ importedOutputIds = m_Runtime->ImportOutputs(m_NetworkId, outputTensors, armnn::MemorySource::Malloc);
+ status = m_Runtime->EnqueueWorkload(m_NetworkId,
+ inputTensors,
+ outputTensors,
+ importedInputIds,
+ importedOutputIds);
if (ctx.measureTimings == MeasureTiming::YES)
{
@@ -430,7 +421,7 @@ ErrorStatus ArmnnPreparedModel::ExecuteGraph(
return ErrorStatus::GENERAL_FAILURE;
}
- if (!pointerMemory)
+ if (!pointerMemory && (!importedInputIds.empty() || !importedOutputIds.empty()))
{
CommitPools(*pMemPools);
}