path: root/source/use_case/kws/src
author    Richard Burton <richard.burton@arm.com>    2022-04-22 16:14:57 +0100
committer Richard Burton <richard.burton@arm.com>    2022-04-22 16:14:57 +0100
commit    b40ecf8522052809d2351677a96195d69e4d0c16 (patch)
tree      8647dfdae7bcae0ec6d9564ba7a971819fdda431 /source/use_case/kws/src
parent    c291144b7f08c21d08cdaf79cc64dc420ca70070 (diff)
download  ml-embedded-evaluation-kit-b40ecf8522052809d2351677a96195d69e4d0c16.tar.gz
MLECO-3174: Minor refactoring to implemented use case APIs
Looks large but it is mainly just many small adjustments:

* Removed the inference runner code as it wasn't used
* Fixes to doc strings
* Consistent naming e.g. Asr/Kws instead of ASR/KWS

Signed-off-by: Richard Burton <richard.burton@arm.com>
Change-Id: I43b620b5c51d7910a29a63b509ac4d8a82c3a8fc
Diffstat (limited to 'source/use_case/kws/src')
-rw-r--r--  source/use_case/kws/src/KwsProcessing.cc   53
-rw-r--r--  source/use_case/kws/src/UseCaseHandler.cc  46
2 files changed, 47 insertions, 52 deletions
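
For orientation, below is a condensed sketch of the call pattern this commit moves to: the pre/post-processing objects are constructed from the model's tensors rather than from the Model itself, and the removed UseCaseRunner is replaced by direct DoPreProcess / RunInference / DoPostProcess calls. The sketch is assembled from the UseCaseHandler.cc hunks further down; names such as classifier, labels, profiler and inferenceWindow stand in for the handler's locals and application-context lookups and are not part of the patch text.

    /* Sketch only - mirrors the flow in ClassifyAudioHandler after this commit.
     * Assumes model, classifier, labels, profiler and the sliding-window data
     * (inferenceWindow, numMfccFeatures, numMfccFrames, mfccFrameLength,
     * mfccFrameStride) are set up as in the handler. */
    TfLiteTensor* inputTensor  = model.GetInputTensor(0);
    TfLiteTensor* outputTensor = model.GetOutputTensor(0);

    /* Pre-processing now takes the input tensor and frame geometry directly. */
    KwsPreProcess preProcess = KwsPreProcess(inputTensor, numMfccFeatures, numMfccFrames,
                                             mfccFrameLength, mfccFrameStride);

    /* Post-processing takes the output tensor, classifier, labels and a results vector. */
    std::vector<ClassificationResult> singleInfResult;
    KwsPostProcess postProcess = KwsPostProcess(outputTensor, classifier, labels, singleInfResult);

    /* Per-window sequence that replaces the removed UseCaseRunner. */
    if (!preProcess.DoPreProcess(inferenceWindow, audio::MicroNetKwsMFCC::ms_defaultSamplingFreq)) {
        printf_err("Pre-processing failed.");
        return false;
    }
    if (!RunInference(model, profiler)) {
        printf_err("Inference failed.");
        return false;
    }
    if (!postProcess.DoPostProcess()) {
        printf_err("Post-processing failed.");
        return false;
    }
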
diff --git a/source/use_case/kws/src/KwsProcessing.cc b/source/use_case/kws/src/KwsProcessing.cc
index 14f9fce..328709d 100644
--- a/source/use_case/kws/src/KwsProcessing.cc
+++ b/source/use_case/kws/src/KwsProcessing.cc
@@ -22,22 +22,19 @@
namespace arm {
namespace app {
- KWSPreProcess::KWSPreProcess(Model* model, size_t numFeatures, int mfccFrameLength, int mfccFrameStride):
+ KwsPreProcess::KwsPreProcess(TfLiteTensor* inputTensor, size_t numFeatures, size_t numMfccFrames,
+ int mfccFrameLength, int mfccFrameStride
+ ):
+ m_inputTensor{inputTensor},
m_mfccFrameLength{mfccFrameLength},
m_mfccFrameStride{mfccFrameStride},
+ m_numMfccFrames{numMfccFrames},
m_mfcc{audio::MicroNetKwsMFCC(numFeatures, mfccFrameLength)}
{
- if (!model->IsInited()) {
- printf_err("Model is not initialised!.\n");
- }
- this->m_model = model;
this->m_mfcc.Init();
- TfLiteIntArray* inputShape = model->GetInputShape(0);
- const uint32_t numMfccFrames = inputShape->data[arm::app::MicroNetKwsModel::ms_inputRowsIdx];
-
/* Deduce the data length required for 1 inference from the network parameters. */
- this->m_audioDataWindowSize = numMfccFrames * this->m_mfccFrameStride +
+ this->m_audioDataWindowSize = this->m_numMfccFrames * this->m_mfccFrameStride +
(this->m_mfccFrameLength - this->m_mfccFrameStride);
/* Creating an MFCC feature sliding window for the data required for 1 inference. */
@@ -62,7 +59,7 @@ namespace app {
- this->m_numMfccVectorsInAudioStride;
/* Construct feature calculation function. */
- this->m_mfccFeatureCalculator = GetFeatureCalculator(this->m_mfcc, this->m_model->GetInputTensor(0),
+ this->m_mfccFeatureCalculator = GetFeatureCalculator(this->m_mfcc, this->m_inputTensor,
this->m_numReusedMfccVectors);
if (!this->m_mfccFeatureCalculator) {
@@ -70,7 +67,7 @@ namespace app {
}
}
- bool KWSPreProcess::DoPreProcess(const void* data, size_t inputSize)
+ bool KwsPreProcess::DoPreProcess(const void* data, size_t inputSize)
{
UNUSED(inputSize);
if (data == nullptr) {
@@ -116,8 +113,8 @@ namespace app {
*/
template<class T>
std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
- KWSPreProcess::FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
- std::function<std::vector<T> (std::vector<int16_t>& )> compute)
+ KwsPreProcess::FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
+ std::function<std::vector<T> (std::vector<int16_t>& )> compute)
{
/* Feature cache to be captured by lambda function. */
static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize);
@@ -149,18 +146,18 @@ namespace app {
}
template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)>
- KWSPreProcess::FeatureCalc<int8_t>(TfLiteTensor* inputTensor,
- size_t cacheSize,
- std::function<std::vector<int8_t> (std::vector<int16_t>&)> compute);
+ KwsPreProcess::FeatureCalc<int8_t>(TfLiteTensor* inputTensor,
+ size_t cacheSize,
+ std::function<std::vector<int8_t> (std::vector<int16_t>&)> compute);
template std::function<void(std::vector<int16_t>&, size_t, bool, size_t)>
- KWSPreProcess::FeatureCalc<float>(TfLiteTensor* inputTensor,
- size_t cacheSize,
- std::function<std::vector<float>(std::vector<int16_t>&)> compute);
+ KwsPreProcess::FeatureCalc<float>(TfLiteTensor* inputTensor,
+ size_t cacheSize,
+ std::function<std::vector<float>(std::vector<int16_t>&)> compute);
std::function<void (std::vector<int16_t>&, int, bool, size_t)>
- KWSPreProcess::GetFeatureCalculator(audio::MicroNetKwsMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize)
+ KwsPreProcess::GetFeatureCalculator(audio::MicroNetKwsMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize)
{
std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> mfccFeatureCalc;
@@ -195,23 +192,19 @@ namespace app {
return mfccFeatureCalc;
}
- KWSPostProcess::KWSPostProcess(Classifier& classifier, Model* model,
+ KwsPostProcess::KwsPostProcess(TfLiteTensor* outputTensor, Classifier& classifier,
const std::vector<std::string>& labels,
std::vector<ClassificationResult>& results)
- :m_kwsClassifier{classifier},
+ :m_outputTensor{outputTensor},
+ m_kwsClassifier{classifier},
m_labels{labels},
m_results{results}
- {
- if (!model->IsInited()) {
- printf_err("Model is not initialised!.\n");
- }
- this->m_model = model;
- }
+ {}
- bool KWSPostProcess::DoPostProcess()
+ bool KwsPostProcess::DoPostProcess()
{
return this->m_kwsClassifier.GetClassificationResults(
- this->m_model->GetOutputTensor(0), this->m_results,
+ this->m_outputTensor, this->m_results,
this->m_labels, 1, true);
}
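
The FeatureCalc template above returns a calculator that writes one frame's MFCC features into the input tensor and keeps a static cache of cacheSize feature vectors, so frames shared between overlapping sliding windows can be reused instead of recomputed. The lambda body is not part of this hunk, so the following is only a plausible sketch of that caching pattern, inferred from the visible signature and the static featureCache declaration; the helper name MakeCachedFeatureCalc and the exact bookkeeping are assumptions, not the patch text.

    #include <cstdint>
    #include <cstring>      /* std::memcpy */
    #include <functional>
    #include <vector>
    #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"  /* tflite::GetTensorData */

    template <class T>
    std::function<void(std::vector<int16_t>&, size_t, bool, size_t)>
    MakeCachedFeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
                          std::function<std::vector<T>(std::vector<int16_t>&)> compute)
    {
        /* Cache referenced by the lambda; sized for the overlapping region. */
        static std::vector<std::vector<T>> featureCache(cacheSize);

        return [=](std::vector<int16_t>& audioWindow, size_t index,
                   bool useCache, size_t featuresOverlapIndex)
        {
            T* tensorData = tflite::GetTensorData<T>(inputTensor);
            std::vector<T> features;

            /* Reuse a cached vector when this frame overlaps a previous window. */
            if (useCache && index < featureCache.size()) {
                features = std::move(featureCache[index]);
            } else {
                features = compute(audioWindow);
            }

            /* Copy this frame's features into the model's input tensor. */
            std::memcpy(tensorData + (index * features.size()),
                        features.data(), features.size() * sizeof(T));

            /* Refresh the cache once past the overlapping region. */
            if (index >= featuresOverlapIndex) {
                featureCache[index - featuresOverlapIndex] = std::move(features);
            }
        };
    }
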
diff --git a/source/use_case/kws/src/UseCaseHandler.cc b/source/use_case/kws/src/UseCaseHandler.cc
index e73a2c3..61c6eb6 100644
--- a/source/use_case/kws/src/UseCaseHandler.cc
+++ b/source/use_case/kws/src/UseCaseHandler.cc
@@ -34,13 +34,12 @@ using KwsClassifier = arm::app::Classifier;
namespace arm {
namespace app {
-
/**
* @brief Presents KWS inference results.
* @param[in] results Vector of KWS classification results to be displayed.
* @return true if successful, false otherwise.
**/
- static bool PresentInferenceResult(const std::vector<arm::app::kws::KwsResult>& results);
+ static bool PresentInferenceResult(const std::vector<kws::KwsResult>& results);
/* KWS inference handler. */
bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
@@ -50,6 +49,7 @@ namespace app {
const auto mfccFrameLength = ctx.Get<int>("frameLength");
const auto mfccFrameStride = ctx.Get<int>("frameStride");
const auto scoreThreshold = ctx.Get<float>("scoreThreshold");
+
/* If the request has a valid size, set the audio index. */
if (clipIndex < NUMBER_OF_FILES) {
if (!SetAppCtxIfmIdx(ctx, clipIndex,"clipIndex")) {
@@ -61,16 +61,17 @@ namespace app {
constexpr uint32_t dataPsnTxtInfStartX = 20;
constexpr uint32_t dataPsnTxtInfStartY = 40;
constexpr int minTensorDims = static_cast<int>(
- (arm::app::MicroNetKwsModel::ms_inputRowsIdx > arm::app::MicroNetKwsModel::ms_inputColsIdx)?
- arm::app::MicroNetKwsModel::ms_inputRowsIdx : arm::app::MicroNetKwsModel::ms_inputColsIdx);
-
+ (MicroNetKwsModel::ms_inputRowsIdx > MicroNetKwsModel::ms_inputColsIdx)?
+ MicroNetKwsModel::ms_inputRowsIdx : MicroNetKwsModel::ms_inputColsIdx);
if (!model.IsInited()) {
printf_err("Model is not initialised! Terminating processing.\n");
return false;
}
+ /* Get Input and Output tensors for pre/post processing. */
TfLiteTensor* inputTensor = model.GetInputTensor(0);
+ TfLiteTensor* outputTensor = model.GetOutputTensor(0);
if (!inputTensor->dims) {
printf_err("Invalid input tensor dims\n");
return false;
@@ -81,22 +82,23 @@ namespace app {
/* Get input shape for feature extraction. */
TfLiteIntArray* inputShape = model.GetInputShape(0);
- const uint32_t numMfccFeatures = inputShape->data[arm::app::MicroNetKwsModel::ms_inputColsIdx];
+ const uint32_t numMfccFeatures = inputShape->data[MicroNetKwsModel::ms_inputColsIdx];
+ const uint32_t numMfccFrames = inputShape->data[arm::app::MicroNetKwsModel::ms_inputRowsIdx];
/* We expect to be sampling 1 second worth of data at a time.
* NOTE: This is only used for time stamp calculation. */
const float secondsPerSample = 1.0 / audio::MicroNetKwsMFCC::ms_defaultSamplingFreq;
/* Set up pre and post-processing. */
- KWSPreProcess preprocess = KWSPreProcess(&model, numMfccFeatures, mfccFrameLength, mfccFrameStride);
+ KwsPreProcess preProcess = KwsPreProcess(inputTensor, numMfccFeatures, numMfccFrames,
+ mfccFrameLength, mfccFrameStride);
std::vector<ClassificationResult> singleInfResult;
- KWSPostProcess postprocess = KWSPostProcess(ctx.Get<KwsClassifier &>("classifier"), &model,
+ KwsPostProcess postProcess = KwsPostProcess(outputTensor, ctx.Get<KwsClassifier &>("classifier"),
ctx.Get<std::vector<std::string>&>("labels"),
singleInfResult);
- UseCaseRunner runner = UseCaseRunner(&preprocess, &postprocess, &model);
-
+ /* Loop to process audio clips. */
do {
hal_lcd_clear(COLOR_BLACK);
@@ -106,7 +108,7 @@ namespace app {
auto audioDataSlider = audio::SlidingWindow<const int16_t>(
get_audio_array(currentIndex),
get_audio_array_size(currentIndex),
- preprocess.m_audioDataWindowSize, preprocess.m_audioDataStride);
+ preProcess.m_audioDataWindowSize, preProcess.m_audioDataStride);
/* Declare a container to hold results from across the whole audio clip. */
std::vector<kws::KwsResult> finalResults;
@@ -123,34 +125,34 @@ namespace app {
const int16_t* inferenceWindow = audioDataSlider.Next();
/* The first window does not have cache ready. */
- preprocess.m_audioWindowIndex = audioDataSlider.Index();
+ preProcess.m_audioWindowIndex = audioDataSlider.Index();
info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
audioDataSlider.TotalStrides() + 1);
/* Run the pre-processing, inference and post-processing. */
- if (!runner.PreProcess(inferenceWindow, audio::MicroNetKwsMFCC::ms_defaultSamplingFreq)) {
+ if (!preProcess.DoPreProcess(inferenceWindow, audio::MicroNetKwsMFCC::ms_defaultSamplingFreq)) {
+ printf_err("Pre-processing failed.");
return false;
}
- profiler.StartProfiling("Inference");
- if (!runner.RunInference()) {
+ if (!RunInference(model, profiler)) {
+ printf_err("Inference failed.");
return false;
}
- profiler.StopProfiling();
- if (!runner.PostProcess()) {
+ if (!postProcess.DoPostProcess()) {
+ printf_err("Post-processing failed.");
return false;
}
/* Add results from this window to our final results vector. */
finalResults.emplace_back(kws::KwsResult(singleInfResult,
- audioDataSlider.Index() * secondsPerSample * preprocess.m_audioDataStride,
+ audioDataSlider.Index() * secondsPerSample * preProcess.m_audioDataStride,
audioDataSlider.Index(), scoreThreshold));
#if VERIFY_TEST_OUTPUT
- TfLiteTensor* outputTensor = model.GetOutputTensor(0);
- arm::app::DumpTensor(outputTensor);
+ DumpTensor(outputTensor);
#endif /* VERIFY_TEST_OUTPUT */
} /* while (audioDataSlider.HasNext()) */
@@ -174,7 +176,7 @@ namespace app {
return true;
}
- static bool PresentInferenceResult(const std::vector<arm::app::kws::KwsResult>& results)
+ static bool PresentInferenceResult(const std::vector<kws::KwsResult>& results)
{
constexpr uint32_t dataPsnTxtStartX1 = 20;
constexpr uint32_t dataPsnTxtStartY1 = 30;
@@ -187,7 +189,7 @@ namespace app {
/* Display each result */
uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr;
- for (const auto & result : results) {
+ for (const auto& result : results) {
std::string topKeyword{"<none>"};
float score = 0.f;