Diffstat (limited to 'source/use_case/kws_asr/src/UseCaseHandler.cc')
-rw-r--r--  source/use_case/kws_asr/src/UseCaseHandler.cc  492
1 file changed, 150 insertions(+), 342 deletions(-)
diff --git a/source/use_case/kws_asr/src/UseCaseHandler.cc b/source/use_case/kws_asr/src/UseCaseHandler.cc
index 1e1a400..01aefae 100644
--- a/source/use_case/kws_asr/src/UseCaseHandler.cc
+++ b/source/use_case/kws_asr/src/UseCaseHandler.cc
@@ -28,6 +28,7 @@
#include "Wav2LetterMfcc.hpp"
#include "Wav2LetterPreprocess.hpp"
#include "Wav2LetterPostprocess.hpp"
+#include "KwsProcessing.hpp"
#include "AsrResult.hpp"
#include "AsrClassifier.hpp"
#include "OutputDecode.hpp"
@@ -39,11 +40,6 @@ using KwsClassifier = arm::app::Classifier;
namespace arm {
namespace app {
- enum AsrOutputReductionAxis {
- AxisRow = 1,
- AxisCol = 2
- };
-
struct KWSOutput {
bool executionSuccess = false;
const int16_t* asrAudioStart = nullptr;
@@ -51,73 +47,53 @@ namespace app {
};
/**
- * @brief Presents kws inference results using the data presentation
- * object.
- * @param[in] results vector of classification results to be displayed
- * @return true if successful, false otherwise
+ * @brief Presents KWS inference results.
+ * @param[in] results Vector of KWS classification results to be displayed.
+ * @return true if successful, false otherwise.
**/
- static bool PresentInferenceResult(std::vector<arm::app::kws::KwsResult>& results);
+ static bool PresentInferenceResult(std::vector<kws::KwsResult>& results);
/**
- * @brief Presents asr inference results using the data presentation
- * object.
- * @param[in] platform reference to the hal platform object
- * @param[in] results vector of classification results to be displayed
- * @return true if successful, false otherwise
+ * @brief Presents ASR inference results.
+ * @param[in] results Vector of ASR classification results to be displayed.
+ * @return true if successful, false otherwise.
**/
- static bool PresentInferenceResult(std::vector<arm::app::asr::AsrResult>& results);
+ static bool PresentInferenceResult(std::vector<asr::AsrResult>& results);
/**
- * @brief Returns a function to perform feature calculation and populates input tensor data with
- * MFCC data.
- *
- * Input tensor data type check is performed to choose correct MFCC feature data type.
- * If tensor has an integer data type then original features are quantised.
- *
- * Warning: mfcc calculator provided as input must have the same life scope as returned function.
- *
- * @param[in] mfcc MFCC feature calculator.
- * @param[in,out] inputTensor Input tensor pointer to store calculated features.
- * @param[in] cacheSize Size of the feature vectors cache (number of feature vectors).
- *
- * @return function function to be called providing audio sample and sliding window index.
+ * @brief Performs the KWS pipeline.
+ * @param[in,out] ctx Reference to the application context object.
+ * @return struct containing pointer to audio data where ASR should begin
+ * and how much data to process.
**/
- static std::function<void (std::vector<int16_t>&, int, bool, size_t)>
- GetFeatureCalculator(audio::MicroNetMFCC& mfcc,
- TfLiteTensor* inputTensor,
- size_t cacheSize);
+ static KWSOutput doKws(ApplicationContext& ctx)
+ {
+ auto& profiler = ctx.Get<Profiler&>("profiler");
+ auto& kwsModel = ctx.Get<Model&>("kwsModel");
+ const auto kwsMfccFrameLength = ctx.Get<int>("kwsFrameLength");
+ const auto kwsMfccFrameStride = ctx.Get<int>("kwsFrameStride");
+ const auto kwsScoreThreshold = ctx.Get<float>("kwsScoreThreshold");
+
+ auto currentIndex = ctx.Get<uint32_t>("clipIndex");
- /**
- * @brief Performs the KWS pipeline.
- * @param[in,out] ctx pointer to the application context object
- *
- * @return KWSOutput struct containing pointer to audio data where ASR should begin
- * and how much data to process.
- */
- static KWSOutput doKws(ApplicationContext& ctx) {
constexpr uint32_t dataPsnTxtInfStartX = 20;
constexpr uint32_t dataPsnTxtInfStartY = 40;
constexpr int minTensorDims = static_cast<int>(
- (arm::app::MicroNetKwsModel::ms_inputRowsIdx > arm::app::MicroNetKwsModel::ms_inputColsIdx)?
- arm::app::MicroNetKwsModel::ms_inputRowsIdx : arm::app::MicroNetKwsModel::ms_inputColsIdx);
+ (MicroNetKwsModel::ms_inputRowsIdx > MicroNetKwsModel::ms_inputColsIdx)?
+ MicroNetKwsModel::ms_inputRowsIdx : MicroNetKwsModel::ms_inputColsIdx);
- KWSOutput output;
+ /* Output struct from doing KWS. */
+ KWSOutput output {};
- auto& profiler = ctx.Get<Profiler&>("profiler");
- auto& kwsModel = ctx.Get<Model&>("kwsmodel");
if (!kwsModel.IsInited()) {
printf_err("KWS model has not been initialised\n");
return output;
}
- const int kwsFrameLength = ctx.Get<int>("kwsframeLength");
- const int kwsFrameStride = ctx.Get<int>("kwsframeStride");
- const float kwsScoreThreshold = ctx.Get<float>("kwsscoreThreshold");
-
- TfLiteTensor* kwsOutputTensor = kwsModel.GetOutputTensor(0);
+ /* Get Input and Output tensors for pre/post processing. */
TfLiteTensor* kwsInputTensor = kwsModel.GetInputTensor(0);
-
+ TfLiteTensor* kwsOutputTensor = kwsModel.GetOutputTensor(0);
if (!kwsInputTensor->dims) {
printf_err("Invalid input tensor dims\n");
return output;
@@ -126,63 +102,32 @@ namespace app {
return output;
}
- const uint32_t kwsNumMfccFeats = ctx.Get<uint32_t>("kwsNumMfcc");
- const uint32_t kwsNumAudioWindows = ctx.Get<uint32_t>("kwsNumAudioWins");
-
- audio::MicroNetMFCC kwsMfcc = audio::MicroNetMFCC(kwsNumMfccFeats, kwsFrameLength);
- kwsMfcc.Init();
-
- /* Deduce the data length required for 1 KWS inference from the network parameters. */
- auto kwsAudioDataWindowSize = kwsNumAudioWindows * kwsFrameStride +
- (kwsFrameLength - kwsFrameStride);
- auto kwsMfccWindowSize = kwsFrameLength;
- auto kwsMfccWindowStride = kwsFrameStride;
-
- /* We are choosing to move by half the window size => for a 1 second window size,
- * this means an overlap of 0.5 seconds. */
- auto kwsAudioDataStride = kwsAudioDataWindowSize / 2;
-
- info("KWS audio data window size %" PRIu32 "\n", kwsAudioDataWindowSize);
-
- /* Stride must be multiple of mfcc features window stride to re-use features. */
- if (0 != kwsAudioDataStride % kwsMfccWindowStride) {
- kwsAudioDataStride -= kwsAudioDataStride % kwsMfccWindowStride;
- }
-
- auto kwsMfccVectorsInAudioStride = kwsAudioDataStride/kwsMfccWindowStride;
+ /* Get input shape for feature extraction. */
+ TfLiteIntArray* inputShape = kwsModel.GetInputShape(0);
+ const uint32_t numMfccFeatures = inputShape->data[MicroNetKwsModel::ms_inputColsIdx];
+ const uint32_t numMfccFrames = inputShape->data[MicroNetKwsModel::ms_inputRowsIdx];
/* We expect to be sampling 1 second worth of data at a time
* NOTE: This is only used for time stamp calculation. */
- const float kwsAudioParamsSecondsPerSample = 1.0/audio::MicroNetMFCC::ms_defaultSamplingFreq;
+ const float kwsAudioParamsSecondsPerSample = 1.0 / audio::MicroNetKwsMFCC::ms_defaultSamplingFreq;
- auto currentIndex = ctx.Get<uint32_t>("clipIndex");
+ /* Set up pre and post-processing. */
+ KwsPreProcess preProcess = KwsPreProcess(kwsInputTensor, numMfccFeatures, numMfccFrames,
+ kwsMfccFrameLength, kwsMfccFrameStride);
- /* Creating a mfcc features sliding window for the data required for 1 inference. */
- auto kwsAudioMFCCWindowSlider = audio::SlidingWindow<const int16_t>(
- get_audio_array(currentIndex),
- kwsAudioDataWindowSize, kwsMfccWindowSize,
- kwsMfccWindowStride);
+ std::vector<ClassificationResult> singleInfResult;
+ KwsPostProcess postProcess = KwsPostProcess(kwsOutputTensor, ctx.Get<KwsClassifier &>("kwsClassifier"),
+ ctx.Get<std::vector<std::string>&>("kwsLabels"),
+ singleInfResult);
/* Creating a sliding window through the whole audio clip. */
auto audioDataSlider = audio::SlidingWindow<const int16_t>(
get_audio_array(currentIndex),
get_audio_array_size(currentIndex),
- kwsAudioDataWindowSize, kwsAudioDataStride);
-
- /* Calculate number of the feature vectors in the window overlap region.
- * These feature vectors will be reused.*/
- size_t numberOfReusedFeatureVectors = kwsAudioMFCCWindowSlider.TotalStrides() + 1
- - kwsMfccVectorsInAudioStride;
+ preProcess.m_audioDataWindowSize, preProcess.m_audioDataStride);
- auto kwsMfccFeatureCalc = GetFeatureCalculator(kwsMfcc, kwsInputTensor,
- numberOfReusedFeatureVectors);
-
- if (!kwsMfccFeatureCalc){
- return output;
- }
-
- /* Container for KWS results. */
- std::vector<arm::app::kws::KwsResult> kwsResults;
+ /* Declare a container to hold kws results from across the whole audio clip. */
+ std::vector<kws::KwsResult> finalResults;
/* Display message on the LCD - inference running. */
std::string str_inf{"Running KWS inference... "};
@@ -197,70 +142,56 @@ namespace app {
while (audioDataSlider.HasNext()) {
const int16_t* inferenceWindow = audioDataSlider.Next();
- /* We moved to the next window - set the features sliding to the new address. */
- kwsAudioMFCCWindowSlider.Reset(inferenceWindow);
-
/* The first window does not have cache ready. */
- bool useCache = audioDataSlider.Index() > 0 && numberOfReusedFeatureVectors > 0;
-
- /* Start calculating features inside one audio sliding window. */
- while (kwsAudioMFCCWindowSlider.HasNext()) {
- const int16_t* kwsMfccWindow = kwsAudioMFCCWindowSlider.Next();
- std::vector<int16_t> kwsMfccAudioData =
- std::vector<int16_t>(kwsMfccWindow, kwsMfccWindow + kwsMfccWindowSize);
-
- /* Compute features for this window and write them to input tensor. */
- kwsMfccFeatureCalc(kwsMfccAudioData,
- kwsAudioMFCCWindowSlider.Index(),
- useCache,
- kwsMfccVectorsInAudioStride);
- }
+ preProcess.m_audioWindowIndex = audioDataSlider.Index();
- info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
- audioDataSlider.TotalStrides() + 1);
+ /* Run the pre-processing, inference and post-processing. */
+ if (!preProcess.DoPreProcess(inferenceWindow, audio::MicroNetKwsMFCC::ms_defaultSamplingFreq)) {
+ printf_err("KWS Pre-processing failed.");
+ return output;
+ }
- /* Run inference over this audio clip sliding window. */
if (!RunInference(kwsModel, profiler)) {
- printf_err("KWS inference failed\n");
+ printf_err("KWS Inference failed.");
return output;
}
- std::vector<ClassificationResult> kwsClassificationResult;
- auto& kwsClassifier = ctx.Get<KwsClassifier&>("kwsclassifier");
+ if (!postProcess.DoPostProcess()) {
+ printf_err("KWS Post-processing failed.");
+ return output;
+ }
- kwsClassifier.GetClassificationResults(
- kwsOutputTensor, kwsClassificationResult,
- ctx.Get<std::vector<std::string>&>("kwslabels"), 1, true);
+ info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
+ audioDataSlider.TotalStrides() + 1);
- kwsResults.emplace_back(
- kws::KwsResult(
- kwsClassificationResult,
- audioDataSlider.Index() * kwsAudioParamsSecondsPerSample * kwsAudioDataStride,
- audioDataSlider.Index(), kwsScoreThreshold)
- );
+ /* Add results from this window to our final results vector. */
+ finalResults.emplace_back(
+ kws::KwsResult(singleInfResult,
+ audioDataSlider.Index() * kwsAudioParamsSecondsPerSample * preProcess.m_audioDataStride,
+ audioDataSlider.Index(), kwsScoreThreshold));
- /* Keyword detected. */
- if (kwsClassificationResult[0].m_label == ctx.Get<const std::string&>("triggerkeyword")) {
- output.asrAudioStart = inferenceWindow + kwsAudioDataWindowSize;
+ /* Break out when trigger keyword is detected. */
+ if (singleInfResult[0].m_label == ctx.Get<const std::string&>("triggerKeyword")
+ && singleInfResult[0].m_normalisedVal > kwsScoreThreshold) {
+ output.asrAudioStart = inferenceWindow + preProcess.m_audioDataWindowSize;
output.asrAudioSamples = get_audio_array_size(currentIndex) -
(audioDataSlider.NextWindowStartIndex() -
- kwsAudioDataStride + kwsAudioDataWindowSize);
+ preProcess.m_audioDataStride + preProcess.m_audioDataWindowSize);
break;
}
#if VERIFY_TEST_OUTPUT
- arm::app::DumpTensor(kwsOutputTensor);
+ DumpTensor(kwsOutputTensor);
#endif /* VERIFY_TEST_OUTPUT */
} /* while (audioDataSlider.HasNext()) */
/* Erase. */
str_inf = std::string(str_inf.size(), ' ');
- hal_lcd_display_text(
- str_inf.c_str(), str_inf.size(),
- dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
+ hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
+ dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
- if (!PresentInferenceResult(kwsResults)) {
+ if (!PresentInferenceResult(finalResults)) {
return output;
}
@@ -271,41 +202,41 @@ namespace app {
}
/**
- * @brief Performs the ASR pipeline.
- *
- * @param[in,out] ctx pointer to the application context object
- * @param[in] kwsOutput struct containing pointer to audio data where ASR should begin
- * and how much data to process
- * @return bool true if pipeline executed without failure
- */
- static bool doAsr(ApplicationContext& ctx, const KWSOutput& kwsOutput) {
+ * @brief Performs the ASR pipeline.
+ * @param[in,out] ctx Reference to the application context object.
+ * @param[in] kwsOutput Struct containing pointer to audio data where ASR should begin
+ * and how much data to process.
+ * @return true if pipeline executed without failure.
+ **/
+ static bool doAsr(ApplicationContext& ctx, const KWSOutput& kwsOutput)
+ {
+ auto& asrModel = ctx.Get<Model&>("asrModel");
+ auto& profiler = ctx.Get<Profiler&>("profiler");
+ auto asrMfccFrameLen = ctx.Get<uint32_t>("asrFrameLength");
+ auto asrMfccFrameStride = ctx.Get<uint32_t>("asrFrameStride");
+ auto asrScoreThreshold = ctx.Get<float>("asrScoreThreshold");
+ auto asrInputCtxLen = ctx.Get<uint32_t>("ctxLen");
+
constexpr uint32_t dataPsnTxtInfStartX = 20;
constexpr uint32_t dataPsnTxtInfStartY = 40;
- auto& profiler = ctx.Get<Profiler&>("profiler");
- hal_lcd_clear(COLOR_BLACK);
-
- /* Get model reference. */
- auto& asrModel = ctx.Get<Model&>("asrmodel");
if (!asrModel.IsInited()) {
printf_err("ASR model has not been initialised\n");
return false;
}
- /* Get score threshold to be applied for the classifier (post-inference). */
- auto asrScoreThreshold = ctx.Get<float>("asrscoreThreshold");
+ hal_lcd_clear(COLOR_BLACK);
- /* Dimensions of the tensor should have been verified by the callee. */
+ /* Get Input and Output tensors for pre/post processing. */
TfLiteTensor* asrInputTensor = asrModel.GetInputTensor(0);
TfLiteTensor* asrOutputTensor = asrModel.GetOutputTensor(0);
- const uint32_t asrInputRows = asrInputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx];
- /* Populate ASR MFCC related parameters. */
- auto asrMfccParamsWinLen = ctx.Get<uint32_t>("asrframeLength");
- auto asrMfccParamsWinStride = ctx.Get<uint32_t>("asrframeStride");
+ /* Get input shape. Dimensions of the tensor should have been verified by
+ * the callee. */
+ TfLiteIntArray* inputShape = asrModel.GetInputShape(0);
- /* Populate ASR inference context and inner lengths for input. */
- auto asrInputCtxLen = ctx.Get<uint32_t>("ctxLen");
+
+ const uint32_t asrInputRows = asrInputTensor->dims->data[Wav2LetterModel::ms_inputRowsIdx];
const uint32_t asrInputInnerLen = asrInputRows - (2 * asrInputCtxLen);
/* Make sure the input tensor supports the above context and inner lengths. */
@@ -316,18 +247,9 @@ namespace app {
}
/* Audio data stride corresponds to inputInnerLen feature vectors. */
- const uint32_t asrAudioParamsWinLen = (asrInputRows - 1) *
- asrMfccParamsWinStride + (asrMfccParamsWinLen);
- const uint32_t asrAudioParamsWinStride = asrInputInnerLen * asrMfccParamsWinStride;
- const float asrAudioParamsSecondsPerSample =
- (1.0/audio::Wav2LetterMFCC::ms_defaultSamplingFreq);
-
- /* Get pre/post-processing objects */
- auto& asrPrep = ctx.Get<audio::asr::Preprocess&>("preprocess");
- auto& asrPostp = ctx.Get<audio::asr::Postprocess&>("postprocess");
-
- /* Set default reduction axis for post-processing. */
- const uint32_t reductionAxis = arm::app::Wav2LetterModel::ms_outputRowsIdx;
+ const uint32_t asrAudioDataWindowLen = (asrInputRows - 1) * asrMfccFrameStride + (asrMfccFrameLen);
+ const uint32_t asrAudioDataWindowStride = asrInputInnerLen * asrMfccFrameStride;
+ const float asrAudioParamsSecondsPerSample = 1.0 / audio::Wav2LetterMFCC::ms_defaultSamplingFreq;
/* Get the remaining audio buffer and respective size from KWS results. */
const int16_t* audioArr = kwsOutput.asrAudioStart;
@@ -335,9 +257,9 @@ namespace app {
/* Audio clip must have enough samples to produce 1 MFCC feature. */
std::vector<int16_t> audioBuffer = std::vector<int16_t>(audioArr, audioArr + audioArrSize);
- if (audioArrSize < asrMfccParamsWinLen) {
+ if (audioArrSize < asrMfccFrameLen) {
printf_err("Not enough audio samples, minimum needed is %" PRIu32 "\n",
- asrMfccParamsWinLen);
+ asrMfccFrameLen);
return false;
}
@@ -345,26 +267,38 @@ namespace app {
auto audioDataSlider = audio::FractionalSlidingWindow<const int16_t>(
audioBuffer.data(),
audioBuffer.size(),
- asrAudioParamsWinLen,
- asrAudioParamsWinStride);
+ asrAudioDataWindowLen,
+ asrAudioDataWindowStride);
/* Declare a container for results. */
- std::vector<arm::app::asr::AsrResult> asrResults;
+ std::vector<asr::AsrResult> asrResults;
/* Display message on the LCD - inference running. */
std::string str_inf{"Running ASR inference... "};
- hal_lcd_display_text(
- str_inf.c_str(), str_inf.size(),
+ hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
- size_t asrInferenceWindowLen = asrAudioParamsWinLen;
-
+ size_t asrInferenceWindowLen = asrAudioDataWindowLen;
+
+ /* Set up pre and post-processing objects. */
+ AsrPreProcess asrPreProcess = AsrPreProcess(asrInputTensor, arm::app::Wav2LetterModel::ms_numMfccFeatures,
+ inputShape->data[Wav2LetterModel::ms_inputRowsIdx],
+ asrMfccFrameLen, asrMfccFrameStride);
+
+ std::vector<ClassificationResult> singleInfResult;
+ const uint32_t outputCtxLen = AsrPostProcess::GetOutputContextLen(asrModel, asrInputCtxLen);
+ AsrPostProcess asrPostProcess = AsrPostProcess(
+ asrOutputTensor, ctx.Get<AsrClassifier&>("asrClassifier"),
+ ctx.Get<std::vector<std::string>&>("asrLabels"),
+ singleInfResult, outputCtxLen,
+ Wav2LetterModel::ms_blankTokenIdx, Wav2LetterModel::ms_outputRowsIdx
+ );
/* Start sliding through audio clip. */
while (audioDataSlider.HasNext()) {
/* If not enough audio see how much can be sent for processing. */
size_t nextStartIndex = audioDataSlider.NextWindowStartIndex();
- if (nextStartIndex + asrAudioParamsWinLen > audioBuffer.size()) {
+ if (nextStartIndex + asrAudioDataWindowLen > audioBuffer.size()) {
asrInferenceWindowLen = audioBuffer.size() - nextStartIndex;
}
@@ -373,8 +307,11 @@ namespace app {
info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1)));
- /* Calculate MFCCs, deltas and populate the input tensor. */
- asrPrep.Invoke(asrInferenceWindow, asrInferenceWindowLen, asrInputTensor);
+ /* Run the pre-processing, inference and post-processing. */
+ if (!asrPreProcess.DoPreProcess(asrInferenceWindow, asrInferenceWindowLen)) {
+ printf_err("ASR pre-processing failed.");
+ return false;
+ }
/* Run inference over this audio clip sliding window. */
if (!RunInference(asrModel, profiler)) {
@@ -382,24 +319,28 @@ namespace app {
return false;
}
- /* Post-process. */
- asrPostp.Invoke(asrOutputTensor, reductionAxis, !audioDataSlider.HasNext());
+ /* Post processing needs to know if we are on the last audio window. */
+ asrPostProcess.m_lastIteration = !audioDataSlider.HasNext();
+ if (!asrPostProcess.DoPostProcess()) {
+ printf_err("ASR post-processing failed.");
+ return false;
+ }
/* Get results. */
std::vector<ClassificationResult> asrClassificationResult;
- auto& asrClassifier = ctx.Get<AsrClassifier&>("asrclassifier");
+ auto& asrClassifier = ctx.Get<AsrClassifier&>("asrClassifier");
asrClassifier.GetClassificationResults(
asrOutputTensor, asrClassificationResult,
- ctx.Get<std::vector<std::string>&>("asrlabels"), 1);
+ ctx.Get<std::vector<std::string>&>("asrLabels"), 1);
asrResults.emplace_back(asr::AsrResult(asrClassificationResult,
(audioDataSlider.Index() *
asrAudioParamsSecondsPerSample *
- asrAudioParamsWinStride),
+ asrAudioDataWindowStride),
audioDataSlider.Index(), asrScoreThreshold));
#if VERIFY_TEST_OUTPUT
- arm::app::DumpTensor(asrOutputTensor, asrOutputTensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx]);
+ DumpTensor(asrOutputTensor, asrOutputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx]);
#endif /* VERIFY_TEST_OUTPUT */
/* Erase */
@@ -417,7 +358,7 @@ namespace app {
return true;
}
- /* Audio inference classification handler. */
+ /* KWS and ASR inference handler. */
bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
{
hal_lcd_clear(COLOR_BLACK);
@@ -434,13 +375,14 @@ namespace app {
do {
KWSOutput kwsOutput = doKws(ctx);
if (!kwsOutput.executionSuccess) {
+ printf_err("KWS failed\n");
return false;
}
if (kwsOutput.asrAudioStart != nullptr && kwsOutput.asrAudioSamples > 0) {
- info("Keyword spotted\n");
+ info("Trigger keyword spotted\n");
if(!doAsr(ctx, kwsOutput)) {
- printf_err("ASR failed");
+ printf_err("ASR failed\n");
return false;
}
}
@@ -452,7 +394,6 @@ namespace app {
return true;
}
-
static bool PresentInferenceResult(std::vector<arm::app::kws::KwsResult>& results)
{
constexpr uint32_t dataPsnTxtStartX1 = 20;
@@ -464,33 +405,31 @@ namespace app {
/* Display each result. */
uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr;
- for (uint32_t i = 0; i < results.size(); ++i) {
-
+ for (auto & result : results) {
std::string topKeyword{"<none>"};
float score = 0.f;
- if (!results[i].m_resultVec.empty()) {
- topKeyword = results[i].m_resultVec[0].m_label;
- score = results[i].m_resultVec[0].m_normalisedVal;
+ if (!result.m_resultVec.empty()) {
+ topKeyword = result.m_resultVec[0].m_label;
+ score = result.m_resultVec[0].m_normalisedVal;
}
std::string resultStr =
- std::string{"@"} + std::to_string(results[i].m_timeStamp) +
+ std::string{"@"} + std::to_string(result.m_timeStamp) +
std::string{"s: "} + topKeyword + std::string{" ("} +
std::to_string(static_cast<int>(score * 100)) + std::string{"%)"};
- hal_lcd_display_text(
- resultStr.c_str(), resultStr.size(),
- dataPsnTxtStartX1, rowIdx1, 0);
+ hal_lcd_display_text(resultStr.c_str(), resultStr.size(),
+ dataPsnTxtStartX1, rowIdx1, 0);
rowIdx1 += dataPsnTxtYIncr;
info("For timestamp: %f (inference #: %" PRIu32 "); threshold: %f\n",
- results[i].m_timeStamp, results[i].m_inferenceNumber,
- results[i].m_threshold);
- for (uint32_t j = 0; j < results[i].m_resultVec.size(); ++j) {
+ result.m_timeStamp, result.m_inferenceNumber,
+ result.m_threshold);
+ for (uint32_t j = 0; j < result.m_resultVec.size(); ++j) {
info("\t\tlabel @ %" PRIu32 ": %s, score: %f\n", j,
- results[i].m_resultVec[j].m_label.c_str(),
- results[i].m_resultVec[j].m_normalisedVal);
+ result.m_resultVec[j].m_label.c_str(),
+ result.m_resultVec[j].m_normalisedVal);
}
}
@@ -523,143 +462,12 @@ namespace app {
std::string finalResultStr = audio::asr::DecodeOutput(combinedResults);
- hal_lcd_display_text(
- finalResultStr.c_str(), finalResultStr.size(),
- dataPsnTxtStartX1, dataPsnTxtStartY1, allow_multiple_lines);
+ hal_lcd_display_text(finalResultStr.c_str(), finalResultStr.size(),
+ dataPsnTxtStartX1, dataPsnTxtStartY1, allow_multiple_lines);
info("Final result: %s\n", finalResultStr.c_str());
return true;
}
- /**
- * @brief Generic feature calculator factory.
- *
- * Returns lambda function to compute features using features cache.
- * Real features math is done by a lambda function provided as a parameter.
- * Features are written to input tensor memory.
- *
- * @tparam T feature vector type.
- * @param inputTensor model input tensor pointer.
- * @param cacheSize number of feature vectors to cache. Defined by the sliding window overlap.
- * @param compute features calculator function.
- * @return lambda function to compute features.
- **/
- template<class T>
- std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
- FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
- std::function<std::vector<T> (std::vector<int16_t>& )> compute)
- {
- /* Feature cache to be captured by lambda function. */
- static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize);
-
- return [=](std::vector<int16_t>& audioDataWindow,
- size_t index,
- bool useCache,
- size_t featuresOverlapIndex)
- {
- T* tensorData = tflite::GetTensorData<T>(inputTensor);
- std::vector<T> features;
-
- /* Reuse features from cache if cache is ready and sliding windows overlap.
- * Overlap is in the beginning of sliding window with a size of a feature cache.
- */
- if (useCache && index < featureCache.size()) {
- features = std::move(featureCache[index]);
- } else {
- features = std::move(compute(audioDataWindow));
- }
- auto size = features.size();
- auto sizeBytes = sizeof(T) * size;
- std::memcpy(tensorData + (index * size), features.data(), sizeBytes);
-
- /* Start renewing cache as soon iteration goes out of the windows overlap. */
- if (index >= featuresOverlapIndex) {
- featureCache[index - featuresOverlapIndex] = std::move(features);
- }
- };
- }
-
- template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)>
- FeatureCalc<int8_t>(TfLiteTensor* inputTensor,
- size_t cacheSize,
- std::function<std::vector<int8_t> (std::vector<int16_t>& )> compute);
-
- template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)>
- FeatureCalc<uint8_t>(TfLiteTensor* inputTensor,
- size_t cacheSize,
- std::function<std::vector<uint8_t> (std::vector<int16_t>& )> compute);
-
- template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)>
- FeatureCalc<int16_t>(TfLiteTensor* inputTensor,
- size_t cacheSize,
- std::function<std::vector<int16_t> (std::vector<int16_t>& )> compute);
-
- template std::function<void(std::vector<int16_t>&, size_t, bool, size_t)>
- FeatureCalc<float>(TfLiteTensor* inputTensor,
- size_t cacheSize,
- std::function<std::vector<float>(std::vector<int16_t>&)> compute);
-
-
- static std::function<void (std::vector<int16_t>&, int, bool, size_t)>
- GetFeatureCalculator(audio::MicroNetMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize)
- {
- std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> mfccFeatureCalc;
-
- TfLiteQuantization quant = inputTensor->quantization;
-
- if (kTfLiteAffineQuantization == quant.type) {
-
- auto* quantParams = (TfLiteAffineQuantization*) quant.params;
- const float quantScale = quantParams->scale->data[0];
- const int quantOffset = quantParams->zero_point->data[0];
-
- switch (inputTensor->type) {
- case kTfLiteInt8: {
- mfccFeatureCalc = FeatureCalc<int8_t>(inputTensor,
- cacheSize,
- [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
- return mfcc.MfccComputeQuant<int8_t>(audioDataWindow,
- quantScale,
- quantOffset);
- }
- );
- break;
- }
- case kTfLiteUInt8: {
- mfccFeatureCalc = FeatureCalc<uint8_t>(inputTensor,
- cacheSize,
- [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
- return mfcc.MfccComputeQuant<uint8_t>(audioDataWindow,
- quantScale,
- quantOffset);
- }
- );
- break;
- }
- case kTfLiteInt16: {
- mfccFeatureCalc = FeatureCalc<int16_t>(inputTensor,
- cacheSize,
- [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
- return mfcc.MfccComputeQuant<int16_t>(audioDataWindow,
- quantScale,
- quantOffset);
- }
- );
- break;
- }
- default:
- printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type));
- }
-
-
- } else {
- mfccFeatureCalc = mfccFeatureCalc = FeatureCalc<float>(inputTensor,
- cacheSize,
- [&mfcc](std::vector<int16_t>& audioDataWindow) {
- return mfcc.MfccCompute(audioDataWindow);
- });
- }
- return mfccFeatureCalc;
- }
} /* namespace app */
} /* namespace arm */
\ No newline at end of file