summary | refs | log | tree | commit | diff
path: root/source/use_case/asr/src/UseCaseHandler.cc
diff options
context:
space:
mode:
author: Kshitij Sisodia <kshitij.sisodia@arm.com> 2022-12-19 16:37:33 +0000
committer: Kshitij Sisodia <kshitij.sisodia@arm.com> 2022-12-19 17:05:29 +0000
commit: 2ea46232a15aaf7600f1b92314612f4aa2fc6cd2 (patch)
tree: 7c05c514c3bbe932a067067b719d46ff16e5c2e7 /source/use_case/asr/src/UseCaseHandler.cc
parent: 9a97134ee00125c7a406cbf57c3ba8360df8f980 (diff)
download: ml-embedded-evaluation-kit-2ea46232a15aaf7600f1b92314612f4aa2fc6cd2.tar.gz
MLECO-3611: Formatting fixes for generated files.
Template files updated for generated files to adhere to coding guidelines and clang format configuration. There will still be unavoidable violations, but most of the others have been fixed. Change-Id: Ia03db40f8c62a369f2b07fe02eea65e41993a523 Signed-off-by: Kshitij Sisodia <kshitij.sisodia@arm.com>
Diffstat (limited to 'source/use_case/asr/src/UseCaseHandler.cc')
-rw-r--r--  source/use_case/asr/src/UseCaseHandler.cc | 107
1 file changed, 57 insertions(+), 50 deletions(-)
diff --git a/source/use_case/asr/src/UseCaseHandler.cc b/source/use_case/asr/src/UseCaseHandler.cc
index 76409b6..d13a03a 100644
--- a/source/use_case/asr/src/UseCaseHandler.cc
+++ b/source/use_case/asr/src/UseCaseHandler.cc
@@ -1,6 +1,6 @@
/*
- * SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
- * SPDX-License-Identifier: Apache-2.0
+ * SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates
+ * <open-source-office@arm.com> SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -16,17 +16,17 @@
*/
#include "UseCaseHandler.hpp"
-#include "InputFiles.hpp"
#include "AsrClassifier.hpp"
-#include "Wav2LetterModel.hpp"
-#include "hal.h"
+#include "AsrResult.hpp"
#include "AudioUtils.hpp"
#include "ImageUtils.hpp"
+#include "InputFiles.hpp"
+#include "OutputDecode.hpp"
#include "UseCaseCommonUtils.hpp"
-#include "AsrResult.hpp"
-#include "Wav2LetterPreprocess.hpp"
+#include "Wav2LetterModel.hpp"
#include "Wav2LetterPostprocess.hpp"
-#include "OutputDecode.hpp"
+#include "Wav2LetterPreprocess.hpp"
+#include "hal.h"
#include "log_macros.h"
namespace arm {
@@ -42,19 +42,19 @@ namespace app {
/* ASR inference handler. */
bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
{
- auto& model = ctx.Get<Model&>("model");
- auto& profiler = ctx.Get<Profiler&>("profiler");
- auto mfccFrameLen = ctx.Get<uint32_t>("frameLength");
+ auto& model = ctx.Get<Model&>("model");
+ auto& profiler = ctx.Get<Profiler&>("profiler");
+ auto mfccFrameLen = ctx.Get<uint32_t>("frameLength");
auto mfccFrameStride = ctx.Get<uint32_t>("frameStride");
- auto scoreThreshold = ctx.Get<float>("scoreThreshold");
- auto inputCtxLen = ctx.Get<uint32_t>("ctxLen");
+ auto scoreThreshold = ctx.Get<float>("scoreThreshold");
+ auto inputCtxLen = ctx.Get<uint32_t>("ctxLen");
/* If the request has a valid size, set the audio index. */
if (clipIndex < NUMBER_OF_FILES) {
- if (!SetAppCtxIfmIdx(ctx, clipIndex,"clipIndex")) {
+ if (!SetAppCtxIfmIdx(ctx, clipIndex, "clipIndex")) {
return false;
}
}
- auto initialClipIdx = ctx.Get<uint32_t>("clipIndex");
+ auto initialClipIdx = ctx.Get<uint32_t>("clipIndex");
constexpr uint32_t dataPsnTxtInfStartX = 20;
constexpr uint32_t dataPsnTxtInfStartY = 40;
@@ -63,7 +63,7 @@ namespace app {
return false;
}
- TfLiteTensor* inputTensor = model.GetInputTensor(0);
+ TfLiteTensor* inputTensor = model.GetInputTensor(0);
TfLiteTensor* outputTensor = model.GetOutputTensor(0);
/* Get input shape. Dimensions of the tensor should have been verified by
@@ -81,18 +81,21 @@ namespace app {
const float secondsPerSample = (1.0 / audio::Wav2LetterMFCC::ms_defaultSamplingFreq);
/* Set up pre and post-processing objects. */
- AsrPreProcess preProcess = AsrPreProcess(inputTensor, Wav2LetterModel::ms_numMfccFeatures,
+ AsrPreProcess preProcess = AsrPreProcess(inputTensor,
+ Wav2LetterModel::ms_numMfccFeatures,
inputShape->data[Wav2LetterModel::ms_inputRowsIdx],
- mfccFrameLen, mfccFrameStride);
+ mfccFrameLen,
+ mfccFrameStride);
std::vector<ClassificationResult> singleInfResult;
const uint32_t outputCtxLen = AsrPostProcess::GetOutputContextLen(model, inputCtxLen);
- AsrPostProcess postProcess = AsrPostProcess(
- outputTensor, ctx.Get<AsrClassifier&>("classifier"),
- ctx.Get<std::vector<std::string>&>("labels"),
- singleInfResult, outputCtxLen,
- Wav2LetterModel::ms_blankTokenIdx, Wav2LetterModel::ms_outputRowsIdx
- );
+ AsrPostProcess postProcess = AsrPostProcess(outputTensor,
+ ctx.Get<AsrClassifier&>("classifier"),
+ ctx.Get<std::vector<std::string>&>("labels"),
+ singleInfResult,
+ outputCtxLen,
+ Wav2LetterModel::ms_blankTokenIdx,
+ Wav2LetterModel::ms_outputRowsIdx);
/* Loop to process audio clips. */
do {
@@ -102,8 +105,8 @@ namespace app {
auto currentIndex = ctx.Get<uint32_t>("clipIndex");
/* Get the current audio buffer and respective size. */
- const int16_t* audioArr = get_audio_array(currentIndex);
- const uint32_t audioArrSize = get_audio_array_size(currentIndex);
+ const int16_t* audioArr = GetAudioArray(currentIndex);
+ const uint32_t audioArrSize = GetAudioArraySize(currentIndex);
if (!audioArr) {
printf_err("Invalid audio array pointer.\n");
@@ -119,19 +122,19 @@ namespace app {
/* Creating a sliding window through the whole audio clip. */
auto audioDataSlider = audio::FractionalSlidingWindow<const int16_t>(
- audioArr, audioArrSize,
- audioDataWindowLen, audioDataWindowStride);
+ audioArr, audioArrSize, audioDataWindowLen, audioDataWindowStride);
/* Declare a container for final results. */
std::vector<asr::AsrResult> finalResults;
/* Display message on the LCD - inference running. */
std::string str_inf{"Running inference... "};
- hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
- dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
+ hal_lcd_display_text(
+ str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
- info("Running inference on audio clip %" PRIu32 " => %s\n", currentIndex,
- get_filename(currentIndex));
+ info("Running inference on audio clip %" PRIu32 " => %s\n",
+ currentIndex,
+ GetFilename(currentIndex));
size_t inferenceWindowLen = audioDataWindowLen;
@@ -146,7 +149,8 @@ namespace app {
const int16_t* inferenceWindow = audioDataSlider.Next();
- info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
+ info("Inference %zu/%zu\n",
+ audioDataSlider.Index() + 1,
static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1)));
/* Run the pre-processing, inference and post-processing. */
@@ -168,20 +172,22 @@ namespace app {
}
/* Add results from this window to our final results vector. */
- finalResults.emplace_back(asr::AsrResult(singleInfResult,
- (audioDataSlider.Index() * secondsPerSample * audioDataWindowStride),
- audioDataSlider.Index(), scoreThreshold));
+ finalResults.emplace_back(asr::AsrResult(
+ singleInfResult,
+ (audioDataSlider.Index() * secondsPerSample * audioDataWindowStride),
+ audioDataSlider.Index(),
+ scoreThreshold));
#if VERIFY_TEST_OUTPUT
armDumpTensor(outputTensor,
- outputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx]);
-#endif /* VERIFY_TEST_OUTPUT */
+ outputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx]);
+#endif /* VERIFY_TEST_OUTPUT */
} /* while (audioDataSlider.HasNext()) */
/* Erase. */
str_inf = std::string(str_inf.size(), ' ');
- hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
- dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
+ hal_lcd_display_text(
+ str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
ctx.Set<std::vector<asr::AsrResult>>("results", finalResults);
@@ -191,19 +197,18 @@ namespace app {
profiler.PrintProfilingResult();
- IncrementAppCtxIfmIdx(ctx,"clipIndex");
+ IncrementAppCtxIfmIdx(ctx, "clipIndex");
} while (runAll && ctx.Get<uint32_t>("clipIndex") != initialClipIdx);
return true;
}
-
static bool PresentInferenceResult(const std::vector<asr::AsrResult>& results)
{
constexpr uint32_t dataPsnTxtStartX1 = 20;
constexpr uint32_t dataPsnTxtStartY1 = 60;
- constexpr bool allow_multiple_lines = true;
+ constexpr bool allow_multiple_lines = true;
hal_lcd_set_text_color(COLOR_GREEN);
@@ -212,9 +217,8 @@ namespace app {
/* Results from multiple inferences should be combined before processing. */
std::vector<ClassificationResult> combinedResults;
for (const auto& result : results) {
- combinedResults.insert(combinedResults.end(),
- result.m_resultVec.begin(),
- result.m_resultVec.end());
+ combinedResults.insert(
+ combinedResults.end(), result.m_resultVec.begin(), result.m_resultVec.end());
}
/* Get each inference result string using the decoder. */
@@ -222,16 +226,19 @@ namespace app {
std::string infResultStr = audio::asr::DecodeOutput(result.m_resultVec);
info("For timestamp: %f (inference #: %" PRIu32 "); label: %s\n",
- result.m_timeStamp, result.m_inferenceNumber,
+ result.m_timeStamp,
+ result.m_inferenceNumber,
infResultStr.c_str());
}
/* Get the decoded result for the combined result. */
std::string finalResultStr = audio::asr::DecodeOutput(combinedResults);
- hal_lcd_display_text(finalResultStr.c_str(), finalResultStr.size(),
- dataPsnTxtStartX1, dataPsnTxtStartY1,
- allow_multiple_lines);
+ hal_lcd_display_text(finalResultStr.c_str(),
+ finalResultStr.size(),
+ dataPsnTxtStartX1,
+ dataPsnTxtStartY1,
+ allow_multiple_lines);
info("Complete recognition: %s\n", finalResultStr.c_str());
return true;