summaryrefslogtreecommitdiff
path: root/source/use_case
diff options
context:
space:
mode:
authorIsabella Gottardi <isabella.gottardi@arm.com>2021-04-07 17:15:31 +0100
committerAlexander Efremov <alexander.efremov@arm.com>2021-04-12 14:00:49 +0000
commit8df12f37531d57a10cba2f8b2e8b6a9065202dd5 (patch)
treeba833d15649c3b0f885d57b40d3916970b3fd2c8 /source/use_case
parent37ce22ebc9cf3e8529d9914c0eed0f718243d961 (diff)
downloadml-embedded-evaluation-kit-8df12f37531d57a10cba2f8b2e8b6a9065202dd5.tar.gz
MLECO-1870: Cherry pick profiling changes from dev to open source repo
* Documentation update Change-Id: If85e7ebc44498840b291c408f14e66a5a5faa424 Signed-off-by: Isabella Gottardi <isabella.gottardi@arm.com>
Diffstat (limited to 'source/use_case')
-rw-r--r--source/use_case/ad/src/MainLoop.cc2
-rw-r--r--source/use_case/ad/src/UseCaseHandler.cc7
-rw-r--r--source/use_case/asr/src/MainLoop.cc2
-rw-r--r--source/use_case/asr/src/UseCaseHandler.cc22
-rw-r--r--source/use_case/img_class/src/MainLoop.cc2
-rw-r--r--source/use_case/img_class/src/UseCaseHandler.cc7
-rw-r--r--source/use_case/inference_runner/src/MainLoop.cc3
-rw-r--r--source/use_case/inference_runner/src/UseCaseHandler.cc6
-rw-r--r--source/use_case/kws/src/MainLoop.cc3
-rw-r--r--source/use_case/kws/src/UseCaseHandler.cc30
-rw-r--r--source/use_case/kws_asr/src/MainLoop.cc3
-rw-r--r--source/use_case/kws_asr/src/UseCaseHandler.cc17
12 files changed, 71 insertions, 33 deletions
diff --git a/source/use_case/ad/src/MainLoop.cc b/source/use_case/ad/src/MainLoop.cc
index 5455b43..6a7cbe0 100644
--- a/source/use_case/ad/src/MainLoop.cc
+++ b/source/use_case/ad/src/MainLoop.cc
@@ -56,6 +56,8 @@ void main_loop(hal_platform& platform)
/* Instantiate application context. */
arm::app::ApplicationContext caseContext;
+ arm::app::Profiler profiler{&platform, "ad"};
+ caseContext.Set<arm::app::Profiler&>("profiler", profiler);
caseContext.Set<hal_platform&>("platform", platform);
caseContext.Set<arm::app::Model&>("model", model);
caseContext.Set<uint32_t>("clipIndex", 0);
diff --git a/source/use_case/ad/src/UseCaseHandler.cc b/source/use_case/ad/src/UseCaseHandler.cc
index c18a0a4..1c15595 100644
--- a/source/use_case/ad/src/UseCaseHandler.cc
+++ b/source/use_case/ad/src/UseCaseHandler.cc
@@ -76,6 +76,7 @@ namespace app {
bool ClassifyVibrationHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
{
auto& platform = ctx.Get<hal_platform&>("platform");
+ auto& profiler = ctx.Get<Profiler&>("profiler");
constexpr uint32_t dataPsnTxtInfStartX = 20;
constexpr uint32_t dataPsnTxtInfStartY = 40;
@@ -98,7 +99,7 @@ namespace app {
const auto frameLength = ctx.Get<int>("frameLength");
const auto frameStride = ctx.Get<int>("frameStride");
const auto scoreThreshold = ctx.Get<float>("scoreThreshold");
- const float trainingMean = ctx.Get<float>("trainingMean");
+ const auto trainingMean = ctx.Get<float>("trainingMean");
auto startClipIdx = ctx.Get<uint32_t>("clipIndex");
TfLiteTensor* outputTensor = model.GetOutputTensor(0);
@@ -193,7 +194,7 @@ namespace app {
audioDataSlider.TotalStrides() + 1);
/* Run inference over this audio clip sliding window */
- arm::app::RunInference(platform, model);
+ arm::app::RunInference(model, profiler);
/* Use the negative softmax score of the corresponding index as the outlier score */
std::vector<float> dequantOutput = Dequantize<int8_t>(outputTensor);
@@ -219,6 +220,8 @@ namespace app {
return false;
}
+ profiler.PrintProfilingResult();
+
_IncrementAppCtxClipIdx(ctx);
} while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx);
diff --git a/source/use_case/asr/src/MainLoop.cc b/source/use_case/asr/src/MainLoop.cc
index ca777be..c5a26a4 100644
--- a/source/use_case/asr/src/MainLoop.cc
+++ b/source/use_case/asr/src/MainLoop.cc
@@ -96,6 +96,8 @@ void main_loop(hal_platform& platform)
GetLabelsVector(labels);
arm::app::AsrClassifier classifier; /* Classifier wrapper object. */
+ arm::app::Profiler profiler{&platform, "asr"};
+ caseContext.Set<arm::app::Profiler&>("profiler", profiler);
caseContext.Set<hal_platform&>("platform", platform);
caseContext.Set<arm::app::Model&>("model", model);
caseContext.Set<uint32_t>("clipIndex", 0);
diff --git a/source/use_case/asr/src/UseCaseHandler.cc b/source/use_case/asr/src/UseCaseHandler.cc
index e706eb8..5d3157a 100644
--- a/source/use_case/asr/src/UseCaseHandler.cc
+++ b/source/use_case/asr/src/UseCaseHandler.cc
@@ -67,6 +67,8 @@ namespace app {
auto& platform = ctx.Get<hal_platform&>("platform");
platform.data_psn->clear(COLOR_BLACK);
+ auto& profiler = ctx.Get<Profiler&>("profiler");
+
/* If the request has a valid size, set the audio index. */
if (clipIndex < NUMBER_OF_FILES) {
if (!_SetAppCtxClipIdx(ctx, clipIndex)) {
@@ -168,18 +170,11 @@ namespace app {
info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1)));
- Profiler prepProfiler{&platform, "pre-processing"};
- prepProfiler.StartProfiling();
-
/* Calculate MFCCs, deltas and populate the input tensor. */
prep.Invoke(inferenceWindow, inferenceWindowLen, inputTensor);
- prepProfiler.StopProfiling();
- std::string prepProfileResults = prepProfiler.GetResultsAndReset();
- info("%s\n", prepProfileResults.c_str());
-
/* Run inference over this audio clip sliding window. */
- arm::app::RunInference(platform, model);
+ arm::app::RunInference(model, profiler);
/* Post-process. */
postp.Invoke(outputTensor, reductionAxis, !audioDataSlider.HasNext());
@@ -216,6 +211,8 @@ namespace app {
return false;
}
+ profiler.PrintProfilingResult();
+
_IncrementAppCtxClipIdx(ctx);
} while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx);
@@ -256,6 +253,8 @@ namespace app {
platform.data_psn->set_text_color(COLOR_GREEN);
+ info("Final results:\n");
+ info("Total number of inferences: %zu\n", results.size());
/* Results from multiple inferences should be combined before processing. */
std::vector<arm::app::ClassificationResult> combinedResults;
for (auto& result : results) {
@@ -268,8 +267,9 @@ namespace app {
for (const auto & result : results) {
std::string infResultStr = audio::asr::DecodeOutput(result.m_resultVec);
- info("Result for inf %u: %s\n", result.m_inferenceNumber,
- infResultStr.c_str());
+ info("For timestamp: %f (inference #: %u); label: %s\n",
+ result.m_timeStamp, result.m_inferenceNumber,
+ infResultStr.c_str());
}
/* Get the decoded result for the combined result. */
@@ -280,7 +280,7 @@ namespace app {
dataPsnTxtStartX1, dataPsnTxtStartY1,
allow_multiple_lines);
- info("Final result: %s\n", finalResultStr.c_str());
+ info("Complete recognition: %s\n", finalResultStr.c_str());
return true;
}
diff --git a/source/use_case/img_class/src/MainLoop.cc b/source/use_case/img_class/src/MainLoop.cc
index 469907c..66d7064 100644
--- a/source/use_case/img_class/src/MainLoop.cc
+++ b/source/use_case/img_class/src/MainLoop.cc
@@ -58,6 +58,8 @@ void main_loop(hal_platform& platform)
/* Instantiate application context. */
arm::app::ApplicationContext caseContext;
+ arm::app::Profiler profiler{&platform, "img_class"};
+ caseContext.Set<arm::app::Profiler&>("profiler", profiler);
caseContext.Set<hal_platform&>("platform", platform);
caseContext.Set<arm::app::Model&>("model", model);
caseContext.Set<uint32_t>("imgIndex", 0);
diff --git a/source/use_case/img_class/src/UseCaseHandler.cc b/source/use_case/img_class/src/UseCaseHandler.cc
index a412fec..f7e83f5 100644
--- a/source/use_case/img_class/src/UseCaseHandler.cc
+++ b/source/use_case/img_class/src/UseCaseHandler.cc
@@ -74,6 +74,7 @@ namespace app {
bool ClassifyImageHandler(ApplicationContext& ctx, uint32_t imgIndex, bool runAll)
{
auto& platform = ctx.Get<hal_platform&>("platform");
+ auto& profiler = ctx.Get<Profiler&>("profiler");
constexpr uint32_t dataPsnImgDownscaleFactor = 2;
constexpr uint32_t dataPsnImgStartX = 10;
@@ -144,7 +145,7 @@ namespace app {
info("Running inference on image %u => %s\n", ctx.Get<uint32_t>("imgIndex"),
get_filename(ctx.Get<uint32_t>("imgIndex")));
- RunInference(platform, model);
+ RunInference(model, profiler);
/* Erase. */
str_inf = std::string(str_inf.size(), ' ');
@@ -167,6 +168,8 @@ namespace app {
return false;
}
+ profiler.PrintProfilingResult();
+
_IncrementAppCtxImageIdx(ctx);
} while (runAll && ctx.Get<uint32_t>("imgIndex") != curImIdx);
@@ -230,6 +233,8 @@ namespace app {
uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr;
uint32_t rowIdx2 = dataPsnTxtStartY2;
+ info("Final results:\n");
+ info("Total number of inferences: 1\n");
for (uint32_t i = 0; i < results.size(); ++i) {
std::string resultStr =
std::to_string(i + 1) + ") " +
diff --git a/source/use_case/inference_runner/src/MainLoop.cc b/source/use_case/inference_runner/src/MainLoop.cc
index b110a24..26a20de 100644
--- a/source/use_case/inference_runner/src/MainLoop.cc
+++ b/source/use_case/inference_runner/src/MainLoop.cc
@@ -38,6 +38,9 @@ void main_loop(hal_platform& platform)
/* Instantiate application context. */
arm::app::ApplicationContext caseContext;
+ arm::app::Profiler profiler{&platform, "inference_runner"};
+ caseContext.Set<arm::app::Profiler&>("profiler", profiler);
+
caseContext.Set<hal_platform&>("platform", platform);
caseContext.Set<arm::app::Model&>("model", model);
caseContext.Set<uint32_t>("imgIndex", 0);
diff --git a/source/use_case/inference_runner/src/UseCaseHandler.cc b/source/use_case/inference_runner/src/UseCaseHandler.cc
index ac4ea47..a2b3fb7 100644
--- a/source/use_case/inference_runner/src/UseCaseHandler.cc
+++ b/source/use_case/inference_runner/src/UseCaseHandler.cc
@@ -28,6 +28,7 @@ namespace app {
bool RunInferenceHandler(ApplicationContext& ctx)
{
auto& platform = ctx.Get<hal_platform&>("platform");
+ auto& profiler = ctx.Get<Profiler&>("profiler");
auto& model = ctx.Get<Model&>("model");
constexpr uint32_t dataPsnTxtInfStartX = 150;
@@ -67,7 +68,7 @@ namespace app {
str_inf.c_str(), str_inf.size(),
dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
- RunInference(platform, model);
+ RunInference(model, profiler);
/* Erase. */
str_inf = std::string(str_inf.size(), ' ');
@@ -75,6 +76,9 @@ namespace app {
str_inf.c_str(), str_inf.size(),
dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
+ info("Final results:\n");
+ profiler.PrintProfilingResult();
+
#if VERIFY_TEST_OUTPUT
for (size_t outputIndex = 0; outputIndex < model.GetNumOutputs(); outputIndex++) {
arm::app::DumpTensor(model.GetOutputTensor(outputIndex));
diff --git a/source/use_case/kws/src/MainLoop.cc b/source/use_case/kws/src/MainLoop.cc
index 24cb939..f971c30 100644
--- a/source/use_case/kws/src/MainLoop.cc
+++ b/source/use_case/kws/src/MainLoop.cc
@@ -58,6 +58,9 @@ void main_loop(hal_platform& platform)
/* Instantiate application context. */
arm::app::ApplicationContext caseContext;
+ arm::app::Profiler profiler{&platform, "kws"};
+ caseContext.Set<arm::app::Profiler&>("profiler", profiler);
+
caseContext.Set<hal_platform&>("platform", platform);
caseContext.Set<arm::app::Model&>("model", model);
caseContext.Set<uint32_t>("clipIndex", 0);
diff --git a/source/use_case/kws/src/UseCaseHandler.cc b/source/use_case/kws/src/UseCaseHandler.cc
index 872d323..d2cba55 100644
--- a/source/use_case/kws/src/UseCaseHandler.cc
+++ b/source/use_case/kws/src/UseCaseHandler.cc
@@ -82,6 +82,7 @@ namespace app {
bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
{
auto& platform = ctx.Get<hal_platform&>("platform");
+ auto& profiler = ctx.Get<Profiler&>("profiler");
constexpr uint32_t dataPsnTxtInfStartX = 20;
constexpr uint32_t dataPsnTxtInfStartY = 40;
@@ -215,7 +216,7 @@ namespace app {
audioDataSlider.TotalStrides() + 1);
/* Run inference over this audio clip sliding window. */
- arm::app::RunInference(platform, model);
+ arm::app::RunInference(model, profiler);
std::vector<ClassificationResult> classificationResult;
auto& classifier = ctx.Get<KwsClassifier&>("classifier");
@@ -243,6 +244,8 @@ namespace app {
return false;
}
+ profiler.PrintProfilingResult();
+
_IncrementAppCtxClipIdx(ctx);
} while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx);
@@ -281,6 +284,8 @@ namespace app {
constexpr uint32_t dataPsnTxtYIncr = 16; /* Row index increment. */
platform.data_psn->set_text_color(COLOR_GREEN);
+ info("Final results:\n");
+ info("Total number of inferences: %zu\n", results.size());
/* Display each result */
uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr;
@@ -290,7 +295,7 @@ namespace app {
std::string topKeyword{"<none>"};
float score = 0.f;
- if (results[i].m_resultVec.size()) {
+ if (!results[i].m_resultVec.empty()) {
topKeyword = results[i].m_resultVec[0].m_label;
score = results[i].m_resultVec[0].m_normalisedVal;
}
@@ -305,13 +310,20 @@ namespace app {
dataPsnTxtStartX1, rowIdx1, false);
rowIdx1 += dataPsnTxtYIncr;
- info("For timestamp: %f (inference #: %u); threshold: %f\n",
- results[i].m_timeStamp, results[i].m_inferenceNumber,
- results[i].m_threshold);
- for (uint32_t j = 0; j < results[i].m_resultVec.size(); ++j) {
- info("\t\tlabel @ %u: %s, score: %f\n", j,
- results[i].m_resultVec[j].m_label.c_str(),
- results[i].m_resultVec[j].m_normalisedVal);
+ if (results[i].m_resultVec.empty()) {
+ info("For timestamp: %f (inference #: %u); label: %s; threshold: %f\n",
+ results[i].m_timeStamp, results[i].m_inferenceNumber,
+ topKeyword.c_str(),
+ results[i].m_threshold);
+ } else {
+ for (uint32_t j = 0; j < results[i].m_resultVec.size(); ++j) {
+ info("For timestamp: %f (inference #: %u); label: %s, score: %f; threshold: %f\n",
+ results[i].m_timeStamp,
+ results[i].m_inferenceNumber,
+ results[i].m_resultVec[j].m_label.c_str(),
+ results[i].m_resultVec[j].m_normalisedVal,
+ results[i].m_threshold);
+ }
}
}
diff --git a/source/use_case/kws_asr/src/MainLoop.cc b/source/use_case/kws_asr/src/MainLoop.cc
index 37146c9..95e5a8f 100644
--- a/source/use_case/kws_asr/src/MainLoop.cc
+++ b/source/use_case/kws_asr/src/MainLoop.cc
@@ -101,6 +101,9 @@ void main_loop(hal_platform& platform)
/* Instantiate application context. */
arm::app::ApplicationContext caseContext;
+ arm::app::Profiler profiler{&platform, "kws_asr"};
+ caseContext.Set<arm::app::Profiler&>("profiler", profiler);
+
caseContext.Set<hal_platform&>("platform", platform);
caseContext.Set<arm::app::Model&>("kwsmodel", kwsModel);
caseContext.Set<arm::app::Model&>("asrmodel", asrModel);
diff --git a/source/use_case/kws_asr/src/UseCaseHandler.cc b/source/use_case/kws_asr/src/UseCaseHandler.cc
index c50796f..a428210 100644
--- a/source/use_case/kws_asr/src/UseCaseHandler.cc
+++ b/source/use_case/kws_asr/src/UseCaseHandler.cc
@@ -127,6 +127,7 @@ namespace app {
KWSOutput output;
+ auto& profiler = ctx.Get<Profiler&>("profiler");
auto& kwsModel = ctx.Get<Model&>("kwsmodel");
if (!kwsModel.IsInited()) {
printf_err("KWS model has not been initialised\n");
@@ -243,7 +244,7 @@ namespace app {
audioDataSlider.TotalStrides() + 1);
/* Run inference over this audio clip sliding window. */
- arm::app::RunInference(platform, kwsModel);
+ arm::app::RunInference(kwsModel, profiler);
std::vector<ClassificationResult> kwsClassificationResult;
auto& kwsClassifier = ctx.Get<KwsClassifier&>("kwsclassifier");
@@ -284,6 +285,8 @@ namespace app {
return output;
}
+ profiler.PrintProfilingResult();
+
output.executionSuccess = true;
return output;
}
@@ -300,6 +303,7 @@ namespace app {
constexpr uint32_t dataPsnTxtInfStartX = 20;
constexpr uint32_t dataPsnTxtInfStartY = 40;
+ auto& profiler = ctx.Get<Profiler&>("profiler");
auto& platform = ctx.Get<hal_platform&>("platform");
platform.data_psn->clear(COLOR_BLACK);
@@ -389,18 +393,11 @@ namespace app {
info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1)));
- Profiler prepProfiler{&platform, "pre-processing"};
- prepProfiler.StartProfiling();
-
/* Calculate MFCCs, deltas and populate the input tensor. */
asrPrep.Invoke(asrInferenceWindow, asrInferenceWindowLen, asrInputTensor);
- prepProfiler.StopProfiling();
- std::string prepProfileResults = prepProfiler.GetResultsAndReset();
- info("%s\n", prepProfileResults.c_str());
-
/* Run inference over this audio clip sliding window. */
- arm::app::RunInference(platform, asrModel);
+ arm::app::RunInference(asrModel, profiler);
/* Post-process. */
asrPostp.Invoke(asrOutputTensor, reductionAxis, !audioDataSlider.HasNext());
@@ -432,6 +429,8 @@ namespace app {
return false;
}
+ profiler.PrintProfilingResult();
+
return true;
}