diff options
Diffstat (limited to 'source/use_case/kws_asr/src/MainLoop.cc')
-rw-r--r-- | source/use_case/kws_asr/src/MainLoop.cc | 125 |
1 files changed, 36 insertions, 89 deletions
diff --git a/source/use_case/kws_asr/src/MainLoop.cc b/source/use_case/kws_asr/src/MainLoop.cc index 5c1d0e0..f1d97a0 100644 --- a/source/use_case/kws_asr/src/MainLoop.cc +++ b/source/use_case/kws_asr/src/MainLoop.cc @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,7 +14,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "hal.h" /* Brings in platform definitions. */ #include "InputFiles.hpp" /* For input images. */ #include "Labels_micronetkws.hpp" /* For MicroNetKws label strings. */ #include "Labels_wav2letter.hpp" /* For Wav2Letter label strings. */ @@ -24,8 +23,6 @@ #include "Wav2LetterModel.hpp" /* ASR model class for running inference. */ #include "UseCaseCommonUtils.hpp" /* Utils functions. */ #include "UseCaseHandler.hpp" /* Handlers for different user options. */ -#include "Wav2LetterPreprocess.hpp" /* ASR pre-processing class. */ -#include "Wav2LetterPostprocess.hpp"/* ASR post-processing class. */ #include "log_macros.h" using KwsClassifier = arm::app::Classifier; @@ -53,19 +50,8 @@ static void DisplayMenu() fflush(stdout); } -/** @brief Gets the number of MFCC features for a single window. */ -static uint32_t GetNumMfccFeatures(const arm::app::Model& model); - -/** @brief Gets the number of MFCC feature vectors to be computed. */ -static uint32_t GetNumMfccFeatureVectors(const arm::app::Model& model); - -/** @brief Gets the output context length (left and right) for post-processing. */ -static uint32_t GetOutputContextLen(const arm::app::Model& model, - uint32_t inputCtxLen); - -/** @brief Gets the output inner length for post-processing. */ -static uint32_t GetOutputInnerLen(const arm::app::Model& model, - uint32_t outputCtxLen); +/** @brief Verify input and output tensor are of certain min dimensions. */ +static bool VerifyTensorDimensions(const arm::app::Model& model); void main_loop() { @@ -84,61 +70,46 @@ void main_loop() if (!asrModel.Init(kwsModel.GetAllocator())) { printf_err("Failed to initialise ASR model\n"); return; + } else if (!VerifyTensorDimensions(asrModel)) { + printf_err("Model's input or output dimension verification failed\n"); + return; } - /* Initialise ASR pre-processing. */ - arm::app::audio::asr::Preprocess prep( - GetNumMfccFeatures(asrModel), - arm::app::asr::g_FrameLength, - arm::app::asr::g_FrameStride, - GetNumMfccFeatureVectors(asrModel)); - - /* Initialise ASR post-processing. */ - const uint32_t outputCtxLen = GetOutputContextLen(asrModel, arm::app::asr::g_ctxLen); - const uint32_t blankTokenIdx = 28; - arm::app::audio::asr::Postprocess postp( - outputCtxLen, - GetOutputInnerLen(asrModel, outputCtxLen), - blankTokenIdx); - /* Instantiate application context. */ arm::app::ApplicationContext caseContext; arm::app::Profiler profiler{"kws_asr"}; caseContext.Set<arm::app::Profiler&>("profiler", profiler); - caseContext.Set<arm::app::Model&>("kwsmodel", kwsModel); - caseContext.Set<arm::app::Model&>("asrmodel", asrModel); + caseContext.Set<arm::app::Model&>("kwsModel", kwsModel); + caseContext.Set<arm::app::Model&>("asrModel", asrModel); caseContext.Set<uint32_t>("clipIndex", 0); caseContext.Set<uint32_t>("ctxLen", arm::app::asr::g_ctxLen); /* Left and right context length (MFCC feat vectors). */ - caseContext.Set<int>("kwsframeLength", arm::app::kws::g_FrameLength); - caseContext.Set<int>("kwsframeStride", arm::app::kws::g_FrameStride); - caseContext.Set<float>("kwsscoreThreshold", arm::app::kws::g_ScoreThreshold); /* Normalised score threshold. */ + caseContext.Set<int>("kwsFrameLength", arm::app::kws::g_FrameLength); + caseContext.Set<int>("kwsFrameStride", arm::app::kws::g_FrameStride); + caseContext.Set<float>("kwsScoreThreshold", arm::app::kws::g_ScoreThreshold); /* Normalised score threshold. */ caseContext.Set<uint32_t >("kwsNumMfcc", arm::app::kws::g_NumMfcc); caseContext.Set<uint32_t >("kwsNumAudioWins", arm::app::kws::g_NumAudioWins); - caseContext.Set<int>("asrframeLength", arm::app::asr::g_FrameLength); - caseContext.Set<int>("asrframeStride", arm::app::asr::g_FrameStride); - caseContext.Set<float>("asrscoreThreshold", arm::app::asr::g_ScoreThreshold); /* Normalised score threshold. */ + caseContext.Set<int>("asrFrameLength", arm::app::asr::g_FrameLength); + caseContext.Set<int>("asrFrameStride", arm::app::asr::g_FrameStride); + caseContext.Set<float>("asrScoreThreshold", arm::app::asr::g_ScoreThreshold); /* Normalised score threshold. */ KwsClassifier kwsClassifier; /* Classifier wrapper object. */ arm::app::AsrClassifier asrClassifier; /* Classifier wrapper object. */ - caseContext.Set<arm::app::Classifier&>("kwsclassifier", kwsClassifier); - caseContext.Set<arm::app::AsrClassifier&>("asrclassifier", asrClassifier); - - caseContext.Set<arm::app::audio::asr::Preprocess&>("preprocess", prep); - caseContext.Set<arm::app::audio::asr::Postprocess&>("postprocess", postp); + caseContext.Set<arm::app::Classifier&>("kwsClassifier", kwsClassifier); + caseContext.Set<arm::app::AsrClassifier&>("asrClassifier", asrClassifier); std::vector<std::string> asrLabels; arm::app::asr::GetLabelsVector(asrLabels); std::vector<std::string> kwsLabels; arm::app::kws::GetLabelsVector(kwsLabels); - caseContext.Set<const std::vector <std::string>&>("asrlabels", asrLabels); - caseContext.Set<const std::vector <std::string>&>("kwslabels", kwsLabels); + caseContext.Set<const std::vector <std::string>&>("asrLabels", asrLabels); + caseContext.Set<const std::vector <std::string>&>("kwsLabels", kwsLabels); /* KWS keyword that triggers ASR and associated checks */ - std::string triggerKeyword = std::string("yes"); + std::string triggerKeyword = std::string("no"); if (std::find(kwsLabels.begin(), kwsLabels.end(), triggerKeyword) != kwsLabels.end()) { - caseContext.Set<const std::string &>("triggerkeyword", triggerKeyword); + caseContext.Set<const std::string &>("triggerKeyword", triggerKeyword); } else { printf_err("Selected trigger keyword not found in labels file\n"); @@ -196,50 +167,26 @@ void main_loop() info("Main loop terminated.\n"); } -static uint32_t GetNumMfccFeatures(const arm::app::Model& model) -{ - TfLiteTensor* inputTensor = model.GetInputTensor(0); - const int inputCols = inputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputColsIdx]; - if (0 != inputCols % 3) { - printf_err("Number of input columns is not a multiple of 3\n"); - } - return std::max(inputCols/3, 0); -} - -static uint32_t GetNumMfccFeatureVectors(const arm::app::Model& model) +static bool VerifyTensorDimensions(const arm::app::Model& model) { + /* Populate tensor related parameters. */ TfLiteTensor* inputTensor = model.GetInputTensor(0); - const int inputRows = inputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx]; - return std::max(inputRows, 0); -} - -static uint32_t GetOutputContextLen(const arm::app::Model& model, const uint32_t inputCtxLen) -{ - const uint32_t inputRows = GetNumMfccFeatureVectors(model); - const uint32_t inputInnerLen = inputRows - (2 * inputCtxLen); - constexpr uint32_t ms_outputRowsIdx = arm::app::Wav2LetterModel::ms_outputRowsIdx; - - /* Check to make sure that the input tensor supports the above context and inner lengths. */ - if (inputRows <= 2 * inputCtxLen || inputRows <= inputInnerLen) { - printf_err("Input rows not compatible with ctx of %" PRIu32 "\n", - inputCtxLen); - return 0; + if (!inputTensor->dims) { + printf_err("Invalid input tensor dims\n"); + return false; + } else if (inputTensor->dims->size < 3) { + printf_err("Input tensor dimension should be >= 3\n"); + return false; } TfLiteTensor* outputTensor = model.GetOutputTensor(0); - const uint32_t outputRows = std::max(outputTensor->dims->data[ms_outputRowsIdx], 0); - - const float tensorColRatio = static_cast<float>(inputRows)/ - static_cast<float>(outputRows); - - return std::round(static_cast<float>(inputCtxLen)/tensorColRatio); -} + if (!outputTensor->dims) { + printf_err("Invalid output tensor dims\n"); + return false; + } else if (outputTensor->dims->size < 3) { + printf_err("Output tensor dimension should be >= 3\n"); + return false; + } -static uint32_t GetOutputInnerLen(const arm::app::Model& model, - const uint32_t outputCtxLen) -{ - constexpr uint32_t ms_outputRowsIdx = arm::app::Wav2LetterModel::ms_outputRowsIdx; - TfLiteTensor* outputTensor = model.GetOutputTensor(0); - const uint32_t outputRows = std::max(outputTensor->dims->data[ms_outputRowsIdx], 0); - return (outputRows - (2 * outputCtxLen)); + return true; } |