diff options
author | Richard Burton <richard.burton@arm.com> | 2022-10-05 11:00:37 +0100 |
---|---|---|
committer | Richard Burton <richard.burton@arm.com> | 2022-10-06 14:08:13 +0100 |
commit | ec5e99be3ae6dd0d3811950f155b01e144431452 (patch) | |
tree | a5d6c4dd9267db2465063b8d0e1a5cb6d19dac8d /source/use_case | |
parent | 890b2b89cacc6f2291596a001d555d374c8c9edd (diff) | |
download | ml-embedded-evaluation-kit-ec5e99be3ae6dd0d3811950f155b01e144431452.tar.gz |
MLECO-3164: Additional refactoring of KWS API
Part 1
* Add KwsClassifier
* KwsPostProcess can now be told to average results
* Averaging is handlded by KwsClassifier
* Current sliding window index is now an argument of DoPreProcess
Change-Id: I07626da595ad1cbd982e8366f0d1bb56d1040459
Diffstat (limited to 'source/use_case')
-rw-r--r-- | source/use_case/kws/src/MainLoop.cc | 9 | ||||
-rw-r--r-- | source/use_case/kws/src/UseCaseHandler.cc | 9 | ||||
-rw-r--r-- | source/use_case/kws_asr/src/MainLoop.cc | 8 | ||||
-rw-r--r-- | source/use_case/kws_asr/src/UseCaseHandler.cc | 5 |
4 files changed, 10 insertions, 21 deletions
diff --git a/source/use_case/kws/src/MainLoop.cc b/source/use_case/kws/src/MainLoop.cc index e0518f2..2489df8 100644 --- a/source/use_case/kws/src/MainLoop.cc +++ b/source/use_case/kws/src/MainLoop.cc @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,7 +15,7 @@ * limitations under the License. */ #include "InputFiles.hpp" /* For input audio clips. */ -#include "Classifier.hpp" /* Classifier. */ +#include "KwsClassifier.hpp" /* Classifier. */ #include "MicroNetKwsModel.hpp" /* Model class for running inference. */ #include "hal.h" /* Brings in platform definitions. */ #include "Labels.hpp" /* For label strings. */ @@ -34,7 +34,6 @@ namespace app { } /* namespace app */ } /* namespace arm */ -using KwsClassifier = arm::app::Classifier; enum opcodes { @@ -83,8 +82,8 @@ void main_loop() caseContext.Set<int>("frameStride", arm::app::kws::g_FrameStride); caseContext.Set<float>("scoreThreshold", arm::app::kws::g_ScoreThreshold); /* Normalised score threshold. */ - KwsClassifier classifier; /* classifier wrapper object. */ - caseContext.Set<arm::app::Classifier&>("classifier", classifier); + arm::app::KwsClassifier classifier; /* classifier wrapper object. */ + caseContext.Set<arm::app::KwsClassifier&>("classifier", classifier); std::vector <std::string> labels; GetLabelsVector(labels); diff --git a/source/use_case/kws/src/UseCaseHandler.cc b/source/use_case/kws/src/UseCaseHandler.cc index 61c6eb6..d61ba9d 100644 --- a/source/use_case/kws/src/UseCaseHandler.cc +++ b/source/use_case/kws/src/UseCaseHandler.cc @@ -17,7 +17,7 @@ #include "UseCaseHandler.hpp" #include "InputFiles.hpp" -#include "Classifier.hpp" +#include "KwsClassifier.hpp" #include "MicroNetKwsModel.hpp" #include "hal.h" #include "AudioUtils.hpp" @@ -29,8 +29,6 @@ #include <vector> -using KwsClassifier = arm::app::Classifier; - namespace arm { namespace app { @@ -124,14 +122,11 @@ namespace app { while (audioDataSlider.HasNext()) { const int16_t* inferenceWindow = audioDataSlider.Next(); - /* The first window does not have cache ready. */ - preProcess.m_audioWindowIndex = audioDataSlider.Index(); - info("Inference %zu/%zu\n", audioDataSlider.Index() + 1, audioDataSlider.TotalStrides() + 1); /* Run the pre-processing, inference and post-processing. */ - if (!preProcess.DoPreProcess(inferenceWindow, audio::MicroNetKwsMFCC::ms_defaultSamplingFreq)) { + if (!preProcess.DoPreProcess(inferenceWindow, audioDataSlider.Index())) { printf_err("Pre-processing failed."); return false; } diff --git a/source/use_case/kws_asr/src/MainLoop.cc b/source/use_case/kws_asr/src/MainLoop.cc index 0638ecd..a4f7db9 100644 --- a/source/use_case/kws_asr/src/MainLoop.cc +++ b/source/use_case/kws_asr/src/MainLoop.cc @@ -17,7 +17,7 @@ #include "InputFiles.hpp" /* For input images. */ #include "Labels_micronetkws.hpp" /* For MicroNetKws label strings. */ #include "Labels_wav2letter.hpp" /* For Wav2Letter label strings. */ -#include "Classifier.hpp" /* KWS classifier. */ +#include "KwsClassifier.hpp" /* KWS classifier. */ #include "AsrClassifier.hpp" /* ASR classifier. */ #include "MicroNetKwsModel.hpp" /* KWS model class for running inference. */ #include "Wav2LetterModel.hpp" /* ASR model class for running inference. */ @@ -42,8 +42,6 @@ namespace app { } /* namespace app */ } /* namespace arm */ -using KwsClassifier = arm::app::Classifier; - enum opcodes { MENU_OPT_RUN_INF_NEXT = 1, /* Run on next vector. */ @@ -118,9 +116,9 @@ void main_loop() caseContext.Set<int>("asrFrameStride", arm::app::asr::g_FrameStride); caseContext.Set<float>("asrScoreThreshold", arm::app::asr::g_ScoreThreshold); /* Normalised score threshold. */ - KwsClassifier kwsClassifier; /* Classifier wrapper object. */ + arm::app::KwsClassifier kwsClassifier; /* Classifier wrapper object. */ arm::app::AsrClassifier asrClassifier; /* Classifier wrapper object. */ - caseContext.Set<arm::app::Classifier&>("kwsClassifier", kwsClassifier); + caseContext.Set<arm::app::KwsClassifier&>("kwsClassifier", kwsClassifier); caseContext.Set<arm::app::AsrClassifier&>("asrClassifier", asrClassifier); std::vector<std::string> asrLabels; diff --git a/source/use_case/kws_asr/src/UseCaseHandler.cc b/source/use_case/kws_asr/src/UseCaseHandler.cc index 9427ae0..c5e6ad3 100644 --- a/source/use_case/kws_asr/src/UseCaseHandler.cc +++ b/source/use_case/kws_asr/src/UseCaseHandler.cc @@ -143,11 +143,8 @@ namespace app { while (audioDataSlider.HasNext()) { const int16_t* inferenceWindow = audioDataSlider.Next(); - /* The first window does not have cache ready. */ - preProcess.m_audioWindowIndex = audioDataSlider.Index(); - /* Run the pre-processing, inference and post-processing. */ - if (!preProcess.DoPreProcess(inferenceWindow, audio::MicroNetKwsMFCC::ms_defaultSamplingFreq)) { + if (!preProcess.DoPreProcess(inferenceWindow, audioDataSlider.Index())) { printf_err("KWS Pre-processing failed."); return output; } |