Diffstat (limited to 'source/use_case/noise_reduction/src/UseCaseHandler.cc')
-rw-r--r--  source/use_case/noise_reduction/src/UseCaseHandler.cc | 129
1 file changed, 44 insertions(+), 85 deletions(-)
diff --git a/source/use_case/noise_reduction/src/UseCaseHandler.cc b/source/use_case/noise_reduction/src/UseCaseHandler.cc
index acb8ba7..53bb43e 100644
--- a/source/use_case/noise_reduction/src/UseCaseHandler.cc
+++ b/source/use_case/noise_reduction/src/UseCaseHandler.cc
@@ -21,12 +21,10 @@
#include "ImageUtils.hpp"
#include "InputFiles.hpp"
#include "RNNoiseModel.hpp"
-#include "RNNoiseProcess.hpp"
+#include "RNNoiseFeatureProcessor.hpp"
+#include "RNNoiseProcessing.hpp"
#include "log_macros.h"
-#include <cmath>
-#include <algorithm>
-
namespace arm {
namespace app {
@@ -36,17 +34,6 @@ namespace app {
**/
static void IncrementAppCtxClipIdx(ApplicationContext& ctx);
- /**
- * @brief Quantize the given features and populate the input Tensor.
- * @param[in] inputFeatures Vector of floating point features to quantize.
- * @param[in] quantScale Quantization scale for the inputTensor.
- * @param[in] quantOffset Quantization offset for the inputTensor.
- * @param[in,out] inputTensor TFLite micro tensor to populate.
- **/
- static void QuantizeAndPopulateInput(rnn::vec1D32F& inputFeatures,
- float quantScale, int quantOffset,
- TfLiteTensor* inputTensor);
-
/* Noise reduction inference handler. */
bool NoiseReductionHandler(ApplicationContext& ctx, bool runAll)
{
@@ -57,7 +44,7 @@ namespace app {
size_t memDumpMaxLen = 0;
uint8_t* memDumpBaseAddr = nullptr;
size_t undefMemDumpBytesWritten = 0;
- size_t *pMemDumpBytesWritten = &undefMemDumpBytesWritten;
+ size_t* pMemDumpBytesWritten = &undefMemDumpBytesWritten;
if (ctx.Has("MEM_DUMP_LEN") && ctx.Has("MEM_DUMP_BASE_ADDR") && ctx.Has("MEM_DUMP_BYTE_WRITTEN")) {
memDumpMaxLen = ctx.Get<size_t>("MEM_DUMP_LEN");
memDumpBaseAddr = ctx.Get<uint8_t*>("MEM_DUMP_BASE_ADDR");
@@ -74,8 +61,8 @@ namespace app {
}
/* Populate Pre-Processing related parameters. */
- auto audioParamsWinLen = ctx.Get<uint32_t>("frameLength");
- auto audioParamsWinStride = ctx.Get<uint32_t>("frameStride");
+ auto audioFrameLen = ctx.Get<uint32_t>("frameLength");
+ auto audioFrameStride = ctx.Get<uint32_t>("frameStride");
auto nrNumInputFeatures = ctx.Get<uint32_t>("numInputFeatures");
TfLiteTensor* inputTensor = model.GetInputTensor(0);
@@ -103,7 +90,7 @@ namespace app {
if (ctx.Has("featureFileNames")) {
audioFileAccessorFunc = ctx.Get<std::function<const char*(const uint32_t)>>("featureFileNames");
}
- do{
+ do {
hal_lcd_clear(COLOR_BLACK);
auto startDumpAddress = memDumpBaseAddr + memDumpBytesWritten;
@@ -112,32 +99,38 @@ namespace app {
/* Creating a sliding window through the audio. */
auto audioDataSlider = audio::SlidingWindow<const int16_t>(
audioAccessorFunc(currentIndex),
- audioSizeAccessorFunc(currentIndex), audioParamsWinLen,
- audioParamsWinStride);
+ audioSizeAccessorFunc(currentIndex), audioFrameLen,
+ audioFrameStride);
info("Running inference on input feature map %" PRIu32 " => %s\n", currentIndex,
audioFileAccessorFunc(currentIndex));
memDumpBytesWritten += DumpDenoisedAudioHeader(audioFileAccessorFunc(currentIndex),
- (audioDataSlider.TotalStrides() + 1) * audioParamsWinLen,
+ (audioDataSlider.TotalStrides() + 1) * audioFrameLen,
memDumpBaseAddr + memDumpBytesWritten,
memDumpMaxLen - memDumpBytesWritten);
- rnn::RNNoiseProcess featureProcessor = rnn::RNNoiseProcess();
- rnn::vec1D32F audioFrame(audioParamsWinLen);
- rnn::vec1D32F inputFeatures(nrNumInputFeatures);
- rnn::vec1D32F denoisedAudioFrameFloat(audioParamsWinLen);
- std::vector<int16_t> denoisedAudioFrame(audioParamsWinLen);
+ /* Set up pre and post-processing. */
+ std::shared_ptr<rnn::RNNoiseFeatureProcessor> featureProcessor =
+ std::make_shared<rnn::RNNoiseFeatureProcessor>();
+ std::shared_ptr<rnn::FrameFeatures> frameFeatures =
+ std::make_shared<rnn::FrameFeatures>();
+
+ RNNoisePreProcess preProcess = RNNoisePreProcess(inputTensor, featureProcessor, frameFeatures);
+
+ std::vector<int16_t> denoisedAudioFrame(audioFrameLen);
+ RNNoisePostProcess postProcess = RNNoisePostProcess(outputTensor, denoisedAudioFrame,
+ featureProcessor, frameFeatures);
- std::vector<float> modelOutputFloat(outputTensor->bytes);
- rnn::FrameFeatures frameFeatures;
bool resetGRU = true;
while (audioDataSlider.HasNext()) {
const int16_t* inferenceWindow = audioDataSlider.Next();
- audioFrame = rnn::vec1D32F(inferenceWindow, inferenceWindow+audioParamsWinLen);
- featureProcessor.PreprocessFrame(audioFrame.data(), audioParamsWinLen, frameFeatures);
+ if (!preProcess.DoPreProcess(inferenceWindow, audioFrameLen)) {
+ printf_err("Pre-processing failed.");
+ return false;
+ }
/* Reset or copy over GRU states first to avoid TFLu memory overlap issues. */
if (resetGRU){
@@ -148,53 +141,35 @@ namespace app {
model.CopyGruStates();
}
- QuantizeAndPopulateInput(frameFeatures.m_featuresVec,
- inputTensor->params.scale, inputTensor->params.zero_point,
- inputTensor);
-
/* Strings for presentation/logging. */
std::string str_inf{"Running inference... "};
/* Display message on the LCD - inference running. */
- hal_lcd_display_text(
- str_inf.c_str(), str_inf.size(),
- dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
+ hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
+ dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
info("Inference %zu/%zu\n", audioDataSlider.Index() + 1, audioDataSlider.TotalStrides() + 1);
/* Run inference over this feature sliding window. */
- profiler.StartProfiling("Inference");
- bool success = model.RunInference();
- profiler.StopProfiling();
- resetGRU = false;
-
- if (!success) {
+ if (!RunInference(model, profiler)) {
+ printf_err("Inference failed.");
return false;
}
+ resetGRU = false;
- /* De-quantize main model output ready for post-processing. */
- const auto* outputData = tflite::GetTensorData<int8_t>(outputTensor);
- auto outputQuantParams = arm::app::GetTensorQuantParams(outputTensor);
-
- for (size_t i = 0; i < outputTensor->bytes; ++i) {
- modelOutputFloat[i] = (static_cast<float>(outputData[i]) - outputQuantParams.offset)
- * outputQuantParams.scale;
- }
-
- /* Round and cast the post-processed results for dumping to wav. */
- featureProcessor.PostProcessFrame(modelOutputFloat, frameFeatures, denoisedAudioFrameFloat);
- for (size_t i = 0; i < audioParamsWinLen; ++i) {
- denoisedAudioFrame[i] = static_cast<int16_t>(std::roundf(denoisedAudioFrameFloat[i]));
+ /* Carry out post-processing. */
+ if (!postProcess.DoPostProcess()) {
+ printf_err("Post-processing failed.");
+ return false;
}
/* Erase. */
str_inf = std::string(str_inf.size(), ' ');
- hal_lcd_display_text(
- str_inf.c_str(), str_inf.size(),
- dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
+ hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
+ dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
if (memDumpMaxLen > 0) {
- /* Dump output tensors to memory. */
+ /* Dump final post processed output to memory. */
memDumpBytesWritten += DumpOutputDenoisedAudioFrame(
denoisedAudioFrame,
memDumpBaseAddr + memDumpBytesWritten,
@@ -209,6 +184,7 @@ namespace app {
valMemDumpBytesWritten, startDumpAddress);
}
+ /* Finish by dumping the footer. */
DumpDenoisedAudioFooter(memDumpBaseAddr + memDumpBytesWritten, memDumpMaxLen - memDumpBytesWritten);
info("All inferences for audio clip complete.\n");
@@ -216,15 +192,13 @@ namespace app {
IncrementAppCtxClipIdx(ctx);
std::string clearString{' '};
- hal_lcd_display_text(
- clearString.c_str(), clearString.size(),
+ hal_lcd_display_text(clearString.c_str(), clearString.size(),
dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
std::string completeMsg{"Inference complete!"};
/* Display message on the LCD - inference complete. */
- hal_lcd_display_text(
- completeMsg.c_str(), completeMsg.size(),
+ hal_lcd_display_text(completeMsg.c_str(), completeMsg.size(),
dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
} while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx);
@@ -233,7 +207,7 @@ namespace app {
}
size_t DumpDenoisedAudioHeader(const char* filename, size_t dumpSize,
- uint8_t *memAddress, size_t memSize){
+ uint8_t* memAddress, size_t memSize){
if (memAddress == nullptr){
return 0;
@@ -284,7 +258,7 @@ namespace app {
return numBytesWritten;
}
- size_t DumpDenoisedAudioFooter(uint8_t *memAddress, size_t memSize){
+ size_t DumpDenoisedAudioFooter(uint8_t* memAddress, size_t memSize){
if ((memAddress == nullptr) || (memSize < 4)) {
return 0;
}
@@ -294,8 +268,8 @@ namespace app {
return sizeof(int32_t);
}
- size_t DumpOutputDenoisedAudioFrame(const std::vector<int16_t> &audioFrame,
- uint8_t *memAddress, size_t memSize)
+ size_t DumpOutputDenoisedAudioFrame(const std::vector<int16_t>& audioFrame,
+ uint8_t* memAddress, size_t memSize)
{
if (memAddress == nullptr) {
return 0;
@@ -324,7 +298,7 @@ namespace app {
const TfLiteTensor* tensor = model.GetOutputTensor(i);
const auto* tData = tflite::GetTensorData<uint8_t>(tensor);
#if VERIFY_TEST_OUTPUT
- arm::app::DumpTensor(tensor);
+ DumpTensor(tensor);
#endif /* VERIFY_TEST_OUTPUT */
/* Ensure that we don't overflow the allowed limit. */
if (numBytesWritten + tensor->bytes <= memSize) {
@@ -360,20 +334,5 @@ namespace app {
ctx.Set<uint32_t>("clipIndex", curClipIdx);
}
- void QuantizeAndPopulateInput(rnn::vec1D32F& inputFeatures,
- const float quantScale, const int quantOffset, TfLiteTensor* inputTensor)
- {
- const float minVal = std::numeric_limits<int8_t>::min();
- const float maxVal = std::numeric_limits<int8_t>::max();
-
- auto* inputTensorData = tflite::GetTensorData<int8_t>(inputTensor);
-
- for (size_t i=0; i < inputFeatures.size(); ++i) {
- float quantValue = ((inputFeatures[i] / quantScale) + quantOffset);
- inputTensorData[i] = static_cast<int8_t>(std::min<float>(std::max<float>(quantValue, minVal), maxVal));
- }
- }
-
-
} /* namespace app */
} /* namespace arm */
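
For readers skimming the change: this refactor replaces the hand-rolled QuantizeAndPopulateInput() helper and the inline de-quantization loop with the new RNNoisePreProcess and RNNoisePostProcess classes. The sketch below condenses the resulting per-clip flow using only the constructors and calls visible in the diff above; it is illustrative rather than a drop-in replacement (LCD updates, memory dumping, and error logging are omitted, and ResetGruState() is an assumed RNNoiseModel method, since only CopyGruStates() appears in the visible hunk).

/* Condensed sketch of the refactored inference loop shown in this diff.
 * Assumes the surrounding NoiseReductionHandler() context: model, profiler,
 * inputTensor, outputTensor, audioFrameLen and audioDataSlider as above. */
auto featureProcessor = std::make_shared<rnn::RNNoiseFeatureProcessor>();
auto frameFeatures    = std::make_shared<rnn::FrameFeatures>();

RNNoisePreProcess preProcess(inputTensor, featureProcessor, frameFeatures);

std::vector<int16_t> denoisedAudioFrame(audioFrameLen);
RNNoisePostProcess postProcess(outputTensor, denoisedAudioFrame,
                               featureProcessor, frameFeatures);

bool resetGRU = true;
while (audioDataSlider.HasNext()) {
    const int16_t* inferenceWindow = audioDataSlider.Next();

    /* Feature extraction plus quantization into inputTensor, previously
     * done by hand via QuantizeAndPopulateInput(). */
    if (!preProcess.DoPreProcess(inferenceWindow, audioFrameLen)) {
        return false;
    }

    /* Reset GRU states on the first frame, copy them over afterwards to
     * avoid TFLu memory overlap issues (per the comment in the diff). */
    if (resetGRU) {
        model.ResetGruState();   /* Assumed API; not shown in the hunk. */
    } else {
        model.CopyGruStates();
    }

    if (!RunInference(model, profiler)) {
        return false;
    }
    resetGRU = false;

    /* De-quantize outputTensor and reconstruct the int16 audio frame,
     * previously an explicit loop plus PostProcessFrame(). */
    if (!postProcess.DoPostProcess()) {
        return false;
    }
}

The key design change is ownership: the pre- and post-processing objects capture the tensors, the shared feature processor, and the frame features at construction, so the per-frame loop no longer touches tensor quantization parameters directly.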