summaryrefslogtreecommitdiff
path: root/source/use_case/noise_reduction/src/UseCaseHandler.cc
diff options
context:
space:
mode:
Diffstat (limited to 'source/use_case/noise_reduction/src/UseCaseHandler.cc')
-rw-r--r--source/use_case/noise_reduction/src/UseCaseHandler.cc191
1 files changed, 112 insertions, 79 deletions
diff --git a/source/use_case/noise_reduction/src/UseCaseHandler.cc b/source/use_case/noise_reduction/src/UseCaseHandler.cc
index 0aef600..0c5ff39 100644
--- a/source/use_case/noise_reduction/src/UseCaseHandler.cc
+++ b/source/use_case/noise_reduction/src/UseCaseHandler.cc
@@ -1,6 +1,6 @@
/*
- * SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
- * SPDX-License-Identifier: Apache-2.0
+ * SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates
+ * <open-source-office@arm.com> SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -14,24 +14,24 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-#include "hal.h"
#include "UseCaseHandler.hpp"
-#include "UseCaseCommonUtils.hpp"
#include "AudioUtils.hpp"
#include "ImageUtils.hpp"
#include "InputFiles.hpp"
-#include "RNNoiseModel.hpp"
#include "RNNoiseFeatureProcessor.hpp"
+#include "RNNoiseModel.hpp"
#include "RNNoiseProcessing.hpp"
+#include "UseCaseCommonUtils.hpp"
+#include "hal.h"
#include "log_macros.h"
namespace arm {
namespace app {
/**
- * @brief Helper function to increment current audio clip features index.
- * @param[in,out] ctx Pointer to the application context object.
- **/
+ * @brief Helper function to increment current audio clip features index.
+ * @param[in,out] ctx Pointer to the application context object.
+ **/
static void IncrementAppCtxClipIdx(ApplicationContext& ctx);
/* Noise reduction inference handler. */
@@ -41,17 +41,18 @@ namespace app {
constexpr uint32_t dataPsnTxtInfStartY = 40;
/* Variables used for memory dumping. */
- size_t memDumpMaxLen = 0;
- uint8_t* memDumpBaseAddr = nullptr;
+ size_t memDumpMaxLen = 0;
+ uint8_t* memDumpBaseAddr = nullptr;
size_t undefMemDumpBytesWritten = 0;
- size_t* pMemDumpBytesWritten = &undefMemDumpBytesWritten;
- if (ctx.Has("MEM_DUMP_LEN") && ctx.Has("MEM_DUMP_BASE_ADDR") && ctx.Has("MEM_DUMP_BYTE_WRITTEN")) {
- memDumpMaxLen = ctx.Get<size_t>("MEM_DUMP_LEN");
- memDumpBaseAddr = ctx.Get<uint8_t*>("MEM_DUMP_BASE_ADDR");
+ size_t* pMemDumpBytesWritten = &undefMemDumpBytesWritten;
+ if (ctx.Has("MEM_DUMP_LEN") && ctx.Has("MEM_DUMP_BASE_ADDR") &&
+ ctx.Has("MEM_DUMP_BYTE_WRITTEN")) {
+ memDumpMaxLen = ctx.Get<size_t>("MEM_DUMP_LEN");
+ memDumpBaseAddr = ctx.Get<uint8_t*>("MEM_DUMP_BASE_ADDR");
pMemDumpBytesWritten = ctx.Get<size_t*>("MEM_DUMP_BYTE_WRITTEN");
}
std::reference_wrapper<size_t> memDumpBytesWritten = std::ref(*pMemDumpBytesWritten);
- auto& profiler = ctx.Get<Profiler&>("profiler");
+ auto& profiler = ctx.Get<Profiler&>("profiler");
/* Get model reference. */
auto& model = ctx.Get<RNNoiseModel&>("model");
@@ -61,15 +62,16 @@ namespace app {
}
/* Populate Pre-Processing related parameters. */
- auto audioFrameLen = ctx.Get<uint32_t>("frameLength");
- auto audioFrameStride = ctx.Get<uint32_t>("frameStride");
+ auto audioFrameLen = ctx.Get<uint32_t>("frameLength");
+ auto audioFrameStride = ctx.Get<uint32_t>("frameStride");
auto nrNumInputFeatures = ctx.Get<uint32_t>("numInputFeatures");
TfLiteTensor* inputTensor = model.GetInputTensor(0);
if (nrNumInputFeatures != inputTensor->bytes) {
printf_err("Input features size must be equal to input tensor size."
" Feature size = %" PRIu32 ", Tensor size = %zu.\n",
- nrNumInputFeatures, inputTensor->bytes);
+ nrNumInputFeatures,
+ inputTensor->bytes);
return false;
}
@@ -78,49 +80,55 @@ namespace app {
/* Initial choice of index for WAV file. */
auto startClipIdx = ctx.Get<uint32_t>("clipIndex");
- std::function<const int16_t* (const uint32_t)> audioAccessorFunc = get_audio_array;
+ std::function<const int16_t*(const uint32_t)> audioAccessorFunc = GetAudioArray;
if (ctx.Has("features")) {
- audioAccessorFunc = ctx.Get<std::function<const int16_t* (const uint32_t)>>("features");
+ audioAccessorFunc = ctx.Get<std::function<const int16_t*(const uint32_t)>>("features");
}
- std::function<uint32_t (const uint32_t)> audioSizeAccessorFunc = get_audio_array_size;
+ std::function<uint32_t(const uint32_t)> audioSizeAccessorFunc = GetAudioArraySize;
if (ctx.Has("featureSizes")) {
- audioSizeAccessorFunc = ctx.Get<std::function<uint32_t (const uint32_t)>>("featureSizes");
+ audioSizeAccessorFunc =
+ ctx.Get<std::function<uint32_t(const uint32_t)>>("featureSizes");
}
- std::function<const char*(const uint32_t)> audioFileAccessorFunc = get_filename;
+ std::function<const char*(const uint32_t)> audioFileAccessorFunc = GetFilename;
if (ctx.Has("featureFileNames")) {
- audioFileAccessorFunc = ctx.Get<std::function<const char*(const uint32_t)>>("featureFileNames");
+ audioFileAccessorFunc =
+ ctx.Get<std::function<const char*(const uint32_t)>>("featureFileNames");
}
do {
hal_lcd_clear(COLOR_BLACK);
auto startDumpAddress = memDumpBaseAddr + memDumpBytesWritten;
- auto currentIndex = ctx.Get<uint32_t>("clipIndex");
+ auto currentIndex = ctx.Get<uint32_t>("clipIndex");
/* Creating a sliding window through the audio. */
- auto audioDataSlider = audio::SlidingWindow<const int16_t>(
- audioAccessorFunc(currentIndex),
- audioSizeAccessorFunc(currentIndex), audioFrameLen,
- audioFrameStride);
-
- info("Running inference on input feature map %" PRIu32 " => %s\n", currentIndex,
+ auto audioDataSlider =
+ audio::SlidingWindow<const int16_t>(audioAccessorFunc(currentIndex),
+ audioSizeAccessorFunc(currentIndex),
+ audioFrameLen,
+ audioFrameStride);
+
+ info("Running inference on input feature map %" PRIu32 " => %s\n",
+ currentIndex,
audioFileAccessorFunc(currentIndex));
- memDumpBytesWritten += DumpDenoisedAudioHeader(audioFileAccessorFunc(currentIndex),
- (audioDataSlider.TotalStrides() + 1) * audioFrameLen,
- memDumpBaseAddr + memDumpBytesWritten,
- memDumpMaxLen - memDumpBytesWritten);
+ memDumpBytesWritten +=
+ DumpDenoisedAudioHeader(audioFileAccessorFunc(currentIndex),
+ (audioDataSlider.TotalStrides() + 1) * audioFrameLen,
+ memDumpBaseAddr + memDumpBytesWritten,
+ memDumpMaxLen - memDumpBytesWritten);
/* Set up pre and post-processing. */
std::shared_ptr<rnn::RNNoiseFeatureProcessor> featureProcessor =
- std::make_shared<rnn::RNNoiseFeatureProcessor>();
+ std::make_shared<rnn::RNNoiseFeatureProcessor>();
std::shared_ptr<rnn::FrameFeatures> frameFeatures =
- std::make_shared<rnn::FrameFeatures>();
+ std::make_shared<rnn::FrameFeatures>();
- RNNoisePreProcess preProcess = RNNoisePreProcess(inputTensor, featureProcessor, frameFeatures);
+ RNNoisePreProcess preProcess =
+ RNNoisePreProcess(inputTensor, featureProcessor, frameFeatures);
std::vector<int16_t> denoisedAudioFrame(audioFrameLen);
- RNNoisePostProcess postProcess = RNNoisePostProcess(outputTensor, denoisedAudioFrame,
- featureProcessor, frameFeatures);
+ RNNoisePostProcess postProcess = RNNoisePostProcess(
+ outputTensor, denoisedAudioFrame, featureProcessor, frameFeatures);
bool resetGRU = true;
@@ -133,11 +141,12 @@ namespace app {
}
/* Reset or copy over GRU states first to avoid TFLu memory overlap issues. */
- if (resetGRU){
+ if (resetGRU) {
model.ResetGruState();
} else {
/* Copying gru state outputs to gru state inputs.
- * Call ResetGruState in between the sequence of inferences on unrelated input data. */
+ * Call ResetGruState in between the sequence of inferences on unrelated input
+ * data. */
model.CopyGruStates();
}
@@ -145,10 +154,15 @@ namespace app {
std::string str_inf{"Running inference... "};
/* Display message on the LCD - inference running. */
- hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
- dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
+ hal_lcd_display_text(str_inf.c_str(),
+ str_inf.size(),
+ dataPsnTxtInfStartX,
+ dataPsnTxtInfStartY,
+ false);
- info("Inference %zu/%zu\n", audioDataSlider.Index() + 1, audioDataSlider.TotalStrides() + 1);
+ info("Inference %zu/%zu\n",
+ audioDataSlider.Index() + 1,
+ audioDataSlider.TotalStrides() + 1);
/* Run inference over this feature sliding window. */
if (!RunInference(model, profiler)) {
@@ -165,15 +179,18 @@ namespace app {
/* Erase. */
str_inf = std::string(str_inf.size(), ' ');
- hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
- dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
+ hal_lcd_display_text(str_inf.c_str(),
+ str_inf.size(),
+ dataPsnTxtInfStartX,
+ dataPsnTxtInfStartY,
+ false);
if (memDumpMaxLen > 0) {
/* Dump final post processed output to memory. */
- memDumpBytesWritten += DumpOutputDenoisedAudioFrame(
- denoisedAudioFrame,
- memDumpBaseAddr + memDumpBytesWritten,
- memDumpMaxLen - memDumpBytesWritten);
+ memDumpBytesWritten +=
+ DumpOutputDenoisedAudioFrame(denoisedAudioFrame,
+ memDumpBaseAddr + memDumpBytesWritten,
+ memDumpMaxLen - memDumpBytesWritten);
}
}
@@ -181,43 +198,54 @@ namespace app {
/* Needed to not let the compiler complain about type mismatch. */
size_t valMemDumpBytesWritten = memDumpBytesWritten;
info("Output memory dump of %zu bytes written at address 0x%p\n",
- valMemDumpBytesWritten, startDumpAddress);
+ valMemDumpBytesWritten,
+ startDumpAddress);
}
/* Finish by dumping the footer. */
- DumpDenoisedAudioFooter(memDumpBaseAddr + memDumpBytesWritten, memDumpMaxLen - memDumpBytesWritten);
+ DumpDenoisedAudioFooter(memDumpBaseAddr + memDumpBytesWritten,
+ memDumpMaxLen - memDumpBytesWritten);
info("All inferences for audio clip complete.\n");
profiler.PrintProfilingResult();
IncrementAppCtxClipIdx(ctx);
std::string clearString{' '};
- hal_lcd_display_text(clearString.c_str(), clearString.size(),
- dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
+ hal_lcd_display_text(clearString.c_str(),
+ clearString.size(),
+ dataPsnTxtInfStartX,
+ dataPsnTxtInfStartY,
+ false);
std::string completeMsg{"Inference complete!"};
/* Display message on the LCD - inference complete. */
- hal_lcd_display_text(completeMsg.c_str(), completeMsg.size(),
- dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
+ hal_lcd_display_text(completeMsg.c_str(),
+ completeMsg.size(),
+ dataPsnTxtInfStartX,
+ dataPsnTxtInfStartY,
+ false);
} while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx);
return true;
}
- size_t DumpDenoisedAudioHeader(const char* filename, size_t dumpSize,
- uint8_t* memAddress, size_t memSize){
+ size_t DumpDenoisedAudioHeader(const char* filename,
+ size_t dumpSize,
+ uint8_t* memAddress,
+ size_t memSize)
+ {
- if (memAddress == nullptr){
+ if (memAddress == nullptr) {
return 0;
}
int32_t filenameLength = strlen(filename);
size_t numBytesWritten = 0;
size_t numBytesToWrite = 0;
- int32_t dumpSizeByte = dumpSize * sizeof(int16_t);
- bool overflow = false;
+ int32_t dumpSizeByte = dumpSize * sizeof(int16_t);
+ bool overflow = false;
/* Write the filename length */
numBytesToWrite = sizeof(filenameLength);
@@ -231,7 +259,7 @@ namespace app {
/* Write file name */
numBytesToWrite = filenameLength;
- if(memSize - numBytesToWrite > 0) {
+ if (memSize - numBytesToWrite > 0) {
std::memcpy(memAddress + numBytesWritten, filename, numBytesToWrite);
numBytesWritten += numBytesToWrite;
memSize -= numBytesWritten;
@@ -241,7 +269,7 @@ namespace app {
/* Write dumpSize in byte */
numBytesToWrite = sizeof(dumpSizeByte);
- if(memSize - numBytesToWrite > 0) {
+ if (memSize - numBytesToWrite > 0) {
std::memcpy(memAddress + numBytesWritten, &(dumpSizeByte), numBytesToWrite);
numBytesWritten += numBytesToWrite;
memSize -= numBytesWritten;
@@ -249,8 +277,10 @@ namespace app {
overflow = true;
}
- if(false == overflow) {
- info("Audio Clip dump header info (%zu bytes) written to %p\n", numBytesWritten, memAddress);
+ if (false == overflow) {
+ info("Audio Clip dump header info (%zu bytes) written to %p\n",
+ numBytesWritten,
+ memAddress);
} else {
printf_err("Not enough memory to dump Audio Clip header.\n");
}
@@ -258,7 +288,8 @@ namespace app {
return numBytesWritten;
}
- size_t DumpDenoisedAudioFooter(uint8_t* memAddress, size_t memSize){
+ size_t DumpDenoisedAudioFooter(uint8_t* memAddress, size_t memSize)
+ {
if ((memAddress == nullptr) || (memSize < 4)) {
return 0;
}
@@ -266,23 +297,27 @@ namespace app {
std::memcpy(memAddress, &eofMarker, sizeof(int32_t));
return sizeof(int32_t);
- }
+ }
size_t DumpOutputDenoisedAudioFrame(const std::vector<int16_t>& audioFrame,
- uint8_t* memAddress, size_t memSize)
+ uint8_t* memAddress,
+ size_t memSize)
{
if (memAddress == nullptr) {
return 0;
}
size_t numByteToBeWritten = audioFrame.size() * sizeof(int16_t);
- if( numByteToBeWritten > memSize) {
- printf_err("Overflow error: Writing %zu of %zu bytes to memory @ 0x%p.\n", memSize, numByteToBeWritten, memAddress);
+ if (numByteToBeWritten > memSize) {
+ printf_err("Overflow error: Writing %zu of %zu bytes to memory @ 0x%p.\n",
+ memSize,
+ numByteToBeWritten,
+ memAddress);
numByteToBeWritten = memSize;
}
std::memcpy(memAddress, audioFrame.data(), numByteToBeWritten);
- info("Copied %zu bytes to %p\n", numByteToBeWritten, memAddress);
+ info("Copied %zu bytes to %p\n", numByteToBeWritten, memAddress);
return numByteToBeWritten;
}
@@ -290,13 +325,13 @@ namespace app {
size_t DumpOutputTensorsToMemory(Model& model, uint8_t* memAddress, const size_t memSize)
{
const size_t numOutputs = model.GetNumOutputs();
- size_t numBytesWritten = 0;
- uint8_t* ptr = memAddress;
+ size_t numBytesWritten = 0;
+ uint8_t* ptr = memAddress;
/* Iterate over all output tensors. */
for (size_t i = 0; i < numOutputs; ++i) {
const TfLiteTensor* tensor = model.GetOutputTensor(i);
- const auto* tData = tflite::GetTensorData<uint8_t>(tensor);
+ const auto* tData = tflite::GetTensorData<uint8_t>(tensor);
#if VERIFY_TEST_OUTPUT
DumpTensor(tensor);
#endif /* VERIFY_TEST_OUTPUT */
@@ -305,15 +340,13 @@ namespace app {
if (tensor->bytes > 0) {
std::memcpy(ptr, tData, tensor->bytes);
- info("Copied %zu bytes for tensor %zu to 0x%p\n",
- tensor->bytes, i, ptr);
+ info("Copied %zu bytes for tensor %zu to 0x%p\n", tensor->bytes, i, ptr);
numBytesWritten += tensor->bytes;
ptr += tensor->bytes;
}
} else {
- printf_err("Error writing tensor %zu to memory @ 0x%p\n",
- i, memAddress);
+ printf_err("Error writing tensor %zu to memory @ 0x%p\n", i, memAddress);
break;
}
}