diff options
Diffstat (limited to 'source/use_case/asr/src/UseCaseHandler.cc')
-rw-r--r-- | source/use_case/asr/src/UseCaseHandler.cc | 107 |
1 files changed, 57 insertions, 50 deletions
diff --git a/source/use_case/asr/src/UseCaseHandler.cc b/source/use_case/asr/src/UseCaseHandler.cc index 76409b6..d13a03a 100644 --- a/source/use_case/asr/src/UseCaseHandler.cc +++ b/source/use_case/asr/src/UseCaseHandler.cc @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com> - * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates + * <open-source-office@arm.com> SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,17 +16,17 @@ */ #include "UseCaseHandler.hpp" -#include "InputFiles.hpp" #include "AsrClassifier.hpp" -#include "Wav2LetterModel.hpp" -#include "hal.h" +#include "AsrResult.hpp" #include "AudioUtils.hpp" #include "ImageUtils.hpp" +#include "InputFiles.hpp" +#include "OutputDecode.hpp" #include "UseCaseCommonUtils.hpp" -#include "AsrResult.hpp" -#include "Wav2LetterPreprocess.hpp" +#include "Wav2LetterModel.hpp" #include "Wav2LetterPostprocess.hpp" -#include "OutputDecode.hpp" +#include "Wav2LetterPreprocess.hpp" +#include "hal.h" #include "log_macros.h" namespace arm { @@ -42,19 +42,19 @@ namespace app { /* ASR inference handler. */ bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll) { - auto& model = ctx.Get<Model&>("model"); - auto& profiler = ctx.Get<Profiler&>("profiler"); - auto mfccFrameLen = ctx.Get<uint32_t>("frameLength"); + auto& model = ctx.Get<Model&>("model"); + auto& profiler = ctx.Get<Profiler&>("profiler"); + auto mfccFrameLen = ctx.Get<uint32_t>("frameLength"); auto mfccFrameStride = ctx.Get<uint32_t>("frameStride"); - auto scoreThreshold = ctx.Get<float>("scoreThreshold"); - auto inputCtxLen = ctx.Get<uint32_t>("ctxLen"); + auto scoreThreshold = ctx.Get<float>("scoreThreshold"); + auto inputCtxLen = ctx.Get<uint32_t>("ctxLen"); /* If the request has a valid size, set the audio index. */ if (clipIndex < NUMBER_OF_FILES) { - if (!SetAppCtxIfmIdx(ctx, clipIndex,"clipIndex")) { + if (!SetAppCtxIfmIdx(ctx, clipIndex, "clipIndex")) { return false; } } - auto initialClipIdx = ctx.Get<uint32_t>("clipIndex"); + auto initialClipIdx = ctx.Get<uint32_t>("clipIndex"); constexpr uint32_t dataPsnTxtInfStartX = 20; constexpr uint32_t dataPsnTxtInfStartY = 40; @@ -63,7 +63,7 @@ namespace app { return false; } - TfLiteTensor* inputTensor = model.GetInputTensor(0); + TfLiteTensor* inputTensor = model.GetInputTensor(0); TfLiteTensor* outputTensor = model.GetOutputTensor(0); /* Get input shape. Dimensions of the tensor should have been verified by @@ -81,18 +81,21 @@ namespace app { const float secondsPerSample = (1.0 / audio::Wav2LetterMFCC::ms_defaultSamplingFreq); /* Set up pre and post-processing objects. */ - AsrPreProcess preProcess = AsrPreProcess(inputTensor, Wav2LetterModel::ms_numMfccFeatures, + AsrPreProcess preProcess = AsrPreProcess(inputTensor, + Wav2LetterModel::ms_numMfccFeatures, inputShape->data[Wav2LetterModel::ms_inputRowsIdx], - mfccFrameLen, mfccFrameStride); + mfccFrameLen, + mfccFrameStride); std::vector<ClassificationResult> singleInfResult; const uint32_t outputCtxLen = AsrPostProcess::GetOutputContextLen(model, inputCtxLen); - AsrPostProcess postProcess = AsrPostProcess( - outputTensor, ctx.Get<AsrClassifier&>("classifier"), - ctx.Get<std::vector<std::string>&>("labels"), - singleInfResult, outputCtxLen, - Wav2LetterModel::ms_blankTokenIdx, Wav2LetterModel::ms_outputRowsIdx - ); + AsrPostProcess postProcess = AsrPostProcess(outputTensor, + ctx.Get<AsrClassifier&>("classifier"), + ctx.Get<std::vector<std::string>&>("labels"), + singleInfResult, + outputCtxLen, + Wav2LetterModel::ms_blankTokenIdx, + Wav2LetterModel::ms_outputRowsIdx); /* Loop to process audio clips. */ do { @@ -102,8 +105,8 @@ namespace app { auto currentIndex = ctx.Get<uint32_t>("clipIndex"); /* Get the current audio buffer and respective size. */ - const int16_t* audioArr = get_audio_array(currentIndex); - const uint32_t audioArrSize = get_audio_array_size(currentIndex); + const int16_t* audioArr = GetAudioArray(currentIndex); + const uint32_t audioArrSize = GetAudioArraySize(currentIndex); if (!audioArr) { printf_err("Invalid audio array pointer.\n"); @@ -119,19 +122,19 @@ namespace app { /* Creating a sliding window through the whole audio clip. */ auto audioDataSlider = audio::FractionalSlidingWindow<const int16_t>( - audioArr, audioArrSize, - audioDataWindowLen, audioDataWindowStride); + audioArr, audioArrSize, audioDataWindowLen, audioDataWindowStride); /* Declare a container for final results. */ std::vector<asr::AsrResult> finalResults; /* Display message on the LCD - inference running. */ std::string str_inf{"Running inference... "}; - hal_lcd_display_text(str_inf.c_str(), str_inf.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); + hal_lcd_display_text( + str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); - info("Running inference on audio clip %" PRIu32 " => %s\n", currentIndex, - get_filename(currentIndex)); + info("Running inference on audio clip %" PRIu32 " => %s\n", + currentIndex, + GetFilename(currentIndex)); size_t inferenceWindowLen = audioDataWindowLen; @@ -146,7 +149,8 @@ namespace app { const int16_t* inferenceWindow = audioDataSlider.Next(); - info("Inference %zu/%zu\n", audioDataSlider.Index() + 1, + info("Inference %zu/%zu\n", + audioDataSlider.Index() + 1, static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1))); /* Run the pre-processing, inference and post-processing. */ @@ -168,20 +172,22 @@ namespace app { } /* Add results from this window to our final results vector. */ - finalResults.emplace_back(asr::AsrResult(singleInfResult, - (audioDataSlider.Index() * secondsPerSample * audioDataWindowStride), - audioDataSlider.Index(), scoreThreshold)); + finalResults.emplace_back(asr::AsrResult( + singleInfResult, + (audioDataSlider.Index() * secondsPerSample * audioDataWindowStride), + audioDataSlider.Index(), + scoreThreshold)); #if VERIFY_TEST_OUTPUT armDumpTensor(outputTensor, - outputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx]); -#endif /* VERIFY_TEST_OUTPUT */ + outputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx]); +#endif /* VERIFY_TEST_OUTPUT */ } /* while (audioDataSlider.HasNext()) */ /* Erase. */ str_inf = std::string(str_inf.size(), ' '); - hal_lcd_display_text(str_inf.c_str(), str_inf.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); + hal_lcd_display_text( + str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); ctx.Set<std::vector<asr::AsrResult>>("results", finalResults); @@ -191,19 +197,18 @@ namespace app { profiler.PrintProfilingResult(); - IncrementAppCtxIfmIdx(ctx,"clipIndex"); + IncrementAppCtxIfmIdx(ctx, "clipIndex"); } while (runAll && ctx.Get<uint32_t>("clipIndex") != initialClipIdx); return true; } - static bool PresentInferenceResult(const std::vector<asr::AsrResult>& results) { constexpr uint32_t dataPsnTxtStartX1 = 20; constexpr uint32_t dataPsnTxtStartY1 = 60; - constexpr bool allow_multiple_lines = true; + constexpr bool allow_multiple_lines = true; hal_lcd_set_text_color(COLOR_GREEN); @@ -212,9 +217,8 @@ namespace app { /* Results from multiple inferences should be combined before processing. */ std::vector<ClassificationResult> combinedResults; for (const auto& result : results) { - combinedResults.insert(combinedResults.end(), - result.m_resultVec.begin(), - result.m_resultVec.end()); + combinedResults.insert( + combinedResults.end(), result.m_resultVec.begin(), result.m_resultVec.end()); } /* Get each inference result string using the decoder. */ @@ -222,16 +226,19 @@ namespace app { std::string infResultStr = audio::asr::DecodeOutput(result.m_resultVec); info("For timestamp: %f (inference #: %" PRIu32 "); label: %s\n", - result.m_timeStamp, result.m_inferenceNumber, + result.m_timeStamp, + result.m_inferenceNumber, infResultStr.c_str()); } /* Get the decoded result for the combined result. */ std::string finalResultStr = audio::asr::DecodeOutput(combinedResults); - hal_lcd_display_text(finalResultStr.c_str(), finalResultStr.size(), - dataPsnTxtStartX1, dataPsnTxtStartY1, - allow_multiple_lines); + hal_lcd_display_text(finalResultStr.c_str(), + finalResultStr.size(), + dataPsnTxtStartX1, + dataPsnTxtStartY1, + allow_multiple_lines); info("Complete recognition: %s\n", finalResultStr.c_str()); return true; |