Diffstat (limited to 'tests/InferenceModel.hpp')
-rw-r--r--  tests/InferenceModel.hpp | 71
1 file changed, 67 insertions(+), 4 deletions(-)
diff --git a/tests/InferenceModel.hpp b/tests/InferenceModel.hpp
index cab594ed48..88c704c10e 100644
--- a/tests/InferenceModel.hpp
+++ b/tests/InferenceModel.hpp
@@ -101,6 +101,7 @@ struct Params
std::string m_CachedNetworkFilePath;
unsigned int m_NumberOfThreads;
std::string m_MLGOTuningFilePath;
+ bool m_AsyncEnabled;
Params()
@@ -118,6 +119,7 @@ struct Params
, m_CachedNetworkFilePath("")
, m_NumberOfThreads(0)
, m_MLGOTuningFilePath("")
+ , m_AsyncEnabled(false)
{}
};
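For context on how the new flag would be consumed, here is a minimal caller-side sketch; the model path and the InferenceModel constructor arguments are illustrative placeholders, not taken from this patch:

    // Hedged sketch: opting in to asynchronous execution via the new flag.
    // "model.armnn" and the constructor shape are hypothetical.
    Params params;
    params.m_ModelPath    = "model.armnn"; // hypothetical model file
    params.m_AsyncEnabled = true;          // routes LoadNetwork through the async path below
    InferenceModel<TParser, TDataType> model(params,
                                             /*enableProfiling=*/false,
                                             /*dynamicBackendsPath=*/"");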
@@ -472,14 +474,14 @@ public:
optNet->SerializeToDot(file);
}
-
-
armnn::Status ret;
{
ARMNN_SCOPED_HEAP_PROFILING("LoadNetwork");
const auto loading_start_time = armnn::GetTimeNow();
- ret = m_Runtime->LoadNetwork(m_NetworkIdentifier, std::move(optNet));
+ armnn::INetworkProperties networkProperties(false, false, params.m_AsyncEnabled);
+ std::string errorMessage;
+ ret = m_Runtime->LoadNetwork(m_NetworkIdentifier, std::move(optNet), errorMessage, networkProperties);
ARMNN_LOG(info) << "Network loading time: " << std::setprecision(2)
<< std::fixed << armnn::GetTimeDuration(loading_start_time).count() << " ms\n";
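Taken out of the diff, the new load path looks like this; a hedged sketch assuming the three-argument INetworkProperties(importEnabled, exportEnabled, asyncEnabled) constructor that the hunk above relies on, with 'runtime', 'networkId', and 'optNet' standing in for the members used in InferenceModel.hpp:

    // Hedged sketch of the modified LoadNetwork call.
    armnn::INetworkProperties networkProperties(false, false, params.m_AsyncEnabled);
    std::string errorMessage;
    armnn::Status ret = runtime->LoadNetwork(networkId, std::move(optNet),
                                             errorMessage, networkProperties);
    if (ret != armnn::Status::Success)
    {
        // This overload reports the failure reason through errorMessage,
        // which is why the patch threads the extra string through.
        throw armnn::Exception("Network could not be loaded: " + errorMessage);
    }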
@@ -553,7 +555,6 @@ public:
armnn::Status ret = m_Runtime->EnqueueWorkload(m_NetworkIdentifier,
MakeInputTensors(inputContainers),
MakeOutputTensors(outputContainers));
-
const auto duration = armnn::GetTimeDuration(start_time);
// if profiling is enabled print out the results
@@ -572,6 +573,63 @@ public:
}
}
+ std::tuple<armnn::profiling::ProfilingGuid, std::chrono::duration<double, std::milli>> RunAsync(
+ armnn::experimental::IWorkingMemHandle& workingMemHandleRef,
+ const std::vector<TContainer>& inputContainers,
+ std::vector<TContainer>& outputContainers)
+ {
+ for (unsigned int i = 0; i < outputContainers.size(); ++i)
+ {
+ const unsigned int expectedOutputDataSize = GetOutputSize(i);
+
+ mapbox::util::apply_visitor([expectedOutputDataSize, i](auto&& value)
+ {
+ const unsigned int actualOutputDataSize = armnn::numeric_cast<unsigned int>(value.size());
+ if (actualOutputDataSize < expectedOutputDataSize)
+ {
+ unsigned int outputIndex = i;
+ throw armnn::Exception(
+ fmt::format("Not enough data for output #{0}: expected "
+ "{1} elements, got {2}", outputIndex, expectedOutputDataSize, actualOutputDataSize));
+ }
+ },
+ outputContainers[i]);
+ }
+
+ std::shared_ptr<armnn::IProfiler> profiler = m_Runtime->GetProfiler(m_NetworkIdentifier);
+ if (profiler)
+ {
+ profiler->EnableProfiling(m_EnableProfiling);
+ }
+
+ // Start timer to record inference time in Execute (in milliseconds)
+ const auto start_time = armnn::GetTimeNow();
+
+ armnn::Status ret = m_Runtime->Execute(workingMemHandleRef,
+ MakeInputTensors(inputContainers),
+ MakeOutputTensors(outputContainers));
+ auto inferenceID = workingMemHandleRef.GetInferenceId();
+
+ const auto duration = armnn::GetTimeDuration(start_time);
+
+ // if profiling is enabled print out the results
+ if (profiler && profiler->IsProfilingEnabled())
+ {
+ profiler->Print(std::cout);
+ }
+
+ if (ret == armnn::Status::Failure)
+ {
+ throw armnn::Exception(
+ fmt::format("IRuntime::Execute asynchronously failed for network #{0} on inference #{1}",
+ m_NetworkIdentifier, inferenceID));
+ }
+
+ return std::make_tuple(inferenceID, duration);
+ }
+
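A hedged usage sketch for the new method, assuming the input/output TContainer vectors are prepared the same way as for the synchronous Run path:

    // Hypothetical driver code: one working memory handle per in-flight
    // inference, then unpack the (ProfilingGuid, duration) result.
    auto workingMemHandle = model.CreateWorkingMemHandle();
    auto [inferenceId, duration] =
        model.RunAsync(*workingMemHandle, inputContainers, outputContainers);
    // Assumes ProfilingGuid is streamable/convertible for logging purposes.
    ARMNN_LOG(info) << "Inference " << inferenceId << " took "
                    << duration.count() << " ms";

The tuple return lets the harness correlate each asynchronous inference's timing with its profiling GUID.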
const armnn::BindingPointInfo& GetInputBindingInfo(unsigned int inputIndex = 0u) const
{
CheckInputIndexIsValid(inputIndex);
@@ -618,6 +676,11 @@ public:
return quantizationParams;
}
+ std::unique_ptr<armnn::experimental::IWorkingMemHandle> CreateWorkingMemHandle()
+ {
+ return m_Runtime->CreateWorkingMemHandle(m_NetworkIdentifier);
+ }
+
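One design note: an IWorkingMemHandle carries the intermediate working memory for a single in-flight inference, so concurrent callers would each create their own handle rather than share one. A hedged sketch of that pattern, assuming Execute is safe across distinct handles and giving each thread its own copy of the output buffers:

    // Hypothetical two-thread sketch; requires <thread> and <vector>.
    std::vector<std::thread> workers;
    for (int t = 0; t < 2; ++t)
    {
        workers.emplace_back([&model, &inputContainers, outputs = outputContainers]() mutable
        {
            auto handle = model.CreateWorkingMemHandle(); // per-thread handle
            model.RunAsync(*handle, inputContainers, outputs);
        });
    }
    for (auto& worker : workers)
    {
        worker.join();
    }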
private:
armnn::NetworkId m_NetworkIdentifier;
std::shared_ptr<armnn::IRuntime> m_Runtime;