diff options
Diffstat (limited to 'source/use_case/noise_reduction/src/UseCaseHandler.cc')
-rw-r--r-- | source/use_case/noise_reduction/src/UseCaseHandler.cc | 129 |
1 files changed, 44 insertions, 85 deletions
diff --git a/source/use_case/noise_reduction/src/UseCaseHandler.cc b/source/use_case/noise_reduction/src/UseCaseHandler.cc index acb8ba7..53bb43e 100644 --- a/source/use_case/noise_reduction/src/UseCaseHandler.cc +++ b/source/use_case/noise_reduction/src/UseCaseHandler.cc @@ -21,12 +21,10 @@ #include "ImageUtils.hpp" #include "InputFiles.hpp" #include "RNNoiseModel.hpp" -#include "RNNoiseProcess.hpp" +#include "RNNoiseFeatureProcessor.hpp" +#include "RNNoiseProcessing.hpp" #include "log_macros.h" -#include <cmath> -#include <algorithm> - namespace arm { namespace app { @@ -36,17 +34,6 @@ namespace app { **/ static void IncrementAppCtxClipIdx(ApplicationContext& ctx); - /** - * @brief Quantize the given features and populate the input Tensor. - * @param[in] inputFeatures Vector of floating point features to quantize. - * @param[in] quantScale Quantization scale for the inputTensor. - * @param[in] quantOffset Quantization offset for the inputTensor. - * @param[in,out] inputTensor TFLite micro tensor to populate. - **/ - static void QuantizeAndPopulateInput(rnn::vec1D32F& inputFeatures, - float quantScale, int quantOffset, - TfLiteTensor* inputTensor); - /* Noise reduction inference handler. */ bool NoiseReductionHandler(ApplicationContext& ctx, bool runAll) { @@ -57,7 +44,7 @@ namespace app { size_t memDumpMaxLen = 0; uint8_t* memDumpBaseAddr = nullptr; size_t undefMemDumpBytesWritten = 0; - size_t *pMemDumpBytesWritten = &undefMemDumpBytesWritten; + size_t* pMemDumpBytesWritten = &undefMemDumpBytesWritten; if (ctx.Has("MEM_DUMP_LEN") && ctx.Has("MEM_DUMP_BASE_ADDR") && ctx.Has("MEM_DUMP_BYTE_WRITTEN")) { memDumpMaxLen = ctx.Get<size_t>("MEM_DUMP_LEN"); memDumpBaseAddr = ctx.Get<uint8_t*>("MEM_DUMP_BASE_ADDR"); @@ -74,8 +61,8 @@ namespace app { } /* Populate Pre-Processing related parameters. */ - auto audioParamsWinLen = ctx.Get<uint32_t>("frameLength"); - auto audioParamsWinStride = ctx.Get<uint32_t>("frameStride"); + auto audioFrameLen = ctx.Get<uint32_t>("frameLength"); + auto audioFrameStride = ctx.Get<uint32_t>("frameStride"); auto nrNumInputFeatures = ctx.Get<uint32_t>("numInputFeatures"); TfLiteTensor* inputTensor = model.GetInputTensor(0); @@ -103,7 +90,7 @@ namespace app { if (ctx.Has("featureFileNames")) { audioFileAccessorFunc = ctx.Get<std::function<const char*(const uint32_t)>>("featureFileNames"); } - do{ + do { hal_lcd_clear(COLOR_BLACK); auto startDumpAddress = memDumpBaseAddr + memDumpBytesWritten; @@ -112,32 +99,38 @@ namespace app { /* Creating a sliding window through the audio. */ auto audioDataSlider = audio::SlidingWindow<const int16_t>( audioAccessorFunc(currentIndex), - audioSizeAccessorFunc(currentIndex), audioParamsWinLen, - audioParamsWinStride); + audioSizeAccessorFunc(currentIndex), audioFrameLen, + audioFrameStride); info("Running inference on input feature map %" PRIu32 " => %s\n", currentIndex, audioFileAccessorFunc(currentIndex)); memDumpBytesWritten += DumpDenoisedAudioHeader(audioFileAccessorFunc(currentIndex), - (audioDataSlider.TotalStrides() + 1) * audioParamsWinLen, + (audioDataSlider.TotalStrides() + 1) * audioFrameLen, memDumpBaseAddr + memDumpBytesWritten, memDumpMaxLen - memDumpBytesWritten); - rnn::RNNoiseProcess featureProcessor = rnn::RNNoiseProcess(); - rnn::vec1D32F audioFrame(audioParamsWinLen); - rnn::vec1D32F inputFeatures(nrNumInputFeatures); - rnn::vec1D32F denoisedAudioFrameFloat(audioParamsWinLen); - std::vector<int16_t> denoisedAudioFrame(audioParamsWinLen); + /* Set up pre and post-processing. */ + std::shared_ptr<rnn::RNNoiseFeatureProcessor> featureProcessor = + std::make_shared<rnn::RNNoiseFeatureProcessor>(); + std::shared_ptr<rnn::FrameFeatures> frameFeatures = + std::make_shared<rnn::FrameFeatures>(); + + RNNoisePreProcess preProcess = RNNoisePreProcess(inputTensor, featureProcessor, frameFeatures); + + std::vector<int16_t> denoisedAudioFrame(audioFrameLen); + RNNoisePostProcess postProcess = RNNoisePostProcess(outputTensor, denoisedAudioFrame, + featureProcessor, frameFeatures); - std::vector<float> modelOutputFloat(outputTensor->bytes); - rnn::FrameFeatures frameFeatures; bool resetGRU = true; while (audioDataSlider.HasNext()) { const int16_t* inferenceWindow = audioDataSlider.Next(); - audioFrame = rnn::vec1D32F(inferenceWindow, inferenceWindow+audioParamsWinLen); - featureProcessor.PreprocessFrame(audioFrame.data(), audioParamsWinLen, frameFeatures); + if (!preProcess.DoPreProcess(inferenceWindow, audioFrameLen)) { + printf_err("Pre-processing failed."); + return false; + } /* Reset or copy over GRU states first to avoid TFLu memory overlap issues. */ if (resetGRU){ @@ -148,53 +141,35 @@ namespace app { model.CopyGruStates(); } - QuantizeAndPopulateInput(frameFeatures.m_featuresVec, - inputTensor->params.scale, inputTensor->params.zero_point, - inputTensor); - /* Strings for presentation/logging. */ std::string str_inf{"Running inference... "}; /* Display message on the LCD - inference running. */ - hal_lcd_display_text( - str_inf.c_str(), str_inf.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); + hal_lcd_display_text(str_inf.c_str(), str_inf.size(), + dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); info("Inference %zu/%zu\n", audioDataSlider.Index() + 1, audioDataSlider.TotalStrides() + 1); /* Run inference over this feature sliding window. */ - profiler.StartProfiling("Inference"); - bool success = model.RunInference(); - profiler.StopProfiling(); - resetGRU = false; - - if (!success) { + if (!RunInference(model, profiler)) { + printf_err("Inference failed."); return false; } + resetGRU = false; - /* De-quantize main model output ready for post-processing. */ - const auto* outputData = tflite::GetTensorData<int8_t>(outputTensor); - auto outputQuantParams = arm::app::GetTensorQuantParams(outputTensor); - - for (size_t i = 0; i < outputTensor->bytes; ++i) { - modelOutputFloat[i] = (static_cast<float>(outputData[i]) - outputQuantParams.offset) - * outputQuantParams.scale; - } - - /* Round and cast the post-processed results for dumping to wav. */ - featureProcessor.PostProcessFrame(modelOutputFloat, frameFeatures, denoisedAudioFrameFloat); - for (size_t i = 0; i < audioParamsWinLen; ++i) { - denoisedAudioFrame[i] = static_cast<int16_t>(std::roundf(denoisedAudioFrameFloat[i])); + /* Carry out post-processing. */ + if (!postProcess.DoPostProcess()) { + printf_err("Post-processing failed."); + return false; } /* Erase. */ str_inf = std::string(str_inf.size(), ' '); - hal_lcd_display_text( - str_inf.c_str(), str_inf.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); + hal_lcd_display_text(str_inf.c_str(), str_inf.size(), + dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); if (memDumpMaxLen > 0) { - /* Dump output tensors to memory. */ + /* Dump final post processed output to memory. */ memDumpBytesWritten += DumpOutputDenoisedAudioFrame( denoisedAudioFrame, memDumpBaseAddr + memDumpBytesWritten, @@ -209,6 +184,7 @@ namespace app { valMemDumpBytesWritten, startDumpAddress); } + /* Finish by dumping the footer. */ DumpDenoisedAudioFooter(memDumpBaseAddr + memDumpBytesWritten, memDumpMaxLen - memDumpBytesWritten); info("All inferences for audio clip complete.\n"); @@ -216,15 +192,13 @@ namespace app { IncrementAppCtxClipIdx(ctx); std::string clearString{' '}; - hal_lcd_display_text( - clearString.c_str(), clearString.size(), + hal_lcd_display_text(clearString.c_str(), clearString.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); std::string completeMsg{"Inference complete!"}; /* Display message on the LCD - inference complete. */ - hal_lcd_display_text( - completeMsg.c_str(), completeMsg.size(), + hal_lcd_display_text(completeMsg.c_str(), completeMsg.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx); @@ -233,7 +207,7 @@ namespace app { } size_t DumpDenoisedAudioHeader(const char* filename, size_t dumpSize, - uint8_t *memAddress, size_t memSize){ + uint8_t* memAddress, size_t memSize){ if (memAddress == nullptr){ return 0; @@ -284,7 +258,7 @@ namespace app { return numBytesWritten; } - size_t DumpDenoisedAudioFooter(uint8_t *memAddress, size_t memSize){ + size_t DumpDenoisedAudioFooter(uint8_t* memAddress, size_t memSize){ if ((memAddress == nullptr) || (memSize < 4)) { return 0; } @@ -294,8 +268,8 @@ namespace app { return sizeof(int32_t); } - size_t DumpOutputDenoisedAudioFrame(const std::vector<int16_t> &audioFrame, - uint8_t *memAddress, size_t memSize) + size_t DumpOutputDenoisedAudioFrame(const std::vector<int16_t>& audioFrame, + uint8_t* memAddress, size_t memSize) { if (memAddress == nullptr) { return 0; @@ -324,7 +298,7 @@ namespace app { const TfLiteTensor* tensor = model.GetOutputTensor(i); const auto* tData = tflite::GetTensorData<uint8_t>(tensor); #if VERIFY_TEST_OUTPUT - arm::app::DumpTensor(tensor); + DumpTensor(tensor); #endif /* VERIFY_TEST_OUTPUT */ /* Ensure that we don't overflow the allowed limit. */ if (numBytesWritten + tensor->bytes <= memSize) { @@ -360,20 +334,5 @@ namespace app { ctx.Set<uint32_t>("clipIndex", curClipIdx); } - void QuantizeAndPopulateInput(rnn::vec1D32F& inputFeatures, - const float quantScale, const int quantOffset, TfLiteTensor* inputTensor) - { - const float minVal = std::numeric_limits<int8_t>::min(); - const float maxVal = std::numeric_limits<int8_t>::max(); - - auto* inputTensorData = tflite::GetTensorData<int8_t>(inputTensor); - - for (size_t i=0; i < inputFeatures.size(); ++i) { - float quantValue = ((inputFeatures[i] / quantScale) + quantOffset); - inputTensorData[i] = static_cast<int8_t>(std::min<float>(std::max<float>(quantValue, minVal), maxVal)); - } - } - - } /* namespace app */ } /* namespace arm */ |