diff options
Diffstat (limited to 'source')
-rw-r--r-- | source/application/main/UseCaseCommonUtils.cc | 165 | ||||
-rw-r--r-- | source/use_case/ad/src/UseCaseHandler.cc | 89 | ||||
-rw-r--r-- | source/use_case/asr/src/UseCaseHandler.cc | 107 | ||||
-rw-r--r-- | source/use_case/img_class/src/UseCaseHandler.cc | 64 | ||||
-rw-r--r-- | source/use_case/kws/src/UseCaseHandler.cc | 104 | ||||
-rw-r--r-- | source/use_case/kws_asr/src/UseCaseHandler.cc | 245 | ||||
-rw-r--r-- | source/use_case/noise_reduction/src/UseCaseHandler.cc | 191 | ||||
-rw-r--r-- | source/use_case/object_detection/src/UseCaseHandler.cc | 116 | ||||
-rw-r--r-- | source/use_case/vww/src/UseCaseHandler.cc | 68 |
9 files changed, 618 insertions, 531 deletions
diff --git a/source/application/main/UseCaseCommonUtils.cc b/source/application/main/UseCaseCommonUtils.cc index d1276ab..7d0a2ca 100644 --- a/source/application/main/UseCaseCommonUtils.cc +++ b/source/application/main/UseCaseCommonUtils.cc @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com> - * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates + * <open-source-office@arm.com> SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,7 +21,6 @@ #include <cinttypes> - void DisplayCommonMenu() { printf("\n\n"); @@ -36,7 +35,7 @@ void DisplayCommonMenu() fflush(stdout); } -bool PresentInferenceResult(const std::vector<arm::app::ClassificationResult> &results) +bool PresentInferenceResult(const std::vector<arm::app::ClassificationResult>& results) { constexpr uint32_t dataPsnTxtStartX1 = 150; constexpr uint32_t dataPsnTxtStartY1 = 30; @@ -44,7 +43,7 @@ bool PresentInferenceResult(const std::vector<arm::app::ClassificationResult> &r constexpr uint32_t dataPsnTxtStartX2 = 10; constexpr uint32_t dataPsnTxtStartY2 = 150; - constexpr uint32_t dataPsnTxtYIncr = 16; /* Row index increment. */ + constexpr uint32_t dataPsnTxtYIncr = 16; /* Row index increment. */ hal_lcd_set_text_color(COLOR_GREEN); @@ -56,31 +55,28 @@ bool PresentInferenceResult(const std::vector<arm::app::ClassificationResult> &r info("Total number of inferences: 1\n"); for (uint32_t i = 0; i < results.size(); ++i) { - std::string resultStr = - std::to_string(i + 1) + ") " + - std::to_string(results[i].m_labelIdx) + - " (" + std::to_string(results[i].m_normalisedVal) + ")"; + std::string resultStr = std::to_string(i + 1) + ") " + + std::to_string(results[i].m_labelIdx) + " (" + + std::to_string(results[i].m_normalisedVal) + ")"; hal_lcd_display_text( - resultStr.c_str(), resultStr.size(), - dataPsnTxtStartX1, rowIdx1, false); + resultStr.c_str(), resultStr.size(), dataPsnTxtStartX1, rowIdx1, false); rowIdx1 += dataPsnTxtYIncr; resultStr = std::to_string(i + 1) + ") " + results[i].m_label; - hal_lcd_display_text( - resultStr.c_str(), resultStr.size(), - dataPsnTxtStartX2, rowIdx2, 0); + hal_lcd_display_text(resultStr.c_str(), resultStr.size(), dataPsnTxtStartX2, rowIdx2, 0); rowIdx2 += dataPsnTxtYIncr; - info("%" PRIu32 ") %" PRIu32 " (%f) -> %s\n", i, - results[i].m_labelIdx, results[i].m_normalisedVal, - results[i].m_label.c_str()); + info("%" PRIu32 ") %" PRIu32 " (%f) -> %s\n", + i, + results[i].m_labelIdx, + results[i].m_normalisedVal, + results[i].m_label.c_str()); } return true; } - void IncrementAppCtxIfmIdx(arm::app::ApplicationContext& ctx, const std::string& useCase) { #if NUMBER_OF_FILES > 0 @@ -92,7 +88,7 @@ void IncrementAppCtxIfmIdx(arm::app::ApplicationContext& ctx, const std::string& } ++curImIdx; ctx.Set<uint32_t>(useCase, curImIdx); -#else /* NUMBER_OF_FILES > 0 */ +#else /* NUMBER_OF_FILES > 0 */ UNUSED(ctx); UNUSED(useCase); #endif /* NUMBER_OF_FILES > 0 */ @@ -102,13 +98,12 @@ bool SetAppCtxIfmIdx(arm::app::ApplicationContext& ctx, uint32_t idx, const std: { #if NUMBER_OF_FILES > 0 if (idx >= NUMBER_OF_FILES) { - printf_err("Invalid idx %" PRIu32 " (expected less than %u)\n", - idx, NUMBER_OF_FILES); + printf_err("Invalid idx %" PRIu32 " (expected less than %u)\n", idx, NUMBER_OF_FILES); return false; } ctx.Set<uint32_t>(ctxIfmName, idx); return true; -#else /* NUMBER_OF_FILES > 0 */ +#else /* NUMBER_OF_FILES > 0 */ UNUSED(ctx); UNUSED(idx); UNUSED(ctxIfmName); @@ -119,83 +114,76 @@ bool SetAppCtxIfmIdx(arm::app::ApplicationContext& ctx, uint32_t idx, const std: namespace arm { namespace app { + bool RunInference(arm::app::Model& model, Profiler& profiler) + { + profiler.StartProfiling("Inference"); + bool runInf = model.RunInference(); + profiler.StopProfiling(); -bool RunInference(arm::app::Model& model, Profiler& profiler) -{ - profiler.StartProfiling("Inference"); - bool runInf = model.RunInference(); - profiler.StopProfiling(); - - return runInf; -} - -int ReadUserInputAsInt() -{ - char chInput[128]; - memset(chInput, 0, sizeof(chInput)); + return runInf; + } - hal_get_user_input(chInput, sizeof(chInput)); - return atoi(chInput); -} + int ReadUserInputAsInt() + { + char chInput[128]; + memset(chInput, 0, sizeof(chInput)); -void DumpTensorData(const uint8_t* tensorData, - size_t size, - size_t lineBreakForNumElements) -{ - char strhex[8]; - std::string strdump; + hal_get_user_input(chInput, sizeof(chInput)); + return atoi(chInput); + } - for (size_t i = 0; i < size; ++i) { - if (0 == i % lineBreakForNumElements) { - printf("%s\n\t", strdump.c_str()); - strdump.clear(); + void DumpTensorData(const uint8_t* tensorData, size_t size, size_t lineBreakForNumElements) + { + char strhex[8]; + std::string strdump; + + for (size_t i = 0; i < size; ++i) { + if (0 == i % lineBreakForNumElements) { + printf("%s\n\t", strdump.c_str()); + strdump.clear(); + } + snprintf(strhex, sizeof(strhex) - 1, "0x%02x, ", tensorData[i]); + strdump += std::string(strhex); } - snprintf(strhex, sizeof(strhex) - 1, - "0x%02x, ", tensorData[i]); - strdump += std::string(strhex); - } - if (!strdump.empty()) { - printf("%s\n", strdump.c_str()); + if (!strdump.empty()) { + printf("%s\n", strdump.c_str()); + } } -} -void DumpTensor(const TfLiteTensor* tensor, const size_t lineBreakForNumElements) -{ - if (!tensor) { - printf_err("invalid tensor\n"); - return; - } + void DumpTensor(const TfLiteTensor* tensor, const size_t lineBreakForNumElements) + { + if (!tensor) { + printf_err("invalid tensor\n"); + return; + } - const uint32_t tensorSz = tensor->bytes; - const auto* tensorData = tflite::GetTensorData<uint8_t>(tensor); + const uint32_t tensorSz = tensor->bytes; + const auto* tensorData = tflite::GetTensorData<uint8_t>(tensor); - DumpTensorData(tensorData, tensorSz, lineBreakForNumElements); -} + DumpTensorData(tensorData, tensorSz, lineBreakForNumElements); + } -bool ListFilesHandler(ApplicationContext& ctx) -{ - auto& model = ctx.Get<Model&>("model"); + bool ListFilesHandler(ApplicationContext& ctx) + { + auto& model = ctx.Get<Model&>("model"); - constexpr uint32_t dataPsnTxtStartX = 20; - constexpr uint32_t dataPsnTxtStartY = 40; + constexpr uint32_t dataPsnTxtStartX = 20; + constexpr uint32_t dataPsnTxtStartY = 40; - if (!model.IsInited()) { - printf_err("Model is not initialised! Terminating processing.\n"); - return false; - } + if (!model.IsInited()) { + printf_err("Model is not initialised! Terminating processing.\n"); + return false; + } - /* Clear the LCD */ - hal_lcd_clear(COLOR_BLACK); + /* Clear the LCD */ + hal_lcd_clear(COLOR_BLACK); - /* Show the total number of embedded files. */ - std::string strNumFiles = std::string{"Total Number of Files: "} + - std::to_string(NUMBER_OF_FILES); - hal_lcd_display_text(strNumFiles.c_str(), - strNumFiles.size(), - dataPsnTxtStartX, - dataPsnTxtStartY, - false); + /* Show the total number of embedded files. */ + std::string strNumFiles = + std::string{"Total Number of Files: "} + std::to_string(NUMBER_OF_FILES); + hal_lcd_display_text( + strNumFiles.c_str(), strNumFiles.size(), dataPsnTxtStartX, dataPsnTxtStartY, false); #if NUMBER_OF_FILES > 0 constexpr uint32_t dataPsnTxtYIncr = 16; @@ -203,17 +191,16 @@ bool ListFilesHandler(ApplicationContext& ctx) uint32_t yVal = dataPsnTxtStartY + dataPsnTxtYIncr; for (uint32_t i = 0; i < NUMBER_OF_FILES; ++i, yVal += dataPsnTxtYIncr) { - std::string currentFilename{get_filename(i)}; - hal_lcd_display_text(currentFilename.c_str(), - currentFilename.size(), - dataPsnTxtStartX, yVal, false); + std::string currentFilename{GetFilename(i)}; + hal_lcd_display_text( + currentFilename.c_str(), currentFilename.size(), dataPsnTxtStartX, yVal, false); info("\t%" PRIu32 " => %s\n", i, currentFilename.c_str()); } #endif /* NUMBER_OF_FILES > 0 */ return true; -} + } } /* namespace app */ } /* namespace arm */ diff --git a/source/use_case/ad/src/UseCaseHandler.cc b/source/use_case/ad/src/UseCaseHandler.cc index 222aaf3..c71fdeb 100644 --- a/source/use_case/ad/src/UseCaseHandler.cc +++ b/source/use_case/ad/src/UseCaseHandler.cc @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com> - * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates + * <open-source-office@arm.com> SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,16 +16,16 @@ */ #include "UseCaseHandler.hpp" -#include "AdModel.hpp" -#include "InputFiles.hpp" -#include "Classifier.hpp" -#include "hal.h" #include "AdMelSpectrogram.hpp" +#include "AdModel.hpp" +#include "AdProcessing.hpp" #include "AudioUtils.hpp" +#include "Classifier.hpp" #include "ImageUtils.hpp" +#include "InputFiles.hpp" #include "UseCaseCommonUtils.hpp" +#include "hal.h" #include "log_macros.h" -#include "AdProcessing.hpp" namespace arm { namespace app { @@ -44,7 +44,7 @@ namespace app { * File name should be in format anything_goes_XX_here.wav * where XX is the machine ID e.g. 00, 02, 04 or 06 * @return AD model output index as 8 bit integer. - **/ + **/ static int8_t OutputIndexFromFileName(std::string wavFileName); /* Anomaly Detection inference handler */ @@ -57,7 +57,7 @@ namespace app { /* If the request has a valid size, set the audio index */ if (clipIndex < NUMBER_OF_FILES) { - if (!SetAppCtxIfmIdx(ctx, clipIndex,"clipIndex")) { + if (!SetAppCtxIfmIdx(ctx, clipIndex, "clipIndex")) { return false; } } @@ -66,26 +66,22 @@ namespace app { return false; } - auto& profiler = ctx.Get<Profiler&>("profiler"); + auto& profiler = ctx.Get<Profiler&>("profiler"); const auto melSpecFrameLength = ctx.Get<uint32_t>("frameLength"); const auto melSpecFrameStride = ctx.Get<uint32_t>("frameStride"); - const auto scoreThreshold = ctx.Get<float>("scoreThreshold"); - const auto trainingMean = ctx.Get<float>("trainingMean"); - auto startClipIdx = ctx.Get<uint32_t>("clipIndex"); + const auto scoreThreshold = ctx.Get<float>("scoreThreshold"); + const auto trainingMean = ctx.Get<float>("trainingMean"); + auto startClipIdx = ctx.Get<uint32_t>("clipIndex"); TfLiteTensor* outputTensor = model.GetOutputTensor(0); - TfLiteTensor* inputTensor = model.GetInputTensor(0); + TfLiteTensor* inputTensor = model.GetInputTensor(0); if (!inputTensor->dims) { printf_err("Invalid input tensor dims\n"); return false; } - AdPreProcess preProcess{ - inputTensor, - melSpecFrameLength, - melSpecFrameStride, - trainingMean}; + AdPreProcess preProcess{inputTensor, melSpecFrameLength, melSpecFrameStride, trainingMean}; AdPostProcess postProcess{outputTensor}; @@ -95,17 +91,17 @@ namespace app { auto currentIndex = ctx.Get<uint32_t>("clipIndex"); /* Get the output index to look at based on id in the filename. */ - int8_t machineOutputIndex = OutputIndexFromFileName(get_filename(currentIndex)); + int8_t machineOutputIndex = OutputIndexFromFileName(GetFilename(currentIndex)); if (machineOutputIndex == -1) { return false; } /* Creating a sliding window through the whole audio clip. */ - auto audioDataSlider = audio::SlidingWindow<const int16_t>( - get_audio_array(currentIndex), - get_audio_array_size(currentIndex), - preProcess.GetAudioWindowSize(), - preProcess.GetAudioDataStride()); + auto audioDataSlider = + audio::SlidingWindow<const int16_t>(GetAudioArray(currentIndex), + GetAudioArraySize(currentIndex), + preProcess.GetAudioWindowSize(), + preProcess.GetAudioDataStride()); /* Result is an averaged sum over inferences. */ float result = 0; @@ -113,11 +109,11 @@ namespace app { /* Display message on the LCD - inference running. */ std::string str_inf{"Running inference... "}; hal_lcd_display_text( - str_inf.c_str(), str_inf.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); + str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); info("Running inference on audio clip %" PRIu32 " => %s\n", - currentIndex, get_filename(currentIndex)); + currentIndex, + GetFilename(currentIndex)); /* Start sliding through audio clip. */ while (audioDataSlider.HasNext()) { @@ -126,7 +122,8 @@ namespace app { preProcess.SetAudioWindowIndex(audioDataSlider.Index()); preProcess.DoPreProcess(inferenceWindow, preProcess.GetAudioWindowSize()); - info("Inference %zu/%zu\n", audioDataSlider.Index() + 1, + info("Inference %zu/%zu\n", + audioDataSlider.Index() + 1, audioDataSlider.TotalStrides() + 1); /* Run inference over this audio clip sliding window */ @@ -139,7 +136,7 @@ namespace app { #if VERIFY_TEST_OUTPUT DumpTensor(outputTensor); -#endif /* VERIFY_TEST_OUTPUT */ +#endif /* VERIFY_TEST_OUTPUT */ } /* while (audioDataSlider.HasNext()) */ /* Use average over whole clip as final score. */ @@ -148,8 +145,7 @@ namespace app { /* Erase. */ str_inf = std::string(str_inf.size(), ' '); hal_lcd_display_text( - str_inf.c_str(), str_inf.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); + str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); ctx.Set<float>("result", result); if (!PresentInferenceResult(result, scoreThreshold)) { @@ -158,7 +154,7 @@ namespace app { profiler.PrintProfilingResult(); - IncrementAppCtxIfmIdx(ctx,"clipIndex"); + IncrementAppCtxIfmIdx(ctx, "clipIndex"); } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx); @@ -176,8 +172,10 @@ namespace app { /* Display each result */ uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr; - std::string anomalyScore = std::string{"Average anomaly score is: "} + std::to_string(result); - std::string anomalyThreshold = std::string("Anomaly threshold is: ") + std::to_string(threshold); + std::string anomalyScore = + std::string{"Average anomaly score is: "} + std::to_string(result); + std::string anomalyThreshold = + std::string("Anomaly threshold is: ") + std::to_string(threshold); std::string anomalyResult; if (result > threshold) { @@ -187,8 +185,7 @@ namespace app { } hal_lcd_display_text( - anomalyScore.c_str(), anomalyScore.size(), - dataPsnTxtStartX1, rowIdx1, false); + anomalyScore.c_str(), anomalyScore.size(), dataPsnTxtStartX1, rowIdx1, false); info("%s\n", anomalyScore.c_str()); info("%s\n", anomalyThreshold.c_str()); @@ -200,26 +197,28 @@ namespace app { static int8_t OutputIndexFromFileName(std::string wavFileName) { /* Filename is assumed in the form machine_id_00.wav */ - std::string delimiter = "_"; /* First character used to split the file name up. */ + std::string delimiter = "_"; /* First character used to split the file name up. */ size_t delimiterStart; std::string subString; - size_t machineIdxInString = 3; /* Which part of the file name the machine id should be at. */ + size_t machineIdxInString = + 3; /* Which part of the file name the machine id should be at. */ for (size_t i = 0; i < machineIdxInString; ++i) { delimiterStart = wavFileName.find(delimiter); - subString = wavFileName.substr(0, delimiterStart); + subString = wavFileName.substr(0, delimiterStart); wavFileName.erase(0, delimiterStart + delimiter.length()); } /* At this point substring should be 00.wav */ - delimiter = "."; /* Second character used to split the file name up. */ + delimiter = "."; /* Second character used to split the file name up. */ delimiterStart = subString.find(delimiter); - subString = (delimiterStart != std::string::npos) ? subString.substr(0, delimiterStart) : subString; + subString = + (delimiterStart != std::string::npos) ? subString.substr(0, delimiterStart) : subString; - auto is_number = [](const std::string& str) -> bool - { + auto is_number = [](const std::string& str) -> bool { std::string::const_iterator it = str.begin(); - while (it != str.end() && std::isdigit(*it)) ++it; + while (it != str.end() && std::isdigit(*it)) + ++it; return !str.empty() && it == str.end(); }; diff --git a/source/use_case/asr/src/UseCaseHandler.cc b/source/use_case/asr/src/UseCaseHandler.cc index 76409b6..d13a03a 100644 --- a/source/use_case/asr/src/UseCaseHandler.cc +++ b/source/use_case/asr/src/UseCaseHandler.cc @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com> - * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates + * <open-source-office@arm.com> SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,17 +16,17 @@ */ #include "UseCaseHandler.hpp" -#include "InputFiles.hpp" #include "AsrClassifier.hpp" -#include "Wav2LetterModel.hpp" -#include "hal.h" +#include "AsrResult.hpp" #include "AudioUtils.hpp" #include "ImageUtils.hpp" +#include "InputFiles.hpp" +#include "OutputDecode.hpp" #include "UseCaseCommonUtils.hpp" -#include "AsrResult.hpp" -#include "Wav2LetterPreprocess.hpp" +#include "Wav2LetterModel.hpp" #include "Wav2LetterPostprocess.hpp" -#include "OutputDecode.hpp" +#include "Wav2LetterPreprocess.hpp" +#include "hal.h" #include "log_macros.h" namespace arm { @@ -42,19 +42,19 @@ namespace app { /* ASR inference handler. */ bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll) { - auto& model = ctx.Get<Model&>("model"); - auto& profiler = ctx.Get<Profiler&>("profiler"); - auto mfccFrameLen = ctx.Get<uint32_t>("frameLength"); + auto& model = ctx.Get<Model&>("model"); + auto& profiler = ctx.Get<Profiler&>("profiler"); + auto mfccFrameLen = ctx.Get<uint32_t>("frameLength"); auto mfccFrameStride = ctx.Get<uint32_t>("frameStride"); - auto scoreThreshold = ctx.Get<float>("scoreThreshold"); - auto inputCtxLen = ctx.Get<uint32_t>("ctxLen"); + auto scoreThreshold = ctx.Get<float>("scoreThreshold"); + auto inputCtxLen = ctx.Get<uint32_t>("ctxLen"); /* If the request has a valid size, set the audio index. */ if (clipIndex < NUMBER_OF_FILES) { - if (!SetAppCtxIfmIdx(ctx, clipIndex,"clipIndex")) { + if (!SetAppCtxIfmIdx(ctx, clipIndex, "clipIndex")) { return false; } } - auto initialClipIdx = ctx.Get<uint32_t>("clipIndex"); + auto initialClipIdx = ctx.Get<uint32_t>("clipIndex"); constexpr uint32_t dataPsnTxtInfStartX = 20; constexpr uint32_t dataPsnTxtInfStartY = 40; @@ -63,7 +63,7 @@ namespace app { return false; } - TfLiteTensor* inputTensor = model.GetInputTensor(0); + TfLiteTensor* inputTensor = model.GetInputTensor(0); TfLiteTensor* outputTensor = model.GetOutputTensor(0); /* Get input shape. Dimensions of the tensor should have been verified by @@ -81,18 +81,21 @@ namespace app { const float secondsPerSample = (1.0 / audio::Wav2LetterMFCC::ms_defaultSamplingFreq); /* Set up pre and post-processing objects. */ - AsrPreProcess preProcess = AsrPreProcess(inputTensor, Wav2LetterModel::ms_numMfccFeatures, + AsrPreProcess preProcess = AsrPreProcess(inputTensor, + Wav2LetterModel::ms_numMfccFeatures, inputShape->data[Wav2LetterModel::ms_inputRowsIdx], - mfccFrameLen, mfccFrameStride); + mfccFrameLen, + mfccFrameStride); std::vector<ClassificationResult> singleInfResult; const uint32_t outputCtxLen = AsrPostProcess::GetOutputContextLen(model, inputCtxLen); - AsrPostProcess postProcess = AsrPostProcess( - outputTensor, ctx.Get<AsrClassifier&>("classifier"), - ctx.Get<std::vector<std::string>&>("labels"), - singleInfResult, outputCtxLen, - Wav2LetterModel::ms_blankTokenIdx, Wav2LetterModel::ms_outputRowsIdx - ); + AsrPostProcess postProcess = AsrPostProcess(outputTensor, + ctx.Get<AsrClassifier&>("classifier"), + ctx.Get<std::vector<std::string>&>("labels"), + singleInfResult, + outputCtxLen, + Wav2LetterModel::ms_blankTokenIdx, + Wav2LetterModel::ms_outputRowsIdx); /* Loop to process audio clips. */ do { @@ -102,8 +105,8 @@ namespace app { auto currentIndex = ctx.Get<uint32_t>("clipIndex"); /* Get the current audio buffer and respective size. */ - const int16_t* audioArr = get_audio_array(currentIndex); - const uint32_t audioArrSize = get_audio_array_size(currentIndex); + const int16_t* audioArr = GetAudioArray(currentIndex); + const uint32_t audioArrSize = GetAudioArraySize(currentIndex); if (!audioArr) { printf_err("Invalid audio array pointer.\n"); @@ -119,19 +122,19 @@ namespace app { /* Creating a sliding window through the whole audio clip. */ auto audioDataSlider = audio::FractionalSlidingWindow<const int16_t>( - audioArr, audioArrSize, - audioDataWindowLen, audioDataWindowStride); + audioArr, audioArrSize, audioDataWindowLen, audioDataWindowStride); /* Declare a container for final results. */ std::vector<asr::AsrResult> finalResults; /* Display message on the LCD - inference running. */ std::string str_inf{"Running inference... "}; - hal_lcd_display_text(str_inf.c_str(), str_inf.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); + hal_lcd_display_text( + str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); - info("Running inference on audio clip %" PRIu32 " => %s\n", currentIndex, - get_filename(currentIndex)); + info("Running inference on audio clip %" PRIu32 " => %s\n", + currentIndex, + GetFilename(currentIndex)); size_t inferenceWindowLen = audioDataWindowLen; @@ -146,7 +149,8 @@ namespace app { const int16_t* inferenceWindow = audioDataSlider.Next(); - info("Inference %zu/%zu\n", audioDataSlider.Index() + 1, + info("Inference %zu/%zu\n", + audioDataSlider.Index() + 1, static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1))); /* Run the pre-processing, inference and post-processing. */ @@ -168,20 +172,22 @@ namespace app { } /* Add results from this window to our final results vector. */ - finalResults.emplace_back(asr::AsrResult(singleInfResult, - (audioDataSlider.Index() * secondsPerSample * audioDataWindowStride), - audioDataSlider.Index(), scoreThreshold)); + finalResults.emplace_back(asr::AsrResult( + singleInfResult, + (audioDataSlider.Index() * secondsPerSample * audioDataWindowStride), + audioDataSlider.Index(), + scoreThreshold)); #if VERIFY_TEST_OUTPUT armDumpTensor(outputTensor, - outputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx]); -#endif /* VERIFY_TEST_OUTPUT */ + outputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx]); +#endif /* VERIFY_TEST_OUTPUT */ } /* while (audioDataSlider.HasNext()) */ /* Erase. */ str_inf = std::string(str_inf.size(), ' '); - hal_lcd_display_text(str_inf.c_str(), str_inf.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); + hal_lcd_display_text( + str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); ctx.Set<std::vector<asr::AsrResult>>("results", finalResults); @@ -191,19 +197,18 @@ namespace app { profiler.PrintProfilingResult(); - IncrementAppCtxIfmIdx(ctx,"clipIndex"); + IncrementAppCtxIfmIdx(ctx, "clipIndex"); } while (runAll && ctx.Get<uint32_t>("clipIndex") != initialClipIdx); return true; } - static bool PresentInferenceResult(const std::vector<asr::AsrResult>& results) { constexpr uint32_t dataPsnTxtStartX1 = 20; constexpr uint32_t dataPsnTxtStartY1 = 60; - constexpr bool allow_multiple_lines = true; + constexpr bool allow_multiple_lines = true; hal_lcd_set_text_color(COLOR_GREEN); @@ -212,9 +217,8 @@ namespace app { /* Results from multiple inferences should be combined before processing. */ std::vector<ClassificationResult> combinedResults; for (const auto& result : results) { - combinedResults.insert(combinedResults.end(), - result.m_resultVec.begin(), - result.m_resultVec.end()); + combinedResults.insert( + combinedResults.end(), result.m_resultVec.begin(), result.m_resultVec.end()); } /* Get each inference result string using the decoder. */ @@ -222,16 +226,19 @@ namespace app { std::string infResultStr = audio::asr::DecodeOutput(result.m_resultVec); info("For timestamp: %f (inference #: %" PRIu32 "); label: %s\n", - result.m_timeStamp, result.m_inferenceNumber, + result.m_timeStamp, + result.m_inferenceNumber, infResultStr.c_str()); } /* Get the decoded result for the combined result. */ std::string finalResultStr = audio::asr::DecodeOutput(combinedResults); - hal_lcd_display_text(finalResultStr.c_str(), finalResultStr.size(), - dataPsnTxtStartX1, dataPsnTxtStartY1, - allow_multiple_lines); + hal_lcd_display_text(finalResultStr.c_str(), + finalResultStr.size(), + dataPsnTxtStartX1, + dataPsnTxtStartY1, + allow_multiple_lines); info("Complete recognition: %s\n", finalResultStr.c_str()); return true; diff --git a/source/use_case/img_class/src/UseCaseHandler.cc b/source/use_case/img_class/src/UseCaseHandler.cc index 4732064..52c42f3 100644 --- a/source/use_case/img_class/src/UseCaseHandler.cc +++ b/source/use_case/img_class/src/UseCaseHandler.cc @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com> - * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates + * <open-source-office@arm.com> SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,13 +17,13 @@ #include "UseCaseHandler.hpp" #include "Classifier.hpp" +#include "ImageUtils.hpp" +#include "ImgClassProcessing.hpp" #include "InputFiles.hpp" #include "MobileNetModel.hpp" -#include "ImageUtils.hpp" #include "UseCaseCommonUtils.hpp" #include "hal.h" #include "log_macros.h" -#include "ImgClassProcessing.hpp" #include <cinttypes> @@ -36,7 +36,7 @@ namespace app { bool ClassifyImageHandler(ApplicationContext& ctx, uint32_t imgIndex, bool runAll) { auto& profiler = ctx.Get<Profiler&>("profiler"); - auto& model = ctx.Get<Model&>("model"); + auto& model = ctx.Get<Model&>("model"); /* If the request has a valid size, set the image index as it might not be set. */ if (imgIndex < NUMBER_OF_FILES) { if (!SetAppCtxIfmIdx(ctx, imgIndex, "imgIndex")) { @@ -46,19 +46,18 @@ namespace app { auto initialImgIdx = ctx.Get<uint32_t>("imgIndex"); constexpr uint32_t dataPsnImgDownscaleFactor = 2; - constexpr uint32_t dataPsnImgStartX = 10; - constexpr uint32_t dataPsnImgStartY = 35; + constexpr uint32_t dataPsnImgStartX = 10; + constexpr uint32_t dataPsnImgStartY = 35; constexpr uint32_t dataPsnTxtInfStartX = 150; constexpr uint32_t dataPsnTxtInfStartY = 40; - if (!model.IsInited()) { printf_err("Model is not initialised! Terminating processing.\n"); return false; } - TfLiteTensor* inputTensor = model.GetInputTensor(0); + TfLiteTensor* inputTensor = model.GetInputTensor(0); TfLiteTensor* outputTensor = model.GetOutputTensor(0); if (!inputTensor->dims) { printf_err("Invalid input tensor dims\n"); @@ -70,17 +69,19 @@ namespace app { /* Get input shape for displaying the image. */ TfLiteIntArray* inputShape = model.GetInputShape(0); - const uint32_t nCols = inputShape->data[arm::app::MobileNetModel::ms_inputColsIdx]; - const uint32_t nRows = inputShape->data[arm::app::MobileNetModel::ms_inputRowsIdx]; + const uint32_t nCols = inputShape->data[arm::app::MobileNetModel::ms_inputColsIdx]; + const uint32_t nRows = inputShape->data[arm::app::MobileNetModel::ms_inputRowsIdx]; const uint32_t nChannels = inputShape->data[arm::app::MobileNetModel::ms_inputChannelsIdx]; /* Set up pre and post-processing. */ ImgClassPreProcess preProcess = ImgClassPreProcess(inputTensor, model.IsDataSigned()); std::vector<ClassificationResult> results; - ImgClassPostProcess postProcess = ImgClassPostProcess(outputTensor, - ctx.Get<ImgClassClassifier&>("classifier"), ctx.Get<std::vector<std::string>&>("labels"), - results); + ImgClassPostProcess postProcess = + ImgClassPostProcess(outputTensor, + ctx.Get<ImgClassClassifier&>("classifier"), + ctx.Get<std::vector<std::string>&>("labels"), + results); do { hal_lcd_clear(COLOR_BLACK); @@ -88,29 +89,34 @@ namespace app { /* Strings for presentation/logging. */ std::string str_inf{"Running inference... "}; - const uint8_t* imgSrc = get_img_array(ctx.Get<uint32_t>("imgIndex")); + const uint8_t* imgSrc = GetImgArray(ctx.Get<uint32_t>("imgIndex")); if (nullptr == imgSrc) { - printf_err("Failed to get image index %" PRIu32 " (max: %u)\n", ctx.Get<uint32_t>("imgIndex"), + printf_err("Failed to get image index %" PRIu32 " (max: %u)\n", + ctx.Get<uint32_t>("imgIndex"), NUMBER_OF_FILES - 1); return false; } /* Display this image on the LCD. */ - hal_lcd_display_image( - imgSrc, - nCols, nRows, nChannels, - dataPsnImgStartX, dataPsnImgStartY, dataPsnImgDownscaleFactor); + hal_lcd_display_image(imgSrc, + nCols, + nRows, + nChannels, + dataPsnImgStartX, + dataPsnImgStartY, + dataPsnImgDownscaleFactor); /* Display message on the LCD - inference running. */ - hal_lcd_display_text(str_inf.c_str(), str_inf.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); + hal_lcd_display_text( + str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); /* Select the image to run inference with. */ - info("Running inference on image %" PRIu32 " => %s\n", ctx.Get<uint32_t>("imgIndex"), - get_filename(ctx.Get<uint32_t>("imgIndex"))); + info("Running inference on image %" PRIu32 " => %s\n", + ctx.Get<uint32_t>("imgIndex"), + GetFilename(ctx.Get<uint32_t>("imgIndex"))); - const size_t imgSz = inputTensor->bytes < IMAGE_DATA_SIZE ? - inputTensor->bytes : IMAGE_DATA_SIZE; + const size_t imgSz = + inputTensor->bytes < IMAGE_DATA_SIZE ? inputTensor->bytes : IMAGE_DATA_SIZE; /* Run the pre-processing, inference and post-processing. */ if (!preProcess.DoPreProcess(imgSrc, imgSz)) { @@ -130,8 +136,8 @@ namespace app { /* Erase. */ str_inf = std::string(str_inf.size(), ' '); - hal_lcd_display_text(str_inf.c_str(), str_inf.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); + hal_lcd_display_text( + str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); /* Add results to context for access outside handler. */ ctx.Set<std::vector<ClassificationResult>>("results", results); @@ -146,7 +152,7 @@ namespace app { profiler.PrintProfilingResult(); - IncrementAppCtxIfmIdx(ctx,"imgIndex"); + IncrementAppCtxIfmIdx(ctx, "imgIndex"); } while (runAll && ctx.Get<uint32_t>("imgIndex") != initialImgIdx); diff --git a/source/use_case/kws/src/UseCaseHandler.cc b/source/use_case/kws/src/UseCaseHandler.cc index c20c32b..ce99ed3 100644 --- a/source/use_case/kws/src/UseCaseHandler.cc +++ b/source/use_case/kws/src/UseCaseHandler.cc @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com> - * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates + * <open-source-office@arm.com> SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,16 +16,16 @@ */ #include "UseCaseHandler.hpp" +#include "AudioUtils.hpp" +#include "ImageUtils.hpp" #include "InputFiles.hpp" #include "KwsClassifier.hpp" +#include "KwsProcessing.hpp" +#include "KwsResult.hpp" #include "MicroNetKwsModel.hpp" -#include "hal.h" -#include "AudioUtils.hpp" -#include "ImageUtils.hpp" #include "UseCaseCommonUtils.hpp" -#include "KwsResult.hpp" +#include "hal.h" #include "log_macros.h" -#include "KwsProcessing.hpp" #include <vector> @@ -42,15 +42,15 @@ namespace app { /* KWS inference handler. */ bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll) { - auto& profiler = ctx.Get<Profiler&>("profiler"); - auto& model = ctx.Get<Model&>("model"); + auto& profiler = ctx.Get<Profiler&>("profiler"); + auto& model = ctx.Get<Model&>("model"); const auto mfccFrameLength = ctx.Get<int>("frameLength"); const auto mfccFrameStride = ctx.Get<int>("frameStride"); - const auto scoreThreshold = ctx.Get<float>("scoreThreshold"); + const auto scoreThreshold = ctx.Get<float>("scoreThreshold"); /* If the request has a valid size, set the audio index. */ if (clipIndex < NUMBER_OF_FILES) { - if (!SetAppCtxIfmIdx(ctx, clipIndex,"clipIndex")) { + if (!SetAppCtxIfmIdx(ctx, clipIndex, "clipIndex")) { return false; } } @@ -58,9 +58,10 @@ namespace app { constexpr uint32_t dataPsnTxtInfStartX = 20; constexpr uint32_t dataPsnTxtInfStartY = 40; - constexpr int minTensorDims = static_cast<int>( - (MicroNetKwsModel::ms_inputRowsIdx > MicroNetKwsModel::ms_inputColsIdx)? - MicroNetKwsModel::ms_inputRowsIdx : MicroNetKwsModel::ms_inputColsIdx); + constexpr int minTensorDims = + static_cast<int>((MicroNetKwsModel::ms_inputRowsIdx > MicroNetKwsModel::ms_inputColsIdx) + ? MicroNetKwsModel::ms_inputRowsIdx + : MicroNetKwsModel::ms_inputColsIdx); if (!model.IsInited()) { printf_err("Model is not initialised! Terminating processing.\n"); @@ -68,7 +69,7 @@ namespace app { } /* Get Input and Output tensors for pre/post processing. */ - TfLiteTensor* inputTensor = model.GetInputTensor(0); + TfLiteTensor* inputTensor = model.GetInputTensor(0); TfLiteTensor* outputTensor = model.GetOutputTensor(0); if (!inputTensor->dims) { printf_err("Invalid input tensor dims\n"); @@ -79,20 +80,22 @@ namespace app { } /* Get input shape for feature extraction. */ - TfLiteIntArray* inputShape = model.GetInputShape(0); + TfLiteIntArray* inputShape = model.GetInputShape(0); const uint32_t numMfccFeatures = inputShape->data[MicroNetKwsModel::ms_inputColsIdx]; - const uint32_t numMfccFrames = inputShape->data[arm::app::MicroNetKwsModel::ms_inputRowsIdx]; + const uint32_t numMfccFrames = + inputShape->data[arm::app::MicroNetKwsModel::ms_inputRowsIdx]; /* We expect to be sampling 1 second worth of data at a time. * NOTE: This is only used for time stamp calculation. */ const float secondsPerSample = 1.0 / audio::MicroNetKwsMFCC::ms_defaultSamplingFreq; /* Set up pre and post-processing. */ - KwsPreProcess preProcess = KwsPreProcess(inputTensor, numMfccFeatures, numMfccFrames, - mfccFrameLength, mfccFrameStride); + KwsPreProcess preProcess = KwsPreProcess( + inputTensor, numMfccFeatures, numMfccFrames, mfccFrameLength, mfccFrameStride); std::vector<ClassificationResult> singleInfResult; - KwsPostProcess postProcess = KwsPostProcess(outputTensor, ctx.Get<KwsClassifier &>("classifier"), + KwsPostProcess postProcess = KwsPostProcess(outputTensor, + ctx.Get<KwsClassifier&>("classifier"), ctx.Get<std::vector<std::string>&>("labels"), singleInfResult); @@ -103,26 +106,29 @@ namespace app { auto currentIndex = ctx.Get<uint32_t>("clipIndex"); /* Creating a sliding window through the whole audio clip. */ - auto audioDataSlider = audio::SlidingWindow<const int16_t>( - get_audio_array(currentIndex), - get_audio_array_size(currentIndex), - preProcess.m_audioDataWindowSize, preProcess.m_audioDataStride); + auto audioDataSlider = + audio::SlidingWindow<const int16_t>(GetAudioArray(currentIndex), + GetAudioArraySize(currentIndex), + preProcess.m_audioDataWindowSize, + preProcess.m_audioDataStride); /* Declare a container to hold results from across the whole audio clip. */ std::vector<kws::KwsResult> finalResults; /* Display message on the LCD - inference running. */ std::string str_inf{"Running inference... "}; - hal_lcd_display_text(str_inf.c_str(), str_inf.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); - info("Running inference on audio clip %" PRIu32 " => %s\n", currentIndex, - get_filename(currentIndex)); + hal_lcd_display_text( + str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); + info("Running inference on audio clip %" PRIu32 " => %s\n", + currentIndex, + GetFilename(currentIndex)); /* Start sliding through audio clip. */ while (audioDataSlider.HasNext()) { const int16_t* inferenceWindow = audioDataSlider.Next(); - info("Inference %zu/%zu\n", audioDataSlider.Index() + 1, + info("Inference %zu/%zu\n", + audioDataSlider.Index() + 1, audioDataSlider.TotalStrides() + 1); /* Run the pre-processing, inference and post-processing. */ @@ -142,19 +148,21 @@ namespace app { } /* Add results from this window to our final results vector. */ - finalResults.emplace_back(kws::KwsResult(singleInfResult, - audioDataSlider.Index() * secondsPerSample * preProcess.m_audioDataStride, - audioDataSlider.Index(), scoreThreshold)); + finalResults.emplace_back(kws::KwsResult( + singleInfResult, + audioDataSlider.Index() * secondsPerSample * preProcess.m_audioDataStride, + audioDataSlider.Index(), + scoreThreshold)); #if VERIFY_TEST_OUTPUT DumpTensor(outputTensor); -#endif /* VERIFY_TEST_OUTPUT */ +#endif /* VERIFY_TEST_OUTPUT */ } /* while (audioDataSlider.HasNext()) */ /* Erase. */ str_inf = std::string(str_inf.size(), ' '); - hal_lcd_display_text(str_inf.c_str(), str_inf.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); + hal_lcd_display_text( + str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); ctx.Set<std::vector<kws::KwsResult>>("results", finalResults); @@ -164,7 +172,7 @@ namespace app { profiler.PrintProfilingResult(); - IncrementAppCtxIfmIdx(ctx,"clipIndex"); + IncrementAppCtxIfmIdx(ctx, "clipIndex"); } while (runAll && ctx.Get<uint32_t>("clipIndex") != initialClipIdx); @@ -175,7 +183,7 @@ namespace app { { constexpr uint32_t dataPsnTxtStartX1 = 20; constexpr uint32_t dataPsnTxtStartY1 = 30; - constexpr uint32_t dataPsnTxtYIncr = 16; /* Row index increment. */ + constexpr uint32_t dataPsnTxtYIncr = 16; /* Row index increment. */ hal_lcd_set_text_color(COLOR_GREEN); info("Final results:\n"); @@ -190,28 +198,28 @@ namespace app { float score = 0.f; if (!result.m_resultVec.empty()) { topKeyword = result.m_resultVec[0].m_label; - score = result.m_resultVec[0].m_normalisedVal; + score = result.m_resultVec[0].m_normalisedVal; } - std::string resultStr = - std::string{"@"} + std::to_string(result.m_timeStamp) + - std::string{"s: "} + topKeyword + std::string{" ("} + - std::to_string(static_cast<int>(score * 100)) + std::string{"%)"}; + std::string resultStr = std::string{"@"} + std::to_string(result.m_timeStamp) + + std::string{"s: "} + topKeyword + std::string{" ("} + + std::to_string(static_cast<int>(score * 100)) + + std::string{"%)"}; - hal_lcd_display_text(resultStr.c_str(), resultStr.size(), - dataPsnTxtStartX1, rowIdx1, false); + hal_lcd_display_text( + resultStr.c_str(), resultStr.size(), dataPsnTxtStartX1, rowIdx1, false); rowIdx1 += dataPsnTxtYIncr; if (result.m_resultVec.empty()) { - info("For timestamp: %f (inference #: %" PRIu32 - "); label: %s; threshold: %f\n", - result.m_timeStamp, result.m_inferenceNumber, + info("For timestamp: %f (inference #: %" PRIu32 "); label: %s; threshold: %f\n", + result.m_timeStamp, + result.m_inferenceNumber, topKeyword.c_str(), result.m_threshold); } else { for (uint32_t j = 0; j < result.m_resultVec.size(); ++j) { info("For timestamp: %f (inference #: %" PRIu32 - "); label: %s, score: %f; threshold: %f\n", + "); label: %s, score: %f; threshold: %f\n", result.m_timeStamp, result.m_inferenceNumber, result.m_resultVec[j].m_label.c_str(), diff --git a/source/use_case/kws_asr/src/UseCaseHandler.cc b/source/use_case/kws_asr/src/UseCaseHandler.cc index e733605..8a024b7 100644 --- a/source/use_case/kws_asr/src/UseCaseHandler.cc +++ b/source/use_case/kws_asr/src/UseCaseHandler.cc @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com> - * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates + * <open-source-office@arm.com> SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,35 +16,34 @@ */ #include "UseCaseHandler.hpp" -#include "hal.h" -#include "InputFiles.hpp" +#include "AsrClassifier.hpp" +#include "AsrResult.hpp" #include "AudioUtils.hpp" -#include "ImageUtils.hpp" -#include "UseCaseCommonUtils.hpp" -#include "MicroNetKwsModel.hpp" -#include "MicroNetKwsMfcc.hpp" #include "Classifier.hpp" +#include "ImageUtils.hpp" +#include "InputFiles.hpp" +#include "KwsProcessing.hpp" #include "KwsResult.hpp" -#include "Wav2LetterModel.hpp" +#include "MicroNetKwsMfcc.hpp" +#include "MicroNetKwsModel.hpp" +#include "OutputDecode.hpp" +#include "UseCaseCommonUtils.hpp" #include "Wav2LetterMfcc.hpp" -#include "Wav2LetterPreprocess.hpp" +#include "Wav2LetterModel.hpp" #include "Wav2LetterPostprocess.hpp" -#include "KwsProcessing.hpp" -#include "AsrResult.hpp" -#include "AsrClassifier.hpp" -#include "OutputDecode.hpp" +#include "Wav2LetterPreprocess.hpp" +#include "hal.h" #include "log_macros.h" - using KwsClassifier = arm::app::Classifier; namespace arm { namespace app { struct KWSOutput { - bool executionSuccess = false; + bool executionSuccess = false; const int16_t* asrAudioStart = nullptr; - int32_t asrAudioSamples = 0; + int32_t asrAudioSamples = 0; }; /** @@ -69,23 +68,24 @@ namespace app { **/ static KWSOutput doKws(ApplicationContext& ctx) { - auto& profiler = ctx.Get<Profiler&>("profiler"); - auto& kwsModel = ctx.Get<Model&>("kwsModel"); + auto& profiler = ctx.Get<Profiler&>("profiler"); + auto& kwsModel = ctx.Get<Model&>("kwsModel"); const auto kwsMfccFrameLength = ctx.Get<int>("kwsFrameLength"); const auto kwsMfccFrameStride = ctx.Get<int>("kwsFrameStride"); - const auto kwsScoreThreshold = ctx.Get<float>("kwsScoreThreshold"); + const auto kwsScoreThreshold = ctx.Get<float>("kwsScoreThreshold"); auto currentIndex = ctx.Get<uint32_t>("clipIndex"); constexpr uint32_t dataPsnTxtInfStartX = 20; constexpr uint32_t dataPsnTxtInfStartY = 40; - constexpr int minTensorDims = static_cast<int>( - (MicroNetKwsModel::ms_inputRowsIdx > MicroNetKwsModel::ms_inputColsIdx)? - MicroNetKwsModel::ms_inputRowsIdx : MicroNetKwsModel::ms_inputColsIdx); + constexpr int minTensorDims = + static_cast<int>((MicroNetKwsModel::ms_inputRowsIdx > MicroNetKwsModel::ms_inputColsIdx) + ? MicroNetKwsModel::ms_inputRowsIdx + : MicroNetKwsModel::ms_inputColsIdx); /* Output struct from doing KWS. */ - KWSOutput output {}; + KWSOutput output{}; if (!kwsModel.IsInited()) { printf_err("KWS model has not been initialised\n"); @@ -93,7 +93,7 @@ namespace app { } /* Get Input and Output tensors for pre/post processing. */ - TfLiteTensor* kwsInputTensor = kwsModel.GetInputTensor(0); + TfLiteTensor* kwsInputTensor = kwsModel.GetInputTensor(0); TfLiteTensor* kwsOutputTensor = kwsModel.GetOutputTensor(0); if (!kwsInputTensor->dims) { printf_err("Invalid input tensor dims\n"); @@ -104,28 +104,30 @@ namespace app { } /* Get input shape for feature extraction. */ - TfLiteIntArray* inputShape = kwsModel.GetInputShape(0); + TfLiteIntArray* inputShape = kwsModel.GetInputShape(0); const uint32_t numMfccFeatures = inputShape->data[MicroNetKwsModel::ms_inputColsIdx]; - const uint32_t numMfccFrames = inputShape->data[MicroNetKwsModel::ms_inputRowsIdx]; + const uint32_t numMfccFrames = inputShape->data[MicroNetKwsModel::ms_inputRowsIdx]; /* We expect to be sampling 1 second worth of data at a time * NOTE: This is only used for time stamp calculation. */ - const float kwsAudioParamsSecondsPerSample = 1.0 / audio::MicroNetKwsMFCC::ms_defaultSamplingFreq; + const float kwsAudioParamsSecondsPerSample = + 1.0 / audio::MicroNetKwsMFCC::ms_defaultSamplingFreq; /* Set up pre and post-processing. */ - KwsPreProcess preProcess = KwsPreProcess(kwsInputTensor, numMfccFeatures, numMfccFrames, - kwsMfccFrameLength, kwsMfccFrameStride); + KwsPreProcess preProcess = KwsPreProcess( + kwsInputTensor, numMfccFeatures, numMfccFrames, kwsMfccFrameLength, kwsMfccFrameStride); std::vector<ClassificationResult> singleInfResult; - KwsPostProcess postProcess = KwsPostProcess(kwsOutputTensor, ctx.Get<KwsClassifier &>("kwsClassifier"), + KwsPostProcess postProcess = KwsPostProcess(kwsOutputTensor, + ctx.Get<KwsClassifier&>("kwsClassifier"), ctx.Get<std::vector<std::string>&>("kwsLabels"), singleInfResult); /* Creating a sliding window through the whole audio clip. */ - auto audioDataSlider = audio::SlidingWindow<const int16_t>( - get_audio_array(currentIndex), - get_audio_array_size(currentIndex), - preProcess.m_audioDataWindowSize, preProcess.m_audioDataStride); + auto audioDataSlider = audio::SlidingWindow<const int16_t>(GetAudioArray(currentIndex), + GetAudioArraySize(currentIndex), + preProcess.m_audioDataWindowSize, + preProcess.m_audioDataStride); /* Declare a container to hold kws results from across the whole audio clip. */ std::vector<kws::KwsResult> finalResults; @@ -133,11 +135,11 @@ namespace app { /* Display message on the LCD - inference running. */ std::string str_inf{"Running KWS inference... "}; hal_lcd_display_text( - str_inf.c_str(), str_inf.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); + str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); info("Running KWS inference on audio clip %" PRIu32 " => %s\n", - currentIndex, get_filename(currentIndex)); + currentIndex, + GetFilename(currentIndex)); /* Start sliding through audio clip. */ while (audioDataSlider.HasNext()) { @@ -159,22 +161,26 @@ namespace app { return output; } - info("Inference %zu/%zu\n", audioDataSlider.Index() + 1, + info("Inference %zu/%zu\n", + audioDataSlider.Index() + 1, audioDataSlider.TotalStrides() + 1); /* Add results from this window to our final results vector. */ finalResults.emplace_back( - kws::KwsResult(singleInfResult, - audioDataSlider.Index() * kwsAudioParamsSecondsPerSample * preProcess.m_audioDataStride, - audioDataSlider.Index(), kwsScoreThreshold)); + kws::KwsResult(singleInfResult, + audioDataSlider.Index() * kwsAudioParamsSecondsPerSample * + preProcess.m_audioDataStride, + audioDataSlider.Index(), + kwsScoreThreshold)); /* Break out when trigger keyword is detected. */ - if (singleInfResult[0].m_label == ctx.Get<const std::string&>("triggerKeyword") - && singleInfResult[0].m_normalisedVal > kwsScoreThreshold) { + if (singleInfResult[0].m_label == ctx.Get<const std::string&>("triggerKeyword") && + singleInfResult[0].m_normalisedVal > kwsScoreThreshold) { output.asrAudioStart = inferenceWindow + preProcess.m_audioDataWindowSize; - output.asrAudioSamples = get_audio_array_size(currentIndex) - - (audioDataSlider.NextWindowStartIndex() - - preProcess.m_audioDataStride + preProcess.m_audioDataWindowSize); + output.asrAudioSamples = + GetAudioArraySize(currentIndex) - + (audioDataSlider.NextWindowStartIndex() - preProcess.m_audioDataStride + + preProcess.m_audioDataWindowSize); break; } @@ -186,8 +192,8 @@ namespace app { /* Erase. */ str_inf = std::string(str_inf.size(), ' '); - hal_lcd_display_text(str_inf.c_str(), str_inf.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); + hal_lcd_display_text( + str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); if (!PresentInferenceResult(finalResults)) { return output; @@ -208,12 +214,12 @@ namespace app { **/ static bool doAsr(ApplicationContext& ctx, const KWSOutput& kwsOutput) { - auto& asrModel = ctx.Get<Model&>("asrModel"); - auto& profiler = ctx.Get<Profiler&>("profiler"); - auto asrMfccFrameLen = ctx.Get<uint32_t>("asrFrameLength"); + auto& asrModel = ctx.Get<Model&>("asrModel"); + auto& profiler = ctx.Get<Profiler&>("profiler"); + auto asrMfccFrameLen = ctx.Get<uint32_t>("asrFrameLength"); auto asrMfccFrameStride = ctx.Get<uint32_t>("asrFrameStride"); - auto asrScoreThreshold = ctx.Get<float>("asrScoreThreshold"); - auto asrInputCtxLen = ctx.Get<uint32_t>("ctxLen"); + auto asrScoreThreshold = ctx.Get<float>("asrScoreThreshold"); + auto asrInputCtxLen = ctx.Get<uint32_t>("ctxLen"); constexpr uint32_t dataPsnTxtInfStartX = 20; constexpr uint32_t dataPsnTxtInfStartY = 40; @@ -226,31 +232,32 @@ namespace app { hal_lcd_clear(COLOR_BLACK); /* Get Input and Output tensors for pre/post processing. */ - TfLiteTensor* asrInputTensor = asrModel.GetInputTensor(0); + TfLiteTensor* asrInputTensor = asrModel.GetInputTensor(0); TfLiteTensor* asrOutputTensor = asrModel.GetOutputTensor(0); /* Get input shape. Dimensions of the tensor should have been verified by - * the callee. */ + * the callee. */ TfLiteIntArray* inputShape = asrModel.GetInputShape(0); - const uint32_t asrInputRows = asrInputTensor->dims->data[Wav2LetterModel::ms_inputRowsIdx]; const uint32_t asrInputInnerLen = asrInputRows - (2 * asrInputCtxLen); /* Make sure the input tensor supports the above context and inner lengths. */ if (asrInputRows <= 2 * asrInputCtxLen || asrInputRows <= asrInputInnerLen) { printf_err("ASR input rows not compatible with ctx length %" PRIu32 "\n", - asrInputCtxLen); + asrInputCtxLen); return false; } /* Audio data stride corresponds to inputInnerLen feature vectors. */ - const uint32_t asrAudioDataWindowLen = (asrInputRows - 1) * asrMfccFrameStride + (asrMfccFrameLen); + const uint32_t asrAudioDataWindowLen = + (asrInputRows - 1) * asrMfccFrameStride + (asrMfccFrameLen); const uint32_t asrAudioDataWindowStride = asrInputInnerLen * asrMfccFrameStride; - const float asrAudioParamsSecondsPerSample = 1.0 / audio::Wav2LetterMFCC::ms_defaultSamplingFreq; + const float asrAudioParamsSecondsPerSample = + 1.0 / audio::Wav2LetterMFCC::ms_defaultSamplingFreq; /* Get the remaining audio buffer and respective size from KWS results. */ - const int16_t* audioArr = kwsOutput.asrAudioStart; + const int16_t* audioArr = kwsOutput.asrAudioStart; const uint32_t audioArrSize = kwsOutput.asrAudioSamples; /* Audio clip must have enough samples to produce 1 MFCC feature. */ @@ -262,35 +269,40 @@ namespace app { } /* Initialise an audio slider. */ - auto audioDataSlider = audio::FractionalSlidingWindow<const int16_t>( - audioBuffer.data(), - audioBuffer.size(), - asrAudioDataWindowLen, - asrAudioDataWindowStride); + auto audioDataSlider = + audio::FractionalSlidingWindow<const int16_t>(audioBuffer.data(), + audioBuffer.size(), + asrAudioDataWindowLen, + asrAudioDataWindowStride); /* Declare a container for results. */ std::vector<asr::AsrResult> asrResults; /* Display message on the LCD - inference running. */ std::string str_inf{"Running ASR inference... "}; - hal_lcd_display_text(str_inf.c_str(), str_inf.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); + hal_lcd_display_text( + str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); size_t asrInferenceWindowLen = asrAudioDataWindowLen; /* Set up pre and post-processing objects. */ - AsrPreProcess asrPreProcess = AsrPreProcess(asrInputTensor, arm::app::Wav2LetterModel::ms_numMfccFeatures, - inputShape->data[Wav2LetterModel::ms_inputRowsIdx], - asrMfccFrameLen, asrMfccFrameStride); + AsrPreProcess asrPreProcess = + AsrPreProcess(asrInputTensor, + arm::app::Wav2LetterModel::ms_numMfccFeatures, + inputShape->data[Wav2LetterModel::ms_inputRowsIdx], + asrMfccFrameLen, + asrMfccFrameStride); std::vector<ClassificationResult> singleInfResult; const uint32_t outputCtxLen = AsrPostProcess::GetOutputContextLen(asrModel, asrInputCtxLen); - AsrPostProcess asrPostProcess = AsrPostProcess( - asrOutputTensor, ctx.Get<AsrClassifier&>("asrClassifier"), - ctx.Get<std::vector<std::string>&>("asrLabels"), - singleInfResult, outputCtxLen, - Wav2LetterModel::ms_blankTokenIdx, Wav2LetterModel::ms_outputRowsIdx - ); + AsrPostProcess asrPostProcess = + AsrPostProcess(asrOutputTensor, + ctx.Get<AsrClassifier&>("asrClassifier"), + ctx.Get<std::vector<std::string>&>("asrLabels"), + singleInfResult, + outputCtxLen, + Wav2LetterModel::ms_blankTokenIdx, + Wav2LetterModel::ms_outputRowsIdx); /* Start sliding through audio clip. */ while (audioDataSlider.HasNext()) { @@ -302,8 +314,9 @@ namespace app { const int16_t* asrInferenceWindow = audioDataSlider.Next(); - info("Inference %zu/%zu\n", audioDataSlider.Index() + 1, - static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1))); + info("Inference %zu/%zu\n", + audioDataSlider.Index() + 1, + static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1))); /* Run the pre-processing, inference and post-processing. */ if (!asrPreProcess.DoPreProcess(asrInferenceWindow, asrInferenceWindowLen)) { @@ -327,25 +340,27 @@ namespace app { /* Get results. */ std::vector<ClassificationResult> asrClassificationResult; auto& asrClassifier = ctx.Get<AsrClassifier&>("asrClassifier"); - asrClassifier.GetClassificationResults( - asrOutputTensor, asrClassificationResult, - ctx.Get<std::vector<std::string>&>("asrLabels"), 1); - - asrResults.emplace_back(asr::AsrResult(asrClassificationResult, - (audioDataSlider.Index() * - asrAudioParamsSecondsPerSample * - asrAudioDataWindowStride), - audioDataSlider.Index(), asrScoreThreshold)); + asrClassifier.GetClassificationResults(asrOutputTensor, + asrClassificationResult, + ctx.Get<std::vector<std::string>&>("asrLabels"), + 1); + + asrResults.emplace_back( + asr::AsrResult(asrClassificationResult, + (audioDataSlider.Index() * asrAudioParamsSecondsPerSample * + asrAudioDataWindowStride), + audioDataSlider.Index(), + asrScoreThreshold)); #if VERIFY_TEST_OUTPUT - armDumpTensor(asrOutputTensor, asrOutputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx]); + armDumpTensor(asrOutputTensor, + asrOutputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx]); #endif /* VERIFY_TEST_OUTPUT */ /* Erase */ str_inf = std::string(str_inf.size(), ' '); hal_lcd_display_text( - str_inf.c_str(), str_inf.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); + str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); } if (!PresentInferenceResult(asrResults)) { return false; @@ -363,7 +378,7 @@ namespace app { /* If the request has a valid size, set the audio index. */ if (clipIndex < NUMBER_OF_FILES) { - if (!SetAppCtxIfmIdx(ctx, clipIndex,"kws_asr")) { + if (!SetAppCtxIfmIdx(ctx, clipIndex, "kws_asr")) { return false; } } @@ -379,13 +394,13 @@ namespace app { if (kwsOutput.asrAudioStart != nullptr && kwsOutput.asrAudioSamples > 0) { info("Trigger keyword spotted\n"); - if(!doAsr(ctx, kwsOutput)) { + if (!doAsr(ctx, kwsOutput)) { printf_err("ASR failed\n"); return false; } } - IncrementAppCtxIfmIdx(ctx,"kws_asr"); + IncrementAppCtxIfmIdx(ctx, "kws_asr"); } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx); @@ -396,36 +411,38 @@ namespace app { { constexpr uint32_t dataPsnTxtStartX1 = 20; constexpr uint32_t dataPsnTxtStartY1 = 30; - constexpr uint32_t dataPsnTxtYIncr = 16; /* Row index increment. */ + constexpr uint32_t dataPsnTxtYIncr = 16; /* Row index increment. */ hal_lcd_set_text_color(COLOR_GREEN); /* Display each result. */ uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr; - for (auto & result : results) { + for (auto& result : results) { std::string topKeyword{"<none>"}; float score = 0.f; if (!result.m_resultVec.empty()) { topKeyword = result.m_resultVec[0].m_label; - score = result.m_resultVec[0].m_normalisedVal; + score = result.m_resultVec[0].m_normalisedVal; } - std::string resultStr = - std::string{"@"} + std::to_string(result.m_timeStamp) + - std::string{"s: "} + topKeyword + std::string{" ("} + - std::to_string(static_cast<int>(score * 100)) + std::string{"%)"}; + std::string resultStr = std::string{"@"} + std::to_string(result.m_timeStamp) + + std::string{"s: "} + topKeyword + std::string{" ("} + + std::to_string(static_cast<int>(score * 100)) + + std::string{"%)"}; - hal_lcd_display_text(resultStr.c_str(), resultStr.size(), - dataPsnTxtStartX1, rowIdx1, 0); + hal_lcd_display_text( + resultStr.c_str(), resultStr.size(), dataPsnTxtStartX1, rowIdx1, 0); rowIdx1 += dataPsnTxtYIncr; info("For timestamp: %f (inference #: %" PRIu32 "); threshold: %f\n", - result.m_timeStamp, result.m_inferenceNumber, + result.m_timeStamp, + result.m_inferenceNumber, result.m_threshold); for (uint32_t j = 0; j < result.m_resultVec.size(); ++j) { - info("\t\tlabel @ %" PRIu32 ": %s, score: %f\n", j, + info("\t\tlabel @ %" PRIu32 ": %s, score: %f\n", + j, result.m_resultVec[j].m_label.c_str(), result.m_resultVec[j].m_normalisedVal); } @@ -438,30 +455,32 @@ namespace app { { constexpr uint32_t dataPsnTxtStartX1 = 20; constexpr uint32_t dataPsnTxtStartY1 = 80; - constexpr bool allow_multiple_lines = true; + constexpr bool allow_multiple_lines = true; hal_lcd_set_text_color(COLOR_GREEN); /* Results from multiple inferences should be combined before processing. */ std::vector<arm::app::ClassificationResult> combinedResults; for (auto& result : results) { - combinedResults.insert(combinedResults.end(), - result.m_resultVec.begin(), - result.m_resultVec.end()); + combinedResults.insert( + combinedResults.end(), result.m_resultVec.begin(), result.m_resultVec.end()); } for (auto& result : results) { /* Get the final result string using the decoder. */ std::string infResultStr = audio::asr::DecodeOutput(result.m_resultVec); - info("Result for inf %" PRIu32 ": %s\n", result.m_inferenceNumber, - infResultStr.c_str()); + info( + "Result for inf %" PRIu32 ": %s\n", result.m_inferenceNumber, infResultStr.c_str()); } std::string finalResultStr = audio::asr::DecodeOutput(combinedResults); - hal_lcd_display_text(finalResultStr.c_str(), finalResultStr.size(), - dataPsnTxtStartX1, dataPsnTxtStartY1, allow_multiple_lines); + hal_lcd_display_text(finalResultStr.c_str(), + finalResultStr.size(), + dataPsnTxtStartX1, + dataPsnTxtStartY1, + allow_multiple_lines); info("Final result: %s\n", finalResultStr.c_str()); return true; diff --git a/source/use_case/noise_reduction/src/UseCaseHandler.cc b/source/use_case/noise_reduction/src/UseCaseHandler.cc index 0aef600..0c5ff39 100644 --- a/source/use_case/noise_reduction/src/UseCaseHandler.cc +++ b/source/use_case/noise_reduction/src/UseCaseHandler.cc @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com> - * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates + * <open-source-office@arm.com> SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,24 +14,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "hal.h" #include "UseCaseHandler.hpp" -#include "UseCaseCommonUtils.hpp" #include "AudioUtils.hpp" #include "ImageUtils.hpp" #include "InputFiles.hpp" -#include "RNNoiseModel.hpp" #include "RNNoiseFeatureProcessor.hpp" +#include "RNNoiseModel.hpp" #include "RNNoiseProcessing.hpp" +#include "UseCaseCommonUtils.hpp" +#include "hal.h" #include "log_macros.h" namespace arm { namespace app { /** - * @brief Helper function to increment current audio clip features index. - * @param[in,out] ctx Pointer to the application context object. - **/ + * @brief Helper function to increment current audio clip features index. + * @param[in,out] ctx Pointer to the application context object. + **/ static void IncrementAppCtxClipIdx(ApplicationContext& ctx); /* Noise reduction inference handler. */ @@ -41,17 +41,18 @@ namespace app { constexpr uint32_t dataPsnTxtInfStartY = 40; /* Variables used for memory dumping. */ - size_t memDumpMaxLen = 0; - uint8_t* memDumpBaseAddr = nullptr; + size_t memDumpMaxLen = 0; + uint8_t* memDumpBaseAddr = nullptr; size_t undefMemDumpBytesWritten = 0; - size_t* pMemDumpBytesWritten = &undefMemDumpBytesWritten; - if (ctx.Has("MEM_DUMP_LEN") && ctx.Has("MEM_DUMP_BASE_ADDR") && ctx.Has("MEM_DUMP_BYTE_WRITTEN")) { - memDumpMaxLen = ctx.Get<size_t>("MEM_DUMP_LEN"); - memDumpBaseAddr = ctx.Get<uint8_t*>("MEM_DUMP_BASE_ADDR"); + size_t* pMemDumpBytesWritten = &undefMemDumpBytesWritten; + if (ctx.Has("MEM_DUMP_LEN") && ctx.Has("MEM_DUMP_BASE_ADDR") && + ctx.Has("MEM_DUMP_BYTE_WRITTEN")) { + memDumpMaxLen = ctx.Get<size_t>("MEM_DUMP_LEN"); + memDumpBaseAddr = ctx.Get<uint8_t*>("MEM_DUMP_BASE_ADDR"); pMemDumpBytesWritten = ctx.Get<size_t*>("MEM_DUMP_BYTE_WRITTEN"); } std::reference_wrapper<size_t> memDumpBytesWritten = std::ref(*pMemDumpBytesWritten); - auto& profiler = ctx.Get<Profiler&>("profiler"); + auto& profiler = ctx.Get<Profiler&>("profiler"); /* Get model reference. */ auto& model = ctx.Get<RNNoiseModel&>("model"); @@ -61,15 +62,16 @@ namespace app { } /* Populate Pre-Processing related parameters. */ - auto audioFrameLen = ctx.Get<uint32_t>("frameLength"); - auto audioFrameStride = ctx.Get<uint32_t>("frameStride"); + auto audioFrameLen = ctx.Get<uint32_t>("frameLength"); + auto audioFrameStride = ctx.Get<uint32_t>("frameStride"); auto nrNumInputFeatures = ctx.Get<uint32_t>("numInputFeatures"); TfLiteTensor* inputTensor = model.GetInputTensor(0); if (nrNumInputFeatures != inputTensor->bytes) { printf_err("Input features size must be equal to input tensor size." " Feature size = %" PRIu32 ", Tensor size = %zu.\n", - nrNumInputFeatures, inputTensor->bytes); + nrNumInputFeatures, + inputTensor->bytes); return false; } @@ -78,49 +80,55 @@ namespace app { /* Initial choice of index for WAV file. */ auto startClipIdx = ctx.Get<uint32_t>("clipIndex"); - std::function<const int16_t* (const uint32_t)> audioAccessorFunc = get_audio_array; + std::function<const int16_t*(const uint32_t)> audioAccessorFunc = GetAudioArray; if (ctx.Has("features")) { - audioAccessorFunc = ctx.Get<std::function<const int16_t* (const uint32_t)>>("features"); + audioAccessorFunc = ctx.Get<std::function<const int16_t*(const uint32_t)>>("features"); } - std::function<uint32_t (const uint32_t)> audioSizeAccessorFunc = get_audio_array_size; + std::function<uint32_t(const uint32_t)> audioSizeAccessorFunc = GetAudioArraySize; if (ctx.Has("featureSizes")) { - audioSizeAccessorFunc = ctx.Get<std::function<uint32_t (const uint32_t)>>("featureSizes"); + audioSizeAccessorFunc = + ctx.Get<std::function<uint32_t(const uint32_t)>>("featureSizes"); } - std::function<const char*(const uint32_t)> audioFileAccessorFunc = get_filename; + std::function<const char*(const uint32_t)> audioFileAccessorFunc = GetFilename; if (ctx.Has("featureFileNames")) { - audioFileAccessorFunc = ctx.Get<std::function<const char*(const uint32_t)>>("featureFileNames"); + audioFileAccessorFunc = + ctx.Get<std::function<const char*(const uint32_t)>>("featureFileNames"); } do { hal_lcd_clear(COLOR_BLACK); auto startDumpAddress = memDumpBaseAddr + memDumpBytesWritten; - auto currentIndex = ctx.Get<uint32_t>("clipIndex"); + auto currentIndex = ctx.Get<uint32_t>("clipIndex"); /* Creating a sliding window through the audio. */ - auto audioDataSlider = audio::SlidingWindow<const int16_t>( - audioAccessorFunc(currentIndex), - audioSizeAccessorFunc(currentIndex), audioFrameLen, - audioFrameStride); - - info("Running inference on input feature map %" PRIu32 " => %s\n", currentIndex, + auto audioDataSlider = + audio::SlidingWindow<const int16_t>(audioAccessorFunc(currentIndex), + audioSizeAccessorFunc(currentIndex), + audioFrameLen, + audioFrameStride); + + info("Running inference on input feature map %" PRIu32 " => %s\n", + currentIndex, audioFileAccessorFunc(currentIndex)); - memDumpBytesWritten += DumpDenoisedAudioHeader(audioFileAccessorFunc(currentIndex), - (audioDataSlider.TotalStrides() + 1) * audioFrameLen, - memDumpBaseAddr + memDumpBytesWritten, - memDumpMaxLen - memDumpBytesWritten); + memDumpBytesWritten += + DumpDenoisedAudioHeader(audioFileAccessorFunc(currentIndex), + (audioDataSlider.TotalStrides() + 1) * audioFrameLen, + memDumpBaseAddr + memDumpBytesWritten, + memDumpMaxLen - memDumpBytesWritten); /* Set up pre and post-processing. */ std::shared_ptr<rnn::RNNoiseFeatureProcessor> featureProcessor = - std::make_shared<rnn::RNNoiseFeatureProcessor>(); + std::make_shared<rnn::RNNoiseFeatureProcessor>(); std::shared_ptr<rnn::FrameFeatures> frameFeatures = - std::make_shared<rnn::FrameFeatures>(); + std::make_shared<rnn::FrameFeatures>(); - RNNoisePreProcess preProcess = RNNoisePreProcess(inputTensor, featureProcessor, frameFeatures); + RNNoisePreProcess preProcess = + RNNoisePreProcess(inputTensor, featureProcessor, frameFeatures); std::vector<int16_t> denoisedAudioFrame(audioFrameLen); - RNNoisePostProcess postProcess = RNNoisePostProcess(outputTensor, denoisedAudioFrame, - featureProcessor, frameFeatures); + RNNoisePostProcess postProcess = RNNoisePostProcess( + outputTensor, denoisedAudioFrame, featureProcessor, frameFeatures); bool resetGRU = true; @@ -133,11 +141,12 @@ namespace app { } /* Reset or copy over GRU states first to avoid TFLu memory overlap issues. */ - if (resetGRU){ + if (resetGRU) { model.ResetGruState(); } else { /* Copying gru state outputs to gru state inputs. - * Call ResetGruState in between the sequence of inferences on unrelated input data. */ + * Call ResetGruState in between the sequence of inferences on unrelated input + * data. */ model.CopyGruStates(); } @@ -145,10 +154,15 @@ namespace app { std::string str_inf{"Running inference... "}; /* Display message on the LCD - inference running. */ - hal_lcd_display_text(str_inf.c_str(), str_inf.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); + hal_lcd_display_text(str_inf.c_str(), + str_inf.size(), + dataPsnTxtInfStartX, + dataPsnTxtInfStartY, + false); - info("Inference %zu/%zu\n", audioDataSlider.Index() + 1, audioDataSlider.TotalStrides() + 1); + info("Inference %zu/%zu\n", + audioDataSlider.Index() + 1, + audioDataSlider.TotalStrides() + 1); /* Run inference over this feature sliding window. */ if (!RunInference(model, profiler)) { @@ -165,15 +179,18 @@ namespace app { /* Erase. */ str_inf = std::string(str_inf.size(), ' '); - hal_lcd_display_text(str_inf.c_str(), str_inf.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); + hal_lcd_display_text(str_inf.c_str(), + str_inf.size(), + dataPsnTxtInfStartX, + dataPsnTxtInfStartY, + false); if (memDumpMaxLen > 0) { /* Dump final post processed output to memory. */ - memDumpBytesWritten += DumpOutputDenoisedAudioFrame( - denoisedAudioFrame, - memDumpBaseAddr + memDumpBytesWritten, - memDumpMaxLen - memDumpBytesWritten); + memDumpBytesWritten += + DumpOutputDenoisedAudioFrame(denoisedAudioFrame, + memDumpBaseAddr + memDumpBytesWritten, + memDumpMaxLen - memDumpBytesWritten); } } @@ -181,43 +198,54 @@ namespace app { /* Needed to not let the compiler complain about type mismatch. */ size_t valMemDumpBytesWritten = memDumpBytesWritten; info("Output memory dump of %zu bytes written at address 0x%p\n", - valMemDumpBytesWritten, startDumpAddress); + valMemDumpBytesWritten, + startDumpAddress); } /* Finish by dumping the footer. */ - DumpDenoisedAudioFooter(memDumpBaseAddr + memDumpBytesWritten, memDumpMaxLen - memDumpBytesWritten); + DumpDenoisedAudioFooter(memDumpBaseAddr + memDumpBytesWritten, + memDumpMaxLen - memDumpBytesWritten); info("All inferences for audio clip complete.\n"); profiler.PrintProfilingResult(); IncrementAppCtxClipIdx(ctx); std::string clearString{' '}; - hal_lcd_display_text(clearString.c_str(), clearString.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); + hal_lcd_display_text(clearString.c_str(), + clearString.size(), + dataPsnTxtInfStartX, + dataPsnTxtInfStartY, + false); std::string completeMsg{"Inference complete!"}; /* Display message on the LCD - inference complete. */ - hal_lcd_display_text(completeMsg.c_str(), completeMsg.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); + hal_lcd_display_text(completeMsg.c_str(), + completeMsg.size(), + dataPsnTxtInfStartX, + dataPsnTxtInfStartY, + false); } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx); return true; } - size_t DumpDenoisedAudioHeader(const char* filename, size_t dumpSize, - uint8_t* memAddress, size_t memSize){ + size_t DumpDenoisedAudioHeader(const char* filename, + size_t dumpSize, + uint8_t* memAddress, + size_t memSize) + { - if (memAddress == nullptr){ + if (memAddress == nullptr) { return 0; } int32_t filenameLength = strlen(filename); size_t numBytesWritten = 0; size_t numBytesToWrite = 0; - int32_t dumpSizeByte = dumpSize * sizeof(int16_t); - bool overflow = false; + int32_t dumpSizeByte = dumpSize * sizeof(int16_t); + bool overflow = false; /* Write the filename length */ numBytesToWrite = sizeof(filenameLength); @@ -231,7 +259,7 @@ namespace app { /* Write file name */ numBytesToWrite = filenameLength; - if(memSize - numBytesToWrite > 0) { + if (memSize - numBytesToWrite > 0) { std::memcpy(memAddress + numBytesWritten, filename, numBytesToWrite); numBytesWritten += numBytesToWrite; memSize -= numBytesWritten; @@ -241,7 +269,7 @@ namespace app { /* Write dumpSize in byte */ numBytesToWrite = sizeof(dumpSizeByte); - if(memSize - numBytesToWrite > 0) { + if (memSize - numBytesToWrite > 0) { std::memcpy(memAddress + numBytesWritten, &(dumpSizeByte), numBytesToWrite); numBytesWritten += numBytesToWrite; memSize -= numBytesWritten; @@ -249,8 +277,10 @@ namespace app { overflow = true; } - if(false == overflow) { - info("Audio Clip dump header info (%zu bytes) written to %p\n", numBytesWritten, memAddress); + if (false == overflow) { + info("Audio Clip dump header info (%zu bytes) written to %p\n", + numBytesWritten, + memAddress); } else { printf_err("Not enough memory to dump Audio Clip header.\n"); } @@ -258,7 +288,8 @@ namespace app { return numBytesWritten; } - size_t DumpDenoisedAudioFooter(uint8_t* memAddress, size_t memSize){ + size_t DumpDenoisedAudioFooter(uint8_t* memAddress, size_t memSize) + { if ((memAddress == nullptr) || (memSize < 4)) { return 0; } @@ -266,23 +297,27 @@ namespace app { std::memcpy(memAddress, &eofMarker, sizeof(int32_t)); return sizeof(int32_t); - } + } size_t DumpOutputDenoisedAudioFrame(const std::vector<int16_t>& audioFrame, - uint8_t* memAddress, size_t memSize) + uint8_t* memAddress, + size_t memSize) { if (memAddress == nullptr) { return 0; } size_t numByteToBeWritten = audioFrame.size() * sizeof(int16_t); - if( numByteToBeWritten > memSize) { - printf_err("Overflow error: Writing %zu of %zu bytes to memory @ 0x%p.\n", memSize, numByteToBeWritten, memAddress); + if (numByteToBeWritten > memSize) { + printf_err("Overflow error: Writing %zu of %zu bytes to memory @ 0x%p.\n", + memSize, + numByteToBeWritten, + memAddress); numByteToBeWritten = memSize; } std::memcpy(memAddress, audioFrame.data(), numByteToBeWritten); - info("Copied %zu bytes to %p\n", numByteToBeWritten, memAddress); + info("Copied %zu bytes to %p\n", numByteToBeWritten, memAddress); return numByteToBeWritten; } @@ -290,13 +325,13 @@ namespace app { size_t DumpOutputTensorsToMemory(Model& model, uint8_t* memAddress, const size_t memSize) { const size_t numOutputs = model.GetNumOutputs(); - size_t numBytesWritten = 0; - uint8_t* ptr = memAddress; + size_t numBytesWritten = 0; + uint8_t* ptr = memAddress; /* Iterate over all output tensors. */ for (size_t i = 0; i < numOutputs; ++i) { const TfLiteTensor* tensor = model.GetOutputTensor(i); - const auto* tData = tflite::GetTensorData<uint8_t>(tensor); + const auto* tData = tflite::GetTensorData<uint8_t>(tensor); #if VERIFY_TEST_OUTPUT DumpTensor(tensor); #endif /* VERIFY_TEST_OUTPUT */ @@ -305,15 +340,13 @@ namespace app { if (tensor->bytes > 0) { std::memcpy(ptr, tData, tensor->bytes); - info("Copied %zu bytes for tensor %zu to 0x%p\n", - tensor->bytes, i, ptr); + info("Copied %zu bytes for tensor %zu to 0x%p\n", tensor->bytes, i, ptr); numBytesWritten += tensor->bytes; ptr += tensor->bytes; } } else { - printf_err("Error writing tensor %zu to memory @ 0x%p\n", - i, memAddress); + printf_err("Error writing tensor %zu to memory @ 0x%p\n", i, memAddress); break; } } diff --git a/source/use_case/object_detection/src/UseCaseHandler.cc b/source/use_case/object_detection/src/UseCaseHandler.cc index 084059e..9330187 100644 --- a/source/use_case/object_detection/src/UseCaseHandler.cc +++ b/source/use_case/object_detection/src/UseCaseHandler.cc @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright 2022 Arm Limited and/or its affiliates <open-source-office@arm.com> - * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright 2022 Arm Limited and/or its affiliates + * <open-source-office@arm.com> SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,11 +15,11 @@ * limitations under the License. */ #include "UseCaseHandler.hpp" -#include "InputFiles.hpp" -#include "YoloFastestModel.hpp" -#include "UseCaseCommonUtils.hpp" #include "DetectorPostProcessing.hpp" #include "DetectorPreProcessing.hpp" +#include "InputFiles.hpp" +#include "UseCaseCommonUtils.hpp" +#include "YoloFastestModel.hpp" #include "hal.h" #include "log_macros.h" @@ -34,7 +34,8 @@ namespace app { * @param[in] results Vector of detection results to be displayed. * @return true if successful, false otherwise. **/ - static bool PresentInferenceResult(const std::vector<object_detection::DetectionResult>& results); + static bool + PresentInferenceResult(const std::vector<object_detection::DetectionResult>& results); /** * @brief Draw boxes directly on the LCD for all detected objects. @@ -43,11 +44,10 @@ namespace app { * @param[in] imageStartY Y coordinate where the image starts on the LCD. * @param[in] imgDownscaleFactor How much image has been downscaled on LCD. **/ - static void DrawDetectionBoxes( - const std::vector<object_detection::DetectionResult>& results, - uint32_t imgStartX, - uint32_t imgStartY, - uint32_t imgDownscaleFactor); + static void DrawDetectionBoxes(const std::vector<object_detection::DetectionResult>& results, + uint32_t imgStartX, + uint32_t imgStartY, + uint32_t imgDownscaleFactor); /* Object detection inference handler. */ bool ObjectDetectionHandler(ApplicationContext& ctx, uint32_t imgIndex, bool runAll) @@ -55,8 +55,8 @@ namespace app { auto& profiler = ctx.Get<Profiler&>("profiler"); constexpr uint32_t dataPsnImgDownscaleFactor = 1; - constexpr uint32_t dataPsnImgStartX = 10; - constexpr uint32_t dataPsnImgStartY = 35; + constexpr uint32_t dataPsnImgStartX = 10; + constexpr uint32_t dataPsnImgStartY = 35; constexpr uint32_t dataPsnTxtInfStartX = 20; constexpr uint32_t dataPsnTxtInfStartY = 28; @@ -78,7 +78,7 @@ namespace app { auto initialImgIdx = ctx.Get<uint32_t>("imgIndex"); - TfLiteTensor* inputTensor = model.GetInputTensor(0); + TfLiteTensor* inputTensor = model.GetInputTensor(0); TfLiteTensor* outputTensor0 = model.GetOutputTensor(0); TfLiteTensor* outputTensor1 = model.GetOutputTensor(1); @@ -99,12 +99,14 @@ namespace app { DetectorPreProcess preProcess = DetectorPreProcess(inputTensor, true, model.IsDataSigned()); std::vector<object_detection::DetectionResult> results; - const object_detection::PostProcessParams postProcessParams { - inputImgRows, inputImgCols, object_detection::originalImageSize, - object_detection::anchor1, object_detection::anchor2 - }; - DetectorPostProcess postProcess = DetectorPostProcess(outputTensor0, outputTensor1, - results, postProcessParams); + const object_detection::PostProcessParams postProcessParams{ + inputImgRows, + inputImgCols, + object_detection::originalImageSize, + object_detection::anchor1, + object_detection::anchor2}; + DetectorPostProcess postProcess = + DetectorPostProcess(outputTensor0, outputTensor1, results, postProcessParams); do { /* Ensure there are no results leftover from previous inference when running all. */ results.clear(); @@ -112,11 +114,11 @@ namespace app { /* Strings for presentation/logging. */ std::string str_inf{"Running inference... "}; - const uint8_t* currImage = get_img_array(ctx.Get<uint32_t>("imgIndex")); + const uint8_t* currImage = GetImgArray(ctx.Get<uint32_t>("imgIndex")); auto dstPtr = static_cast<uint8_t*>(inputTensor->data.uint8); - const size_t copySz = inputTensor->bytes < IMAGE_DATA_SIZE ? - inputTensor->bytes : IMAGE_DATA_SIZE; + const size_t copySz = + inputTensor->bytes < IMAGE_DATA_SIZE ? inputTensor->bytes : IMAGE_DATA_SIZE; /* Run the pre-processing, inference and post-processing. */ if (!preProcess.DoPreProcess(currImage, copySz)) { @@ -135,12 +137,13 @@ namespace app { dataPsnImgDownscaleFactor); /* Display message on the LCD - inference running. */ - hal_lcd_display_text(str_inf.c_str(), str_inf.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); + hal_lcd_display_text( + str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); /* Run inference over this image. */ - info("Running inference on image %" PRIu32 " => %s\n", ctx.Get<uint32_t>("imgIndex"), - get_filename(ctx.Get<uint32_t>("imgIndex"))); + info("Running inference on image %" PRIu32 " => %s\n", + ctx.Get<uint32_t>("imgIndex"), + GetFilename(ctx.Get<uint32_t>("imgIndex"))); if (!RunInference(model, profiler)) { printf_err("Inference failed."); @@ -154,11 +157,12 @@ namespace app { /* Erase. */ str_inf = std::string(str_inf.size(), ' '); - hal_lcd_display_text(str_inf.c_str(), str_inf.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); + hal_lcd_display_text( + str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); /* Draw boxes. */ - DrawDetectionBoxes(results, dataPsnImgStartX, dataPsnImgStartY, dataPsnImgDownscaleFactor); + DrawDetectionBoxes( + results, dataPsnImgStartX, dataPsnImgStartY, dataPsnImgDownscaleFactor); #if VERIFY_TEST_OUTPUT DumpTensor(modelOutput0); @@ -171,14 +175,15 @@ namespace app { profiler.PrintProfilingResult(); - IncrementAppCtxIfmIdx(ctx,"imgIndex"); + IncrementAppCtxIfmIdx(ctx, "imgIndex"); } while (runAll && ctx.Get<uint32_t>("imgIndex") != initialImgIdx); return true; } - static bool PresentInferenceResult(const std::vector<object_detection::DetectionResult>& results) + static bool + PresentInferenceResult(const std::vector<object_detection::DetectionResult>& results) { hal_lcd_set_text_color(COLOR_GREEN); @@ -187,9 +192,14 @@ namespace app { info("Total number of inferences: 1\n"); for (uint32_t i = 0; i < results.size(); ++i) { - info("%" PRIu32 ") (%f) -> %s {x=%d,y=%d,w=%d,h=%d}\n", i, - results[i].m_normalisedVal, "Detection box:", - results[i].m_x0, results[i].m_y0, results[i].m_w, results[i].m_h ); + info("%" PRIu32 ") (%f) -> %s {x=%d,y=%d,w=%d,h=%d}\n", + i, + results[i].m_normalisedVal, + "Detection box:", + results[i].m_x0, + results[i].m_y0, + results[i].m_w, + results[i].m_h); } return true; @@ -202,24 +212,34 @@ namespace app { { uint32_t lineThickness = 1; - for (const auto& result: results) { + for (const auto& result : results) { /* Top line. */ - hal_lcd_display_box(imgStartX + result.m_x0/imgDownscaleFactor, - imgStartY + result.m_y0/imgDownscaleFactor, - result.m_w/imgDownscaleFactor, lineThickness, COLOR_GREEN); + hal_lcd_display_box(imgStartX + result.m_x0 / imgDownscaleFactor, + imgStartY + result.m_y0 / imgDownscaleFactor, + result.m_w / imgDownscaleFactor, + lineThickness, + COLOR_GREEN); /* Bot line. */ - hal_lcd_display_box(imgStartX + result.m_x0/imgDownscaleFactor, - imgStartY + (result.m_y0 + result.m_h)/imgDownscaleFactor - lineThickness, - result.m_w/imgDownscaleFactor, lineThickness, COLOR_GREEN); + hal_lcd_display_box(imgStartX + result.m_x0 / imgDownscaleFactor, + imgStartY + (result.m_y0 + result.m_h) / imgDownscaleFactor - + lineThickness, + result.m_w / imgDownscaleFactor, + lineThickness, + COLOR_GREEN); /* Left line. */ - hal_lcd_display_box(imgStartX + result.m_x0/imgDownscaleFactor, - imgStartY + result.m_y0/imgDownscaleFactor, - lineThickness, result.m_h/imgDownscaleFactor, COLOR_GREEN); + hal_lcd_display_box(imgStartX + result.m_x0 / imgDownscaleFactor, + imgStartY + result.m_y0 / imgDownscaleFactor, + lineThickness, + result.m_h / imgDownscaleFactor, + COLOR_GREEN); /* Right line. */ - hal_lcd_display_box(imgStartX + (result.m_x0 + result.m_w)/imgDownscaleFactor - lineThickness, - imgStartY + result.m_y0/imgDownscaleFactor, - lineThickness, result.m_h/imgDownscaleFactor, COLOR_GREEN); + hal_lcd_display_box(imgStartX + (result.m_x0 + result.m_w) / imgDownscaleFactor - + lineThickness, + imgStartY + result.m_y0 / imgDownscaleFactor, + lineThickness, + result.m_h / imgDownscaleFactor, + COLOR_GREEN); } } diff --git a/source/use_case/vww/src/UseCaseHandler.cc b/source/use_case/vww/src/UseCaseHandler.cc index e2e48d1..4704c97 100644 --- a/source/use_case/vww/src/UseCaseHandler.cc +++ b/source/use_case/vww/src/UseCaseHandler.cc @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com> - * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates + * <open-source-office@arm.com> SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,34 +15,34 @@ * limitations under the License. */ #include "UseCaseHandler.hpp" -#include "VisualWakeWordModel.hpp" #include "Classifier.hpp" -#include "InputFiles.hpp" #include "ImageUtils.hpp" +#include "InputFiles.hpp" #include "UseCaseCommonUtils.hpp" +#include "VisualWakeWordModel.hpp" +#include "VisualWakeWordProcessing.hpp" #include "hal.h" #include "log_macros.h" -#include "VisualWakeWordProcessing.hpp" namespace arm { namespace app { /* Visual Wake Word inference handler. */ - bool ClassifyImageHandler(ApplicationContext &ctx, uint32_t imgIndex, bool runAll) + bool ClassifyImageHandler(ApplicationContext& ctx, uint32_t imgIndex, bool runAll) { auto& profiler = ctx.Get<Profiler&>("profiler"); - auto& model = ctx.Get<Model&>("model"); + auto& model = ctx.Get<Model&>("model"); /* If the request has a valid size, set the image index. */ if (imgIndex < NUMBER_OF_FILES) { - if (!SetAppCtxIfmIdx(ctx, imgIndex,"imgIndex")) { + if (!SetAppCtxIfmIdx(ctx, imgIndex, "imgIndex")) { return false; } } auto initialImgIdx = ctx.Get<uint32_t>("imgIndex"); constexpr uint32_t dataPsnImgDownscaleFactor = 1; - constexpr uint32_t dataPsnImgStartX = 10; - constexpr uint32_t dataPsnImgStartY = 35; + constexpr uint32_t dataPsnImgStartX = 10; + constexpr uint32_t dataPsnImgStartY = 35; constexpr uint32_t dataPsnTxtInfStartX = 150; constexpr uint32_t dataPsnTxtInfStartY = 40; @@ -52,7 +52,7 @@ namespace app { return false; } - TfLiteTensor* inputTensor = model.GetInputTensor(0); + TfLiteTensor* inputTensor = model.GetInputTensor(0); TfLiteTensor* outputTensor = model.GetOutputTensor(0); if (!inputTensor->dims) { printf_err("Invalid input tensor dims\n"); @@ -66,7 +66,8 @@ namespace app { TfLiteIntArray* inputShape = model.GetInputShape(0); const uint32_t nCols = inputShape->data[arm::app::VisualWakeWordModel::ms_inputColsIdx]; const uint32_t nRows = inputShape->data[arm::app::VisualWakeWordModel::ms_inputRowsIdx]; - if (arm::app::VisualWakeWordModel::ms_inputChannelsIdx >= static_cast<uint32_t>(inputShape->size)) { + if (arm::app::VisualWakeWordModel::ms_inputChannelsIdx >= + static_cast<uint32_t>(inputShape->size)) { printf_err("Invalid channel index.\n"); return false; } @@ -78,9 +79,11 @@ namespace app { VisualWakeWordPreProcess preProcess = VisualWakeWordPreProcess(inputTensor); std::vector<ClassificationResult> results; - VisualWakeWordPostProcess postProcess = VisualWakeWordPostProcess(outputTensor, - ctx.Get<Classifier&>("classifier"), - ctx.Get<std::vector<std::string>&>("labels"), results); + VisualWakeWordPostProcess postProcess = + VisualWakeWordPostProcess(outputTensor, + ctx.Get<Classifier&>("classifier"), + ctx.Get<std::vector<std::string>&>("labels"), + results); do { hal_lcd_clear(COLOR_BLACK); @@ -88,29 +91,34 @@ namespace app { /* Strings for presentation/logging. */ std::string str_inf{"Running inference... "}; - const uint8_t* imgSrc = get_img_array(ctx.Get<uint32_t>("imgIndex")); + const uint8_t* imgSrc = GetImgArray(ctx.Get<uint32_t>("imgIndex")); if (nullptr == imgSrc) { - printf_err("Failed to get image index %" PRIu32 " (max: %u)\n", ctx.Get<uint32_t>("imgIndex"), + printf_err("Failed to get image index %" PRIu32 " (max: %u)\n", + ctx.Get<uint32_t>("imgIndex"), NUMBER_OF_FILES - 1); return false; } /* Display this image on the LCD. */ - hal_lcd_display_image( - imgSrc, - nCols, nRows, displayChannels, - dataPsnImgStartX, dataPsnImgStartY, dataPsnImgDownscaleFactor); + hal_lcd_display_image(imgSrc, + nCols, + nRows, + displayChannels, + dataPsnImgStartX, + dataPsnImgStartY, + dataPsnImgDownscaleFactor); /* Display message on the LCD - inference running. */ - hal_lcd_display_text(str_inf.c_str(), str_inf.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); + hal_lcd_display_text( + str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); /* Run inference over this image. */ - info("Running inference on image %" PRIu32 " => %s\n", ctx.Get<uint32_t>("imgIndex"), - get_filename(ctx.Get<uint32_t>("imgIndex"))); + info("Running inference on image %" PRIu32 " => %s\n", + ctx.Get<uint32_t>("imgIndex"), + GetFilename(ctx.Get<uint32_t>("imgIndex"))); - const size_t imgSz = inputTensor->bytes < IMAGE_DATA_SIZE ? - inputTensor->bytes : IMAGE_DATA_SIZE; + const size_t imgSz = + inputTensor->bytes < IMAGE_DATA_SIZE ? inputTensor->bytes : IMAGE_DATA_SIZE; /* Run the pre-processing, inference and post-processing. */ if (!preProcess.DoPreProcess(imgSrc, imgSz)) { @@ -130,8 +138,8 @@ namespace app { /* Erase. */ str_inf = std::string(str_inf.size(), ' '); - hal_lcd_display_text(str_inf.c_str(), str_inf.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); + hal_lcd_display_text( + str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); /* Add results to context for access outside handler. */ ctx.Set<std::vector<ClassificationResult>>("results", results); @@ -146,7 +154,7 @@ namespace app { profiler.PrintProfilingResult(); - IncrementAppCtxIfmIdx(ctx,"imgIndex"); + IncrementAppCtxIfmIdx(ctx, "imgIndex"); } while (runAll && ctx.Get<uint32_t>("imgIndex") != initialImgIdx); |