summaryrefslogtreecommitdiff
path: root/source/use_case/kws
diff options
context:
space:
mode:
Diffstat (limited to 'source/use_case/kws')
-rw-r--r--source/use_case/kws/src/MainLoop.cc3
-rw-r--r--source/use_case/kws/src/UseCaseHandler.cc30
2 files changed, 24 insertions, 9 deletions
diff --git a/source/use_case/kws/src/MainLoop.cc b/source/use_case/kws/src/MainLoop.cc
index 24cb939..f971c30 100644
--- a/source/use_case/kws/src/MainLoop.cc
+++ b/source/use_case/kws/src/MainLoop.cc
@@ -58,6 +58,9 @@ void main_loop(hal_platform& platform)
/* Instantiate application context. */
arm::app::ApplicationContext caseContext;
+ arm::app::Profiler profiler{&platform, "kws"};
+ caseContext.Set<arm::app::Profiler&>("profiler", profiler);
+
caseContext.Set<hal_platform&>("platform", platform);
caseContext.Set<arm::app::Model&>("model", model);
caseContext.Set<uint32_t>("clipIndex", 0);
diff --git a/source/use_case/kws/src/UseCaseHandler.cc b/source/use_case/kws/src/UseCaseHandler.cc
index 872d323..d2cba55 100644
--- a/source/use_case/kws/src/UseCaseHandler.cc
+++ b/source/use_case/kws/src/UseCaseHandler.cc
@@ -82,6 +82,7 @@ namespace app {
bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
{
auto& platform = ctx.Get<hal_platform&>("platform");
+ auto& profiler = ctx.Get<Profiler&>("profiler");
constexpr uint32_t dataPsnTxtInfStartX = 20;
constexpr uint32_t dataPsnTxtInfStartY = 40;
@@ -215,7 +216,7 @@ namespace app {
audioDataSlider.TotalStrides() + 1);
/* Run inference over this audio clip sliding window. */
- arm::app::RunInference(platform, model);
+ arm::app::RunInference(model, profiler);
std::vector<ClassificationResult> classificationResult;
auto& classifier = ctx.Get<KwsClassifier&>("classifier");
@@ -243,6 +244,8 @@ namespace app {
return false;
}
+ profiler.PrintProfilingResult();
+
_IncrementAppCtxClipIdx(ctx);
} while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx);
@@ -281,6 +284,8 @@ namespace app {
constexpr uint32_t dataPsnTxtYIncr = 16; /* Row index increment. */
platform.data_psn->set_text_color(COLOR_GREEN);
+ info("Final results:\n");
+ info("Total number of inferences: %zu\n", results.size());
/* Display each result */
uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr;
@@ -290,7 +295,7 @@ namespace app {
std::string topKeyword{"<none>"};
float score = 0.f;
- if (results[i].m_resultVec.size()) {
+ if (!results[i].m_resultVec.empty()) {
topKeyword = results[i].m_resultVec[0].m_label;
score = results[i].m_resultVec[0].m_normalisedVal;
}
@@ -305,13 +310,20 @@ namespace app {
dataPsnTxtStartX1, rowIdx1, false);
rowIdx1 += dataPsnTxtYIncr;
- info("For timestamp: %f (inference #: %u); threshold: %f\n",
- results[i].m_timeStamp, results[i].m_inferenceNumber,
- results[i].m_threshold);
- for (uint32_t j = 0; j < results[i].m_resultVec.size(); ++j) {
- info("\t\tlabel @ %u: %s, score: %f\n", j,
- results[i].m_resultVec[j].m_label.c_str(),
- results[i].m_resultVec[j].m_normalisedVal);
+ if (results[i].m_resultVec.empty()) {
+ info("For timestamp: %f (inference #: %u); label: %s; threshold: %f\n",
+ results[i].m_timeStamp, results[i].m_inferenceNumber,
+ topKeyword.c_str(),
+ results[i].m_threshold);
+ } else {
+ for (uint32_t j = 0; j < results[i].m_resultVec.size(); ++j) {
+ info("For timestamp: %f (inference #: %u); label: %s, score: %f; threshold: %f\n",
+ results[i].m_timeStamp,
+ results[i].m_inferenceNumber,
+ results[i].m_resultVec[j].m_label.c_str(),
+ results[i].m_resultVec[j].m_normalisedVal,
+ results[i].m_threshold);
+ }
}
}