Diffstat (limited to 'source/use_case/kws_asr/src/UseCaseHandler.cc')
-rw-r--r-- | source/use_case/kws_asr/src/UseCaseHandler.cc | 245 |
1 file changed, 132 insertions, 113 deletions
diff --git a/source/use_case/kws_asr/src/UseCaseHandler.cc b/source/use_case/kws_asr/src/UseCaseHandler.cc
index e733605..8a024b7 100644
--- a/source/use_case/kws_asr/src/UseCaseHandler.cc
+++ b/source/use_case/kws_asr/src/UseCaseHandler.cc
@@ -1,6 +1,6 @@
 /*
- * SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
- * SPDX-License-Identifier: Apache-2.0
+ * SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates
+ * <open-source-office@arm.com> SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -16,35 +16,34 @@
  */
 #include "UseCaseHandler.hpp"
 
-#include "hal.h"
-#include "InputFiles.hpp"
+#include "AsrClassifier.hpp"
+#include "AsrResult.hpp"
 #include "AudioUtils.hpp"
-#include "ImageUtils.hpp"
-#include "UseCaseCommonUtils.hpp"
-#include "MicroNetKwsModel.hpp"
-#include "MicroNetKwsMfcc.hpp"
 #include "Classifier.hpp"
+#include "ImageUtils.hpp"
+#include "InputFiles.hpp"
+#include "KwsProcessing.hpp"
 #include "KwsResult.hpp"
-#include "Wav2LetterModel.hpp"
+#include "MicroNetKwsMfcc.hpp"
+#include "MicroNetKwsModel.hpp"
+#include "OutputDecode.hpp"
+#include "UseCaseCommonUtils.hpp"
 #include "Wav2LetterMfcc.hpp"
-#include "Wav2LetterPreprocess.hpp"
+#include "Wav2LetterModel.hpp"
 #include "Wav2LetterPostprocess.hpp"
-#include "KwsProcessing.hpp"
-#include "AsrResult.hpp"
-#include "AsrClassifier.hpp"
-#include "OutputDecode.hpp"
+#include "Wav2LetterPreprocess.hpp"
+#include "hal.h"
 #include "log_macros.h"
 
-
 using KwsClassifier = arm::app::Classifier;
 
 namespace arm {
 namespace app {
 
     struct KWSOutput {
-        bool executionSuccess = false;
+        bool executionSuccess        = false;
         const int16_t* asrAudioStart = nullptr;
-        int32_t asrAudioSamples = 0;
+        int32_t asrAudioSamples      = 0;
     };
 
@@ -69,23 +68,24 @@ namespace app {
      **/
     static KWSOutput doKws(ApplicationContext& ctx) {
-        auto& profiler = ctx.Get<Profiler&>("profiler");
-        auto& kwsModel = ctx.Get<Model&>("kwsModel");
+        auto& profiler                = ctx.Get<Profiler&>("profiler");
+        auto& kwsModel                = ctx.Get<Model&>("kwsModel");
         const auto kwsMfccFrameLength = ctx.Get<int>("kwsFrameLength");
         const auto kwsMfccFrameStride = ctx.Get<int>("kwsFrameStride");
-        const auto kwsScoreThreshold = ctx.Get<float>("kwsScoreThreshold");
+        const auto kwsScoreThreshold  = ctx.Get<float>("kwsScoreThreshold");
         auto currentIndex = ctx.Get<uint32_t>("clipIndex");
 
         constexpr uint32_t dataPsnTxtInfStartX = 20;
         constexpr uint32_t dataPsnTxtInfStartY = 40;
 
-        constexpr int minTensorDims = static_cast<int>(
-            (MicroNetKwsModel::ms_inputRowsIdx > MicroNetKwsModel::ms_inputColsIdx)?
-            MicroNetKwsModel::ms_inputRowsIdx : MicroNetKwsModel::ms_inputColsIdx);
+        constexpr int minTensorDims =
+            static_cast<int>((MicroNetKwsModel::ms_inputRowsIdx > MicroNetKwsModel::ms_inputColsIdx)
+                                 ? MicroNetKwsModel::ms_inputRowsIdx
+                                 : MicroNetKwsModel::ms_inputColsIdx);
 
         /* Output struct from doing KWS. */
-        KWSOutput output {};
+        KWSOutput output{};
 
         if (!kwsModel.IsInited()) {
             printf_err("KWS model has not been initialised\n");
@@ -93,7 +93,7 @@ namespace app {
         }
 
         /* Get Input and Output tensors for pre/post processing. */
-        TfLiteTensor* kwsInputTensor = kwsModel.GetInputTensor(0);
+        TfLiteTensor* kwsInputTensor  = kwsModel.GetInputTensor(0);
         TfLiteTensor* kwsOutputTensor = kwsModel.GetOutputTensor(0);
         if (!kwsInputTensor->dims) {
             printf_err("Invalid input tensor dims\n");
@@ -104,28 +104,30 @@ namespace app {
         }
 
         /* Get input shape for feature extraction. */
-        TfLiteIntArray* inputShape = kwsModel.GetInputShape(0);
+        TfLiteIntArray* inputShape     = kwsModel.GetInputShape(0);
         const uint32_t numMfccFeatures = inputShape->data[MicroNetKwsModel::ms_inputColsIdx];
-        const uint32_t numMfccFrames = inputShape->data[MicroNetKwsModel::ms_inputRowsIdx];
+        const uint32_t numMfccFrames   = inputShape->data[MicroNetKwsModel::ms_inputRowsIdx];
 
         /* We expect to be sampling 1 second worth of data at a time
          * NOTE: This is only used for time stamp calculation. */
-        const float kwsAudioParamsSecondsPerSample = 1.0 / audio::MicroNetKwsMFCC::ms_defaultSamplingFreq;
+        const float kwsAudioParamsSecondsPerSample =
+            1.0 / audio::MicroNetKwsMFCC::ms_defaultSamplingFreq;
 
         /* Set up pre and post-processing. */
-        KwsPreProcess preProcess = KwsPreProcess(kwsInputTensor, numMfccFeatures, numMfccFrames,
-                                                 kwsMfccFrameLength, kwsMfccFrameStride);
+        KwsPreProcess preProcess = KwsPreProcess(
+            kwsInputTensor, numMfccFeatures, numMfccFrames, kwsMfccFrameLength, kwsMfccFrameStride);
 
         std::vector<ClassificationResult> singleInfResult;
-        KwsPostProcess postProcess = KwsPostProcess(kwsOutputTensor, ctx.Get<KwsClassifier &>("kwsClassifier"),
+        KwsPostProcess postProcess = KwsPostProcess(kwsOutputTensor,
+                                                    ctx.Get<KwsClassifier&>("kwsClassifier"),
                                                     ctx.Get<std::vector<std::string>&>("kwsLabels"),
                                                     singleInfResult);
 
         /* Creating a sliding window through the whole audio clip. */
-        auto audioDataSlider = audio::SlidingWindow<const int16_t>(
-            get_audio_array(currentIndex),
-            get_audio_array_size(currentIndex),
-            preProcess.m_audioDataWindowSize, preProcess.m_audioDataStride);
+        auto audioDataSlider = audio::SlidingWindow<const int16_t>(GetAudioArray(currentIndex),
+                                                                   GetAudioArraySize(currentIndex),
+                                                                   preProcess.m_audioDataWindowSize,
+                                                                   preProcess.m_audioDataStride);
 
         /* Declare a container to hold kws results from across the whole audio clip. */
         std::vector<kws::KwsResult> finalResults;
@@ -133,11 +135,11 @@ namespace app {
         /* Display message on the LCD - inference running. */
         std::string str_inf{"Running KWS inference... "};
         hal_lcd_display_text(
-            str_inf.c_str(), str_inf.size(),
-            dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
+            str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
 
         info("Running KWS inference on audio clip %" PRIu32 " => %s\n",
-             currentIndex, get_filename(currentIndex));
+             currentIndex,
+             GetFilename(currentIndex));
 
         /* Start sliding through audio clip. */
         while (audioDataSlider.HasNext()) {
@@ -159,22 +161,26 @@ namespace app {
                 return output;
             }
 
-            info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
+            info("Inference %zu/%zu\n",
+                 audioDataSlider.Index() + 1,
                  audioDataSlider.TotalStrides() + 1);
 
             /* Add results from this window to our final results vector. */
             finalResults.emplace_back(
-                kws::KwsResult(singleInfResult,
-                    audioDataSlider.Index() * kwsAudioParamsSecondsPerSample * preProcess.m_audioDataStride,
-                    audioDataSlider.Index(), kwsScoreThreshold));
+                kws::KwsResult(singleInfResult,
+                               audioDataSlider.Index() * kwsAudioParamsSecondsPerSample *
+                                   preProcess.m_audioDataStride,
+                               audioDataSlider.Index(),
+                               kwsScoreThreshold));
 
             /* Break out when trigger keyword is detected. */
-            if (singleInfResult[0].m_label == ctx.Get<const std::string&>("triggerKeyword")
-                && singleInfResult[0].m_normalisedVal > kwsScoreThreshold) {
+            if (singleInfResult[0].m_label == ctx.Get<const std::string&>("triggerKeyword") &&
+                singleInfResult[0].m_normalisedVal > kwsScoreThreshold) {
 
                 output.asrAudioStart = inferenceWindow + preProcess.m_audioDataWindowSize;
-                output.asrAudioSamples = get_audio_array_size(currentIndex) -
-                    (audioDataSlider.NextWindowStartIndex() -
-                     preProcess.m_audioDataStride + preProcess.m_audioDataWindowSize);
+                output.asrAudioSamples =
+                    GetAudioArraySize(currentIndex) -
+                    (audioDataSlider.NextWindowStartIndex() - preProcess.m_audioDataStride +
+                     preProcess.m_audioDataWindowSize);
                 break;
             }
@@ -186,8 +192,8 @@ namespace app {
 
         /* Erase. */
         str_inf = std::string(str_inf.size(), ' ');
-        hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
-                             dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
+        hal_lcd_display_text(
+            str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
 
         if (!PresentInferenceResult(finalResults)) {
             return output;
@@ -208,12 +214,12 @@ namespace app {
      **/
     static bool doAsr(ApplicationContext& ctx, const KWSOutput& kwsOutput) {
-        auto& asrModel = ctx.Get<Model&>("asrModel");
-        auto& profiler = ctx.Get<Profiler&>("profiler");
-        auto asrMfccFrameLen = ctx.Get<uint32_t>("asrFrameLength");
+        auto& asrModel          = ctx.Get<Model&>("asrModel");
+        auto& profiler          = ctx.Get<Profiler&>("profiler");
+        auto asrMfccFrameLen    = ctx.Get<uint32_t>("asrFrameLength");
         auto asrMfccFrameStride = ctx.Get<uint32_t>("asrFrameStride");
-        auto asrScoreThreshold = ctx.Get<float>("asrScoreThreshold");
-        auto asrInputCtxLen = ctx.Get<uint32_t>("ctxLen");
+        auto asrScoreThreshold  = ctx.Get<float>("asrScoreThreshold");
+        auto asrInputCtxLen     = ctx.Get<uint32_t>("ctxLen");
 
         constexpr uint32_t dataPsnTxtInfStartX = 20;
         constexpr uint32_t dataPsnTxtInfStartY = 40;
@@ -226,31 +232,32 @@ namespace app {
         hal_lcd_clear(COLOR_BLACK);
 
         /* Get Input and Output tensors for pre/post processing. */
-        TfLiteTensor* asrInputTensor = asrModel.GetInputTensor(0);
+        TfLiteTensor* asrInputTensor  = asrModel.GetInputTensor(0);
         TfLiteTensor* asrOutputTensor = asrModel.GetOutputTensor(0);
 
         /* Get input shape. Dimensions of the tensor should have been verified by
-        * the callee. */
+         * the callee. */
         TfLiteIntArray* inputShape = asrModel.GetInputShape(0);
-
         const uint32_t asrInputRows = asrInputTensor->dims->data[Wav2LetterModel::ms_inputRowsIdx];
         const uint32_t asrInputInnerLen = asrInputRows - (2 * asrInputCtxLen);
 
         /* Make sure the input tensor supports the above context and inner lengths. */
         if (asrInputRows <= 2 * asrInputCtxLen || asrInputRows <= asrInputInnerLen) {
             printf_err("ASR input rows not compatible with ctx length %" PRIu32 "\n",
-                    asrInputCtxLen);
+                       asrInputCtxLen);
             return false;
         }
 
         /* Audio data stride corresponds to inputInnerLen feature vectors. */
-        const uint32_t asrAudioDataWindowLen = (asrInputRows - 1) * asrMfccFrameStride + (asrMfccFrameLen);
+        const uint32_t asrAudioDataWindowLen =
+            (asrInputRows - 1) * asrMfccFrameStride + (asrMfccFrameLen);
         const uint32_t asrAudioDataWindowStride = asrInputInnerLen * asrMfccFrameStride;
-        const float asrAudioParamsSecondsPerSample = 1.0 / audio::Wav2LetterMFCC::ms_defaultSamplingFreq;
+        const float asrAudioParamsSecondsPerSample =
+            1.0 / audio::Wav2LetterMFCC::ms_defaultSamplingFreq;
 
         /* Get the remaining audio buffer and respective size from KWS results. */
-        const int16_t* audioArr = kwsOutput.asrAudioStart;
+        const int16_t* audioArr     = kwsOutput.asrAudioStart;
         const uint32_t audioArrSize = kwsOutput.asrAudioSamples;
 
         /* Audio clip must have enough samples to produce 1 MFCC feature. */
@@ -262,35 +269,40 @@ namespace app {
         }
 
         /* Initialise an audio slider. */
-        auto audioDataSlider = audio::FractionalSlidingWindow<const int16_t>(
-            audioBuffer.data(),
-            audioBuffer.size(),
-            asrAudioDataWindowLen,
-            asrAudioDataWindowStride);
+        auto audioDataSlider =
+            audio::FractionalSlidingWindow<const int16_t>(audioBuffer.data(),
+                                                          audioBuffer.size(),
+                                                          asrAudioDataWindowLen,
+                                                          asrAudioDataWindowStride);
 
         /* Declare a container for results. */
         std::vector<asr::AsrResult> asrResults;
 
         /* Display message on the LCD - inference running. */
         std::string str_inf{"Running ASR inference... "};
-        hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
-                             dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
+        hal_lcd_display_text(
+            str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
 
         size_t asrInferenceWindowLen = asrAudioDataWindowLen;
 
         /* Set up pre and post-processing objects. */
-        AsrPreProcess asrPreProcess = AsrPreProcess(asrInputTensor, arm::app::Wav2LetterModel::ms_numMfccFeatures,
-                                                    inputShape->data[Wav2LetterModel::ms_inputRowsIdx],
-                                                    asrMfccFrameLen, asrMfccFrameStride);
+        AsrPreProcess asrPreProcess =
+            AsrPreProcess(asrInputTensor,
+                          arm::app::Wav2LetterModel::ms_numMfccFeatures,
+                          inputShape->data[Wav2LetterModel::ms_inputRowsIdx],
+                          asrMfccFrameLen,
+                          asrMfccFrameStride);
 
         std::vector<ClassificationResult> singleInfResult;
         const uint32_t outputCtxLen = AsrPostProcess::GetOutputContextLen(asrModel, asrInputCtxLen);
-        AsrPostProcess asrPostProcess = AsrPostProcess(
-            asrOutputTensor, ctx.Get<AsrClassifier&>("asrClassifier"),
-            ctx.Get<std::vector<std::string>&>("asrLabels"),
-            singleInfResult, outputCtxLen,
-            Wav2LetterModel::ms_blankTokenIdx, Wav2LetterModel::ms_outputRowsIdx
-        );
+        AsrPostProcess asrPostProcess =
+            AsrPostProcess(asrOutputTensor,
+                           ctx.Get<AsrClassifier&>("asrClassifier"),
+                           ctx.Get<std::vector<std::string>&>("asrLabels"),
+                           singleInfResult,
+                           outputCtxLen,
+                           Wav2LetterModel::ms_blankTokenIdx,
+                           Wav2LetterModel::ms_outputRowsIdx);
 
         /* Start sliding through audio clip. */
         while (audioDataSlider.HasNext()) {
@@ -302,8 +314,9 @@ namespace app {
 
             const int16_t* asrInferenceWindow = audioDataSlider.Next();
 
-            info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
-                 static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1)));
+            info("Inference %zu/%zu\n",
+                 audioDataSlider.Index() + 1,
+                 static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1)));
 
             /* Run the pre-processing, inference and post-processing. */
             if (!asrPreProcess.DoPreProcess(asrInferenceWindow, asrInferenceWindowLen)) {
@@ -327,25 +340,27 @@ namespace app {
            /* Get results. */
            std::vector<ClassificationResult> asrClassificationResult;
            auto& asrClassifier = ctx.Get<AsrClassifier&>("asrClassifier");
-           asrClassifier.GetClassificationResults(
-               asrOutputTensor, asrClassificationResult,
-               ctx.Get<std::vector<std::string>&>("asrLabels"), 1);
-
-           asrResults.emplace_back(asr::AsrResult(asrClassificationResult,
-                                   (audioDataSlider.Index() *
-                                    asrAudioParamsSecondsPerSample *
-                                    asrAudioDataWindowStride),
-                                   audioDataSlider.Index(), asrScoreThreshold));
+           asrClassifier.GetClassificationResults(asrOutputTensor,
+                                                  asrClassificationResult,
+                                                  ctx.Get<std::vector<std::string>&>("asrLabels"),
+                                                  1);
+
+           asrResults.emplace_back(
+               asr::AsrResult(asrClassificationResult,
+                              (audioDataSlider.Index() * asrAudioParamsSecondsPerSample *
+                               asrAudioDataWindowStride),
+                              audioDataSlider.Index(),
+                              asrScoreThreshold));
 
 #if VERIFY_TEST_OUTPUT
-           armDumpTensor(asrOutputTensor, asrOutputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx]);
+           armDumpTensor(asrOutputTensor,
+                         asrOutputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx]);
 #endif /* VERIFY_TEST_OUTPUT */
 
            /* Erase */
            str_inf = std::string(str_inf.size(), ' ');
            hal_lcd_display_text(
-               str_inf.c_str(), str_inf.size(),
-               dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
+               str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
         }
         if (!PresentInferenceResult(asrResults)) {
            return false;
@@ -363,7 +378,7 @@ namespace app {
 
         /* If the request has a valid size, set the audio index. */
         if (clipIndex < NUMBER_OF_FILES) {
-            if (!SetAppCtxIfmIdx(ctx, clipIndex,"kws_asr")) {
+            if (!SetAppCtxIfmIdx(ctx, clipIndex, "kws_asr")) {
                 return false;
            }
        }
@@ -379,13 +394,13 @@ namespace app {
 
            if (kwsOutput.asrAudioStart != nullptr && kwsOutput.asrAudioSamples > 0) {
                info("Trigger keyword spotted\n");
-               if(!doAsr(ctx, kwsOutput)) {
+               if (!doAsr(ctx, kwsOutput)) {
                    printf_err("ASR failed\n");
                    return false;
                }
            }
 
-           IncrementAppCtxIfmIdx(ctx,"kws_asr");
+           IncrementAppCtxIfmIdx(ctx, "kws_asr");
 
        } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx);
 
@@ -396,36 +411,38 @@ namespace app {
    {
        constexpr uint32_t dataPsnTxtStartX1 = 20;
        constexpr uint32_t dataPsnTxtStartY1 = 30;
-       constexpr uint32_t dataPsnTxtYIncr = 16; /* Row index increment. */
+       constexpr uint32_t dataPsnTxtYIncr   = 16; /* Row index increment. */
 
        hal_lcd_set_text_color(COLOR_GREEN);
 
        /* Display each result. */
        uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr;
 
-       for (auto & result : results) {
+       for (auto& result : results) {
 
            std::string topKeyword{"<none>"};
            float score = 0.f;
            if (!result.m_resultVec.empty()) {
                topKeyword = result.m_resultVec[0].m_label;
-               score = result.m_resultVec[0].m_normalisedVal;
+               score      = result.m_resultVec[0].m_normalisedVal;
            }
 
-           std::string resultStr =
-               std::string{"@"} + std::to_string(result.m_timeStamp) +
-               std::string{"s: "} + topKeyword + std::string{" ("} +
-               std::to_string(static_cast<int>(score * 100)) + std::string{"%)"};
+           std::string resultStr = std::string{"@"} + std::to_string(result.m_timeStamp) +
+                                   std::string{"s: "} + topKeyword + std::string{" ("} +
+                                   std::to_string(static_cast<int>(score * 100)) +
+                                   std::string{"%)"};
 
-           hal_lcd_display_text(resultStr.c_str(), resultStr.size(),
-                                dataPsnTxtStartX1, rowIdx1, 0);
+           hal_lcd_display_text(
+               resultStr.c_str(), resultStr.size(), dataPsnTxtStartX1, rowIdx1, 0);
            rowIdx1 += dataPsnTxtYIncr;
 
            info("For timestamp: %f (inference #: %" PRIu32 "); threshold: %f\n",
-                result.m_timeStamp, result.m_inferenceNumber,
+                result.m_timeStamp,
+                result.m_inferenceNumber,
                 result.m_threshold);
            for (uint32_t j = 0; j < result.m_resultVec.size(); ++j) {
-               info("\t\tlabel @ %" PRIu32 ": %s, score: %f\n", j,
+               info("\t\tlabel @ %" PRIu32 ": %s, score: %f\n",
+                    j,
                    result.m_resultVec[j].m_label.c_str(),
                    result.m_resultVec[j].m_normalisedVal);
            }
@@ -438,30 +455,32 @@ namespace app {
    {
        constexpr uint32_t dataPsnTxtStartX1 = 20;
        constexpr uint32_t dataPsnTxtStartY1 = 80;
-       constexpr bool allow_multiple_lines = true;
+       constexpr bool allow_multiple_lines  = true;
 
        hal_lcd_set_text_color(COLOR_GREEN);
 
        /* Results from multiple inferences should be combined before processing. */
        std::vector<arm::app::ClassificationResult> combinedResults;
        for (auto& result : results) {
-           combinedResults.insert(combinedResults.end(),
-                                  result.m_resultVec.begin(),
-                                  result.m_resultVec.end());
+           combinedResults.insert(
+               combinedResults.end(), result.m_resultVec.begin(), result.m_resultVec.end());
        }
 
        for (auto& result : results) {
            /* Get the final result string using the decoder. */
            std::string infResultStr = audio::asr::DecodeOutput(result.m_resultVec);
-           info("Result for inf %" PRIu32 ": %s\n", result.m_inferenceNumber,
-                infResultStr.c_str());
+           info(
+               "Result for inf %" PRIu32 ": %s\n", result.m_inferenceNumber, infResultStr.c_str());
        }
 
        std::string finalResultStr = audio::asr::DecodeOutput(combinedResults);
-       hal_lcd_display_text(finalResultStr.c_str(), finalResultStr.size(),
-                            dataPsnTxtStartX1, dataPsnTxtStartY1, allow_multiple_lines);
+       hal_lcd_display_text(finalResultStr.c_str(),
+                            finalResultStr.size(),
+                            dataPsnTxtStartX1,
+                            dataPsnTxtStartY1,
+                            allow_multiple_lines);
 
        info("Final result: %s\n", finalResultStr.c_str());
 
        return true;
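For readers skimming the diff, the control flow this handler implements (KWS slides a fixed window over the clip and, once the trigger keyword fires, hands the remaining samples to ASR) can be summarised by the standalone sketch below. This is a simplified illustration only, not code from the repository: KeywordSpottedInWindow and SlideKwsOverClip are hypothetical placeholder names, and the real handler drives MFCC feature extraction, inference and classification through the KwsPreProcess/KwsPostProcess objects shown in the diff.

#include <cstddef>
#include <cstdint>

/* Stub standing in for the real pre-process -> inference -> classify pipeline.
 * The real code compares the top classification against kwsScoreThreshold. */
static bool KeywordSpottedInWindow(const int16_t* window, size_t windowSize)
{
    (void)window;
    (void)windowSize;
    return false; /* Placeholder result. */
}

struct KwsHandoff {
    const int16_t* asrAudioStart = nullptr; /* First sample ASR should consume. */
    int32_t asrAudioSamples      = 0;       /* Number of samples left for ASR. */
};

/* Slide a fixed-size window over the clip; on the first detection, hand the
 * remainder of the clip over to ASR, mirroring how doKws() fills KWSOutput. */
static KwsHandoff SlideKwsOverClip(const int16_t* audio,
                                   size_t numSamples,
                                   size_t windowSize,
                                   size_t stride)
{
    KwsHandoff handoff{};
    for (size_t start = 0; start + windowSize <= numSamples; start += stride) {
        if (KeywordSpottedInWindow(audio + start, windowSize)) {
            handoff.asrAudioStart   = audio + start + windowSize;
            handoff.asrAudioSamples = static_cast<int32_t>(numSamples - (start + windowSize));
            break;
        }
    }
    return handoff;
}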