diff options
Diffstat (limited to 'source/use_case/noise_reduction/src')
-rw-r--r-- | source/use_case/noise_reduction/src/UseCaseHandler.cc | 191 |
1 files changed, 112 insertions, 79 deletions
diff --git a/source/use_case/noise_reduction/src/UseCaseHandler.cc b/source/use_case/noise_reduction/src/UseCaseHandler.cc index 0aef600..0c5ff39 100644 --- a/source/use_case/noise_reduction/src/UseCaseHandler.cc +++ b/source/use_case/noise_reduction/src/UseCaseHandler.cc @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com> - * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates + * <open-source-office@arm.com> SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,24 +14,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "hal.h" #include "UseCaseHandler.hpp" -#include "UseCaseCommonUtils.hpp" #include "AudioUtils.hpp" #include "ImageUtils.hpp" #include "InputFiles.hpp" -#include "RNNoiseModel.hpp" #include "RNNoiseFeatureProcessor.hpp" +#include "RNNoiseModel.hpp" #include "RNNoiseProcessing.hpp" +#include "UseCaseCommonUtils.hpp" +#include "hal.h" #include "log_macros.h" namespace arm { namespace app { /** - * @brief Helper function to increment current audio clip features index. - * @param[in,out] ctx Pointer to the application context object. - **/ + * @brief Helper function to increment current audio clip features index. + * @param[in,out] ctx Pointer to the application context object. + **/ static void IncrementAppCtxClipIdx(ApplicationContext& ctx); /* Noise reduction inference handler. */ @@ -41,17 +41,18 @@ namespace app { constexpr uint32_t dataPsnTxtInfStartY = 40; /* Variables used for memory dumping. */ - size_t memDumpMaxLen = 0; - uint8_t* memDumpBaseAddr = nullptr; + size_t memDumpMaxLen = 0; + uint8_t* memDumpBaseAddr = nullptr; size_t undefMemDumpBytesWritten = 0; - size_t* pMemDumpBytesWritten = &undefMemDumpBytesWritten; - if (ctx.Has("MEM_DUMP_LEN") && ctx.Has("MEM_DUMP_BASE_ADDR") && ctx.Has("MEM_DUMP_BYTE_WRITTEN")) { - memDumpMaxLen = ctx.Get<size_t>("MEM_DUMP_LEN"); - memDumpBaseAddr = ctx.Get<uint8_t*>("MEM_DUMP_BASE_ADDR"); + size_t* pMemDumpBytesWritten = &undefMemDumpBytesWritten; + if (ctx.Has("MEM_DUMP_LEN") && ctx.Has("MEM_DUMP_BASE_ADDR") && + ctx.Has("MEM_DUMP_BYTE_WRITTEN")) { + memDumpMaxLen = ctx.Get<size_t>("MEM_DUMP_LEN"); + memDumpBaseAddr = ctx.Get<uint8_t*>("MEM_DUMP_BASE_ADDR"); pMemDumpBytesWritten = ctx.Get<size_t*>("MEM_DUMP_BYTE_WRITTEN"); } std::reference_wrapper<size_t> memDumpBytesWritten = std::ref(*pMemDumpBytesWritten); - auto& profiler = ctx.Get<Profiler&>("profiler"); + auto& profiler = ctx.Get<Profiler&>("profiler"); /* Get model reference. */ auto& model = ctx.Get<RNNoiseModel&>("model"); @@ -61,15 +62,16 @@ namespace app { } /* Populate Pre-Processing related parameters. */ - auto audioFrameLen = ctx.Get<uint32_t>("frameLength"); - auto audioFrameStride = ctx.Get<uint32_t>("frameStride"); + auto audioFrameLen = ctx.Get<uint32_t>("frameLength"); + auto audioFrameStride = ctx.Get<uint32_t>("frameStride"); auto nrNumInputFeatures = ctx.Get<uint32_t>("numInputFeatures"); TfLiteTensor* inputTensor = model.GetInputTensor(0); if (nrNumInputFeatures != inputTensor->bytes) { printf_err("Input features size must be equal to input tensor size." " Feature size = %" PRIu32 ", Tensor size = %zu.\n", - nrNumInputFeatures, inputTensor->bytes); + nrNumInputFeatures, + inputTensor->bytes); return false; } @@ -78,49 +80,55 @@ namespace app { /* Initial choice of index for WAV file. */ auto startClipIdx = ctx.Get<uint32_t>("clipIndex"); - std::function<const int16_t* (const uint32_t)> audioAccessorFunc = get_audio_array; + std::function<const int16_t*(const uint32_t)> audioAccessorFunc = GetAudioArray; if (ctx.Has("features")) { - audioAccessorFunc = ctx.Get<std::function<const int16_t* (const uint32_t)>>("features"); + audioAccessorFunc = ctx.Get<std::function<const int16_t*(const uint32_t)>>("features"); } - std::function<uint32_t (const uint32_t)> audioSizeAccessorFunc = get_audio_array_size; + std::function<uint32_t(const uint32_t)> audioSizeAccessorFunc = GetAudioArraySize; if (ctx.Has("featureSizes")) { - audioSizeAccessorFunc = ctx.Get<std::function<uint32_t (const uint32_t)>>("featureSizes"); + audioSizeAccessorFunc = + ctx.Get<std::function<uint32_t(const uint32_t)>>("featureSizes"); } - std::function<const char*(const uint32_t)> audioFileAccessorFunc = get_filename; + std::function<const char*(const uint32_t)> audioFileAccessorFunc = GetFilename; if (ctx.Has("featureFileNames")) { - audioFileAccessorFunc = ctx.Get<std::function<const char*(const uint32_t)>>("featureFileNames"); + audioFileAccessorFunc = + ctx.Get<std::function<const char*(const uint32_t)>>("featureFileNames"); } do { hal_lcd_clear(COLOR_BLACK); auto startDumpAddress = memDumpBaseAddr + memDumpBytesWritten; - auto currentIndex = ctx.Get<uint32_t>("clipIndex"); + auto currentIndex = ctx.Get<uint32_t>("clipIndex"); /* Creating a sliding window through the audio. */ - auto audioDataSlider = audio::SlidingWindow<const int16_t>( - audioAccessorFunc(currentIndex), - audioSizeAccessorFunc(currentIndex), audioFrameLen, - audioFrameStride); - - info("Running inference on input feature map %" PRIu32 " => %s\n", currentIndex, + auto audioDataSlider = + audio::SlidingWindow<const int16_t>(audioAccessorFunc(currentIndex), + audioSizeAccessorFunc(currentIndex), + audioFrameLen, + audioFrameStride); + + info("Running inference on input feature map %" PRIu32 " => %s\n", + currentIndex, audioFileAccessorFunc(currentIndex)); - memDumpBytesWritten += DumpDenoisedAudioHeader(audioFileAccessorFunc(currentIndex), - (audioDataSlider.TotalStrides() + 1) * audioFrameLen, - memDumpBaseAddr + memDumpBytesWritten, - memDumpMaxLen - memDumpBytesWritten); + memDumpBytesWritten += + DumpDenoisedAudioHeader(audioFileAccessorFunc(currentIndex), + (audioDataSlider.TotalStrides() + 1) * audioFrameLen, + memDumpBaseAddr + memDumpBytesWritten, + memDumpMaxLen - memDumpBytesWritten); /* Set up pre and post-processing. */ std::shared_ptr<rnn::RNNoiseFeatureProcessor> featureProcessor = - std::make_shared<rnn::RNNoiseFeatureProcessor>(); + std::make_shared<rnn::RNNoiseFeatureProcessor>(); std::shared_ptr<rnn::FrameFeatures> frameFeatures = - std::make_shared<rnn::FrameFeatures>(); + std::make_shared<rnn::FrameFeatures>(); - RNNoisePreProcess preProcess = RNNoisePreProcess(inputTensor, featureProcessor, frameFeatures); + RNNoisePreProcess preProcess = + RNNoisePreProcess(inputTensor, featureProcessor, frameFeatures); std::vector<int16_t> denoisedAudioFrame(audioFrameLen); - RNNoisePostProcess postProcess = RNNoisePostProcess(outputTensor, denoisedAudioFrame, - featureProcessor, frameFeatures); + RNNoisePostProcess postProcess = RNNoisePostProcess( + outputTensor, denoisedAudioFrame, featureProcessor, frameFeatures); bool resetGRU = true; @@ -133,11 +141,12 @@ namespace app { } /* Reset or copy over GRU states first to avoid TFLu memory overlap issues. */ - if (resetGRU){ + if (resetGRU) { model.ResetGruState(); } else { /* Copying gru state outputs to gru state inputs. - * Call ResetGruState in between the sequence of inferences on unrelated input data. */ + * Call ResetGruState in between the sequence of inferences on unrelated input + * data. */ model.CopyGruStates(); } @@ -145,10 +154,15 @@ namespace app { std::string str_inf{"Running inference... "}; /* Display message on the LCD - inference running. */ - hal_lcd_display_text(str_inf.c_str(), str_inf.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); + hal_lcd_display_text(str_inf.c_str(), + str_inf.size(), + dataPsnTxtInfStartX, + dataPsnTxtInfStartY, + false); - info("Inference %zu/%zu\n", audioDataSlider.Index() + 1, audioDataSlider.TotalStrides() + 1); + info("Inference %zu/%zu\n", + audioDataSlider.Index() + 1, + audioDataSlider.TotalStrides() + 1); /* Run inference over this feature sliding window. */ if (!RunInference(model, profiler)) { @@ -165,15 +179,18 @@ namespace app { /* Erase. */ str_inf = std::string(str_inf.size(), ' '); - hal_lcd_display_text(str_inf.c_str(), str_inf.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); + hal_lcd_display_text(str_inf.c_str(), + str_inf.size(), + dataPsnTxtInfStartX, + dataPsnTxtInfStartY, + false); if (memDumpMaxLen > 0) { /* Dump final post processed output to memory. */ - memDumpBytesWritten += DumpOutputDenoisedAudioFrame( - denoisedAudioFrame, - memDumpBaseAddr + memDumpBytesWritten, - memDumpMaxLen - memDumpBytesWritten); + memDumpBytesWritten += + DumpOutputDenoisedAudioFrame(denoisedAudioFrame, + memDumpBaseAddr + memDumpBytesWritten, + memDumpMaxLen - memDumpBytesWritten); } } @@ -181,43 +198,54 @@ namespace app { /* Needed to not let the compiler complain about type mismatch. */ size_t valMemDumpBytesWritten = memDumpBytesWritten; info("Output memory dump of %zu bytes written at address 0x%p\n", - valMemDumpBytesWritten, startDumpAddress); + valMemDumpBytesWritten, + startDumpAddress); } /* Finish by dumping the footer. */ - DumpDenoisedAudioFooter(memDumpBaseAddr + memDumpBytesWritten, memDumpMaxLen - memDumpBytesWritten); + DumpDenoisedAudioFooter(memDumpBaseAddr + memDumpBytesWritten, + memDumpMaxLen - memDumpBytesWritten); info("All inferences for audio clip complete.\n"); profiler.PrintProfilingResult(); IncrementAppCtxClipIdx(ctx); std::string clearString{' '}; - hal_lcd_display_text(clearString.c_str(), clearString.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); + hal_lcd_display_text(clearString.c_str(), + clearString.size(), + dataPsnTxtInfStartX, + dataPsnTxtInfStartY, + false); std::string completeMsg{"Inference complete!"}; /* Display message on the LCD - inference complete. */ - hal_lcd_display_text(completeMsg.c_str(), completeMsg.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); + hal_lcd_display_text(completeMsg.c_str(), + completeMsg.size(), + dataPsnTxtInfStartX, + dataPsnTxtInfStartY, + false); } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx); return true; } - size_t DumpDenoisedAudioHeader(const char* filename, size_t dumpSize, - uint8_t* memAddress, size_t memSize){ + size_t DumpDenoisedAudioHeader(const char* filename, + size_t dumpSize, + uint8_t* memAddress, + size_t memSize) + { - if (memAddress == nullptr){ + if (memAddress == nullptr) { return 0; } int32_t filenameLength = strlen(filename); size_t numBytesWritten = 0; size_t numBytesToWrite = 0; - int32_t dumpSizeByte = dumpSize * sizeof(int16_t); - bool overflow = false; + int32_t dumpSizeByte = dumpSize * sizeof(int16_t); + bool overflow = false; /* Write the filename length */ numBytesToWrite = sizeof(filenameLength); @@ -231,7 +259,7 @@ namespace app { /* Write file name */ numBytesToWrite = filenameLength; - if(memSize - numBytesToWrite > 0) { + if (memSize - numBytesToWrite > 0) { std::memcpy(memAddress + numBytesWritten, filename, numBytesToWrite); numBytesWritten += numBytesToWrite; memSize -= numBytesWritten; @@ -241,7 +269,7 @@ namespace app { /* Write dumpSize in byte */ numBytesToWrite = sizeof(dumpSizeByte); - if(memSize - numBytesToWrite > 0) { + if (memSize - numBytesToWrite > 0) { std::memcpy(memAddress + numBytesWritten, &(dumpSizeByte), numBytesToWrite); numBytesWritten += numBytesToWrite; memSize -= numBytesWritten; @@ -249,8 +277,10 @@ namespace app { overflow = true; } - if(false == overflow) { - info("Audio Clip dump header info (%zu bytes) written to %p\n", numBytesWritten, memAddress); + if (false == overflow) { + info("Audio Clip dump header info (%zu bytes) written to %p\n", + numBytesWritten, + memAddress); } else { printf_err("Not enough memory to dump Audio Clip header.\n"); } @@ -258,7 +288,8 @@ namespace app { return numBytesWritten; } - size_t DumpDenoisedAudioFooter(uint8_t* memAddress, size_t memSize){ + size_t DumpDenoisedAudioFooter(uint8_t* memAddress, size_t memSize) + { if ((memAddress == nullptr) || (memSize < 4)) { return 0; } @@ -266,23 +297,27 @@ namespace app { std::memcpy(memAddress, &eofMarker, sizeof(int32_t)); return sizeof(int32_t); - } + } size_t DumpOutputDenoisedAudioFrame(const std::vector<int16_t>& audioFrame, - uint8_t* memAddress, size_t memSize) + uint8_t* memAddress, + size_t memSize) { if (memAddress == nullptr) { return 0; } size_t numByteToBeWritten = audioFrame.size() * sizeof(int16_t); - if( numByteToBeWritten > memSize) { - printf_err("Overflow error: Writing %zu of %zu bytes to memory @ 0x%p.\n", memSize, numByteToBeWritten, memAddress); + if (numByteToBeWritten > memSize) { + printf_err("Overflow error: Writing %zu of %zu bytes to memory @ 0x%p.\n", + memSize, + numByteToBeWritten, + memAddress); numByteToBeWritten = memSize; } std::memcpy(memAddress, audioFrame.data(), numByteToBeWritten); - info("Copied %zu bytes to %p\n", numByteToBeWritten, memAddress); + info("Copied %zu bytes to %p\n", numByteToBeWritten, memAddress); return numByteToBeWritten; } @@ -290,13 +325,13 @@ namespace app { size_t DumpOutputTensorsToMemory(Model& model, uint8_t* memAddress, const size_t memSize) { const size_t numOutputs = model.GetNumOutputs(); - size_t numBytesWritten = 0; - uint8_t* ptr = memAddress; + size_t numBytesWritten = 0; + uint8_t* ptr = memAddress; /* Iterate over all output tensors. */ for (size_t i = 0; i < numOutputs; ++i) { const TfLiteTensor* tensor = model.GetOutputTensor(i); - const auto* tData = tflite::GetTensorData<uint8_t>(tensor); + const auto* tData = tflite::GetTensorData<uint8_t>(tensor); #if VERIFY_TEST_OUTPUT DumpTensor(tensor); #endif /* VERIFY_TEST_OUTPUT */ @@ -305,15 +340,13 @@ namespace app { if (tensor->bytes > 0) { std::memcpy(ptr, tData, tensor->bytes); - info("Copied %zu bytes for tensor %zu to 0x%p\n", - tensor->bytes, i, ptr); + info("Copied %zu bytes for tensor %zu to 0x%p\n", tensor->bytes, i, ptr); numBytesWritten += tensor->bytes; ptr += tensor->bytes; } } else { - printf_err("Error writing tensor %zu to memory @ 0x%p\n", - i, memAddress); + printf_err("Error writing tensor %zu to memory @ 0x%p\n", i, memAddress); break; } } |