summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLiam Barry <liam.barry@arm.com>2021-12-30 11:35:00 +0000
committerLiam Barry <liam.barry@arm.com>2022-01-04 12:17:03 +0000
commitb5b32d3e6188cc7126b3181e3be328d6583c5967 (patch)
tree557b5ec9476a1a8bde291e6e500fd614ea5e3667
parentb48b5b6e7c2943965d9ae7370485fd85081e4dbe (diff)
downloadml-embedded-evaluation-kit-b5b32d3e6188cc7126b3181e3be328d6583c5967.tar.gz
MLECO-2835: Remove magic number for ASR-KWS
Replaced ctx.set/get<uint32>(keywordindex) with keyword itself as const std::string& Change-Id: I1811d93548105d6db58e57b88675f9b41e66d914 Signed-off-by: Liam Barry <liam.barry@arm.com>
-rw-r--r--source/use_case/kws_asr/src/MainLoop.cc11
-rw-r--r--source/use_case/kws_asr/src/UseCaseHandler.cc2
2 files changed, 10 insertions, 3 deletions
diff --git a/source/use_case/kws_asr/src/MainLoop.cc b/source/use_case/kws_asr/src/MainLoop.cc
index 30cb084..c7e977f 100644
--- a/source/use_case/kws_asr/src/MainLoop.cc
+++ b/source/use_case/kws_asr/src/MainLoop.cc
@@ -136,8 +136,15 @@ void main_loop(hal_platform& platform)
caseContext.Set<const std::vector <std::string>&>("asrlabels", asrLabels);
caseContext.Set<const std::vector <std::string>&>("kwslabels", kwsLabels);
- /* Index of the kws outputs we trigger ASR on. */
- caseContext.Set<uint32_t>("keywordindex", 9 );
+ /* KWS keyword that triggers ASR and associated checks */
+ std::string triggerKeyword = std::string("yes");
+ if (std::find(kwsLabels.begin(), kwsLabels.end(), triggerKeyword) != kwsLabels.end()) {
+ caseContext.Set<const std::string &>("triggerkeyword", triggerKeyword);
+ }
+ else {
+ printf_err("Selected trigger keyword not found in labels file\n");
+ return;
+ }
/* Loop. */
bool executionSuccessful = true;
diff --git a/source/use_case/kws_asr/src/UseCaseHandler.cc b/source/use_case/kws_asr/src/UseCaseHandler.cc
index c67be22..a3ebdb1 100644
--- a/source/use_case/kws_asr/src/UseCaseHandler.cc
+++ b/source/use_case/kws_asr/src/UseCaseHandler.cc
@@ -240,7 +240,7 @@ namespace app {
);
/* Keyword detected. */
- if (kwsClassificationResult[0].m_labelIdx == ctx.Get<uint32_t>("keywordindex")) {
+ if (kwsClassificationResult[0].m_label == ctx.Get<const std::string&>("triggerkeyword")) {
output.asrAudioStart = inferenceWindow + kwsAudioDataWindowSize;
output.asrAudioSamples = get_audio_array_size(currentIndex) -
(audioDataSlider.NextWindowStartIndex() -