diff options
Diffstat (limited to 'source/use_case/asr/src/UseCaseHandler.cc')
-rw-r--r-- | source/use_case/asr/src/UseCaseHandler.cc | 166 |
1 file changed, 75 insertions, 91 deletions
diff --git a/source/use_case/asr/src/UseCaseHandler.cc b/source/use_case/asr/src/UseCaseHandler.cc index 420f725..7fe959b 100644 --- a/source/use_case/asr/src/UseCaseHandler.cc +++ b/source/use_case/asr/src/UseCaseHandler.cc @@ -20,7 +20,6 @@ #include "AsrClassifier.hpp" #include "Wav2LetterModel.hpp" #include "hal.h" -#include "Wav2LetterMfcc.hpp" #include "AudioUtils.hpp" #include "ImageUtils.hpp" #include "UseCaseCommonUtils.hpp" @@ -34,68 +33,63 @@ namespace arm { namespace app { /** - * @brief Presents inference results using the data presentation - * object. - * @param[in] results Vector of classification results to be displayed. + * @brief Presents ASR inference results. + * @param[in] results Vector of ASR classification results to be displayed. * @return true if successful, false otherwise. **/ - static bool PresentInferenceResult(const std::vector<arm::app::asr::AsrResult>& results); + static bool PresentInferenceResult(const std::vector<asr::AsrResult>& results); - /* Audio inference classification handler. */ + /* ASR inference handler. */ bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll) { - constexpr uint32_t dataPsnTxtInfStartX = 20; - constexpr uint32_t dataPsnTxtInfStartY = 40; - - hal_lcd_clear(COLOR_BLACK); - + auto& model = ctx.Get<Model&>("model"); auto& profiler = ctx.Get<Profiler&>("profiler"); - + auto mfccFrameLen = ctx.Get<uint32_t>("frameLength"); + auto mfccFrameStride = ctx.Get<uint32_t>("frameStride"); + auto scoreThreshold = ctx.Get<float>("scoreThreshold"); + auto inputCtxLen = ctx.Get<uint32_t>("ctxLen"); /* If the request has a valid size, set the audio index. */ if (clipIndex < NUMBER_OF_FILES) { if (!SetAppCtxIfmIdx(ctx, clipIndex,"clipIndex")) { return false; } } + auto initialClipIdx = ctx.Get<uint32_t>("clipIndex"); + constexpr uint32_t dataPsnTxtInfStartX = 20; + constexpr uint32_t dataPsnTxtInfStartY = 40; - /* Get model reference. 
*/ - auto& model = ctx.Get<Model&>("model"); if (!model.IsInited()) { printf_err("Model is not initialised! Terminating processing.\n"); return false; } - /* Get score threshold to be applied for the classifier (post-inference). */ - auto scoreThreshold = ctx.Get<float>("scoreThreshold"); - - /* Get tensors. Dimensions of the tensor should have been verified by + /* Get input shape. Dimensions of the tensor should have been verified by * the callee. */ - TfLiteTensor* inputTensor = model.GetInputTensor(0); - TfLiteTensor* outputTensor = model.GetOutputTensor(0); - const uint32_t inputRows = inputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx]; + TfLiteIntArray* inputShape = model.GetInputShape(0); - /* Populate MFCC related parameters. */ - auto mfccParamsWinLen = ctx.Get<uint32_t>("frameLength"); - auto mfccParamsWinStride = ctx.Get<uint32_t>("frameStride"); - - /* Populate ASR inference context and inner lengths for input. */ - auto inputCtxLen = ctx.Get<uint32_t>("ctxLen"); - const uint32_t inputInnerLen = inputRows - (2 * inputCtxLen); + const uint32_t inputRowsSize = inputShape->data[Wav2LetterModel::ms_inputRowsIdx]; + const uint32_t inputInnerLen = inputRowsSize - (2 * inputCtxLen); /* Audio data stride corresponds to inputInnerLen feature vectors. */ - const uint32_t audioParamsWinLen = (inputRows - 1) * mfccParamsWinStride + (mfccParamsWinLen); - const uint32_t audioParamsWinStride = inputInnerLen * mfccParamsWinStride; - const float audioParamsSecondsPerSample = (1.0/audio::Wav2LetterMFCC::ms_defaultSamplingFreq); + const uint32_t audioDataWindowLen = (inputRowsSize - 1) * mfccFrameStride + (mfccFrameLen); + const uint32_t audioDataWindowStride = inputInnerLen * mfccFrameStride; + + /* NOTE: This is only used for time stamp calculation. */ + const float secondsPerSample = (1.0 / audio::Wav2LetterMFCC::ms_defaultSamplingFreq); - /* Get pre/post-processing objects. 
*/ - auto& prep = ctx.Get<audio::asr::Preprocess&>("preprocess"); - auto& postp = ctx.Get<audio::asr::Postprocess&>("postprocess"); + /* Set up pre and post-processing objects. */ + ASRPreProcess preProcess = ASRPreProcess(model.GetInputTensor(0), Wav2LetterModel::ms_numMfccFeatures, + inputShape->data[Wav2LetterModel::ms_inputRowsIdx], mfccFrameLen, mfccFrameStride); - /* Set default reduction axis for post-processing. */ - const uint32_t reductionAxis = arm::app::Wav2LetterModel::ms_outputRowsIdx; + std::vector<ClassificationResult> singleInfResult; + const uint32_t outputCtxLen = ASRPostProcess::GetOutputContextLen(model, inputCtxLen); + ASRPostProcess postProcess = ASRPostProcess(ctx.Get<AsrClassifier&>("classifier"), + model.GetOutputTensor(0), ctx.Get<std::vector<std::string>&>("labels"), + singleInfResult, outputCtxLen, + Wav2LetterModel::ms_blankTokenIdx, Wav2LetterModel::ms_outputRowsIdx + ); - /* Audio clip start index. */ - auto startClipIdx = ctx.Get<uint32_t>("clipIndex"); + UseCaseRunner runner = UseCaseRunner(&preProcess, &postProcess, &model); /* Loop to process audio clips. */ do { @@ -109,44 +103,41 @@ namespace app { const uint32_t audioArrSize = get_audio_array_size(currentIndex); if (!audioArr) { - printf_err("Invalid audio array pointer\n"); + printf_err("Invalid audio array pointer.\n"); return false; } - /* Audio clip must have enough samples to produce 1 MFCC feature. */ - if (audioArrSize < mfccParamsWinLen) { + /* Audio clip needs enough samples to produce at least 1 MFCC feature. */ + if (audioArrSize < mfccFrameLen) { printf_err("Not enough audio samples, minimum needed is %" PRIu32 "\n", - mfccParamsWinLen); + mfccFrameLen); return false; } - /* Initialise an audio slider. */ + /* Creating a sliding window through the whole audio clip. 
*/ auto audioDataSlider = audio::FractionalSlidingWindow<const int16_t>( - audioArr, - audioArrSize, - audioParamsWinLen, - audioParamsWinStride); + audioArr, audioArrSize, + audioDataWindowLen, audioDataWindowStride); - /* Declare a container for results. */ - std::vector<arm::app::asr::AsrResult> results; + /* Declare a container for final results. */ + std::vector<asr::AsrResult> finalResults; /* Display message on the LCD - inference running. */ std::string str_inf{"Running inference... "}; - hal_lcd_display_text( - str_inf.c_str(), str_inf.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); + hal_lcd_display_text(str_inf.c_str(), str_inf.size(), + dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); info("Running inference on audio clip %" PRIu32 " => %s\n", currentIndex, get_filename(currentIndex)); - size_t inferenceWindowLen = audioParamsWinLen; + size_t inferenceWindowLen = audioDataWindowLen; /* Start sliding through audio clip. */ while (audioDataSlider.HasNext()) { - /* If not enough audio see how much can be sent for processing. */ + /* If not enough audio, see how much can be sent for processing. */ size_t nextStartIndex = audioDataSlider.NextWindowStartIndex(); - if (nextStartIndex + audioParamsWinLen > audioArrSize) { + if (nextStartIndex + audioDataWindowLen > audioArrSize) { inferenceWindowLen = audioArrSize - nextStartIndex; } @@ -155,46 +146,40 @@ namespace app { info("Inference %zu/%zu\n", audioDataSlider.Index() + 1, static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1))); - /* Calculate MFCCs, deltas and populate the input tensor. */ - prep.Invoke(inferenceWindow, inferenceWindowLen, inputTensor); + /* Run the pre-processing, inference and post-processing. */ + runner.PreProcess(inferenceWindow, inferenceWindowLen); - /* Run inference over this audio clip sliding window. 
*/ - if (!RunInference(model, profiler)) { + profiler.StartProfiling("Inference"); + if (!runner.RunInference()) { return false; } + profiler.StopProfiling(); - /* Post-process. */ - postp.Invoke(outputTensor, reductionAxis, !audioDataSlider.HasNext()); - - /* Get results. */ - std::vector<ClassificationResult> classificationResult; - auto& classifier = ctx.Get<AsrClassifier&>("classifier"); - classifier.GetClassificationResults( - outputTensor, classificationResult, - ctx.Get<std::vector<std::string>&>("labels"), 1); + postProcess.m_lastIteration = !audioDataSlider.HasNext(); + if (!runner.PostProcess()) { + return false; + } - results.emplace_back(asr::AsrResult(classificationResult, - (audioDataSlider.Index() * - audioParamsSecondsPerSample * - audioParamsWinStride), - audioDataSlider.Index(), scoreThreshold)); + /* Add results from this window to our final results vector. */ + finalResults.emplace_back(asr::AsrResult(singleInfResult, + (audioDataSlider.Index() * secondsPerSample * audioDataWindowStride), + audioDataSlider.Index(), scoreThreshold)); #if VERIFY_TEST_OUTPUT - arm::app::DumpTensor(outputTensor, - outputTensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx]); + TfLiteTensor* outputTensor = model.GetOutputTensor(0); + armDumpTensor(outputTensor, + outputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx]); #endif /* VERIFY_TEST_OUTPUT */ - - } + } /* while (audioDataSlider.HasNext()) */ /* Erase. 
*/ str_inf = std::string(str_inf.size(), ' '); - hal_lcd_display_text( - str_inf.c_str(), str_inf.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); + hal_lcd_display_text(str_inf.c_str(), str_inf.size(), + dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); - ctx.Set<std::vector<arm::app::asr::AsrResult>>("results", results); + ctx.Set<std::vector<asr::AsrResult>>("results", finalResults); - if (!PresentInferenceResult(results)) { + if (!PresentInferenceResult(finalResults)) { return false; } @@ -202,13 +187,13 @@ namespace app { IncrementAppCtxIfmIdx(ctx,"clipIndex"); - } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx); + } while (runAll && ctx.Get<uint32_t>("clipIndex") != initialClipIdx); return true; } - static bool PresentInferenceResult(const std::vector<arm::app::asr::AsrResult>& results) + static bool PresentInferenceResult(const std::vector<asr::AsrResult>& results) { constexpr uint32_t dataPsnTxtStartX1 = 20; constexpr uint32_t dataPsnTxtStartY1 = 60; @@ -219,15 +204,15 @@ namespace app { info("Final results:\n"); info("Total number of inferences: %zu\n", results.size()); /* Results from multiple inferences should be combined before processing. */ - std::vector<arm::app::ClassificationResult> combinedResults; - for (auto& result : results) { + std::vector<ClassificationResult> combinedResults; + for (const auto& result : results) { combinedResults.insert(combinedResults.end(), result.m_resultVec.begin(), result.m_resultVec.end()); } /* Get each inference result string using the decoder. */ - for (const auto & result : results) { + for (const auto& result : results) { std::string infResultStr = audio::asr::DecodeOutput(result.m_resultVec); info("For timestamp: %f (inference #: %" PRIu32 "); label: %s\n", @@ -238,10 +223,9 @@ namespace app { /* Get the decoded result for the combined result. 
*/ std::string finalResultStr = audio::asr::DecodeOutput(combinedResults); - hal_lcd_display_text( - finalResultStr.c_str(), finalResultStr.size(), - dataPsnTxtStartX1, dataPsnTxtStartY1, - allow_multiple_lines); + hal_lcd_display_text(finalResultStr.c_str(), finalResultStr.size(), + dataPsnTxtStartX1, dataPsnTxtStartY1, + allow_multiple_lines); info("Complete recognition: %s\n", finalResultStr.c_str()); return true; |