diff options
author | Richard Burton <richard.burton@arm.com> | 2022-05-04 09:45:02 +0100 |
---|---|---|
committer | Richard Burton <richard.burton@arm.com> | 2022-05-04 09:45:02 +0100 |
commit | 4e002791bc6781b549c6951cfe44f918289d7e82 (patch) | |
tree | b639243b5fa433657c207783a384bad1ed248536 /source/use_case/kws_asr/src | |
parent | dd6d07b24bbf9023ebe8e8927be8aac3291d0f58 (diff) | |
download | ml-embedded-evaluation-kit-4e002791bc6781b549c6951cfe44f918289d7e82.tar.gz |
MLECO-3173: Add AD, KWS_ASR and Noise reduction use case APIs
Signed-off-by: Richard Burton <richard.burton@arm.com>
Change-Id: I36f61ce74bf17f7b327cdae9704a22ca54144f37
Diffstat (limited to 'source/use_case/kws_asr/src')
-rw-r--r-- | source/use_case/kws_asr/src/KwsProcessing.cc | 212 | ||||
-rw-r--r-- | source/use_case/kws_asr/src/MainLoop.cc | 125 | ||||
-rw-r--r-- | source/use_case/kws_asr/src/UseCaseHandler.cc | 492 | ||||
-rw-r--r-- | source/use_case/kws_asr/src/Wav2LetterPostprocess.cc | 146 | ||||
-rw-r--r-- | source/use_case/kws_asr/src/Wav2LetterPreprocess.cc | 116 |
5 files changed, 548 insertions(+), 543 deletions(-)
diff --git a/source/use_case/kws_asr/src/KwsProcessing.cc b/source/use_case/kws_asr/src/KwsProcessing.cc new file mode 100644 index 0000000..328709d --- /dev/null +++ b/source/use_case/kws_asr/src/KwsProcessing.cc @@ -0,0 +1,212 @@ +/* + * Copyright (c) 2022 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "KwsProcessing.hpp" +#include "ImageUtils.hpp" +#include "log_macros.h" +#include "MicroNetKwsModel.hpp" + +namespace arm { +namespace app { + + KwsPreProcess::KwsPreProcess(TfLiteTensor* inputTensor, size_t numFeatures, size_t numMfccFrames, + int mfccFrameLength, int mfccFrameStride + ): + m_inputTensor{inputTensor}, + m_mfccFrameLength{mfccFrameLength}, + m_mfccFrameStride{mfccFrameStride}, + m_numMfccFrames{numMfccFrames}, + m_mfcc{audio::MicroNetKwsMFCC(numFeatures, mfccFrameLength)} + { + this->m_mfcc.Init(); + + /* Deduce the data length required for 1 inference from the network parameters. */ + this->m_audioDataWindowSize = this->m_numMfccFrames * this->m_mfccFrameStride + + (this->m_mfccFrameLength - this->m_mfccFrameStride); + + /* Creating an MFCC feature sliding window for the data required for 1 inference. 
*/ + this->m_mfccSlidingWindow = audio::SlidingWindow<const int16_t>(nullptr, this->m_audioDataWindowSize, + this->m_mfccFrameLength, this->m_mfccFrameStride); + + /* For longer audio clips we choose to move by half the audio window size + * => for a 1 second window size there is an overlap of 0.5 seconds. */ + this->m_audioDataStride = this->m_audioDataWindowSize / 2; + + /* To have the previously calculated features re-usable, stride must be multiple + * of MFCC features window stride. Reduce stride through audio if needed. */ + if (0 != this->m_audioDataStride % this->m_mfccFrameStride) { + this->m_audioDataStride -= this->m_audioDataStride % this->m_mfccFrameStride; + } + + this->m_numMfccVectorsInAudioStride = this->m_audioDataStride / this->m_mfccFrameStride; + + /* Calculate number of the feature vectors in the window overlap region. + * These feature vectors will be reused.*/ + this->m_numReusedMfccVectors = this->m_mfccSlidingWindow.TotalStrides() + 1 + - this->m_numMfccVectorsInAudioStride; + + /* Construct feature calculation function. */ + this->m_mfccFeatureCalculator = GetFeatureCalculator(this->m_mfcc, this->m_inputTensor, + this->m_numReusedMfccVectors); + + if (!this->m_mfccFeatureCalculator) { + printf_err("Feature calculator not initialized."); + } + } + + bool KwsPreProcess::DoPreProcess(const void* data, size_t inputSize) + { + UNUSED(inputSize); + if (data == nullptr) { + printf_err("Data pointer is null"); + } + + /* Set the features sliding window to the new address. */ + auto input = static_cast<const int16_t*>(data); + this->m_mfccSlidingWindow.Reset(input); + + /* Cache is only usable if we have more than 1 inference in an audio clip. */ + bool useCache = this->m_audioWindowIndex > 0 && this->m_numReusedMfccVectors > 0; + + /* Use a sliding window to calculate MFCC features frame by frame. 
*/ + while (this->m_mfccSlidingWindow.HasNext()) { + const int16_t* mfccWindow = this->m_mfccSlidingWindow.Next(); + + std::vector<int16_t> mfccFrameAudioData = std::vector<int16_t>(mfccWindow, + mfccWindow + this->m_mfccFrameLength); + + /* Compute features for this window and write them to input tensor. */ + this->m_mfccFeatureCalculator(mfccFrameAudioData, this->m_mfccSlidingWindow.Index(), + useCache, this->m_numMfccVectorsInAudioStride); + } + + debug("Input tensor populated \n"); + + return true; + } + + /** + * @brief Generic feature calculator factory. + * + * Returns lambda function to compute features using features cache. + * Real features math is done by a lambda function provided as a parameter. + * Features are written to input tensor memory. + * + * @tparam T Feature vector type. + * @param[in] inputTensor Model input tensor pointer. + * @param[in] cacheSize Number of feature vectors to cache. Defined by the sliding window overlap. + * @param[in] compute Features calculator function. + * @return Lambda function to compute features. + */ + template<class T> + std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> + KwsPreProcess::FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize, + std::function<std::vector<T> (std::vector<int16_t>& )> compute) + { + /* Feature cache to be captured by lambda function. */ + static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize); + + return [=](std::vector<int16_t>& audioDataWindow, + size_t index, + bool useCache, + size_t featuresOverlapIndex) + { + T* tensorData = tflite::GetTensorData<T>(inputTensor); + std::vector<T> features; + + /* Reuse features from cache if cache is ready and sliding windows overlap. + * Overlap is in the beginning of sliding window with a size of a feature cache. 
*/ + if (useCache && index < featureCache.size()) { + features = std::move(featureCache[index]); + } else { + features = std::move(compute(audioDataWindow)); + } + auto size = features.size(); + auto sizeBytes = sizeof(T) * size; + std::memcpy(tensorData + (index * size), features.data(), sizeBytes); + + /* Start renewing cache as soon iteration goes out of the windows overlap. */ + if (index >= featuresOverlapIndex) { + featureCache[index - featuresOverlapIndex] = std::move(features); + } + }; + } + + template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)> + KwsPreProcess::FeatureCalc<int8_t>(TfLiteTensor* inputTensor, + size_t cacheSize, + std::function<std::vector<int8_t> (std::vector<int16_t>&)> compute); + + template std::function<void(std::vector<int16_t>&, size_t, bool, size_t)> + KwsPreProcess::FeatureCalc<float>(TfLiteTensor* inputTensor, + size_t cacheSize, + std::function<std::vector<float>(std::vector<int16_t>&)> compute); + + + std::function<void (std::vector<int16_t>&, int, bool, size_t)> + KwsPreProcess::GetFeatureCalculator(audio::MicroNetKwsMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize) + { + std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> mfccFeatureCalc; + + TfLiteQuantization quant = inputTensor->quantization; + + if (kTfLiteAffineQuantization == quant.type) { + auto *quantParams = (TfLiteAffineQuantization *) quant.params; + const float quantScale = quantParams->scale->data[0]; + const int quantOffset = quantParams->zero_point->data[0]; + + switch (inputTensor->type) { + case kTfLiteInt8: { + mfccFeatureCalc = this->FeatureCalc<int8_t>(inputTensor, + cacheSize, + [=, &mfcc](std::vector<int16_t>& audioDataWindow) { + return mfcc.MfccComputeQuant<int8_t>(audioDataWindow, + quantScale, + quantOffset); + } + ); + break; + } + default: + printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type)); + } + } else { + mfccFeatureCalc = this->FeatureCalc<float>(inputTensor, cacheSize, 
+ [&mfcc](std::vector<int16_t>& audioDataWindow) { + return mfcc.MfccCompute(audioDataWindow); } + ); + } + return mfccFeatureCalc; + } + + KwsPostProcess::KwsPostProcess(TfLiteTensor* outputTensor, Classifier& classifier, + const std::vector<std::string>& labels, + std::vector<ClassificationResult>& results) + :m_outputTensor{outputTensor}, + m_kwsClassifier{classifier}, + m_labels{labels}, + m_results{results} + {} + + bool KwsPostProcess::DoPostProcess() + { + return this->m_kwsClassifier.GetClassificationResults( + this->m_outputTensor, this->m_results, + this->m_labels, 1, true); + } + +} /* namespace app */ +} /* namespace arm */
\ No newline at end of file diff --git a/source/use_case/kws_asr/src/MainLoop.cc b/source/use_case/kws_asr/src/MainLoop.cc index 5c1d0e0..f1d97a0 100644 --- a/source/use_case/kws_asr/src/MainLoop.cc +++ b/source/use_case/kws_asr/src/MainLoop.cc @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,7 +14,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "hal.h" /* Brings in platform definitions. */ #include "InputFiles.hpp" /* For input images. */ #include "Labels_micronetkws.hpp" /* For MicroNetKws label strings. */ #include "Labels_wav2letter.hpp" /* For Wav2Letter label strings. */ @@ -24,8 +23,6 @@ #include "Wav2LetterModel.hpp" /* ASR model class for running inference. */ #include "UseCaseCommonUtils.hpp" /* Utils functions. */ #include "UseCaseHandler.hpp" /* Handlers for different user options. */ -#include "Wav2LetterPreprocess.hpp" /* ASR pre-processing class. */ -#include "Wav2LetterPostprocess.hpp"/* ASR post-processing class. */ #include "log_macros.h" using KwsClassifier = arm::app::Classifier; @@ -53,19 +50,8 @@ static void DisplayMenu() fflush(stdout); } -/** @brief Gets the number of MFCC features for a single window. */ -static uint32_t GetNumMfccFeatures(const arm::app::Model& model); - -/** @brief Gets the number of MFCC feature vectors to be computed. */ -static uint32_t GetNumMfccFeatureVectors(const arm::app::Model& model); - -/** @brief Gets the output context length (left and right) for post-processing. */ -static uint32_t GetOutputContextLen(const arm::app::Model& model, - uint32_t inputCtxLen); - -/** @brief Gets the output inner length for post-processing. 
*/ -static uint32_t GetOutputInnerLen(const arm::app::Model& model, - uint32_t outputCtxLen); +/** @brief Verify input and output tensor are of certain min dimensions. */ +static bool VerifyTensorDimensions(const arm::app::Model& model); void main_loop() { @@ -84,61 +70,46 @@ void main_loop() if (!asrModel.Init(kwsModel.GetAllocator())) { printf_err("Failed to initialise ASR model\n"); return; + } else if (!VerifyTensorDimensions(asrModel)) { + printf_err("Model's input or output dimension verification failed\n"); + return; } - /* Initialise ASR pre-processing. */ - arm::app::audio::asr::Preprocess prep( - GetNumMfccFeatures(asrModel), - arm::app::asr::g_FrameLength, - arm::app::asr::g_FrameStride, - GetNumMfccFeatureVectors(asrModel)); - - /* Initialise ASR post-processing. */ - const uint32_t outputCtxLen = GetOutputContextLen(asrModel, arm::app::asr::g_ctxLen); - const uint32_t blankTokenIdx = 28; - arm::app::audio::asr::Postprocess postp( - outputCtxLen, - GetOutputInnerLen(asrModel, outputCtxLen), - blankTokenIdx); - /* Instantiate application context. */ arm::app::ApplicationContext caseContext; arm::app::Profiler profiler{"kws_asr"}; caseContext.Set<arm::app::Profiler&>("profiler", profiler); - caseContext.Set<arm::app::Model&>("kwsmodel", kwsModel); - caseContext.Set<arm::app::Model&>("asrmodel", asrModel); + caseContext.Set<arm::app::Model&>("kwsModel", kwsModel); + caseContext.Set<arm::app::Model&>("asrModel", asrModel); caseContext.Set<uint32_t>("clipIndex", 0); caseContext.Set<uint32_t>("ctxLen", arm::app::asr::g_ctxLen); /* Left and right context length (MFCC feat vectors). */ - caseContext.Set<int>("kwsframeLength", arm::app::kws::g_FrameLength); - caseContext.Set<int>("kwsframeStride", arm::app::kws::g_FrameStride); - caseContext.Set<float>("kwsscoreThreshold", arm::app::kws::g_ScoreThreshold); /* Normalised score threshold. 
*/ + caseContext.Set<int>("kwsFrameLength", arm::app::kws::g_FrameLength); + caseContext.Set<int>("kwsFrameStride", arm::app::kws::g_FrameStride); + caseContext.Set<float>("kwsScoreThreshold", arm::app::kws::g_ScoreThreshold); /* Normalised score threshold. */ caseContext.Set<uint32_t >("kwsNumMfcc", arm::app::kws::g_NumMfcc); caseContext.Set<uint32_t >("kwsNumAudioWins", arm::app::kws::g_NumAudioWins); - caseContext.Set<int>("asrframeLength", arm::app::asr::g_FrameLength); - caseContext.Set<int>("asrframeStride", arm::app::asr::g_FrameStride); - caseContext.Set<float>("asrscoreThreshold", arm::app::asr::g_ScoreThreshold); /* Normalised score threshold. */ + caseContext.Set<int>("asrFrameLength", arm::app::asr::g_FrameLength); + caseContext.Set<int>("asrFrameStride", arm::app::asr::g_FrameStride); + caseContext.Set<float>("asrScoreThreshold", arm::app::asr::g_ScoreThreshold); /* Normalised score threshold. */ KwsClassifier kwsClassifier; /* Classifier wrapper object. */ arm::app::AsrClassifier asrClassifier; /* Classifier wrapper object. 
*/ - caseContext.Set<arm::app::Classifier&>("kwsclassifier", kwsClassifier); - caseContext.Set<arm::app::AsrClassifier&>("asrclassifier", asrClassifier); - - caseContext.Set<arm::app::audio::asr::Preprocess&>("preprocess", prep); - caseContext.Set<arm::app::audio::asr::Postprocess&>("postprocess", postp); + caseContext.Set<arm::app::Classifier&>("kwsClassifier", kwsClassifier); + caseContext.Set<arm::app::AsrClassifier&>("asrClassifier", asrClassifier); std::vector<std::string> asrLabels; arm::app::asr::GetLabelsVector(asrLabels); std::vector<std::string> kwsLabels; arm::app::kws::GetLabelsVector(kwsLabels); - caseContext.Set<const std::vector <std::string>&>("asrlabels", asrLabels); - caseContext.Set<const std::vector <std::string>&>("kwslabels", kwsLabels); + caseContext.Set<const std::vector <std::string>&>("asrLabels", asrLabels); + caseContext.Set<const std::vector <std::string>&>("kwsLabels", kwsLabels); /* KWS keyword that triggers ASR and associated checks */ - std::string triggerKeyword = std::string("yes"); + std::string triggerKeyword = std::string("no"); if (std::find(kwsLabels.begin(), kwsLabels.end(), triggerKeyword) != kwsLabels.end()) { - caseContext.Set<const std::string &>("triggerkeyword", triggerKeyword); + caseContext.Set<const std::string &>("triggerKeyword", triggerKeyword); } else { printf_err("Selected trigger keyword not found in labels file\n"); @@ -196,50 +167,26 @@ void main_loop() info("Main loop terminated.\n"); } -static uint32_t GetNumMfccFeatures(const arm::app::Model& model) -{ - TfLiteTensor* inputTensor = model.GetInputTensor(0); - const int inputCols = inputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputColsIdx]; - if (0 != inputCols % 3) { - printf_err("Number of input columns is not a multiple of 3\n"); - } - return std::max(inputCols/3, 0); -} - -static uint32_t GetNumMfccFeatureVectors(const arm::app::Model& model) +static bool VerifyTensorDimensions(const arm::app::Model& model) { + /* Populate tensor related 
parameters. */ TfLiteTensor* inputTensor = model.GetInputTensor(0); - const int inputRows = inputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx]; - return std::max(inputRows, 0); -} - -static uint32_t GetOutputContextLen(const arm::app::Model& model, const uint32_t inputCtxLen) -{ - const uint32_t inputRows = GetNumMfccFeatureVectors(model); - const uint32_t inputInnerLen = inputRows - (2 * inputCtxLen); - constexpr uint32_t ms_outputRowsIdx = arm::app::Wav2LetterModel::ms_outputRowsIdx; - - /* Check to make sure that the input tensor supports the above context and inner lengths. */ - if (inputRows <= 2 * inputCtxLen || inputRows <= inputInnerLen) { - printf_err("Input rows not compatible with ctx of %" PRIu32 "\n", - inputCtxLen); - return 0; + if (!inputTensor->dims) { + printf_err("Invalid input tensor dims\n"); + return false; + } else if (inputTensor->dims->size < 3) { + printf_err("Input tensor dimension should be >= 3\n"); + return false; } TfLiteTensor* outputTensor = model.GetOutputTensor(0); - const uint32_t outputRows = std::max(outputTensor->dims->data[ms_outputRowsIdx], 0); - - const float tensorColRatio = static_cast<float>(inputRows)/ - static_cast<float>(outputRows); - - return std::round(static_cast<float>(inputCtxLen)/tensorColRatio); -} + if (!outputTensor->dims) { + printf_err("Invalid output tensor dims\n"); + return false; + } else if (outputTensor->dims->size < 3) { + printf_err("Output tensor dimension should be >= 3\n"); + return false; + } -static uint32_t GetOutputInnerLen(const arm::app::Model& model, - const uint32_t outputCtxLen) -{ - constexpr uint32_t ms_outputRowsIdx = arm::app::Wav2LetterModel::ms_outputRowsIdx; - TfLiteTensor* outputTensor = model.GetOutputTensor(0); - const uint32_t outputRows = std::max(outputTensor->dims->data[ms_outputRowsIdx], 0); - return (outputRows - (2 * outputCtxLen)); + return true; } diff --git a/source/use_case/kws_asr/src/UseCaseHandler.cc 
b/source/use_case/kws_asr/src/UseCaseHandler.cc index 1e1a400..01aefae 100644 --- a/source/use_case/kws_asr/src/UseCaseHandler.cc +++ b/source/use_case/kws_asr/src/UseCaseHandler.cc @@ -28,6 +28,7 @@ #include "Wav2LetterMfcc.hpp" #include "Wav2LetterPreprocess.hpp" #include "Wav2LetterPostprocess.hpp" +#include "KwsProcessing.hpp" #include "AsrResult.hpp" #include "AsrClassifier.hpp" #include "OutputDecode.hpp" @@ -39,11 +40,6 @@ using KwsClassifier = arm::app::Classifier; namespace arm { namespace app { - enum AsrOutputReductionAxis { - AxisRow = 1, - AxisCol = 2 - }; - struct KWSOutput { bool executionSuccess = false; const int16_t* asrAudioStart = nullptr; @@ -51,73 +47,53 @@ namespace app { }; /** - * @brief Presents kws inference results using the data presentation - * object. - * @param[in] results vector of classification results to be displayed - * @return true if successful, false otherwise + * @brief Presents KWS inference results. + * @param[in] results Vector of KWS classification results to be displayed. + * @return true if successful, false otherwise. **/ - static bool PresentInferenceResult(std::vector<arm::app::kws::KwsResult>& results); + static bool PresentInferenceResult(std::vector<kws::KwsResult>& results); /** - * @brief Presents asr inference results using the data presentation - * object. - * @param[in] platform reference to the hal platform object - * @param[in] results vector of classification results to be displayed - * @return true if successful, false otherwise + * @brief Presents ASR inference results. + * @param[in] results Vector of ASR classification results to be displayed. + * @return true if successful, false otherwise. **/ - static bool PresentInferenceResult(std::vector<arm::app::asr::AsrResult>& results); + static bool PresentInferenceResult(std::vector<asr::AsrResult>& results); /** - * @brief Returns a function to perform feature calculation and populates input tensor data with - * MFCC data. 
- * - * Input tensor data type check is performed to choose correct MFCC feature data type. - * If tensor has an integer data type then original features are quantised. - * - * Warning: mfcc calculator provided as input must have the same life scope as returned function. - * - * @param[in] mfcc MFCC feature calculator. - * @param[in,out] inputTensor Input tensor pointer to store calculated features. - * @param[in] cacheSize Size of the feature vectors cache (number of feature vectors). - * - * @return function function to be called providing audio sample and sliding window index. + * @brief Performs the KWS pipeline. + * @param[in,out] ctx pointer to the application context object + * @return struct containing pointer to audio data where ASR should begin + * and how much data to process. **/ - static std::function<void (std::vector<int16_t>&, int, bool, size_t)> - GetFeatureCalculator(audio::MicroNetMFCC& mfcc, - TfLiteTensor* inputTensor, - size_t cacheSize); + static KWSOutput doKws(ApplicationContext& ctx) + { + auto& profiler = ctx.Get<Profiler&>("profiler"); + auto& kwsModel = ctx.Get<Model&>("kwsModel"); + const auto kwsMfccFrameLength = ctx.Get<int>("kwsFrameLength"); + const auto kwsMfccFrameStride = ctx.Get<int>("kwsFrameStride"); + const auto kwsScoreThreshold = ctx.Get<float>("kwsScoreThreshold"); + + auto currentIndex = ctx.Get<uint32_t>("clipIndex"); - /** - * @brief Performs the KWS pipeline. - * @param[in,out] ctx pointer to the application context object - * - * @return KWSOutput struct containing pointer to audio data where ASR should begin - * and how much data to process. - */ - static KWSOutput doKws(ApplicationContext& ctx) { constexpr uint32_t dataPsnTxtInfStartX = 20; constexpr uint32_t dataPsnTxtInfStartY = 40; constexpr int minTensorDims = static_cast<int>( - (arm::app::MicroNetKwsModel::ms_inputRowsIdx > arm::app::MicroNetKwsModel::ms_inputColsIdx)? 
- arm::app::MicroNetKwsModel::ms_inputRowsIdx : arm::app::MicroNetKwsModel::ms_inputColsIdx); + (MicroNetKwsModel::ms_inputRowsIdx > MicroNetKwsModel::ms_inputColsIdx)? + MicroNetKwsModel::ms_inputRowsIdx : MicroNetKwsModel::ms_inputColsIdx); - KWSOutput output; + /* Output struct from doing KWS. */ + KWSOutput output {}; - auto& profiler = ctx.Get<Profiler&>("profiler"); - auto& kwsModel = ctx.Get<Model&>("kwsmodel"); if (!kwsModel.IsInited()) { printf_err("KWS model has not been initialised\n"); return output; } - const int kwsFrameLength = ctx.Get<int>("kwsframeLength"); - const int kwsFrameStride = ctx.Get<int>("kwsframeStride"); - const float kwsScoreThreshold = ctx.Get<float>("kwsscoreThreshold"); - - TfLiteTensor* kwsOutputTensor = kwsModel.GetOutputTensor(0); + /* Get Input and Output tensors for pre/post processing. */ TfLiteTensor* kwsInputTensor = kwsModel.GetInputTensor(0); - + TfLiteTensor* kwsOutputTensor = kwsModel.GetOutputTensor(0); if (!kwsInputTensor->dims) { printf_err("Invalid input tensor dims\n"); return output; @@ -126,63 +102,32 @@ namespace app { return output; } - const uint32_t kwsNumMfccFeats = ctx.Get<uint32_t>("kwsNumMfcc"); - const uint32_t kwsNumAudioWindows = ctx.Get<uint32_t>("kwsNumAudioWins"); - - audio::MicroNetMFCC kwsMfcc = audio::MicroNetMFCC(kwsNumMfccFeats, kwsFrameLength); - kwsMfcc.Init(); - - /* Deduce the data length required for 1 KWS inference from the network parameters. */ - auto kwsAudioDataWindowSize = kwsNumAudioWindows * kwsFrameStride + - (kwsFrameLength - kwsFrameStride); - auto kwsMfccWindowSize = kwsFrameLength; - auto kwsMfccWindowStride = kwsFrameStride; - - /* We are choosing to move by half the window size => for a 1 second window size, - * this means an overlap of 0.5 seconds. */ - auto kwsAudioDataStride = kwsAudioDataWindowSize / 2; - - info("KWS audio data window size %" PRIu32 "\n", kwsAudioDataWindowSize); - - /* Stride must be multiple of mfcc features window stride to re-use features. 
*/ - if (0 != kwsAudioDataStride % kwsMfccWindowStride) { - kwsAudioDataStride -= kwsAudioDataStride % kwsMfccWindowStride; - } - - auto kwsMfccVectorsInAudioStride = kwsAudioDataStride/kwsMfccWindowStride; + /* Get input shape for feature extraction. */ + TfLiteIntArray* inputShape = kwsModel.GetInputShape(0); + const uint32_t numMfccFeatures = inputShape->data[MicroNetKwsModel::ms_inputColsIdx]; + const uint32_t numMfccFrames = inputShape->data[MicroNetKwsModel::ms_inputRowsIdx]; /* We expect to be sampling 1 second worth of data at a time * NOTE: This is only used for time stamp calculation. */ - const float kwsAudioParamsSecondsPerSample = 1.0/audio::MicroNetMFCC::ms_defaultSamplingFreq; + const float kwsAudioParamsSecondsPerSample = 1.0 / audio::MicroNetKwsMFCC::ms_defaultSamplingFreq; - auto currentIndex = ctx.Get<uint32_t>("clipIndex"); + /* Set up pre and post-processing. */ + KwsPreProcess preProcess = KwsPreProcess(kwsInputTensor, numMfccFeatures, numMfccFrames, + kwsMfccFrameLength, kwsMfccFrameStride); - /* Creating a mfcc features sliding window for the data required for 1 inference. */ - auto kwsAudioMFCCWindowSlider = audio::SlidingWindow<const int16_t>( - get_audio_array(currentIndex), - kwsAudioDataWindowSize, kwsMfccWindowSize, - kwsMfccWindowStride); + std::vector<ClassificationResult> singleInfResult; + KwsPostProcess postProcess = KwsPostProcess(kwsOutputTensor, ctx.Get<KwsClassifier &>("kwsClassifier"), + ctx.Get<std::vector<std::string>&>("kwsLabels"), + singleInfResult); /* Creating a sliding window through the whole audio clip. */ auto audioDataSlider = audio::SlidingWindow<const int16_t>( get_audio_array(currentIndex), get_audio_array_size(currentIndex), - kwsAudioDataWindowSize, kwsAudioDataStride); - - /* Calculate number of the feature vectors in the window overlap region. 
- * These feature vectors will be reused.*/ - size_t numberOfReusedFeatureVectors = kwsAudioMFCCWindowSlider.TotalStrides() + 1 - - kwsMfccVectorsInAudioStride; + preProcess.m_audioDataWindowSize, preProcess.m_audioDataStride); - auto kwsMfccFeatureCalc = GetFeatureCalculator(kwsMfcc, kwsInputTensor, - numberOfReusedFeatureVectors); - - if (!kwsMfccFeatureCalc){ - return output; - } - - /* Container for KWS results. */ - std::vector<arm::app::kws::KwsResult> kwsResults; + /* Declare a container to hold kws results from across the whole audio clip. */ + std::vector<kws::KwsResult> finalResults; /* Display message on the LCD - inference running. */ std::string str_inf{"Running KWS inference... "}; @@ -197,70 +142,56 @@ namespace app { while (audioDataSlider.HasNext()) { const int16_t* inferenceWindow = audioDataSlider.Next(); - /* We moved to the next window - set the features sliding to the new address. */ - kwsAudioMFCCWindowSlider.Reset(inferenceWindow); - /* The first window does not have cache ready. */ - bool useCache = audioDataSlider.Index() > 0 && numberOfReusedFeatureVectors > 0; - - /* Start calculating features inside one audio sliding window. */ - while (kwsAudioMFCCWindowSlider.HasNext()) { - const int16_t* kwsMfccWindow = kwsAudioMFCCWindowSlider.Next(); - std::vector<int16_t> kwsMfccAudioData = - std::vector<int16_t>(kwsMfccWindow, kwsMfccWindow + kwsMfccWindowSize); - - /* Compute features for this window and write them to input tensor. */ - kwsMfccFeatureCalc(kwsMfccAudioData, - kwsAudioMFCCWindowSlider.Index(), - useCache, - kwsMfccVectorsInAudioStride); - } + preProcess.m_audioWindowIndex = audioDataSlider.Index(); - info("Inference %zu/%zu\n", audioDataSlider.Index() + 1, - audioDataSlider.TotalStrides() + 1); + /* Run the pre-processing, inference and post-processing. 
*/ + if (!preProcess.DoPreProcess(inferenceWindow, audio::MicroNetKwsMFCC::ms_defaultSamplingFreq)) { + printf_err("KWS Pre-processing failed."); + return output; + } - /* Run inference over this audio clip sliding window. */ if (!RunInference(kwsModel, profiler)) { - printf_err("KWS inference failed\n"); + printf_err("KWS Inference failed."); return output; } - std::vector<ClassificationResult> kwsClassificationResult; - auto& kwsClassifier = ctx.Get<KwsClassifier&>("kwsclassifier"); + if (!postProcess.DoPostProcess()) { + printf_err("KWS Post-processing failed."); + return output; + } - kwsClassifier.GetClassificationResults( - kwsOutputTensor, kwsClassificationResult, - ctx.Get<std::vector<std::string>&>("kwslabels"), 1, true); + info("Inference %zu/%zu\n", audioDataSlider.Index() + 1, + audioDataSlider.TotalStrides() + 1); - kwsResults.emplace_back( - kws::KwsResult( - kwsClassificationResult, - audioDataSlider.Index() * kwsAudioParamsSecondsPerSample * kwsAudioDataStride, - audioDataSlider.Index(), kwsScoreThreshold) - ); + /* Add results from this window to our final results vector. */ + finalResults.emplace_back( + kws::KwsResult(singleInfResult, + audioDataSlider.Index() * kwsAudioParamsSecondsPerSample * preProcess.m_audioDataStride, + audioDataSlider.Index(), kwsScoreThreshold)); - /* Keyword detected. */ - if (kwsClassificationResult[0].m_label == ctx.Get<const std::string&>("triggerkeyword")) { - output.asrAudioStart = inferenceWindow + kwsAudioDataWindowSize; + /* Break out when trigger keyword is detected. 
*/ + if (singleInfResult[0].m_label == ctx.Get<const std::string&>("triggerKeyword") + && singleInfResult[0].m_normalisedVal > kwsScoreThreshold) { + output.asrAudioStart = inferenceWindow + preProcess.m_audioDataWindowSize; output.asrAudioSamples = get_audio_array_size(currentIndex) - (audioDataSlider.NextWindowStartIndex() - - kwsAudioDataStride + kwsAudioDataWindowSize); + preProcess.m_audioDataStride + preProcess.m_audioDataWindowSize); break; } #if VERIFY_TEST_OUTPUT - arm::app::DumpTensor(kwsOutputTensor); + DumpTensor(kwsOutputTensor); #endif /* VERIFY_TEST_OUTPUT */ } /* while (audioDataSlider.HasNext()) */ /* Erase. */ str_inf = std::string(str_inf.size(), ' '); - hal_lcd_display_text( - str_inf.c_str(), str_inf.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); + hal_lcd_display_text(str_inf.c_str(), str_inf.size(), + dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); - if (!PresentInferenceResult(kwsResults)) { + if (!PresentInferenceResult(finalResults)) { return output; } @@ -271,41 +202,41 @@ namespace app { } /** - * @brief Performs the ASR pipeline. - * - * @param[in,out] ctx pointer to the application context object - * @param[in] kwsOutput struct containing pointer to audio data where ASR should begin - * and how much data to process - * @return bool true if pipeline executed without failure - */ - static bool doAsr(ApplicationContext& ctx, const KWSOutput& kwsOutput) { + * @brief Performs the ASR pipeline. + * @param[in,out] ctx Pointer to the application context object. + * @param[in] kwsOutput Struct containing pointer to audio data where ASR should begin + * and how much data to process. + * @return true if pipeline executed without failure. 
+ **/ + static bool doAsr(ApplicationContext& ctx, const KWSOutput& kwsOutput) + { + auto& asrModel = ctx.Get<Model&>("asrModel"); + auto& profiler = ctx.Get<Profiler&>("profiler"); + auto asrMfccFrameLen = ctx.Get<uint32_t>("asrFrameLength"); + auto asrMfccFrameStride = ctx.Get<uint32_t>("asrFrameStride"); + auto asrScoreThreshold = ctx.Get<float>("asrScoreThreshold"); + auto asrInputCtxLen = ctx.Get<uint32_t>("ctxLen"); + constexpr uint32_t dataPsnTxtInfStartX = 20; constexpr uint32_t dataPsnTxtInfStartY = 40; - auto& profiler = ctx.Get<Profiler&>("profiler"); - hal_lcd_clear(COLOR_BLACK); - - /* Get model reference. */ - auto& asrModel = ctx.Get<Model&>("asrmodel"); if (!asrModel.IsInited()) { printf_err("ASR model has not been initialised\n"); return false; } - /* Get score threshold to be applied for the classifier (post-inference). */ - auto asrScoreThreshold = ctx.Get<float>("asrscoreThreshold"); + hal_lcd_clear(COLOR_BLACK); - /* Dimensions of the tensor should have been verified by the callee. */ + /* Get Input and Output tensors for pre/post processing. */ TfLiteTensor* asrInputTensor = asrModel.GetInputTensor(0); TfLiteTensor* asrOutputTensor = asrModel.GetOutputTensor(0); - const uint32_t asrInputRows = asrInputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx]; - /* Populate ASR MFCC related parameters. */ - auto asrMfccParamsWinLen = ctx.Get<uint32_t>("asrframeLength"); - auto asrMfccParamsWinStride = ctx.Get<uint32_t>("asrframeStride"); + /* Get input shape. Dimensions of the tensor should have been verified by + * the callee. */ + TfLiteIntArray* inputShape = asrModel.GetInputShape(0); - /* Populate ASR inference context and inner lengths for input. 
*/ - auto asrInputCtxLen = ctx.Get<uint32_t>("ctxLen"); + + const uint32_t asrInputRows = asrInputTensor->dims->data[Wav2LetterModel::ms_inputRowsIdx]; const uint32_t asrInputInnerLen = asrInputRows - (2 * asrInputCtxLen); /* Make sure the input tensor supports the above context and inner lengths. */ @@ -316,18 +247,9 @@ namespace app { } /* Audio data stride corresponds to inputInnerLen feature vectors. */ - const uint32_t asrAudioParamsWinLen = (asrInputRows - 1) * - asrMfccParamsWinStride + (asrMfccParamsWinLen); - const uint32_t asrAudioParamsWinStride = asrInputInnerLen * asrMfccParamsWinStride; - const float asrAudioParamsSecondsPerSample = - (1.0/audio::Wav2LetterMFCC::ms_defaultSamplingFreq); - - /* Get pre/post-processing objects */ - auto& asrPrep = ctx.Get<audio::asr::Preprocess&>("preprocess"); - auto& asrPostp = ctx.Get<audio::asr::Postprocess&>("postprocess"); - - /* Set default reduction axis for post-processing. */ - const uint32_t reductionAxis = arm::app::Wav2LetterModel::ms_outputRowsIdx; + const uint32_t asrAudioDataWindowLen = (asrInputRows - 1) * asrMfccFrameStride + (asrMfccFrameLen); + const uint32_t asrAudioDataWindowStride = asrInputInnerLen * asrMfccFrameStride; + const float asrAudioParamsSecondsPerSample = 1.0 / audio::Wav2LetterMFCC::ms_defaultSamplingFreq; /* Get the remaining audio buffer and respective size from KWS results. */ const int16_t* audioArr = kwsOutput.asrAudioStart; @@ -335,9 +257,9 @@ namespace app { /* Audio clip must have enough samples to produce 1 MFCC feature. 
*/ std::vector<int16_t> audioBuffer = std::vector<int16_t>(audioArr, audioArr + audioArrSize); - if (audioArrSize < asrMfccParamsWinLen) { + if (audioArrSize < asrMfccFrameLen) { printf_err("Not enough audio samples, minimum needed is %" PRIu32 "\n", - asrMfccParamsWinLen); + asrMfccFrameLen); return false; } @@ -345,26 +267,38 @@ namespace app { auto audioDataSlider = audio::FractionalSlidingWindow<const int16_t>( audioBuffer.data(), audioBuffer.size(), - asrAudioParamsWinLen, - asrAudioParamsWinStride); + asrAudioDataWindowLen, + asrAudioDataWindowStride); /* Declare a container for results. */ - std::vector<arm::app::asr::AsrResult> asrResults; + std::vector<asr::AsrResult> asrResults; /* Display message on the LCD - inference running. */ std::string str_inf{"Running ASR inference... "}; - hal_lcd_display_text( - str_inf.c_str(), str_inf.size(), + hal_lcd_display_text(str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); - size_t asrInferenceWindowLen = asrAudioParamsWinLen; - + size_t asrInferenceWindowLen = asrAudioDataWindowLen; + + /* Set up pre and post-processing objects. */ + AsrPreProcess asrPreProcess = AsrPreProcess(asrInputTensor, arm::app::Wav2LetterModel::ms_numMfccFeatures, + inputShape->data[Wav2LetterModel::ms_inputRowsIdx], + asrMfccFrameLen, asrMfccFrameStride); + + std::vector<ClassificationResult> singleInfResult; + const uint32_t outputCtxLen = AsrPostProcess::GetOutputContextLen(asrModel, asrInputCtxLen); + AsrPostProcess asrPostProcess = AsrPostProcess( + asrOutputTensor, ctx.Get<AsrClassifier&>("asrClassifier"), + ctx.Get<std::vector<std::string>&>("asrLabels"), + singleInfResult, outputCtxLen, + Wav2LetterModel::ms_blankTokenIdx, Wav2LetterModel::ms_outputRowsIdx + ); /* Start sliding through audio clip. */ while (audioDataSlider.HasNext()) { /* If not enough audio see how much can be sent for processing. 
*/ size_t nextStartIndex = audioDataSlider.NextWindowStartIndex(); - if (nextStartIndex + asrAudioParamsWinLen > audioBuffer.size()) { + if (nextStartIndex + asrAudioDataWindowLen > audioBuffer.size()) { asrInferenceWindowLen = audioBuffer.size() - nextStartIndex; } @@ -373,8 +307,11 @@ namespace app { info("Inference %zu/%zu\n", audioDataSlider.Index() + 1, static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1))); - /* Calculate MFCCs, deltas and populate the input tensor. */ - asrPrep.Invoke(asrInferenceWindow, asrInferenceWindowLen, asrInputTensor); + /* Run the pre-processing, inference and post-processing. */ + if (!asrPreProcess.DoPreProcess(asrInferenceWindow, asrInferenceWindowLen)) { + printf_err("ASR pre-processing failed."); + return false; + } /* Run inference over this audio clip sliding window. */ if (!RunInference(asrModel, profiler)) { @@ -382,24 +319,28 @@ namespace app { return false; } - /* Post-process. */ - asrPostp.Invoke(asrOutputTensor, reductionAxis, !audioDataSlider.HasNext()); + /* Post processing needs to know if we are on the last audio window. */ + asrPostProcess.m_lastIteration = !audioDataSlider.HasNext(); + if (!asrPostProcess.DoPostProcess()) { + printf_err("ASR post-processing failed."); + return false; + } /* Get results. 
*/ std::vector<ClassificationResult> asrClassificationResult; - auto& asrClassifier = ctx.Get<AsrClassifier&>("asrclassifier"); + auto& asrClassifier = ctx.Get<AsrClassifier&>("asrClassifier"); asrClassifier.GetClassificationResults( asrOutputTensor, asrClassificationResult, - ctx.Get<std::vector<std::string>&>("asrlabels"), 1); + ctx.Get<std::vector<std::string>&>("asrLabels"), 1); asrResults.emplace_back(asr::AsrResult(asrClassificationResult, (audioDataSlider.Index() * asrAudioParamsSecondsPerSample * - asrAudioParamsWinStride), + asrAudioDataWindowStride), audioDataSlider.Index(), asrScoreThreshold)); #if VERIFY_TEST_OUTPUT - arm::app::DumpTensor(asrOutputTensor, asrOutputTensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx]); + armDumpTensor(asrOutputTensor, asrOutputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx]); #endif /* VERIFY_TEST_OUTPUT */ /* Erase */ @@ -417,7 +358,7 @@ namespace app { return true; } - /* Audio inference classification handler. */ + /* KWS and ASR inference handler. */ bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll) { hal_lcd_clear(COLOR_BLACK); @@ -434,13 +375,14 @@ namespace app { do { KWSOutput kwsOutput = doKws(ctx); if (!kwsOutput.executionSuccess) { + printf_err("KWS failed\n"); return false; } if (kwsOutput.asrAudioStart != nullptr && kwsOutput.asrAudioSamples > 0) { - info("Keyword spotted\n"); + info("Trigger keyword spotted\n"); if(!doAsr(ctx, kwsOutput)) { - printf_err("ASR failed"); + printf_err("ASR failed\n"); return false; } } @@ -452,7 +394,6 @@ namespace app { return true; } - static bool PresentInferenceResult(std::vector<arm::app::kws::KwsResult>& results) { constexpr uint32_t dataPsnTxtStartX1 = 20; @@ -464,33 +405,31 @@ namespace app { /* Display each result. 
*/ uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr; - for (uint32_t i = 0; i < results.size(); ++i) { - + for (auto & result : results) { std::string topKeyword{"<none>"}; float score = 0.f; - if (!results[i].m_resultVec.empty()) { - topKeyword = results[i].m_resultVec[0].m_label; - score = results[i].m_resultVec[0].m_normalisedVal; + if (!result.m_resultVec.empty()) { + topKeyword = result.m_resultVec[0].m_label; + score = result.m_resultVec[0].m_normalisedVal; } std::string resultStr = - std::string{"@"} + std::to_string(results[i].m_timeStamp) + + std::string{"@"} + std::to_string(result.m_timeStamp) + std::string{"s: "} + topKeyword + std::string{" ("} + std::to_string(static_cast<int>(score * 100)) + std::string{"%)"}; - hal_lcd_display_text( - resultStr.c_str(), resultStr.size(), - dataPsnTxtStartX1, rowIdx1, 0); + hal_lcd_display_text(resultStr.c_str(), resultStr.size(), + dataPsnTxtStartX1, rowIdx1, 0); rowIdx1 += dataPsnTxtYIncr; info("For timestamp: %f (inference #: %" PRIu32 "); threshold: %f\n", - results[i].m_timeStamp, results[i].m_inferenceNumber, - results[i].m_threshold); - for (uint32_t j = 0; j < results[i].m_resultVec.size(); ++j) { + result.m_timeStamp, result.m_inferenceNumber, + result.m_threshold); + for (uint32_t j = 0; j < result.m_resultVec.size(); ++j) { info("\t\tlabel @ %" PRIu32 ": %s, score: %f\n", j, - results[i].m_resultVec[j].m_label.c_str(), - results[i].m_resultVec[j].m_normalisedVal); + result.m_resultVec[j].m_label.c_str(), + result.m_resultVec[j].m_normalisedVal); } } @@ -523,143 +462,12 @@ namespace app { std::string finalResultStr = audio::asr::DecodeOutput(combinedResults); - hal_lcd_display_text( - finalResultStr.c_str(), finalResultStr.size(), - dataPsnTxtStartX1, dataPsnTxtStartY1, allow_multiple_lines); + hal_lcd_display_text(finalResultStr.c_str(), finalResultStr.size(), + dataPsnTxtStartX1, dataPsnTxtStartY1, allow_multiple_lines); info("Final result: %s\n", finalResultStr.c_str()); return true; } - /** - 
* @brief Generic feature calculator factory. - * - * Returns lambda function to compute features using features cache. - * Real features math is done by a lambda function provided as a parameter. - * Features are written to input tensor memory. - * - * @tparam T feature vector type. - * @param inputTensor model input tensor pointer. - * @param cacheSize number of feature vectors to cache. Defined by the sliding window overlap. - * @param compute features calculator function. - * @return lambda function to compute features. - **/ - template<class T> - std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> - FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize, - std::function<std::vector<T> (std::vector<int16_t>& )> compute) - { - /* Feature cache to be captured by lambda function. */ - static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize); - - return [=](std::vector<int16_t>& audioDataWindow, - size_t index, - bool useCache, - size_t featuresOverlapIndex) - { - T* tensorData = tflite::GetTensorData<T>(inputTensor); - std::vector<T> features; - - /* Reuse features from cache if cache is ready and sliding windows overlap. - * Overlap is in the beginning of sliding window with a size of a feature cache. - */ - if (useCache && index < featureCache.size()) { - features = std::move(featureCache[index]); - } else { - features = std::move(compute(audioDataWindow)); - } - auto size = features.size(); - auto sizeBytes = sizeof(T) * size; - std::memcpy(tensorData + (index * size), features.data(), sizeBytes); - - /* Start renewing cache as soon iteration goes out of the windows overlap. 
*/ - if (index >= featuresOverlapIndex) { - featureCache[index - featuresOverlapIndex] = std::move(features); - } - }; - } - - template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)> - FeatureCalc<int8_t>(TfLiteTensor* inputTensor, - size_t cacheSize, - std::function<std::vector<int8_t> (std::vector<int16_t>& )> compute); - - template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)> - FeatureCalc<uint8_t>(TfLiteTensor* inputTensor, - size_t cacheSize, - std::function<std::vector<uint8_t> (std::vector<int16_t>& )> compute); - - template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)> - FeatureCalc<int16_t>(TfLiteTensor* inputTensor, - size_t cacheSize, - std::function<std::vector<int16_t> (std::vector<int16_t>& )> compute); - - template std::function<void(std::vector<int16_t>&, size_t, bool, size_t)> - FeatureCalc<float>(TfLiteTensor* inputTensor, - size_t cacheSize, - std::function<std::vector<float>(std::vector<int16_t>&)> compute); - - - static std::function<void (std::vector<int16_t>&, int, bool, size_t)> - GetFeatureCalculator(audio::MicroNetMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize) - { - std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> mfccFeatureCalc; - - TfLiteQuantization quant = inputTensor->quantization; - - if (kTfLiteAffineQuantization == quant.type) { - - auto* quantParams = (TfLiteAffineQuantization*) quant.params; - const float quantScale = quantParams->scale->data[0]; - const int quantOffset = quantParams->zero_point->data[0]; - - switch (inputTensor->type) { - case kTfLiteInt8: { - mfccFeatureCalc = FeatureCalc<int8_t>(inputTensor, - cacheSize, - [=, &mfcc](std::vector<int16_t>& audioDataWindow) { - return mfcc.MfccComputeQuant<int8_t>(audioDataWindow, - quantScale, - quantOffset); - } - ); - break; - } - case kTfLiteUInt8: { - mfccFeatureCalc = FeatureCalc<uint8_t>(inputTensor, - cacheSize, - [=, &mfcc](std::vector<int16_t>& audioDataWindow) { - return 
mfcc.MfccComputeQuant<uint8_t>(audioDataWindow, - quantScale, - quantOffset); - } - ); - break; - } - case kTfLiteInt16: { - mfccFeatureCalc = FeatureCalc<int16_t>(inputTensor, - cacheSize, - [=, &mfcc](std::vector<int16_t>& audioDataWindow) { - return mfcc.MfccComputeQuant<int16_t>(audioDataWindow, - quantScale, - quantOffset); - } - ); - break; - } - default: - printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type)); - } - - - } else { - mfccFeatureCalc = mfccFeatureCalc = FeatureCalc<float>(inputTensor, - cacheSize, - [&mfcc](std::vector<int16_t>& audioDataWindow) { - return mfcc.MfccCompute(audioDataWindow); - }); - } - return mfccFeatureCalc; - } } /* namespace app */ } /* namespace arm */
\ No newline at end of file diff --git a/source/use_case/kws_asr/src/Wav2LetterPostprocess.cc b/source/use_case/kws_asr/src/Wav2LetterPostprocess.cc index 2a76b1b..42f434e 100644 --- a/source/use_case/kws_asr/src/Wav2LetterPostprocess.cc +++ b/source/use_case/kws_asr/src/Wav2LetterPostprocess.cc @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,62 +15,71 @@ * limitations under the License. */ #include "Wav2LetterPostprocess.hpp" + #include "Wav2LetterModel.hpp" #include "log_macros.h" +#include <cmath> + namespace arm { namespace app { -namespace audio { -namespace asr { - - Postprocess::Postprocess(const uint32_t contextLen, - const uint32_t innerLen, - const uint32_t blankTokenIdx) - : m_contextLen(contextLen), - m_innerLen(innerLen), - m_totalLen(2 * this->m_contextLen + this->m_innerLen), + + AsrPostProcess::AsrPostProcess(TfLiteTensor* outputTensor, AsrClassifier& classifier, + const std::vector<std::string>& labels, std::vector<ClassificationResult>& results, + const uint32_t outputContextLen, + const uint32_t blankTokenIdx, const uint32_t reductionAxisIdx + ): + m_classifier(classifier), + m_outputTensor(outputTensor), + m_labels{labels}, + m_results(results), + m_outputContextLen(outputContextLen), m_countIterations(0), - m_blankTokenIdx(blankTokenIdx) - {} + m_blankTokenIdx(blankTokenIdx), + m_reductionAxisIdx(reductionAxisIdx) + { + this->m_outputInnerLen = AsrPostProcess::GetOutputInnerLen(this->m_outputTensor, this->m_outputContextLen); + this->m_totalLen = (2 * this->m_outputContextLen + this->m_outputInnerLen); + } - bool Postprocess::Invoke(TfLiteTensor* tensor, - const uint32_t axisIdx, - const bool lastIteration) + bool AsrPostProcess::DoPostProcess() { /* Basic checks. 
*/ - if (!this->IsInputValid(tensor, axisIdx)) { + if (!this->IsInputValid(this->m_outputTensor, this->m_reductionAxisIdx)) { return false; } /* Irrespective of tensor type, we use unsigned "byte" */ - uint8_t* ptrData = tflite::GetTensorData<uint8_t>(tensor); - const uint32_t elemSz = this->GetTensorElementSize(tensor); + auto* ptrData = tflite::GetTensorData<uint8_t>(this->m_outputTensor); + const uint32_t elemSz = AsrPostProcess::GetTensorElementSize(this->m_outputTensor); /* Other sanity checks. */ if (0 == elemSz) { printf_err("Tensor type not supported for post processing\n"); return false; - } else if (elemSz * this->m_totalLen > tensor->bytes) { + } else if (elemSz * this->m_totalLen > this->m_outputTensor->bytes) { printf_err("Insufficient number of tensor bytes\n"); return false; } /* Which axis do we need to process? */ - switch (axisIdx) { - case arm::app::Wav2LetterModel::ms_outputRowsIdx: - return this->EraseSectionsRowWise(ptrData, - elemSz * - tensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx], - lastIteration); + switch (this->m_reductionAxisIdx) { + case Wav2LetterModel::ms_outputRowsIdx: + this->EraseSectionsRowWise( + ptrData, elemSz * this->m_outputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx], + this->m_lastIteration); + break; default: - printf_err("Unsupported axis index: %" PRIu32 "\n", axisIdx); + printf_err("Unsupported axis index: %" PRIu32 "\n", this->m_reductionAxisIdx); + return false; } + this->m_classifier.GetClassificationResults(this->m_outputTensor, + this->m_results, this->m_labels, 1); - return false; + return true; } - bool Postprocess::IsInputValid(TfLiteTensor* tensor, - const uint32_t axisIdx) const + bool AsrPostProcess::IsInputValid(TfLiteTensor* tensor, const uint32_t axisIdx) const { if (nullptr == tensor) { return false; @@ -84,25 +93,23 @@ namespace asr { if (static_cast<int>(this->m_totalLen) != tensor->dims->data[axisIdx]) { - printf_err("Unexpected tensor dimension for axis %d, \n", - 
tensor->dims->data[axisIdx]); + printf_err("Unexpected tensor dimension for axis %d, got %d, \n", + axisIdx, tensor->dims->data[axisIdx]); return false; } return true; } - uint32_t Postprocess::GetTensorElementSize(TfLiteTensor* tensor) + uint32_t AsrPostProcess::GetTensorElementSize(TfLiteTensor* tensor) { switch(tensor->type) { case kTfLiteUInt8: - return 1; case kTfLiteInt8: return 1; case kTfLiteInt16: return 2; case kTfLiteInt32: - return 4; case kTfLiteFloat32: return 4; default: @@ -113,30 +120,30 @@ namespace asr { return 0; } - bool Postprocess::EraseSectionsRowWise( - uint8_t* ptrData, - const uint32_t strideSzBytes, - const bool lastIteration) + bool AsrPostProcess::EraseSectionsRowWise( + uint8_t* ptrData, + const uint32_t strideSzBytes, + const bool lastIteration) { /* In this case, the "zero-ing" is quite simple as the region * to be zeroed sits in contiguous memory (row-major). */ - const uint32_t eraseLen = strideSzBytes * this->m_contextLen; + const uint32_t eraseLen = strideSzBytes * this->m_outputContextLen; /* Erase left context? */ if (this->m_countIterations > 0) { /* Set output of each classification window to the blank token. */ std::memset(ptrData, 0, eraseLen); - for (size_t windowIdx = 0; windowIdx < this->m_contextLen; windowIdx++) { + for (size_t windowIdx = 0; windowIdx < this->m_outputContextLen; windowIdx++) { ptrData[windowIdx*strideSzBytes + this->m_blankTokenIdx] = 1; } } /* Erase right context? */ if (false == lastIteration) { - uint8_t * rightCtxPtr = ptrData + (strideSzBytes * (this->m_contextLen + this->m_innerLen)); + uint8_t* rightCtxPtr = ptrData + (strideSzBytes * (this->m_outputContextLen + this->m_outputInnerLen)); /* Set output of each classification window to the blank token. 
*/ std::memset(rightCtxPtr, 0, eraseLen); - for (size_t windowIdx = 0; windowIdx < this->m_contextLen; windowIdx++) { + for (size_t windowIdx = 0; windowIdx < this->m_outputContextLen; windowIdx++) { rightCtxPtr[windowIdx*strideSzBytes + this->m_blankTokenIdx] = 1; } } @@ -150,7 +157,58 @@ namespace asr { return true; } -} /* namespace asr */ -} /* namespace audio */ + uint32_t AsrPostProcess::GetNumFeatureVectors(const Model& model) + { + TfLiteTensor* inputTensor = model.GetInputTensor(0); + const int inputRows = std::max(inputTensor->dims->data[Wav2LetterModel::ms_inputRowsIdx], 0); + if (inputRows == 0) { + printf_err("Error getting number of input rows for axis: %" PRIu32 "\n", + Wav2LetterModel::ms_inputRowsIdx); + } + return inputRows; + } + + uint32_t AsrPostProcess::GetOutputInnerLen(const TfLiteTensor* outputTensor, const uint32_t outputCtxLen) + { + const uint32_t outputRows = std::max(outputTensor->dims->data[Wav2LetterModel::ms_outputRowsIdx], 0); + if (outputRows == 0) { + printf_err("Error getting number of output rows for axis: %" PRIu32 "\n", + Wav2LetterModel::ms_outputRowsIdx); + } + + /* Watching for underflow. */ + int innerLen = (outputRows - (2 * outputCtxLen)); + + return std::max(innerLen, 0); + } + + uint32_t AsrPostProcess::GetOutputContextLen(const Model& model, const uint32_t inputCtxLen) + { + const uint32_t inputRows = AsrPostProcess::GetNumFeatureVectors(model); + const uint32_t inputInnerLen = inputRows - (2 * inputCtxLen); + constexpr uint32_t ms_outputRowsIdx = Wav2LetterModel::ms_outputRowsIdx; + + /* Check to make sure that the input tensor supports the above + * context and inner lengths. 
*/ + if (inputRows <= 2 * inputCtxLen || inputRows <= inputInnerLen) { + printf_err("Input rows not compatible with ctx of %" PRIu32 "\n", + inputCtxLen); + return 0; + } + + TfLiteTensor* outputTensor = model.GetOutputTensor(0); + const uint32_t outputRows = std::max(outputTensor->dims->data[ms_outputRowsIdx], 0); + if (outputRows == 0) { + printf_err("Error getting number of output rows for axis: %" PRIu32 "\n", + Wav2LetterModel::ms_outputRowsIdx); + return 0; + } + + const float inOutRowRatio = static_cast<float>(inputRows) / + static_cast<float>(outputRows); + + return std::round(static_cast<float>(inputCtxLen) / inOutRowRatio); + } + } /* namespace app */ } /* namespace arm */
\ No newline at end of file diff --git a/source/use_case/kws_asr/src/Wav2LetterPreprocess.cc b/source/use_case/kws_asr/src/Wav2LetterPreprocess.cc index d3f3579..92b0631 100644 --- a/source/use_case/kws_asr/src/Wav2LetterPreprocess.cc +++ b/source/use_case/kws_asr/src/Wav2LetterPreprocess.cc @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -20,41 +20,35 @@ #include "TensorFlowLiteMicro.hpp" #include <algorithm> -#include <math.h> +#include <cmath> namespace arm { namespace app { -namespace audio { -namespace asr { - - Preprocess::Preprocess( - const uint32_t numMfccFeatures, - const uint32_t windowLen, - const uint32_t windowStride, - const uint32_t numMfccVectors): - m_mfcc(numMfccFeatures, windowLen), - m_mfccBuf(numMfccFeatures, numMfccVectors), - m_delta1Buf(numMfccFeatures, numMfccVectors), - m_delta2Buf(numMfccFeatures, numMfccVectors), - m_windowLen(windowLen), - m_windowStride(windowStride), + + AsrPreProcess::AsrPreProcess(TfLiteTensor* inputTensor, const uint32_t numMfccFeatures, + const uint32_t numFeatureFrames, const uint32_t mfccWindowLen, + const uint32_t mfccWindowStride + ): + m_mfcc(numMfccFeatures, mfccWindowLen), + m_inputTensor(inputTensor), + m_mfccBuf(numMfccFeatures, numFeatureFrames), + m_delta1Buf(numMfccFeatures, numFeatureFrames), + m_delta2Buf(numMfccFeatures, numFeatureFrames), + m_mfccWindowLen(mfccWindowLen), + m_mfccWindowStride(mfccWindowStride), m_numMfccFeats(numMfccFeatures), - m_numFeatVectors(numMfccVectors), - m_window() + m_numFeatureFrames(numFeatureFrames) { - if (numMfccFeatures > 0 && windowLen > 0) { + if (numMfccFeatures > 0 && mfccWindowLen > 0) { this->m_mfcc.Init(); } } - bool Preprocess::Invoke( - const int16_t* audioData, - const uint32_t audioDataLen, - TfLiteTensor* tensor) + bool AsrPreProcess::DoPreProcess(const 
void* audioData, const size_t audioDataLen) { - this->m_window = SlidingWindow<const int16_t>( - audioData, audioDataLen, - this->m_windowLen, this->m_windowStride); + this->m_mfccSlidingWindow = audio::SlidingWindow<const int16_t>( + static_cast<const int16_t*>(audioData), audioDataLen, + this->m_mfccWindowLen, this->m_mfccWindowStride); uint32_t mfccBufIdx = 0; @@ -62,12 +56,12 @@ namespace asr { std::fill(m_delta1Buf.begin(), m_delta1Buf.end(), 0.f); std::fill(m_delta2Buf.begin(), m_delta2Buf.end(), 0.f); - /* While we can slide over the window. */ - while (this->m_window.HasNext()) { - const int16_t* mfccWindow = this->m_window.Next(); + /* While we can slide over the audio. */ + while (this->m_mfccSlidingWindow.HasNext()) { + const int16_t* mfccWindow = this->m_mfccSlidingWindow.Next(); auto mfccAudioData = std::vector<int16_t>( mfccWindow, - mfccWindow + this->m_windowLen); + mfccWindow + this->m_mfccWindowLen); auto mfcc = this->m_mfcc.MfccCompute(mfccAudioData); for (size_t i = 0; i < this->m_mfccBuf.size(0); ++i) { this->m_mfccBuf(i, mfccBufIdx) = mfcc[i]; @@ -76,11 +70,11 @@ namespace asr { } /* Pad MFCC if needed by adding MFCC for zeros. */ - if (mfccBufIdx != this->m_numFeatVectors) { - std::vector<int16_t> zerosWindow = std::vector<int16_t>(this->m_windowLen, 0); + if (mfccBufIdx != this->m_numFeatureFrames) { + std::vector<int16_t> zerosWindow = std::vector<int16_t>(this->m_mfccWindowLen, 0); std::vector<float> mfccZeros = this->m_mfcc.MfccCompute(zerosWindow); - while (mfccBufIdx != this->m_numFeatVectors) { + while (mfccBufIdx != this->m_numFeatureFrames) { memcpy(&this->m_mfccBuf(0, mfccBufIdx), mfccZeros.data(), sizeof(float) * m_numMfccFeats); ++mfccBufIdx; @@ -88,41 +82,39 @@ namespace asr { } /* Compute first and second order deltas from MFCCs. */ - this->ComputeDeltas(this->m_mfccBuf, - this->m_delta1Buf, - this->m_delta2Buf); + AsrPreProcess::ComputeDeltas(this->m_mfccBuf, this->m_delta1Buf, this->m_delta2Buf); - /* Normalise. 
*/ - this->Normalise(); + /* Standardize calculated features. */ + this->Standarize(); /* Quantise. */ - QuantParams quantParams = GetTensorQuantParams(tensor); + QuantParams quantParams = GetTensorQuantParams(this->m_inputTensor); if (0 == quantParams.scale) { printf_err("Quantisation scale can't be 0\n"); return false; } - switch(tensor->type) { + switch(this->m_inputTensor->type) { case kTfLiteUInt8: return this->Quantise<uint8_t>( - tflite::GetTensorData<uint8_t>(tensor), tensor->bytes, + tflite::GetTensorData<uint8_t>(this->m_inputTensor), this->m_inputTensor->bytes, quantParams.scale, quantParams.offset); case kTfLiteInt8: return this->Quantise<int8_t>( - tflite::GetTensorData<int8_t>(tensor), tensor->bytes, + tflite::GetTensorData<int8_t>(this->m_inputTensor), this->m_inputTensor->bytes, quantParams.scale, quantParams.offset); default: printf_err("Unsupported tensor type %s\n", - TfLiteTypeGetName(tensor->type)); + TfLiteTypeGetName(this->m_inputTensor->type)); } return false; } - bool Preprocess::ComputeDeltas(Array2d<float>& mfcc, - Array2d<float>& delta1, - Array2d<float>& delta2) + bool AsrPreProcess::ComputeDeltas(Array2d<float>& mfcc, + Array2d<float>& delta1, + Array2d<float>& delta2) { const std::vector <float> delta1Coeffs = {6.66666667e-02, 5.00000000e-02, 3.33333333e-02, @@ -148,11 +140,11 @@ namespace asr { /* Iterate through features in MFCC vector. */ for (size_t i = 0; i < numFeatures; ++i) { /* For each feature, iterate through time (t) samples representing feature evolution and - * calculate d/dt and d^2/dt^2, using 1d convolution with differential kernels. + * calculate d/dt and d^2/dt^2, using 1D convolution with differential kernels. * Convolution padding = valid, result size is `time length - kernel length + 1`. * The result is padded with 0 from both sides to match the size of initial time samples data. * - * For the small filter, conv1d implementation as a simple loop is efficient enough. 
+ * For the small filter, conv1D implementation as a simple loop is efficient enough. * Filters of a greater size would need CMSIS-DSP functions to be used, like arm_fir_f32. */ @@ -175,20 +167,10 @@ namespace asr { return true; } - float Preprocess::GetMean(Array2d<float>& vec) - { - return math::MathUtils::MeanF32(vec.begin(), vec.totalSize()); - } - - float Preprocess::GetStdDev(Array2d<float>& vec, const float mean) - { - return math::MathUtils::StdDevF32(vec.begin(), vec.totalSize(), mean); - } - - void Preprocess::NormaliseVec(Array2d<float>& vec) + void AsrPreProcess::StandardizeVecF32(Array2d<float>& vec) { - auto mean = Preprocess::GetMean(vec); - auto stddev = Preprocess::GetStdDev(vec, mean); + auto mean = math::MathUtils::MeanF32(vec.begin(), vec.totalSize()); + auto stddev = math::MathUtils::StdDevF32(vec.begin(), vec.totalSize(), mean); debug("Mean: %f, Stddev: %f\n", mean, stddev); if (stddev == 0) { @@ -204,14 +186,14 @@ namespace asr { } } - void Preprocess::Normalise() + void AsrPreProcess::Standarize() { - Preprocess::NormaliseVec(this->m_mfccBuf); - Preprocess::NormaliseVec(this->m_delta1Buf); - Preprocess::NormaliseVec(this->m_delta2Buf); + AsrPreProcess::StandardizeVecF32(this->m_mfccBuf); + AsrPreProcess::StandardizeVecF32(this->m_delta1Buf); + AsrPreProcess::StandardizeVecF32(this->m_delta2Buf); } - float Preprocess::GetQuantElem( + float AsrPreProcess::GetQuantElem( const float elem, const float quantScale, const int quantOffset, @@ -222,7 +204,5 @@ namespace asr { return std::min<float>(std::max<float>(val, minVal), maxVal); } -} /* namespace asr */ -} /* namespace audio */ } /* namespace app */ } /* namespace arm */
\ No newline at end of file |