author    James Conroy <james.conroy@arm.com>    2019-04-11 10:23:58 +0100
committer James Conroy <james.conroy@arm.com>    2019-04-11 16:32:39 +0100
commit    7b4886faccb52af9afe7fdeffcbae87e7fbc1484 (patch)
tree      5aa8891568bd75b48676ef555f6bece585043a45 /tests/InferenceModel.hpp
parent    774f6f1d7c862fc2b8e1783abef9a0bccdaf9d0c (diff)
download  armnn-7b4886faccb52af9afe7fdeffcbae87e7fbc1484.tar.gz
IVGCVSW-2543 Add timing for ExecuteNetwork inference
* Adds a new command line option 'threshold-time' to ExecuteNetwork: the
  maximum allowed time for the inference in EnqueueWorkload.
* ExecuteNetwork now outputs the elapsed inference time and (if supplied)
  the threshold time.
* If the actual elapsed inference time is greater than the supplied
  threshold time, the test fails.

Change-Id: If441b49a29cf5450687c07500c9046a80ece56fc
Signed-off-by: James Conroy <james.conroy@arm.com>
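The diff in this page only touches InferenceModel.hpp, which now returns the
measured EnqueueWorkload duration from Run(). The threshold comparison itself
lives on the caller side in ExecuteNetwork and is not part of this file, so
the following is only a rough, hypothetical sketch of how that check might
look; the helper name RunAndCheckThreshold, the parameter thresholdTimeMs and
the log messages are illustrative assumptions, not code from this commit.

```cpp
// Hypothetical caller-side sketch: consuming the duration now returned by
// InferenceModel::Run(). Names and messages are illustrative only.
#include <chrono>
#include <iostream>
#include <vector>

template <typename TModel, typename TContainer>
int RunAndCheckThreshold(TModel& model,
                         const std::vector<TContainer>& inputs,
                         std::vector<TContainer>& outputs,
                         double thresholdTimeMs) // 0 means "no threshold supplied"
{
    // Run() now returns the wall-clock time spent in EnqueueWorkload, in ms.
    const std::chrono::duration<double, std::milli> elapsed =
        model.Run(inputs, outputs);

    std::cout << "Inference time: " << elapsed.count() << " ms\n";

    if (thresholdTimeMs > 0.0)
    {
        std::cout << "Threshold time: " << thresholdTimeMs << " ms\n";
        if (elapsed.count() > thresholdTimeMs)
        {
            std::cerr << "Elapsed inference time exceeds threshold, failing\n";
            return 1; // non-zero exit marks the test as failed
        }
    }
    return 0;
}
```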
Diffstat (limited to 'tests/InferenceModel.hpp')
-rw-r--r--  tests/InferenceModel.hpp  |  27
1 file changed, 26 insertions(+), 1 deletion(-)
diff --git a/tests/InferenceModel.hpp b/tests/InferenceModel.hpp
index 25ccbee45a..e168923048 100644
--- a/tests/InferenceModel.hpp
+++ b/tests/InferenceModel.hpp
@@ -30,6 +30,7 @@
#include <boost/variant.hpp>
#include <algorithm>
+#include <chrono>
#include <iterator>
#include <fstream>
#include <map>
@@ -506,7 +507,9 @@ public:
return m_OutputBindings[outputIndex].second.GetNumElements();
}
- void Run(const std::vector<TContainer>& inputContainers, std::vector<TContainer>& outputContainers)
+ std::chrono::duration<double, std::milli> Run(
+ const std::vector<TContainer>& inputContainers,
+ std::vector<TContainer>& outputContainers)
{
for (unsigned int i = 0; i < outputContainers.size(); ++i)
{
@@ -532,10 +535,15 @@ public:
profiler->EnableProfiling(m_EnableProfiling);
}
+ // Start timer to record inference time in EnqueueWorkload (in milliseconds)
+ const auto start_time = GetCurrentTime();
+
armnn::Status ret = m_Runtime->EnqueueWorkload(m_NetworkIdentifier,
MakeInputTensors(inputContainers),
MakeOutputTensors(outputContainers));
+ const auto end_time = GetCurrentTime();
+
// if profiling is enabled print out the results
if (profiler && profiler->IsProfilingEnabled())
{
@@ -546,6 +554,10 @@ public:
{
throw armnn::Exception("IRuntime::EnqueueWorkload failed");
}
+ else
+ {
+ return std::chrono::duration<double, std::milli>(end_time - start_time);
+ }
}
const BindingPointInfo& GetInputBindingInfo(unsigned int inputIndex = 0u) const
@@ -613,4 +625,17 @@ private:
{
return ::MakeOutputTensors(m_OutputBindings, outputDataContainers);
}
+
+ std::chrono::high_resolution_clock::time_point GetCurrentTime()
+ {
+ return std::chrono::high_resolution_clock::now();
+ }
+
+ std::chrono::duration<double, std::milli> GetTimeDuration(
+ std::chrono::high_resolution_clock::time_point& start_time,
+ std::chrono::high_resolution_clock::time_point& end_time)
+ {
+ return std::chrono::duration<double, std::milli>(end_time - start_time);
+ }
+
};
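For readers less familiar with the std::chrono idioms used by the new
GetCurrentTime and GetTimeDuration helpers, below is a minimal standalone
example of the same measurement pattern. It is independent of Arm NN and not
part of the commit; the sleep_for call simply stands in for the work being
timed.

```cpp
// Standalone illustration of the timing pattern used above: capture two
// high_resolution_clock time points and express their difference as a
// double-precision millisecond duration.
#include <chrono>
#include <iostream>
#include <thread>

int main()
{
    const auto start = std::chrono::high_resolution_clock::now();

    // Stand-in for EnqueueWorkload(): any work to be timed.
    std::this_thread::sleep_for(std::chrono::milliseconds(5));

    const auto end = std::chrono::high_resolution_clock::now();

    // duration<double, std::milli> converts the clock's native ticks to ms.
    const std::chrono::duration<double, std::milli> elapsed = end - start;
    std::cout << "Elapsed: " << elapsed.count() << " ms\n";
    return 0;
}
```

One general caveat with this pattern: std::chrono::high_resolution_clock is
not guaranteed to be monotonic on every platform, so code that must never see
time go backwards typically prefers std::chrono::steady_clock for interval
measurements.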