summaryrefslogtreecommitdiff
path: root/source/use_case
diff options
context:
space:
mode:
authorRichard Burton <richard.burton@arm.com>2022-10-05 11:00:37 +0100
committerRichard Burton <richard.burton@arm.com>2022-10-06 14:08:13 +0100
commitec5e99be3ae6dd0d3811950f155b01e144431452 (patch)
treea5d6c4dd9267db2465063b8d0e1a5cb6d19dac8d /source/use_case
parent890b2b89cacc6f2291596a001d555d374c8c9edd (diff)
downloadml-embedded-evaluation-kit-ec5e99be3ae6dd0d3811950f155b01e144431452.tar.gz
MLECO-3164: Additional refactoring of KWS API
Part 1 * Add KwsClassifier * KwsPostProcess can now be told to average results * Averaging is handlded by KwsClassifier * Current sliding window index is now an argument of DoPreProcess Change-Id: I07626da595ad1cbd982e8366f0d1bb56d1040459
Diffstat (limited to 'source/use_case')
-rw-r--r--source/use_case/kws/src/MainLoop.cc9
-rw-r--r--source/use_case/kws/src/UseCaseHandler.cc9
-rw-r--r--source/use_case/kws_asr/src/MainLoop.cc8
-rw-r--r--source/use_case/kws_asr/src/UseCaseHandler.cc5
4 files changed, 10 insertions, 21 deletions
diff --git a/source/use_case/kws/src/MainLoop.cc b/source/use_case/kws/src/MainLoop.cc
index e0518f2..2489df8 100644
--- a/source/use_case/kws/src/MainLoop.cc
+++ b/source/use_case/kws/src/MainLoop.cc
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,7 +15,7 @@
* limitations under the License.
*/
#include "InputFiles.hpp" /* For input audio clips. */
-#include "Classifier.hpp" /* Classifier. */
+#include "KwsClassifier.hpp" /* Classifier. */
#include "MicroNetKwsModel.hpp" /* Model class for running inference. */
#include "hal.h" /* Brings in platform definitions. */
#include "Labels.hpp" /* For label strings. */
@@ -34,7 +34,6 @@ namespace app {
} /* namespace app */
} /* namespace arm */
-using KwsClassifier = arm::app::Classifier;
enum opcodes
{
@@ -83,8 +82,8 @@ void main_loop()
caseContext.Set<int>("frameStride", arm::app::kws::g_FrameStride);
caseContext.Set<float>("scoreThreshold", arm::app::kws::g_ScoreThreshold); /* Normalised score threshold. */
- KwsClassifier classifier; /* classifier wrapper object. */
- caseContext.Set<arm::app::Classifier&>("classifier", classifier);
+ arm::app::KwsClassifier classifier; /* classifier wrapper object. */
+ caseContext.Set<arm::app::KwsClassifier&>("classifier", classifier);
std::vector <std::string> labels;
GetLabelsVector(labels);
diff --git a/source/use_case/kws/src/UseCaseHandler.cc b/source/use_case/kws/src/UseCaseHandler.cc
index 61c6eb6..d61ba9d 100644
--- a/source/use_case/kws/src/UseCaseHandler.cc
+++ b/source/use_case/kws/src/UseCaseHandler.cc
@@ -17,7 +17,7 @@
#include "UseCaseHandler.hpp"
#include "InputFiles.hpp"
-#include "Classifier.hpp"
+#include "KwsClassifier.hpp"
#include "MicroNetKwsModel.hpp"
#include "hal.h"
#include "AudioUtils.hpp"
@@ -29,8 +29,6 @@
#include <vector>
-using KwsClassifier = arm::app::Classifier;
-
namespace arm {
namespace app {
@@ -124,14 +122,11 @@ namespace app {
while (audioDataSlider.HasNext()) {
const int16_t* inferenceWindow = audioDataSlider.Next();
- /* The first window does not have cache ready. */
- preProcess.m_audioWindowIndex = audioDataSlider.Index();
-
info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
audioDataSlider.TotalStrides() + 1);
/* Run the pre-processing, inference and post-processing. */
- if (!preProcess.DoPreProcess(inferenceWindow, audio::MicroNetKwsMFCC::ms_defaultSamplingFreq)) {
+ if (!preProcess.DoPreProcess(inferenceWindow, audioDataSlider.Index())) {
printf_err("Pre-processing failed.");
return false;
}
diff --git a/source/use_case/kws_asr/src/MainLoop.cc b/source/use_case/kws_asr/src/MainLoop.cc
index 0638ecd..a4f7db9 100644
--- a/source/use_case/kws_asr/src/MainLoop.cc
+++ b/source/use_case/kws_asr/src/MainLoop.cc
@@ -17,7 +17,7 @@
#include "InputFiles.hpp" /* For input images. */
#include "Labels_micronetkws.hpp" /* For MicroNetKws label strings. */
#include "Labels_wav2letter.hpp" /* For Wav2Letter label strings. */
-#include "Classifier.hpp" /* KWS classifier. */
+#include "KwsClassifier.hpp" /* KWS classifier. */
#include "AsrClassifier.hpp" /* ASR classifier. */
#include "MicroNetKwsModel.hpp" /* KWS model class for running inference. */
#include "Wav2LetterModel.hpp" /* ASR model class for running inference. */
@@ -42,8 +42,6 @@ namespace app {
} /* namespace app */
} /* namespace arm */
-using KwsClassifier = arm::app::Classifier;
-
enum opcodes
{
MENU_OPT_RUN_INF_NEXT = 1, /* Run on next vector. */
@@ -118,9 +116,9 @@ void main_loop()
caseContext.Set<int>("asrFrameStride", arm::app::asr::g_FrameStride);
caseContext.Set<float>("asrScoreThreshold", arm::app::asr::g_ScoreThreshold); /* Normalised score threshold. */
- KwsClassifier kwsClassifier; /* Classifier wrapper object. */
+ arm::app::KwsClassifier kwsClassifier; /* Classifier wrapper object. */
arm::app::AsrClassifier asrClassifier; /* Classifier wrapper object. */
- caseContext.Set<arm::app::Classifier&>("kwsClassifier", kwsClassifier);
+ caseContext.Set<arm::app::KwsClassifier&>("kwsClassifier", kwsClassifier);
caseContext.Set<arm::app::AsrClassifier&>("asrClassifier", asrClassifier);
std::vector<std::string> asrLabels;
diff --git a/source/use_case/kws_asr/src/UseCaseHandler.cc b/source/use_case/kws_asr/src/UseCaseHandler.cc
index 9427ae0..c5e6ad3 100644
--- a/source/use_case/kws_asr/src/UseCaseHandler.cc
+++ b/source/use_case/kws_asr/src/UseCaseHandler.cc
@@ -143,11 +143,8 @@ namespace app {
while (audioDataSlider.HasNext()) {
const int16_t* inferenceWindow = audioDataSlider.Next();
- /* The first window does not have cache ready. */
- preProcess.m_audioWindowIndex = audioDataSlider.Index();
-
/* Run the pre-processing, inference and post-processing. */
- if (!preProcess.DoPreProcess(inferenceWindow, audio::MicroNetKwsMFCC::ms_defaultSamplingFreq)) {
+ if (!preProcess.DoPreProcess(inferenceWindow, audioDataSlider.Index())) {
printf_err("KWS Pre-processing failed.");
return output;
}