aboutsummaryrefslogtreecommitdiff
path: root/tests/InferenceModel.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/InferenceModel.hpp')
-rw-r--r--tests/InferenceModel.hpp27
1 files changed, 26 insertions, 1 deletions
diff --git a/tests/InferenceModel.hpp b/tests/InferenceModel.hpp
index 25ccbee45a..e168923048 100644
--- a/tests/InferenceModel.hpp
+++ b/tests/InferenceModel.hpp
@@ -30,6 +30,7 @@
#include <boost/variant.hpp>
#include <algorithm>
+#include <chrono>
#include <iterator>
#include <fstream>
#include <map>
@@ -506,7 +507,9 @@ public:
return m_OutputBindings[outputIndex].second.GetNumElements();
}
- void Run(const std::vector<TContainer>& inputContainers, std::vector<TContainer>& outputContainers)
+ std::chrono::duration<double, std::milli> Run(
+ const std::vector<TContainer>& inputContainers,
+ std::vector<TContainer>& outputContainers)
{
for (unsigned int i = 0; i < outputContainers.size(); ++i)
{
@@ -532,10 +535,15 @@ public:
profiler->EnableProfiling(m_EnableProfiling);
}
+ // Start timer to record inference time in EnqueueWorkload (in milliseconds)
+ const auto start_time = GetCurrentTime();
+
armnn::Status ret = m_Runtime->EnqueueWorkload(m_NetworkIdentifier,
MakeInputTensors(inputContainers),
MakeOutputTensors(outputContainers));
+ const auto end_time = GetCurrentTime();
+
// if profiling is enabled print out the results
if (profiler && profiler->IsProfilingEnabled())
{
@@ -546,6 +554,10 @@ public:
{
throw armnn::Exception("IRuntime::EnqueueWorkload failed");
}
+ else
+ {
+ return std::chrono::duration<double, std::milli>(end_time - start_time);
+ }
}
const BindingPointInfo& GetInputBindingInfo(unsigned int inputIndex = 0u) const
@@ -613,4 +625,17 @@ private:
{
return ::MakeOutputTensors(m_OutputBindings, outputDataContainers);
}
+
+ std::chrono::high_resolution_clock::time_point GetCurrentTime()
+ {
+ return std::chrono::high_resolution_clock::now();
+ }
+
+ std::chrono::duration<double, std::milli> GetTimeDuration(
+ std::chrono::high_resolution_clock::time_point& start_time,
+ std::chrono::high_resolution_clock::time_point& end_time)
+ {
+ return std::chrono::duration<double, std::milli>(end_time - start_time);
+ }
+
};