summaryrefslogtreecommitdiff
path: root/source/use_case/ad
diff options
context:
space:
mode:
authorIsabella Gottardi <isabella.gottardi@arm.com>2021-04-07 17:15:31 +0100
committerAlexander Efremov <alexander.efremov@arm.com>2021-04-12 14:00:49 +0000
commit8df12f37531d57a10cba2f8b2e8b6a9065202dd5 (patch)
treeba833d15649c3b0f885d57b40d3916970b3fd2c8 /source/use_case/ad
parent37ce22ebc9cf3e8529d9914c0eed0f718243d961 (diff)
downloadml-embedded-evaluation-kit-8df12f37531d57a10cba2f8b2e8b6a9065202dd5.tar.gz
MLECO-1870: Cherry pick profiling changes from dev to open source repo
* Documentation update Change-Id: If85e7ebc44498840b291c408f14e66a5a5faa424 Signed-off-by: Isabella Gottardi <isabella.gottardi@arm.com>
Diffstat (limited to 'source/use_case/ad')
-rw-r--r--source/use_case/ad/src/MainLoop.cc2
-rw-r--r--source/use_case/ad/src/UseCaseHandler.cc7
2 files changed, 7 insertions, 2 deletions
diff --git a/source/use_case/ad/src/MainLoop.cc b/source/use_case/ad/src/MainLoop.cc
index 5455b43..6a7cbe0 100644
--- a/source/use_case/ad/src/MainLoop.cc
+++ b/source/use_case/ad/src/MainLoop.cc
@@ -56,6 +56,8 @@ void main_loop(hal_platform& platform)
/* Instantiate application context. */
arm::app::ApplicationContext caseContext;
+ arm::app::Profiler profiler{&platform, "ad"};
+ caseContext.Set<arm::app::Profiler&>("profiler", profiler);
caseContext.Set<hal_platform&>("platform", platform);
caseContext.Set<arm::app::Model&>("model", model);
caseContext.Set<uint32_t>("clipIndex", 0);
diff --git a/source/use_case/ad/src/UseCaseHandler.cc b/source/use_case/ad/src/UseCaseHandler.cc
index c18a0a4..1c15595 100644
--- a/source/use_case/ad/src/UseCaseHandler.cc
+++ b/source/use_case/ad/src/UseCaseHandler.cc
@@ -76,6 +76,7 @@ namespace app {
bool ClassifyVibrationHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
{
auto& platform = ctx.Get<hal_platform&>("platform");
+ auto& profiler = ctx.Get<Profiler&>("profiler");
constexpr uint32_t dataPsnTxtInfStartX = 20;
constexpr uint32_t dataPsnTxtInfStartY = 40;
@@ -98,7 +99,7 @@ namespace app {
const auto frameLength = ctx.Get<int>("frameLength");
const auto frameStride = ctx.Get<int>("frameStride");
const auto scoreThreshold = ctx.Get<float>("scoreThreshold");
- const float trainingMean = ctx.Get<float>("trainingMean");
+ const auto trainingMean = ctx.Get<float>("trainingMean");
auto startClipIdx = ctx.Get<uint32_t>("clipIndex");
TfLiteTensor* outputTensor = model.GetOutputTensor(0);
@@ -193,7 +194,7 @@ namespace app {
audioDataSlider.TotalStrides() + 1);
/* Run inference over this audio clip sliding window */
- arm::app::RunInference(platform, model);
+ arm::app::RunInference(model, profiler);
/* Use the negative softmax score of the corresponding index as the outlier score */
std::vector<float> dequantOutput = Dequantize<int8_t>(outputTensor);
@@ -219,6 +220,8 @@ namespace app {
return false;
}
+ profiler.PrintProfilingResult();
+
_IncrementAppCtxClipIdx(ctx);
} while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx);