Diffstat (limited to 'source/use_case/kws_asr/src/UseCaseHandler.cc')
-rw-r--r-- | source/use_case/kws_asr/src/UseCaseHandler.cc | 492
1 file changed, 150 insertions(+), 342 deletions(-)
diff --git a/source/use_case/kws_asr/src/UseCaseHandler.cc b/source/use_case/kws_asr/src/UseCaseHandler.cc
index 1e1a400..01aefae 100644
--- a/source/use_case/kws_asr/src/UseCaseHandler.cc
+++ b/source/use_case/kws_asr/src/UseCaseHandler.cc
@@ -28,6 +28,7 @@
 #include "Wav2LetterMfcc.hpp"
 #include "Wav2LetterPreprocess.hpp"
 #include "Wav2LetterPostprocess.hpp"
+#include "KwsProcessing.hpp"
 #include "AsrResult.hpp"
 #include "AsrClassifier.hpp"
 #include "OutputDecode.hpp"
@@ -39,11 +40,6 @@ using KwsClassifier = arm::app::Classifier;
 namespace arm {
 namespace app {

-    enum AsrOutputReductionAxis {
-        AxisRow = 1,
-        AxisCol = 2
-    };
-
     struct KWSOutput {
         bool executionSuccess = false;
         const int16_t* asrAudioStart = nullptr;
@@ -51,73 +47,53 @@ namespace app {
     };

     /**
-     * @brief           Presents kws inference results using the data presentation
-     *                  object.
-     * @param[in]       results     vector of classification results to be displayed
-     * @return          true if successful, false otherwise
+     * @brief           Presents KWS inference results.
+     * @param[in]       results     Vector of KWS classification results to be displayed.
+     * @return          true if successful, false otherwise.
      **/
-    static bool PresentInferenceResult(std::vector<arm::app::kws::KwsResult>& results);
+    static bool PresentInferenceResult(std::vector<kws::KwsResult>& results);

     /**
-     * @brief           Presents asr inference results using the data presentation
-     *                  object.
-     * @param[in]       platform    reference to the hal platform object
-     * @param[in]       results     vector of classification results to be displayed
-     * @return          true if successful, false otherwise
+     * @brief           Presents ASR inference results.
+     * @param[in]       results     Vector of ASR classification results to be displayed.
+     * @return          true if successful, false otherwise.
      **/
-    static bool PresentInferenceResult(std::vector<arm::app::asr::AsrResult>& results);
+    static bool PresentInferenceResult(std::vector<asr::AsrResult>& results);

     /**
-     * @brief Returns a function to perform feature calculation and populates input tensor data with
-     * MFCC data.
-     *
-     * Input tensor data type check is performed to choose correct MFCC feature data type.
-     * If tensor has an integer data type then original features are quantised.
-     *
-     * Warning: mfcc calculator provided as input must have the same life scope as returned function.
-     *
-     * @param[in]       mfcc          MFCC feature calculator.
-     * @param[in,out]   inputTensor   Input tensor pointer to store calculated features.
-     * @param[in]       cacheSize     Size of the feature vectors cache (number of feature vectors).
-     *
-     * @return          function      function to be called providing audio sample and sliding window index.
+     * @brief           Performs the KWS pipeline.
+     * @param[in,out]   ctx         pointer to the application context object
+     * @return          struct containing pointer to audio data where ASR should begin
+     *                  and how much data to process.
      **/
-    static std::function<void (std::vector<int16_t>&, int, bool, size_t)>
-    GetFeatureCalculator(audio::MicroNetMFCC&  mfcc,
-                         TfLiteTensor*         inputTensor,
-                         size_t                cacheSize);
+    static KWSOutput doKws(ApplicationContext& ctx)
+    {
+        auto& profiler = ctx.Get<Profiler&>("profiler");
+        auto& kwsModel = ctx.Get<Model&>("kwsModel");
+        const auto kwsMfccFrameLength = ctx.Get<int>("kwsFrameLength");
+        const auto kwsMfccFrameStride = ctx.Get<int>("kwsFrameStride");
+        const auto kwsScoreThreshold = ctx.Get<float>("kwsScoreThreshold");
+
+        auto currentIndex = ctx.Get<uint32_t>("clipIndex");

-    /**
-     * @brief Performs the KWS pipeline.
-     * @param[in,out]   ctx   pointer to the application context object
-     *
-     * @return KWSOutput  struct containing pointer to audio data where ASR should begin
-     *                    and how much data to process.
-     */
-    static KWSOutput doKws(ApplicationContext& ctx) {
         constexpr uint32_t dataPsnTxtInfStartX = 20;
         constexpr uint32_t dataPsnTxtInfStartY = 40;

         constexpr int minTensorDims = static_cast<int>(
-            (arm::app::MicroNetKwsModel::ms_inputRowsIdx > arm::app::MicroNetKwsModel::ms_inputColsIdx)?
-            arm::app::MicroNetKwsModel::ms_inputRowsIdx : arm::app::MicroNetKwsModel::ms_inputColsIdx);
+            (MicroNetKwsModel::ms_inputRowsIdx > MicroNetKwsModel::ms_inputColsIdx)?
+            MicroNetKwsModel::ms_inputRowsIdx : MicroNetKwsModel::ms_inputColsIdx);

-        KWSOutput output;
+        /* Output struct from doing KWS. */
+        KWSOutput output {};

-        auto& profiler = ctx.Get<Profiler&>("profiler");
-        auto& kwsModel = ctx.Get<Model&>("kwsmodel");
         if (!kwsModel.IsInited()) {
             printf_err("KWS model has not been initialised\n");
             return output;
         }

-        const int kwsFrameLength = ctx.Get<int>("kwsframeLength");
-        const int kwsFrameStride = ctx.Get<int>("kwsframeStride");
-        const float kwsScoreThreshold = ctx.Get<float>("kwsscoreThreshold");
-
-        TfLiteTensor* kwsOutputTensor = kwsModel.GetOutputTensor(0);
+        /* Get Input and Output tensors for pre/post processing. */
         TfLiteTensor* kwsInputTensor = kwsModel.GetInputTensor(0);
-
+        TfLiteTensor* kwsOutputTensor = kwsModel.GetOutputTensor(0);
         if (!kwsInputTensor->dims) {
             printf_err("Invalid input tensor dims\n");
             return output;
@@ -126,63 +102,32 @@ namespace app {
             return output;
         }

-        const uint32_t kwsNumMfccFeats = ctx.Get<uint32_t>("kwsNumMfcc");
-        const uint32_t kwsNumAudioWindows = ctx.Get<uint32_t>("kwsNumAudioWins");
-
-        audio::MicroNetMFCC kwsMfcc = audio::MicroNetMFCC(kwsNumMfccFeats, kwsFrameLength);
-        kwsMfcc.Init();
-
-        /* Deduce the data length required for 1 KWS inference from the network parameters. */
-        auto kwsAudioDataWindowSize = kwsNumAudioWindows * kwsFrameStride +
-                                      (kwsFrameLength - kwsFrameStride);
-        auto kwsMfccWindowSize = kwsFrameLength;
-        auto kwsMfccWindowStride = kwsFrameStride;
-
-        /* We are choosing to move by half the window size => for a 1 second window size,
-         * this means an overlap of 0.5 seconds. */
-        auto kwsAudioDataStride = kwsAudioDataWindowSize / 2;
-
-        info("KWS audio data window size %" PRIu32 "\n", kwsAudioDataWindowSize);
-
-        /* Stride must be multiple of mfcc features window stride to re-use features. */
-        if (0 != kwsAudioDataStride % kwsMfccWindowStride) {
-            kwsAudioDataStride -= kwsAudioDataStride % kwsMfccWindowStride;
-        }
-
-        auto kwsMfccVectorsInAudioStride = kwsAudioDataStride/kwsMfccWindowStride;
+        /* Get input shape for feature extraction. */
+        TfLiteIntArray* inputShape = kwsModel.GetInputShape(0);
+        const uint32_t numMfccFeatures = inputShape->data[MicroNetKwsModel::ms_inputColsIdx];
+        const uint32_t numMfccFrames = inputShape->data[MicroNetKwsModel::ms_inputRowsIdx];

         /* We expect to be sampling 1 second worth of data at a time
          * NOTE: This is only used for time stamp calculation. */
-        const float kwsAudioParamsSecondsPerSample = 1.0/audio::MicroNetMFCC::ms_defaultSamplingFreq;
+        const float kwsAudioParamsSecondsPerSample = 1.0 / audio::MicroNetKwsMFCC::ms_defaultSamplingFreq;

-        auto currentIndex = ctx.Get<uint32_t>("clipIndex");
+        /* Set up pre and post-processing. */
+        KwsPreProcess preProcess = KwsPreProcess(kwsInputTensor, numMfccFeatures, numMfccFrames,
+                                                 kwsMfccFrameLength, kwsMfccFrameStride);

-        /* Creating a mfcc features sliding window for the data required for 1 inference. */
-        auto kwsAudioMFCCWindowSlider = audio::SlidingWindow<const int16_t>(
-                get_audio_array(currentIndex),
-                kwsAudioDataWindowSize, kwsMfccWindowSize,
-                kwsMfccWindowStride);
+        std::vector<ClassificationResult> singleInfResult;
+        KwsPostProcess postProcess = KwsPostProcess(kwsOutputTensor, ctx.Get<KwsClassifier &>("kwsClassifier"),
+                                                    ctx.Get<std::vector<std::string>&>("kwsLabels"),
+                                                    singleInfResult);

         /* Creating a sliding window through the whole audio clip. */
         auto audioDataSlider = audio::SlidingWindow<const int16_t>(
                 get_audio_array(currentIndex),
                 get_audio_array_size(currentIndex),
-                kwsAudioDataWindowSize, kwsAudioDataStride);
-
-        /* Calculate number of the feature vectors in the window overlap region.
-         * These feature vectors will be reused.*/
-        size_t numberOfReusedFeatureVectors = kwsAudioMFCCWindowSlider.TotalStrides() + 1
-                                              - kwsMfccVectorsInAudioStride;
+                preProcess.m_audioDataWindowSize, preProcess.m_audioDataStride);

-        auto kwsMfccFeatureCalc = GetFeatureCalculator(kwsMfcc, kwsInputTensor,
-                                                       numberOfReusedFeatureVectors);
-
-        if (!kwsMfccFeatureCalc){
-            return output;
-        }
-
-        /* Container for KWS results. */
-        std::vector<arm::app::kws::KwsResult> kwsResults;
+        /* Declare a container to hold kws results from across the whole audio clip. */
+        std::vector<kws::KwsResult> finalResults;

         /* Display message on the LCD - inference running. */
         std::string str_inf{"Running KWS inference... "};
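The refactor above moves the hand-rolled MFCC windowing into the preProcess object (which now owns m_audioDataWindowSize and m_audioDataStride), while an outer sliding window still walks the whole clip. The window/stride arithmetic driving that outer loop reduces to the following self-contained sketch (standard C++ only; the sizes are made-up placeholders, not values read from the model):

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    /* Stand-in for the audio::SlidingWindow<const int16_t> loop: one window
     * start per stride, until a full window no longer fits in the clip. */
    int main()
    {
        const std::vector<int16_t> audio(16000 * 3, 0); /* 3 s of 16 kHz audio (placeholder) */
        const size_t windowSize = 16000;                /* one inference per second of audio */
        const size_t stride     = 8000;                 /* 50% overlap between inferences    */

        size_t inferences = 0;
        for (size_t start = 0; start + windowSize <= audio.size(); start += stride) {
            const int16_t* window = audio.data() + start; /* what Next() hands back */
            (void)window;          /* DoPreProcess() + RunInference() would go here */
            ++inferences;
        }
        std::printf("inferences for the clip: %zu\n", inferences); /* (48000-16000)/8000+1 = 5 */
        return 0;
    }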
@@ -197,70 +142,56 @@ namespace app {
         while (audioDataSlider.HasNext()) {
             const int16_t* inferenceWindow = audioDataSlider.Next();

-            /* We moved to the next window - set the features sliding to the new address. */
-            kwsAudioMFCCWindowSlider.Reset(inferenceWindow);
-
-            /* The first window does not have cache ready. */
-            bool useCache = audioDataSlider.Index() > 0 && numberOfReusedFeatureVectors > 0;
-
-            /* Start calculating features inside one audio sliding window. */
-            while (kwsAudioMFCCWindowSlider.HasNext()) {
-                const int16_t* kwsMfccWindow = kwsAudioMFCCWindowSlider.Next();
-                std::vector<int16_t> kwsMfccAudioData =
-                    std::vector<int16_t>(kwsMfccWindow, kwsMfccWindow + kwsMfccWindowSize);
-
-                /* Compute features for this window and write them to input tensor. */
-                kwsMfccFeatureCalc(kwsMfccAudioData,
-                                   kwsAudioMFCCWindowSlider.Index(),
-                                   useCache,
-                                   kwsMfccVectorsInAudioStride);
-            }
+            preProcess.m_audioWindowIndex = audioDataSlider.Index();

-            info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
-                 audioDataSlider.TotalStrides() + 1);
+            /* Run the pre-processing, inference and post-processing. */
+            if (!preProcess.DoPreProcess(inferenceWindow, audio::MicroNetKwsMFCC::ms_defaultSamplingFreq)) {
+                printf_err("KWS Pre-processing failed.");
+                return output;
+            }

-            /* Run inference over this audio clip sliding window. */
             if (!RunInference(kwsModel, profiler)) {
-                printf_err("KWS inference failed\n");
+                printf_err("KWS Inference failed.");
                 return output;
             }

-            std::vector<ClassificationResult> kwsClassificationResult;
-            auto& kwsClassifier = ctx.Get<KwsClassifier&>("kwsclassifier");
+            if (!postProcess.DoPostProcess()) {
+                printf_err("KWS Post-processing failed.");
+                return output;
+            }

-            kwsClassifier.GetClassificationResults(
-                    kwsOutputTensor, kwsClassificationResult,
-                    ctx.Get<std::vector<std::string>&>("kwslabels"), 1, true);
+            info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
+                 audioDataSlider.TotalStrides() + 1);

-            kwsResults.emplace_back(
-                    kws::KwsResult(
-                            kwsClassificationResult,
-                            audioDataSlider.Index() * kwsAudioParamsSecondsPerSample * kwsAudioDataStride,
-                            audioDataSlider.Index(), kwsScoreThreshold)
-                    );
+            /* Add results from this window to our final results vector. */
+            finalResults.emplace_back(
+                    kws::KwsResult(singleInfResult,
+                            audioDataSlider.Index() * kwsAudioParamsSecondsPerSample * preProcess.m_audioDataStride,
+                            audioDataSlider.Index(), kwsScoreThreshold));

-            /* Keyword detected. */
-            if (kwsClassificationResult[0].m_label == ctx.Get<const std::string&>("triggerkeyword")) {
-                output.asrAudioStart = inferenceWindow + kwsAudioDataWindowSize;
+            /* Break out when trigger keyword is detected. */
+            if (singleInfResult[0].m_label == ctx.Get<const std::string&>("triggerKeyword")
+                    && singleInfResult[0].m_normalisedVal > kwsScoreThreshold) {
+                output.asrAudioStart = inferenceWindow + preProcess.m_audioDataWindowSize;
                 output.asrAudioSamples = get_audio_array_size(currentIndex) -
                                          (audioDataSlider.NextWindowStartIndex() -
-                                         kwsAudioDataStride + kwsAudioDataWindowSize);
+                                         preProcess.m_audioDataStride + preProcess.m_audioDataWindowSize);
                 break;
             }

 #if VERIFY_TEST_OUTPUT
-            arm::app::DumpTensor(kwsOutputTensor);
+            DumpTensor(kwsOutputTensor);
 #endif /* VERIFY_TEST_OUTPUT */

         } /* while (audioDataSlider.HasNext()) */

         /* Erase. */
         str_inf = std::string(str_inf.size(), ' ');
-        hal_lcd_display_text(
-                str_inf.c_str(), str_inf.size(),
-                dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
+        hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
+                             dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);

-        if (!PresentInferenceResult(kwsResults)) {
+        if (!PresentInferenceResult(finalResults)) {
             return output;
         }
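The trigger branch above computes two related quantities: asrAudioStart points just past the window that fired, while asrAudioSamples is derived from NextWindowStartIndex(). Since nextStart - stride equals the current window's start, the two expressions are consistent, as this self-contained numeric check shows (all sizes are illustrative placeholders):

    #include <cstddef>
    #include <cstdio>

    int main()
    {
        const size_t total      = 48000; /* clip length in samples (placeholder) */
        const size_t windowSize = 16000; /* KWS window (placeholder)             */
        const size_t stride     = 8000;  /* KWS stride (placeholder)             */
        const size_t index      = 2;     /* window in which the keyword fired    */

        const size_t windowStart = index * stride;           /* 16000                      */
        const size_t nextStart   = windowStart + stride;     /* NextWindowStartIndex()     */
        const size_t asrStart    = windowStart + windowSize; /* first sample handed to ASR */
        const size_t asrSamples  = total - (nextStart - stride + windowSize);

        /* nextStart - stride == windowStart, so asrSamples == total - asrStart. */
        std::printf("ASR starts at sample %zu, %zu samples available\n", asrStart, asrSamples);
        return 0;
    }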
@@ -271,41 +202,41 @@ namespace app {
     }

     /**
-     * @brief Performs the ASR pipeline.
-     *
-     * @param[in,out] ctx   pointer to the application context object
-     * @param[in] kwsOutput struct containing pointer to audio data where ASR should begin
-     *                      and how much data to process
-     * @return bool         true if pipeline executed without failure
-     */
-    static bool doAsr(ApplicationContext& ctx, const KWSOutput& kwsOutput) {
+     * @brief           Performs the ASR pipeline.
+     * @param[in,out]   ctx         Pointer to the application context object.
+     * @param[in]       kwsOutput   Struct containing pointer to audio data where ASR should begin
+     *                              and how much data to process.
+     * @return          true if pipeline executed without failure.
+     **/
+    static bool doAsr(ApplicationContext& ctx, const KWSOutput& kwsOutput)
+    {
+        auto& asrModel = ctx.Get<Model&>("asrModel");
+        auto& profiler = ctx.Get<Profiler&>("profiler");
+        auto asrMfccFrameLen = ctx.Get<uint32_t>("asrFrameLength");
+        auto asrMfccFrameStride = ctx.Get<uint32_t>("asrFrameStride");
+        auto asrScoreThreshold = ctx.Get<float>("asrScoreThreshold");
+        auto asrInputCtxLen = ctx.Get<uint32_t>("ctxLen");
+
         constexpr uint32_t dataPsnTxtInfStartX = 20;
         constexpr uint32_t dataPsnTxtInfStartY = 40;

-        auto& profiler = ctx.Get<Profiler&>("profiler");
-        hal_lcd_clear(COLOR_BLACK);
-
-        /* Get model reference. */
-        auto& asrModel = ctx.Get<Model&>("asrmodel");
         if (!asrModel.IsInited()) {
             printf_err("ASR model has not been initialised\n");
             return false;
         }

-        /* Get score threshold to be applied for the classifier (post-inference). */
-        auto asrScoreThreshold = ctx.Get<float>("asrscoreThreshold");
+        hal_lcd_clear(COLOR_BLACK);

-        /* Dimensions of the tensor should have been verified by the callee. */
+        /* Get Input and Output tensors for pre/post processing. */
         TfLiteTensor* asrInputTensor = asrModel.GetInputTensor(0);
         TfLiteTensor* asrOutputTensor = asrModel.GetOutputTensor(0);

-        const uint32_t asrInputRows = asrInputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx];
-        /* Populate ASR MFCC related parameters. */
-        auto asrMfccParamsWinLen = ctx.Get<uint32_t>("asrframeLength");
-        auto asrMfccParamsWinStride = ctx.Get<uint32_t>("asrframeStride");
+        /* Get input shape. Dimensions of the tensor should have been verified by
+         * the callee. */
+        TfLiteIntArray* inputShape = asrModel.GetInputShape(0);

-        /* Populate ASR inference context and inner lengths for input. */
-        auto asrInputCtxLen = ctx.Get<uint32_t>("ctxLen");
+
+        const uint32_t asrInputRows = asrInputTensor->dims->data[Wav2LetterModel::ms_inputRowsIdx];
         const uint32_t asrInputInnerLen = asrInputRows - (2 * asrInputCtxLen);

         /* Make sure the input tensor supports the above context and inner lengths. */
@@ -316,18 +247,9 @@ namespace app {
         }

         /* Audio data stride corresponds to inputInnerLen feature vectors. */
-        const uint32_t asrAudioParamsWinLen = (asrInputRows - 1) *
-                                              asrMfccParamsWinStride + (asrMfccParamsWinLen);
-        const uint32_t asrAudioParamsWinStride = asrInputInnerLen * asrMfccParamsWinStride;
-        const float asrAudioParamsSecondsPerSample =
-                                              (1.0/audio::Wav2LetterMFCC::ms_defaultSamplingFreq);
-
-        /* Get pre/post-processing objects */
-        auto& asrPrep = ctx.Get<audio::asr::Preprocess&>("preprocess");
-        auto& asrPostp = ctx.Get<audio::asr::Postprocess&>("postprocess");
-
-        /* Set default reduction axis for post-processing. */
-        const uint32_t reductionAxis = arm::app::Wav2LetterModel::ms_outputRowsIdx;
+        const uint32_t asrAudioDataWindowLen = (asrInputRows - 1) * asrMfccFrameStride + (asrMfccFrameLen);
+        const uint32_t asrAudioDataWindowStride = asrInputInnerLen * asrMfccFrameStride;
+        const float asrAudioParamsSecondsPerSample = 1.0 / audio::Wav2LetterMFCC::ms_defaultSamplingFreq;

         /* Get the remaining audio buffer and respective size from KWS results. */
         const int16_t* audioArr = kwsOutput.asrAudioStart;
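asrAudioDataWindowLen and asrAudioDataWindowStride encode how many raw samples one ASR inference consumes and how far the window then advances: every input row beyond the 2 * ctxLen context rows corresponds to one MFCC frame stride of fresh audio. A worked example follows (the numbers match a typical Wav2Letter configuration but are illustrative here; the real code reads them from the model and the application context):

    #include <cinttypes>
    #include <cstdint>
    #include <cstdio>

    int main()
    {
        const uint32_t inputRows   = 296;  /* MFCC time steps the model consumes */
        const uint32_t frameLen    = 512;  /* samples per MFCC frame             */
        const uint32_t frameStride = 160;  /* samples between MFCC frames        */
        const uint32_t ctxLen      = 98;   /* left/right context rows            */

        const uint32_t innerLen     = inputRows - 2 * ctxLen;                   /* 100   */
        const uint32_t windowLen    = (inputRows - 1) * frameStride + frameLen; /* 47712 */
        const uint32_t windowStride = innerLen * frameStride;                   /* 16000 */

        /* At 16 kHz this is a ~2.98 s window advancing exactly 1 s per inference,
         * so only the "inner" rows see new audio; the context rows overlap. */
        std::printf("window %" PRIu32 " samples, stride %" PRIu32 " samples\n",
                    windowLen, windowStride);
        return 0;
    }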
@@ -335,9 +257,9 @@ namespace app {
         /* Audio clip must have enough samples to produce 1 MFCC feature. */
         std::vector<int16_t> audioBuffer = std::vector<int16_t>(audioArr, audioArr + audioArrSize);
-        if (audioArrSize < asrMfccParamsWinLen) {
+        if (audioArrSize < asrMfccFrameLen) {
             printf_err("Not enough audio samples, minimum needed is %" PRIu32 "\n",
-                asrMfccParamsWinLen);
+                asrMfccFrameLen);
             return false;
         }

@@ -345,26 +267,38 @@ namespace app {
         auto audioDataSlider = audio::FractionalSlidingWindow<const int16_t>(
                 audioBuffer.data(),
                 audioBuffer.size(),
-                asrAudioParamsWinLen,
-                asrAudioParamsWinStride);
+                asrAudioDataWindowLen,
+                asrAudioDataWindowStride);

         /* Declare a container for results. */
-        std::vector<arm::app::asr::AsrResult> asrResults;
+        std::vector<asr::AsrResult> asrResults;

         /* Display message on the LCD - inference running. */
         std::string str_inf{"Running ASR inference... "};
-        hal_lcd_display_text(
-                str_inf.c_str(), str_inf.size(),
+        hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
                 dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);

-        size_t asrInferenceWindowLen = asrAudioParamsWinLen;
-
+        size_t asrInferenceWindowLen = asrAudioDataWindowLen;
+
+        /* Set up pre and post-processing objects. */
+        AsrPreProcess asrPreProcess = AsrPreProcess(asrInputTensor, arm::app::Wav2LetterModel::ms_numMfccFeatures,
+                                                    inputShape->data[Wav2LetterModel::ms_inputRowsIdx],
+                                                    asrMfccFrameLen, asrMfccFrameStride);
+
+        std::vector<ClassificationResult> singleInfResult;
+        const uint32_t outputCtxLen = AsrPostProcess::GetOutputContextLen(asrModel, asrInputCtxLen);
+        AsrPostProcess asrPostProcess = AsrPostProcess(
+                asrOutputTensor, ctx.Get<AsrClassifier&>("asrClassifier"),
+                ctx.Get<std::vector<std::string>&>("asrLabels"),
+                singleInfResult, outputCtxLen,
+                Wav2LetterModel::ms_blankTokenIdx, Wav2LetterModel::ms_outputRowsIdx
+                );

         /* Start sliding through audio clip. */
         while (audioDataSlider.HasNext()) {

             /* If not enough audio see how much can be sent for processing. */
             size_t nextStartIndex = audioDataSlider.NextWindowStartIndex();
-            if (nextStartIndex + asrAudioParamsWinLen > audioBuffer.size()) {
+            if (nextStartIndex + asrAudioDataWindowLen > audioBuffer.size()) {
                 asrInferenceWindowLen = audioBuffer.size() - nextStartIndex;
             }
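GetOutputContextLen above maps the input context length onto the model's output rows, letting the post-processor discard the context portion of each window's output and keep only the inner rows, except at the clip edges. A toy illustration of that stitching rule (the numbers are invented and the real logic lives in AsrPostProcess; this only sketches the idea):

    #include <cinttypes>
    #include <cstdint>
    #include <cstdio>

    int main()
    {
        /* Invented sizes: each inference emits 148 output rows, of which the
         * first and last 49 are "context" overlapping the neighbouring windows. */
        const uint32_t outputRows   = 148;
        const uint32_t outputCtxLen = 49;
        const uint32_t windows      = 3;

        uint32_t kept = 0;
        for (uint32_t w = 0; w < windows; ++w) {
            const bool first = (w == 0);
            const bool last  = (w == windows - 1);
            /* Keep the left context only on the first window and the right
             * context only on the last; the inner rows are always kept. */
            const uint32_t begin = first ? 0 : outputCtxLen;
            const uint32_t end   = last ? outputRows : outputRows - outputCtxLen;
            kept += end - begin;
        }
        std::printf("rows decoded overall: %" PRIu32 "\n", kept); /* 99 + 50 + 99 = 248 */
        return 0;
    }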
@@ -373,8 +307,11 @@ namespace app {
             info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
                 static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1)));

-            /* Calculate MFCCs, deltas and populate the input tensor. */
-            asrPrep.Invoke(asrInferenceWindow, asrInferenceWindowLen, asrInputTensor);
+            /* Run the pre-processing, inference and post-processing. */
+            if (!asrPreProcess.DoPreProcess(asrInferenceWindow, asrInferenceWindowLen)) {
+                printf_err("ASR pre-processing failed.");
+                return false;
+            }

             /* Run inference over this audio clip sliding window. */
             if (!RunInference(asrModel, profiler)) {
@@ -382,24 +319,28 @@ namespace app {
                 return false;
             }

-            /* Post-process. */
-            asrPostp.Invoke(asrOutputTensor, reductionAxis, !audioDataSlider.HasNext());
+            /* Post processing needs to know if we are on the last audio window. */
+            asrPostProcess.m_lastIteration = !audioDataSlider.HasNext();
+            if (!asrPostProcess.DoPostProcess()) {
+                printf_err("ASR post-processing failed.");
+                return false;
+            }

             /* Get results. */
             std::vector<ClassificationResult> asrClassificationResult;
-            auto& asrClassifier = ctx.Get<AsrClassifier&>("asrclassifier");
+            auto& asrClassifier = ctx.Get<AsrClassifier&>("asrClassifier");
             asrClassifier.GetClassificationResults(
                     asrOutputTensor, asrClassificationResult,
-                    ctx.Get<std::vector<std::string>&>("asrlabels"), 1);
+                    ctx.Get<std::vector<std::string>&>("asrLabels"), 1);

             asrResults.emplace_back(asr::AsrResult(asrClassificationResult,
                                     (audioDataSlider.Index() *
                                      asrAudioParamsSecondsPerSample *
-                                     asrAudioParamsWinStride),
+                                     asrAudioDataWindowStride),
                                     audioDataSlider.Index(), asrScoreThreshold));

 #if VERIFY_TEST_OUTPUT
-            arm::app::DumpTensor(asrOutputTensor, asrOutputTensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx]);
+            DumpTensor(asrOutputTensor, asrOutputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx]);
 #endif /* VERIFY_TEST_OUTPUT */

             /* Erase */
@@ -417,7 +358,7 @@ namespace app {
         return true;
     }

-    /* Audio inference classification handler. */
+    /* KWS and ASR inference handler. */
     bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
     {
         hal_lcd_clear(COLOR_BLACK);
@@ -434,13 +375,14 @@ namespace app {
         do {
             KWSOutput kwsOutput = doKws(ctx);
             if (!kwsOutput.executionSuccess) {
+                printf_err("KWS failed\n");
                 return false;
             }

             if (kwsOutput.asrAudioStart != nullptr && kwsOutput.asrAudioSamples > 0) {
-                info("Keyword spotted\n");
+                info("Trigger keyword spotted\n");
                 if(!doAsr(ctx, kwsOutput)) {
-                    printf_err("ASR failed");
+                    printf_err("ASR failed\n");
                     return false;
                 }
             }
@@ -452,7 +394,6 @@ namespace app {
         return true;
     }

-
     static bool PresentInferenceResult(std::vector<arm::app::kws::KwsResult>& results)
     {
         constexpr uint32_t dataPsnTxtStartX1 = 20;
@@ -464,33 +405,31 @@ namespace app {
         /* Display each result. */
         uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr;

-        for (uint32_t i = 0; i < results.size(); ++i) {
-
+        for (auto & result : results) {
             std::string topKeyword{"<none>"};
             float score = 0.f;
-            if (!results[i].m_resultVec.empty()) {
-                topKeyword = results[i].m_resultVec[0].m_label;
-                score = results[i].m_resultVec[0].m_normalisedVal;
+            if (!result.m_resultVec.empty()) {
+                topKeyword = result.m_resultVec[0].m_label;
+                score = result.m_resultVec[0].m_normalisedVal;
             }

             std::string resultStr =
-                std::string{"@"} + std::to_string(results[i].m_timeStamp) +
+                std::string{"@"} + std::to_string(result.m_timeStamp) +
                 std::string{"s: "} + topKeyword +
                 std::string{" ("} + std::to_string(static_cast<int>(score * 100)) + std::string{"%)"};

-            hal_lcd_display_text(
-                    resultStr.c_str(), resultStr.size(),
-                    dataPsnTxtStartX1, rowIdx1, 0);
+            hal_lcd_display_text(resultStr.c_str(), resultStr.size(),
+                                 dataPsnTxtStartX1, rowIdx1, 0);
             rowIdx1 += dataPsnTxtYIncr;

             info("For timestamp: %f (inference #: %" PRIu32 "); threshold: %f\n",
-                 results[i].m_timeStamp, results[i].m_inferenceNumber,
-                 results[i].m_threshold);
-            for (uint32_t j = 0; j < results[i].m_resultVec.size(); ++j) {
+                 result.m_timeStamp, result.m_inferenceNumber,
+                 result.m_threshold);
+            for (uint32_t j = 0; j < result.m_resultVec.size(); ++j) {
                 info("\t\tlabel @ %" PRIu32 ": %s, score: %f\n", j,
-                     results[i].m_resultVec[j].m_label.c_str(),
-                     results[i].m_resultVec[j].m_normalisedVal);
+                     result.m_resultVec[j].m_label.c_str(),
+                     result.m_resultVec[j].m_normalisedVal);
             }
         }

@@ -523,143 +462,12 @@ namespace app {
         std::string finalResultStr = audio::asr::DecodeOutput(combinedResults);

-        hal_lcd_display_text(
-                finalResultStr.c_str(), finalResultStr.size(),
-                dataPsnTxtStartX1, dataPsnTxtStartY1, allow_multiple_lines);
+        hal_lcd_display_text(finalResultStr.c_str(), finalResultStr.size(),
+                             dataPsnTxtStartX1, dataPsnTxtStartY1, allow_multiple_lines);

         info("Final result: %s\n", finalResultStr.c_str());

         return true;
     }

-    /**
-     * @brief Generic feature calculator factory.
-     *
-     * Returns lambda function to compute features using features cache.
-     * Real features math is done by a lambda function provided as a parameter.
-     * Features are written to input tensor memory.
-     *
-     * @tparam T            feature vector type.
-     * @param inputTensor   model input tensor pointer.
-     * @param cacheSize     number of feature vectors to cache. Defined by the sliding window overlap.
-     * @param compute       features calculator function.
-     * @return              lambda function to compute features.
-     **/
-    template<class T>
-    std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
-    FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
-                std::function<std::vector<T> (std::vector<int16_t>& )> compute)
-    {
-        /* Feature cache to be captured by lambda function. */
-        static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize);
-
-        return [=](std::vector<int16_t>& audioDataWindow,
-                   size_t index,
-                   bool useCache,
-                   size_t featuresOverlapIndex)
-        {
-            T* tensorData = tflite::GetTensorData<T>(inputTensor);
-            std::vector<T> features;
-
-            /* Reuse features from cache if cache is ready and sliding windows overlap.
-             * Overlap is in the beginning of sliding window with a size of a feature cache. */
-            if (useCache && index < featureCache.size()) {
-                features = std::move(featureCache[index]);
-            } else {
-                features = std::move(compute(audioDataWindow));
-            }
-            auto size = features.size();
-            auto sizeBytes = sizeof(T) * size;
-            std::memcpy(tensorData + (index * size), features.data(), sizeBytes);
-
-            /* Start renewing cache as soon iteration goes out of the windows overlap. */
-            if (index >= featuresOverlapIndex) {
-                featureCache[index - featuresOverlapIndex] = std::move(features);
-            }
-        };
-    }
-
-    template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)>
-    FeatureCalc<int8_t>(TfLiteTensor* inputTensor,
-                        size_t cacheSize,
-                        std::function<std::vector<int8_t> (std::vector<int16_t>& )> compute);
-
-    template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)>
-    FeatureCalc<uint8_t>(TfLiteTensor* inputTensor,
-                         size_t cacheSize,
-                         std::function<std::vector<uint8_t> (std::vector<int16_t>& )> compute);
-
-    template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)>
-    FeatureCalc<int16_t>(TfLiteTensor* inputTensor,
-                         size_t cacheSize,
-                         std::function<std::vector<int16_t> (std::vector<int16_t>& )> compute);
-
-    template std::function<void(std::vector<int16_t>&, size_t, bool, size_t)>
-    FeatureCalc<float>(TfLiteTensor* inputTensor,
-                       size_t cacheSize,
-                       std::function<std::vector<float>(std::vector<int16_t>&)> compute);
-
-
-    static std::function<void (std::vector<int16_t>&, int, bool, size_t)>
-    GetFeatureCalculator(audio::MicroNetMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize)
-    {
-        std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> mfccFeatureCalc;
-
-        TfLiteQuantization quant = inputTensor->quantization;
-
-        if (kTfLiteAffineQuantization == quant.type) {
-
-            auto* quantParams = (TfLiteAffineQuantization*) quant.params;
-            const float quantScale = quantParams->scale->data[0];
-            const int quantOffset = quantParams->zero_point->data[0];
-
-            switch (inputTensor->type) {
-                case kTfLiteInt8: {
-                    mfccFeatureCalc = FeatureCalc<int8_t>(inputTensor,
-                                                          cacheSize,
-                                                          [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
-                                                              return mfcc.MfccComputeQuant<int8_t>(audioDataWindow,
-                                                                                                   quantScale,
-                                                                                                   quantOffset);
-                                                          }
-                    );
-                    break;
-                }
-                case kTfLiteUInt8: {
-                    mfccFeatureCalc = FeatureCalc<uint8_t>(inputTensor,
-                                                           cacheSize,
-                                                           [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
-                                                               return mfcc.MfccComputeQuant<uint8_t>(audioDataWindow,
-                                                                                                     quantScale,
-                                                                                                     quantOffset);
-                                                           }
-                    );
-                    break;
-                }
-                case kTfLiteInt16: {
-                    mfccFeatureCalc = FeatureCalc<int16_t>(inputTensor,
                                                           cacheSize,
-                                                           [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
-                                                               return mfcc.MfccComputeQuant<int16_t>(audioDataWindow,
-                                                                                                     quantScale,
-                                                                                                     quantOffset);
-                                                           }
-                    );
-                    break;
-                }
-                default:
-                    printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type));
-            }
-
-
-        } else {
-            mfccFeatureCalc = mfccFeatureCalc = FeatureCalc<float>(inputTensor,
-                                                                   cacheSize,
-                                                                   [&mfcc](std::vector<int16_t>& audioDataWindow) {
-                                                                       return mfcc.MfccCompute(audioDataWindow);
-                                                                   });
-        }
-        return mfccFeatureCalc;
-    }

 } /* namespace app */
 } /* namespace arm */
\ No newline at end of file
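For reference, the caching scheme implemented by the deleted FeatureCalc/GetFeatureCalculator pair (now owned by KwsPreProcess) distils to a closure over a feature cache: vectors computed for the tail of one audio window are reused for the head of the next. A self-contained distillation follows (standard C++ only; the function and parameter names are illustrative, not the kit's API, and a shared_ptr replaces the original's function-local static):

    #include <cstddef>
    #include <cstdio>
    #include <functional>
    #include <memory>
    #include <utility>
    #include <vector>

    /* Distillation of the deleted FeatureCalc<T>: 'overlapIndex' is the number
     * of feature vectors covered by the audio stride (the non-overlapping part)
     * and 'cacheSize' is the number shared between consecutive windows. */
    std::function<void(std::size_t, bool)>
    MakeFeatureCalc(std::size_t cacheSize, std::size_t overlapIndex,
                    std::function<std::vector<float>()> compute)
    {
        /* The original captured a function-local static; a shared_ptr avoids
         * sharing cache state between independently created calculators. */
        auto cache = std::make_shared<std::vector<std::vector<float>>>(cacheSize);

        return [=](std::size_t index, bool useCache) {
            std::vector<float> features;
            if (useCache && index < cache->size()) {
                features = std::move((*cache)[index]); /* head of window == tail of previous */
                std::printf("slot %zu: cache hit\n", index);
            } else {
                features = compute();                  /* the expensive MFCC computation */
                std::printf("slot %zu: computed\n", index);
            }
            /* (The real code memcpy's 'features' into the input tensor here.) */

            /* Vectors past the overlap boundary become the head of the next window. */
            if (index >= overlapIndex) {
                (*cache)[index - overlapIndex] = std::move(features);
            }
        };
    }

    int main()
    {
        /* 6 feature vectors per window; the window advances by 4, so 2 are shared. */
        auto calc = MakeFeatureCalc(2, 4, [] { return std::vector<float>(10, 0.0f); });
        for (std::size_t i = 0; i < 6; ++i) { calc(i, false); } /* first window: all computed   */
        for (std::size_t i = 0; i < 6; ++i) { calc(i, true); }  /* next: slots 0-1 are cache hits */
        return 0;
    }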