summaryrefslogtreecommitdiff
path: root/source/use_case/inference_runner
diff options
context:
space:
mode:
Diffstat (limited to 'source/use_case/inference_runner')
-rw-r--r--source/use_case/inference_runner/src/MainLoop.cc3
-rw-r--r--source/use_case/inference_runner/src/UseCaseHandler.cc6
2 files changed, 8 insertions, 1 deletions
diff --git a/source/use_case/inference_runner/src/MainLoop.cc b/source/use_case/inference_runner/src/MainLoop.cc
index b110a24..26a20de 100644
--- a/source/use_case/inference_runner/src/MainLoop.cc
+++ b/source/use_case/inference_runner/src/MainLoop.cc
@@ -38,6 +38,9 @@ void main_loop(hal_platform& platform)
/* Instantiate application context. */
arm::app::ApplicationContext caseContext;
+ arm::app::Profiler profiler{&platform, "inference_runner"};
+ caseContext.Set<arm::app::Profiler&>("profiler", profiler);
+
caseContext.Set<hal_platform&>("platform", platform);
caseContext.Set<arm::app::Model&>("model", model);
caseContext.Set<uint32_t>("imgIndex", 0);
diff --git a/source/use_case/inference_runner/src/UseCaseHandler.cc b/source/use_case/inference_runner/src/UseCaseHandler.cc
index ac4ea47..a2b3fb7 100644
--- a/source/use_case/inference_runner/src/UseCaseHandler.cc
+++ b/source/use_case/inference_runner/src/UseCaseHandler.cc
@@ -28,6 +28,7 @@ namespace app {
bool RunInferenceHandler(ApplicationContext& ctx)
{
auto& platform = ctx.Get<hal_platform&>("platform");
+ auto& profiler = ctx.Get<Profiler&>("profiler");
auto& model = ctx.Get<Model&>("model");
constexpr uint32_t dataPsnTxtInfStartX = 150;
@@ -67,7 +68,7 @@ namespace app {
str_inf.c_str(), str_inf.size(),
dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
- RunInference(platform, model);
+ RunInference(model, profiler);
/* Erase. */
str_inf = std::string(str_inf.size(), ' ');
@@ -75,6 +76,9 @@ namespace app {
str_inf.c_str(), str_inf.size(),
dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
+ info("Final results:\n");
+ profiler.PrintProfilingResult();
+
#if VERIFY_TEST_OUTPUT
for (size_t outputIndex = 0; outputIndex < model.GetNumOutputs(); outputIndex++) {
arm::app::DumpTensor(model.GetOutputTensor(outputIndex));