aboutsummaryrefslogtreecommitdiff
path: root/tests/InferenceModel.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/InferenceModel.hpp')
-rw-r--r--tests/InferenceModel.hpp32
1 files changed, 27 insertions, 5 deletions
diff --git a/tests/InferenceModel.hpp b/tests/InferenceModel.hpp
index bddaf557fd..e2a1a97568 100644
--- a/tests/InferenceModel.hpp
+++ b/tests/InferenceModel.hpp
@@ -111,6 +111,7 @@ struct Params
std::string m_MLGOTuningFilePath;
bool m_AsyncEnabled;
size_t m_ThreadPoolSize;
+ bool m_ImportInputsIfAligned;
Params()
@@ -132,6 +133,7 @@ struct Params
, m_MLGOTuningFilePath("")
, m_AsyncEnabled(false)
, m_ThreadPoolSize(0)
+ , m_ImportInputsIfAligned(false)
{}
};
@@ -438,8 +440,9 @@ public:
const std::string& dynamicBackendsPath,
const std::shared_ptr<armnn::IRuntime>& runtime = nullptr)
: m_EnableProfiling(enableProfiling),
- m_ProfilingDetailsMethod(armnn::ProfilingDetailsMethod::Undefined)
- , m_DynamicBackendsPath(dynamicBackendsPath)
+ m_ProfilingDetailsMethod(armnn::ProfilingDetailsMethod::Undefined),
+ m_DynamicBackendsPath(dynamicBackendsPath),
+ m_ImportInputsIfAligned(params.m_ImportInputsIfAligned)
{
if (runtime)
{
@@ -612,9 +615,27 @@ public:
// Start timer to record inference time in EnqueueWorkload (in milliseconds)
const auto start_time = armnn::GetTimeNow();
- armnn::Status ret = m_Runtime->EnqueueWorkload(m_NetworkIdentifier,
- MakeInputTensors(inputContainers),
- MakeOutputTensors(outputContainers));
+ armnn::Status ret;
+ if (m_ImportInputsIfAligned)
+ {
+ std::vector<armnn::ImportedInputId> importedInputIds = m_Runtime->ImportInputs(
+ m_NetworkIdentifier, MakeInputTensors(inputContainers), armnn::MemorySource::Malloc);
+
+ std::vector<armnn::ImportedOutputId> importedOutputIds = m_Runtime->ImportOutputs(
+ m_NetworkIdentifier, MakeOutputTensors(outputContainers), armnn::MemorySource::Malloc);
+
+ ret = m_Runtime->EnqueueWorkload(m_NetworkIdentifier,
+ MakeInputTensors(inputContainers),
+ MakeOutputTensors(outputContainers),
+ importedInputIds,
+ importedOutputIds);
+ }
+ else
+ {
+ ret = m_Runtime->EnqueueWorkload(m_NetworkIdentifier,
+ MakeInputTensors(inputContainers),
+ MakeOutputTensors(outputContainers));
+ }
const auto duration = armnn::GetTimeDuration(start_time);
// if profiling is enabled print out the results
@@ -784,6 +805,7 @@ private:
bool m_EnableProfiling;
armnn::ProfilingDetailsMethod m_ProfilingDetailsMethod;
std::string m_DynamicBackendsPath;
+ bool m_ImportInputsIfAligned;
template<typename TContainer>
armnn::InputTensors MakeInputTensors(const std::vector<TContainer>& inputDataContainers)