From 4e002791bc6781b549c6951cfe44f918289d7e82 Mon Sep 17 00:00:00 2001
From: Richard Burton
Date: Wed, 4 May 2022 09:45:02 +0100
Subject: MLECO-3173: Add AD, KWS_ASR and Noise reduction use case APIs

Signed-off-by: Richard Burton
Change-Id: I36f61ce74bf17f7b327cdae9704a22ca54144f37
---
 source/application/main/include/BaseProcessing.hpp |   6 -
 source/use_case/ad/include/AdModel.hpp             |   8 +-
 source/use_case/ad/include/AdPostProcessing.hpp    |  50 --
 source/use_case/ad/include/AdProcessing.hpp        | 230 ++++++
 source/use_case/ad/src/AdPostProcessing.cc         | 115 ---
 source/use_case/ad/src/AdProcessing.cc             | 208 +++++
 source/use_case/ad/src/MainLoop.cc                 |   7 +-
 source/use_case/ad/src/UseCaseHandler.cc           | 291 ++-----
 source/use_case/asr/include/Wav2LetterModel.hpp    |   2 +-
 source/use_case/kws_asr/include/KwsProcessing.hpp  | 138 ++++
 .../use_case/kws_asr/include/MicroNetKwsMfcc.hpp   |  10 +-
 .../use_case/kws_asr/include/Wav2LetterModel.hpp   |  12 +-
 .../kws_asr/include/Wav2LetterPostprocess.hpp      | 117 +--
 .../kws_asr/include/Wav2LetterPreprocess.hpp       | 138 ++--
 source/use_case/kws_asr/src/KwsProcessing.cc       | 212 +++++
 source/use_case/kws_asr/src/MainLoop.cc            | 125 +--
 source/use_case/kws_asr/src/UseCaseHandler.cc      | 492 ++++--------
 .../use_case/kws_asr/src/Wav2LetterPostprocess.cc  | 146 +++-
 .../use_case/kws_asr/src/Wav2LetterPreprocess.cc   | 116 ++-
 source/use_case/kws_asr/usecase.cmake              |   4 +-
 .../include/RNNoiseFeatureProcessor.hpp            | 341 ++++++++
 .../noise_reduction/include/RNNoiseProcess.hpp     | 337 --------
 .../noise_reduction/include/RNNoiseProcessing.hpp  | 113 +++
 source/use_case/noise_reduction/src/MainLoop.cc    |   4 +-
 .../noise_reduction/src/RNNoiseFeatureProcessor.cc | 892 +++++++++++++++++++++
 .../use_case/noise_reduction/src/RNNoiseProcess.cc | 892 ---------------------
 .../noise_reduction/src/RNNoiseProcessing.cc       | 100 +++
 .../use_case/noise_reduction/src/UseCaseHandler.cc | 129 +--
 tests/use_case/ad/PostProcessTests.cc              |  53 --
 tests/use_case/kws_asr/MfccTests.cc                |   8 +-
 .../kws_asr/Wav2LetterPostprocessingTest.cc        | 142 ++--
 .../kws_asr/Wav2LetterPreprocessingTest.cc         | 126 ++-
 .../noise_reduction/RNNoiseProcessingTests.cpp     |   8 +-
 33 files changed, 2979 insertions(+), 2593 deletions(-)
 delete mode 100644 source/use_case/ad/include/AdPostProcessing.hpp
 create mode 100644 source/use_case/ad/include/AdProcessing.hpp
 delete mode 100644 source/use_case/ad/src/AdPostProcessing.cc
 create mode 100644 source/use_case/ad/src/AdProcessing.cc
 create mode 100644 source/use_case/kws_asr/include/KwsProcessing.hpp
 create mode 100644 source/use_case/kws_asr/src/KwsProcessing.cc
 create mode 100644 source/use_case/noise_reduction/include/RNNoiseFeatureProcessor.hpp
 delete mode 100644 source/use_case/noise_reduction/include/RNNoiseProcess.hpp
 create mode 100644 source/use_case/noise_reduction/include/RNNoiseProcessing.hpp
 create mode 100644 source/use_case/noise_reduction/src/RNNoiseFeatureProcessor.cc
 delete mode 100644 source/use_case/noise_reduction/src/RNNoiseProcess.cc
 create mode 100644 source/use_case/noise_reduction/src/RNNoiseProcessing.cc
 delete mode 100644 tests/use_case/ad/PostProcessTests.cc

diff --git a/source/application/main/include/BaseProcessing.hpp b/source/application/main/include/BaseProcessing.hpp
index c1c3255..c099db2 100644
--- a/source/application/main/include/BaseProcessing.hpp
+++ b/source/application/main/include/BaseProcessing.hpp
@@ -41,9 +41,6 @@ namespace app {
          * @return      true if successful, false otherwise.
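         * Implementations are expected to populate the model's input tensor(s) with the pre-processed data.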
**/ virtual bool DoPreProcess(const void* input, size_t inputSize) = 0; - - protected: - Model* m_model = nullptr; }; /** @@ -62,9 +59,6 @@ namespace app { * @return true if successful, false otherwise. **/ virtual bool DoPostProcess() = 0; - - protected: - Model* m_model = nullptr; }; } /* namespace app */ diff --git a/source/use_case/ad/include/AdModel.hpp b/source/use_case/ad/include/AdModel.hpp index 8d914c4..2195a7c 100644 --- a/source/use_case/ad/include/AdModel.hpp +++ b/source/use_case/ad/include/AdModel.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,6 +28,12 @@ namespace arm { namespace app { class AdModel : public Model { + + public: + /* Indices for the expected model - based on input tensor shape */ + static constexpr uint32_t ms_inputRowsIdx = 1; + static constexpr uint32_t ms_inputColsIdx = 2; + protected: /** @brief Gets the reference to op resolver interface class */ const tflite::MicroOpResolver& GetOpResolver() override; diff --git a/source/use_case/ad/include/AdPostProcessing.hpp b/source/use_case/ad/include/AdPostProcessing.hpp deleted file mode 100644 index 7eaec84..0000000 --- a/source/use_case/ad/include/AdPostProcessing.hpp +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Copyright (c) 2021 Arm Limited. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef ADPOSTPROCESSING_HPP -#define ADPOSTPROCESSING_HPP - -#include "TensorFlowLiteMicro.hpp" - -#include - -namespace arm { -namespace app { - - /** @brief Dequantize TensorFlow Lite Micro tensor. - * @param[in] tensor Pointer to the TensorFlow Lite Micro tensor to be dequantized. - * @return Vector with the dequantized tensor values. - **/ - template - std::vector Dequantize(TfLiteTensor* tensor); - - /** - * @brief Calculates the softmax of vector in place. **/ - void Softmax(std::vector& inputVector); - - - /** @brief Given a wav file name return AD model output index. - * @param[in] wavFileName Audio WAV filename. - * File name should be in format anything_goes_XX_here.wav - * where XX is the machine ID e.g. 00, 02, 04 or 06 - * @return AD model output index as 8 bit integer. - **/ - int8_t OutputIndexFromFileName(std::string wavFileName); - -} /* namespace app */ -} /* namespace arm */ - -#endif /* ADPOSTPROCESSING_HPP */ diff --git a/source/use_case/ad/include/AdProcessing.hpp b/source/use_case/ad/include/AdProcessing.hpp new file mode 100644 index 0000000..9abf6f1 --- /dev/null +++ b/source/use_case/ad/include/AdProcessing.hpp @@ -0,0 +1,230 @@ +/* + * Copyright (c) 2022 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef AD_PROCESSING_HPP +#define AD_PROCESSING_HPP + +#include "BaseProcessing.hpp" +#include "AudioUtils.hpp" +#include "AdMelSpectrogram.hpp" +#include "log_macros.h" + +namespace arm { +namespace app { + + /** + * @brief Pre-processing class for anomaly detection use case. + * Implements methods declared by BasePreProcess and anything else needed + * to populate input tensors ready for inference. + */ + class AdPreProcess : public BasePreProcess { + + public: + /** + * @brief Constructor for AdPreProcess class objects + * @param[in] inputTensor input tensor pointer from the tensor arena. + * @param[in] melSpectrogramFrameLen MEL spectrogram's frame length + * @param[in] melSpectrogramFrameStride MEL spectrogram's frame stride + * @param[in] adModelTrainingMean Training mean for the Anomaly detection model being used. + */ + explicit AdPreProcess(TfLiteTensor* inputTensor, + uint32_t melSpectrogramFrameLen, + uint32_t melSpectrogramFrameStride, + float adModelTrainingMean); + + ~AdPreProcess() = default; + + /** + * @brief Function to invoke pre-processing and populate the input vector + * @param input pointer to input data. For anomaly detection, this is the pointer to + * the audio data. + * @param inputSize Size of the data being passed in for pre-processing. + * @return True if successful, false otherwise. + */ + bool DoPreProcess(const void* input, size_t inputSize) override; + + /** + * @brief Getter function for audio window size computed when constructing + * the class object. + * @return Audio window size as 32 bit unsigned integer. + */ + uint32_t GetAudioWindowSize(); + + /** + * @brief Getter function for audio window stride computed when constructing + * the class object. + * @return Audio window stride as 32 bit unsigned integer. + */ + uint32_t GetAudioDataStride(); + + /** + * @brief Setter function for current audio index. This is only used for evaluating + * if previously computed features can be re-used from cache. + */ + void SetAudioWindowIndex(uint32_t idx); + + private: + bool m_validInstance{false}; /**< Indicates the current object is valid. */ + uint32_t m_melSpectrogramFrameLen{}; /**< MEL spectrogram's window frame length */ + uint32_t m_melSpectrogramFrameStride{}; /**< MEL spectrogram's window frame stride */ + uint8_t m_inputResizeScale{}; /**< Downscaling factor for the MEL energy matrix. */ + uint32_t m_numMelSpecVectorsInAudioStride{}; /**< Number of frames to move across the audio. */ + uint32_t m_audioDataWindowSize{}; /**< Audio window size computed based on other parameters. */ + uint32_t m_audioDataStride{}; /**< Audio window stride computed. 
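Product of the number of MEL frames moved per inference and the MEL frame stride.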
*/ + uint32_t m_numReusedFeatureVectors{}; /**< Number of MEL vectors that can be re-used */ + uint32_t m_audioWindowIndex{}; /**< Current audio window index (from audio's sliding window) */ + + audio::SlidingWindow m_melWindowSlider; /**< Internal MEL spectrogram window slider */ + audio::AdMelSpectrogram m_melSpec; /**< MEL spectrogram computation object */ + std::function&, int, bool, size_t, size_t)> m_featureCalc; /**< Feature calculator object */ + }; + + class AdPostProcess : public BasePostProcess { + public: + /** + * @brief Constructor for AdPostProcess object. + * @param[in] outputTensor Output tensor pointer. + */ + explicit AdPostProcess(TfLiteTensor* outputTensor); + + ~AdPostProcess() = default; + + /** + * @brief Function to do the post-processing on the output tensor. + * @return True if successful, false otherwise. + */ + bool DoPostProcess() override; + + /** + * @brief Getter function for an element from the de-quantised output vector. + * @param index Index of the element to be retrieved. + * @return index represented as a 32 bit floating point number. + */ + float GetOutputValue(uint32_t index); + + private: + TfLiteTensor* m_outputTensor{}; /**< Output tensor pointer */ + std::vector m_dequantizedOutputVec{}; /**< Internal output vector */ + + /** + * @brief De-quantizes and flattens the output tensor into a vector. + * @tparam T template parameter to indicate data type. + * @return True if successful, false otherwise. + */ + template + bool Dequantize() + { + TfLiteTensor* tensor = this->m_outputTensor; + if (tensor == nullptr) { + printf_err("Invalid output tensor.\n"); + return false; + } + T* tensorData = tflite::GetTensorData(tensor); + + uint32_t totalOutputSize = 1; + for (int inputDim = 0; inputDim < tensor->dims->size; inputDim++){ + totalOutputSize *= tensor->dims->data[inputDim]; + } + + /* For getting the floating point values, we need quantization parameters */ + QuantParams quantParams = GetTensorQuantParams(tensor); + + this->m_dequantizedOutputVec = std::vector(totalOutputSize, 0); + + for (size_t i = 0; i < totalOutputSize; ++i) { + this->m_dequantizedOutputVec[i] = quantParams.scale * (tensorData[i] - quantParams.offset); + } + + return true; + } + }; + + /* Templated instances available: */ + template bool AdPostProcess::Dequantize(); + + /** + * @brief Generic feature calculator factory. + * + * Returns lambda function to compute features using features cache. + * Real features math is done by a lambda function provided as a parameter. + * Features are written to input tensor memory. + * + * @tparam T feature vector type. + * @param inputTensor model input tensor pointer. + * @param cacheSize number of feature vectors to cache. Defined by the sliding window overlap. + * @param compute features calculator function. + * @return lambda function to compute features. + */ + template + std::function&, size_t, bool, size_t, size_t)> + FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize, + std::function (std::vector& )> compute) + { + /* Feature cache to be captured by lambda function*/ + static std::vector> featureCache = std::vector>(cacheSize); + + return [=](std::vector& audioDataWindow, + size_t index, + bool useCache, + size_t featuresOverlapIndex, + size_t resizeScale) + { + T* tensorData = tflite::GetTensorData(inputTensor); + std::vector features; + + /* Reuse features from cache if cache is ready and sliding windows overlap. + * Overlap is in the beginning of sliding window with a size of a feature cache. 
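+     * For a cache holding N vectors, window indices 0 to N-1 are then served directly from the cache.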
*/ + if (useCache && index < featureCache.size()) { + features = std::move(featureCache[index]); + } else { + features = std::move(compute(audioDataWindow)); + } + auto size = features.size() / resizeScale; + auto sizeBytes = sizeof(T); + + /* Input should be transposed and "resized" by skipping elements. */ + for (size_t outIndex = 0; outIndex < size; outIndex++) { + std::memcpy(tensorData + (outIndex*size) + index, &features[outIndex*resizeScale], sizeBytes); + } + + /* Start renewing cache as soon iteration goes out of the windows overlap. */ + if (index >= featuresOverlapIndex / resizeScale) { + featureCache[index - featuresOverlapIndex / resizeScale] = std::move(features); + } + }; + } + + template std::function&, size_t , bool, size_t, size_t)> + FeatureCalc(TfLiteTensor* inputTensor, + size_t cacheSize, + std::function (std::vector&)> compute); + + template std::function&, size_t, bool, size_t, size_t)> + FeatureCalc(TfLiteTensor *inputTensor, + size_t cacheSize, + std::function(std::vector&)> compute); + + std::function&, int, bool, size_t, size_t)> + GetFeatureCalculator(audio::AdMelSpectrogram& melSpec, + TfLiteTensor* inputTensor, + size_t cacheSize, + float trainingMean); + +} /* namespace app */ +} /* namespace arm */ + +#endif /* AD_PROCESSING_HPP */ diff --git a/source/use_case/ad/src/AdPostProcessing.cc b/source/use_case/ad/src/AdPostProcessing.cc deleted file mode 100644 index c461875..0000000 --- a/source/use_case/ad/src/AdPostProcessing.cc +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Copyright (c) 2021 Arm Limited. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "AdPostProcessing.hpp" -#include "log_macros.h" - -#include -#include -#include - -namespace arm { -namespace app { - - template - std::vector Dequantize(TfLiteTensor* tensor) { - - if (tensor == nullptr) { - printf_err("Tensor is null pointer can not dequantize.\n"); - return std::vector(); - } - T* tensorData = tflite::GetTensorData(tensor); - - uint32_t totalOutputSize = 1; - for (int inputDim = 0; inputDim < tensor->dims->size; inputDim++){ - totalOutputSize *= tensor->dims->data[inputDim]; - } - - /* For getting the floating point values, we need quantization parameters */ - QuantParams quantParams = GetTensorQuantParams(tensor); - - std::vector dequantizedOutput(totalOutputSize); - - for (size_t i = 0; i < totalOutputSize; ++i) { - dequantizedOutput[i] = quantParams.scale * (tensorData[i] - quantParams.offset); - } - - return dequantizedOutput; - } - - void Softmax(std::vector& inputVector) { - auto start = inputVector.begin(); - auto end = inputVector.end(); - - /* Fix for numerical stability and apply exp. 
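Subtracting the maximum element keeps every exponent argument non-positive, so std::exp cannot overflow.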
*/ - float maxValue = *std::max_element(start, end); - for (auto it = start; it!=end; ++it) { - *it = std::exp((*it) - maxValue); - } - - float sumExp = std::accumulate(start, end, 0.0f); - - for (auto it = start; it!=end; ++it) { - *it = (*it)/sumExp; - } - } - - int8_t OutputIndexFromFileName(std::string wavFileName) { - /* Filename is assumed in the form machine_id_00.wav */ - std::string delimiter = "_"; /* First character used to split the file name up. */ - size_t delimiterStart; - std::string subString; - size_t machineIdxInString = 3; /* Which part of the file name the machine id should be at. */ - - for (size_t i = 0; i < machineIdxInString; ++i) { - delimiterStart = wavFileName.find(delimiter); - subString = wavFileName.substr(0, delimiterStart); - wavFileName.erase(0, delimiterStart + delimiter.length()); - } - - /* At this point substring should be 00.wav */ - delimiter = "."; /* Second character used to split the file name up. */ - delimiterStart = subString.find(delimiter); - subString = (delimiterStart != std::string::npos) ? subString.substr(0, delimiterStart) : subString; - - auto is_number = [](const std::string& str) -> bool - { - std::string::const_iterator it = str.begin(); - while (it != str.end() && std::isdigit(*it)) ++it; - return !str.empty() && it == str.end(); - }; - - const int8_t machineIdx = is_number(subString) ? std::stoi(subString) : -1; - - /* Return corresponding index in the output vector. */ - if (machineIdx == 0) { - return 0; - } else if (machineIdx == 2) { - return 1; - } else if (machineIdx == 4) { - return 2; - } else if (machineIdx == 6) { - return 3; - } else { - printf_err("%d is an invalid machine index \n", machineIdx); - return -1; - } - } - - template std::vector Dequantize(TfLiteTensor* tensor); - template std::vector Dequantize(TfLiteTensor* tensor); -} /* namespace app */ -} /* namespace arm */ \ No newline at end of file diff --git a/source/use_case/ad/src/AdProcessing.cc b/source/use_case/ad/src/AdProcessing.cc new file mode 100644 index 0000000..a33131c --- /dev/null +++ b/source/use_case/ad/src/AdProcessing.cc @@ -0,0 +1,208 @@ +/* + * Copyright (c) 2022 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "AdProcessing.hpp" + +#include "AdModel.hpp" + +namespace arm { +namespace app { + +AdPreProcess::AdPreProcess(TfLiteTensor* inputTensor, + uint32_t melSpectrogramFrameLen, + uint32_t melSpectrogramFrameStride, + float adModelTrainingMean): + m_validInstance{false}, + m_melSpectrogramFrameLen{melSpectrogramFrameLen}, + m_melSpectrogramFrameStride{melSpectrogramFrameStride}, + /**< Model is trained on features downsampled 2x */ + m_inputResizeScale{2}, + /**< We are choosing to move by 20 frames across the audio for each inference. 
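With a MEL frame stride of S samples, this advances the audio window by 20 * S samples per inference.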
*/ + m_numMelSpecVectorsInAudioStride{20}, + m_audioDataStride{m_numMelSpecVectorsInAudioStride * melSpectrogramFrameStride}, + m_melSpec{melSpectrogramFrameLen} +{ + if (!inputTensor) { + printf_err("Invalid input tensor provided to pre-process\n"); + return; + } + + TfLiteIntArray* inputShape = inputTensor->dims; + + if (!inputShape) { + printf_err("Invalid input tensor dims\n"); + return; + } + + const uint32_t kNumRows = inputShape->data[AdModel::ms_inputRowsIdx]; + const uint32_t kNumCols = inputShape->data[AdModel::ms_inputColsIdx]; + + /* Deduce the data length required for 1 inference from the network parameters. */ + this->m_audioDataWindowSize = (((this->m_inputResizeScale * kNumCols) - 1) * + melSpectrogramFrameStride) + + melSpectrogramFrameLen; + this->m_numReusedFeatureVectors = kNumRows - + (this->m_numMelSpecVectorsInAudioStride / + this->m_inputResizeScale); + this->m_melSpec.Init(); + + /* Creating a Mel Spectrogram sliding window for the data required for 1 inference. + * "resizing" done here by multiplying stride by resize scale. */ + this->m_melWindowSlider = audio::SlidingWindow( + nullptr, /* to be populated later. */ + this->m_audioDataWindowSize, + melSpectrogramFrameLen, + melSpectrogramFrameStride * this->m_inputResizeScale); + + /* Construct feature calculation function. */ + this->m_featureCalc = GetFeatureCalculator(this->m_melSpec, inputTensor, + this->m_numReusedFeatureVectors, + adModelTrainingMean); + this->m_validInstance = true; +} + +bool AdPreProcess::DoPreProcess(const void* input, size_t inputSize) +{ + /* Check that we have a valid instance. */ + if (!this->m_validInstance) { + printf_err("Invalid pre-processor instance\n"); + return false; + } + + /* We expect that we can traverse the size with which the MEL spectrogram + * sliding window was initialised with. */ + if (!input || inputSize < this->m_audioDataWindowSize) { + printf_err("Invalid input provided for pre-processing\n"); + return false; + } + + /* We moved to the next window - set the features sliding to the new address. */ + this->m_melWindowSlider.Reset(static_cast(input)); + + /* The first window does not have cache ready. */ + const bool useCache = this->m_audioWindowIndex > 0 && this->m_numReusedFeatureVectors > 0; + + /* Start calculating features inside one audio sliding window. */ + while (this->m_melWindowSlider.HasNext()) { + const int16_t* melSpecWindow = this->m_melWindowSlider.Next(); + std::vector melSpecAudioData = std::vector( + melSpecWindow, + melSpecWindow + this->m_melSpectrogramFrameLen); + + /* Compute features for this window and write them to input tensor. 
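Cached feature vectors are re-used where audio windows overlap, and results are transposed into tensor memory.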
*/ + this->m_featureCalc(melSpecAudioData, + this->m_melWindowSlider.Index(), + useCache, + this->m_numMelSpecVectorsInAudioStride, + this->m_inputResizeScale); + } + + return true; +} + +uint32_t AdPreProcess::GetAudioWindowSize() +{ + return this->m_audioDataWindowSize; +} + +uint32_t AdPreProcess::GetAudioDataStride() +{ + return this->m_audioDataStride; +} + +void AdPreProcess::SetAudioWindowIndex(uint32_t idx) +{ + this->m_audioWindowIndex = idx; +} + +AdPostProcess::AdPostProcess(TfLiteTensor* outputTensor) : + m_outputTensor {outputTensor} +{} + +bool AdPostProcess::DoPostProcess() +{ + switch (this->m_outputTensor->type) { + case kTfLiteInt8: + this->Dequantize(); + break; + default: + printf_err("Unsupported tensor type"); + return false; + } + + math::MathUtils::SoftmaxF32(this->m_dequantizedOutputVec); + return true; +} + +float AdPostProcess::GetOutputValue(uint32_t index) +{ + if (index < this->m_dequantizedOutputVec.size()) { + return this->m_dequantizedOutputVec[index]; + } + printf_err("Invalid index for output\n"); + return 0.0; +} + +std::function&, int, bool, size_t, size_t)> +GetFeatureCalculator(audio::AdMelSpectrogram& melSpec, + TfLiteTensor* inputTensor, + size_t cacheSize, + float trainingMean) +{ + std::function&, size_t, bool, size_t, size_t)> melSpecFeatureCalc; + + TfLiteQuantization quant = inputTensor->quantization; + + if (kTfLiteAffineQuantization == quant.type) { + + auto* quantParams = static_cast(quant.params); + const float quantScale = quantParams->scale->data[0]; + const int quantOffset = quantParams->zero_point->data[0]; + + switch (inputTensor->type) { + case kTfLiteInt8: { + melSpecFeatureCalc = FeatureCalc( + inputTensor, + cacheSize, + [=, &melSpec](std::vector& audioDataWindow) { + return melSpec.MelSpecComputeQuant( + audioDataWindow, + quantScale, + quantOffset, + trainingMean); + } + ); + break; + } + default: + printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type)); + } + } else { + melSpecFeatureCalc = FeatureCalc( + inputTensor, + cacheSize, + [=, &melSpec]( + std::vector& audioDataWindow) { + return melSpec.ComputeMelSpec( + audioDataWindow, + trainingMean); + }); + } + return melSpecFeatureCalc; +} + +} /* namespace app */ +} /* namespace arm */ diff --git a/source/use_case/ad/src/MainLoop.cc b/source/use_case/ad/src/MainLoop.cc index 23d1e51..140359b 100644 --- a/source/use_case/ad/src/MainLoop.cc +++ b/source/use_case/ad/src/MainLoop.cc @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,7 +14,6 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#include "hal.h" /* Brings in platform definitions */ #include "InputFiles.hpp" /* For input data */ #include "AdModel.hpp" /* Model class for running inference */ #include "UseCaseCommonUtils.hpp" /* Utils functions */ @@ -63,8 +62,8 @@ void main_loop() caseContext.Set("profiler", profiler); caseContext.Set("model", model); caseContext.Set("clipIndex", 0); - caseContext.Set("frameLength", g_FrameLength); - caseContext.Set("frameStride", g_FrameStride); + caseContext.Set("frameLength", g_FrameLength); + caseContext.Set("frameStride", g_FrameStride); caseContext.Set("scoreThreshold", g_ScoreThreshold); caseContext.Set("trainingMean", g_TrainingMean); diff --git a/source/use_case/ad/src/UseCaseHandler.cc b/source/use_case/ad/src/UseCaseHandler.cc index 5585f36..0179d6b 100644 --- a/source/use_case/ad/src/UseCaseHandler.cc +++ b/source/use_case/ad/src/UseCaseHandler.cc @@ -24,8 +24,8 @@ #include "AudioUtils.hpp" #include "ImageUtils.hpp" #include "UseCaseCommonUtils.hpp" -#include "AdPostProcessing.hpp" #include "log_macros.h" +#include "AdProcessing.hpp" namespace arm { namespace app { @@ -39,32 +39,17 @@ namespace app { **/ static bool PresentInferenceResult(float result, float threshold); - /** - * @brief Returns a function to perform feature calculation and populates input tensor data with - * MelSpe data. - * - * Input tensor data type check is performed to choose correct MFCC feature data type. - * If tensor has an integer data type then original features are quantised. - * - * Warning: mfcc calculator provided as input must have the same life scope as returned function. - * - * @param[in] melSpec MFCC feature calculator. - * @param[in,out] inputTensor Input tensor pointer to store calculated features. - * @param[in] cacheSize Size of the feture vectors cache (number of feature vectors). - * @param[in] trainingMean Training mean. - * @return function function to be called providing audio sample and sliding window index. - */ - static std::function&, int, bool, size_t, size_t)> - GetFeatureCalculator(audio::AdMelSpectrogram& melSpec, - TfLiteTensor* inputTensor, - size_t cacheSize, - float trainingMean); - - /* Vibration classification handler */ + /** @brief Given a wav file name return AD model output index. + * @param[in] wavFileName Audio WAV filename. + * File name should be in format anything_goes_XX_here.wav + * where XX is the machine ID e.g. 00, 02, 04 or 06 + * @return AD model output index as 8 bit integer. 
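+     * Returns -1 if the machine ID cannot be parsed from the file name.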
+ **/ + static int8_t OutputIndexFromFileName(std::string wavFileName); + + /* Anomaly Detection inference handler */ bool ClassifyVibrationHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll) { - auto& profiler = ctx.Get("profiler"); - constexpr uint32_t dataPsnTxtInfStartX = 20; constexpr uint32_t dataPsnTxtInfStartY = 40; @@ -81,8 +66,9 @@ namespace app { return false; } - const auto frameLength = ctx.Get("frameLength"); - const auto frameStride = ctx.Get("frameStride"); + auto& profiler = ctx.Get("profiler"); + const auto melSpecFrameLength = ctx.Get("frameLength"); + const auto melSpecFrameStride = ctx.Get("frameStride"); const auto scoreThreshold = ctx.Get("scoreThreshold"); const auto trainingMean = ctx.Get("trainingMean"); auto startClipIdx = ctx.Get("clipIndex"); @@ -95,21 +81,13 @@ namespace app { return false; } - TfLiteIntArray* inputShape = model.GetInputShape(0); - const uint32_t kNumRows = inputShape->data[1]; - const uint32_t kNumCols = inputShape->data[2]; + AdPreProcess preProcess{ + inputTensor, + melSpecFrameLength, + melSpecFrameStride, + trainingMean}; - audio::AdMelSpectrogram melSpec = audio::AdMelSpectrogram(frameLength); - melSpec.Init(); - - /* Deduce the data length required for 1 inference from the network parameters. */ - const uint8_t inputResizeScale = 2; - const uint32_t audioDataWindowSize = (((inputResizeScale * kNumCols) - 1) * frameStride) + frameLength; - - /* We are choosing to move by 20 frames across the audio for each inference. */ - const uint8_t nMelSpecVectorsInAudioStride = 20; - - auto audioDataStride = nMelSpecVectorsInAudioStride * frameStride; + AdPostProcess postProcess{outputTensor}; do { hal_lcd_clear(COLOR_BLACK); @@ -122,29 +100,12 @@ namespace app { return false; } - /* Creating a Mel Spectrogram sliding window for the data required for 1 inference. - * "resizing" done here by multiplying stride by resize scale. */ - auto audioMelSpecWindowSlider = audio::SlidingWindow( - get_audio_array(currentIndex), - audioDataWindowSize, frameLength, - frameStride * inputResizeScale); - /* Creating a sliding window through the whole audio clip. */ auto audioDataSlider = audio::SlidingWindow( - get_audio_array(currentIndex), - get_audio_array_size(currentIndex), - audioDataWindowSize, audioDataStride); - - /* Calculate number of the feature vectors in the window overlap region taking into account resizing. - * These feature vectors will be reused.*/ - auto numberOfReusedFeatureVectors = kNumRows - (nMelSpecVectorsInAudioStride / inputResizeScale); - - /* Construct feature calculation function. */ - auto melSpecFeatureCalc = GetFeatureCalculator(melSpec, inputTensor, - numberOfReusedFeatureVectors, trainingMean); - if (!melSpecFeatureCalc){ - return false; - } + get_audio_array(currentIndex), + get_audio_array_size(currentIndex), + preProcess.GetAudioWindowSize(), + preProcess.GetAudioDataStride()); /* Result is an averaged sum over inferences. */ float result = 0; @@ -152,30 +113,18 @@ namespace app { /* Display message on the LCD - inference running. */ std::string str_inf{"Running inference... "}; hal_lcd_display_text( - str_inf.c_str(), str_inf.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); - info("Running inference on audio clip %" PRIu32 " => %s\n", currentIndex, get_filename(currentIndex)); + str_inf.c_str(), str_inf.size(), + dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); + + info("Running inference on audio clip %" PRIu32 " => %s\n", + currentIndex, get_filename(currentIndex)); /* Start sliding through audio clip. 
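Each stride performs pre-processing, inference and post-processing for one audio window.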
*/ while (audioDataSlider.HasNext()) { - const int16_t *inferenceWindow = audioDataSlider.Next(); - - /* We moved to the next window - set the features sliding to the new address. */ - audioMelSpecWindowSlider.Reset(inferenceWindow); + const int16_t* inferenceWindow = audioDataSlider.Next(); - /* The first window does not have cache ready. */ - bool useCache = audioDataSlider.Index() > 0 && numberOfReusedFeatureVectors > 0; - - /* Start calculating features inside one audio sliding window. */ - while (audioMelSpecWindowSlider.HasNext()) { - const int16_t *melSpecWindow = audioMelSpecWindowSlider.Next(); - std::vector melSpecAudioData = std::vector(melSpecWindow, - melSpecWindow + frameLength); - - /* Compute features for this window and write them to input tensor. */ - melSpecFeatureCalc(melSpecAudioData, audioMelSpecWindowSlider.Index(), - useCache, nMelSpecVectorsInAudioStride, inputResizeScale); - } + preProcess.SetAudioWindowIndex(audioDataSlider.Index()); + preProcess.DoPreProcess(inferenceWindow, preProcess.GetAudioWindowSize()); info("Inference %zu/%zu\n", audioDataSlider.Index() + 1, audioDataSlider.TotalStrides() + 1); @@ -185,13 +134,11 @@ namespace app { return false; } - /* Use the negative softmax score of the corresponding index as the outlier score */ - std::vector dequantOutput = Dequantize(outputTensor); - Softmax(dequantOutput); - result += -dequantOutput[machineOutputIndex]; + postProcess.DoPostProcess(); + result += 0 - postProcess.GetOutputValue(machineOutputIndex); #if VERIFY_TEST_OUTPUT - arm::app::DumpTensor(outputTensor); + DumpTensor(outputTensor); #endif /* VERIFY_TEST_OUTPUT */ } /* while (audioDataSlider.HasNext()) */ @@ -218,7 +165,6 @@ namespace app { return true; } - static bool PresentInferenceResult(float result, float threshold) { constexpr uint32_t dataPsnTxtStartX1 = 20; @@ -251,148 +197,47 @@ namespace app { return true; } - /** - * @brief Generic feature calculator factory. - * - * Returns lambda function to compute features using features cache. - * Real features math is done by a lambda function provided as a parameter. - * Features are written to input tensor memory. - * - * @tparam T feature vector type. - * @param inputTensor model input tensor pointer. - * @param cacheSize number of feature vectors to cache. Defined by the sliding window overlap. - * @param compute features calculator function. - * @return lambda function to compute features. - */ - template - std::function&, size_t, bool, size_t, size_t)> - FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize, - std::function (std::vector& )> compute) + static int8_t OutputIndexFromFileName(std::string wavFileName) { - /* Feature cache to be captured by lambda function*/ - static std::vector> featureCache = std::vector>(cacheSize); - - return [=](std::vector& audioDataWindow, - size_t index, - bool useCache, - size_t featuresOverlapIndex, - size_t resizeScale) - { - T *tensorData = tflite::GetTensorData(inputTensor); - std::vector features; - - /* Reuse features from cache if cache is ready and sliding windows overlap. - * Overlap is in the beginning of sliding window with a size of a feature cache. */ - if (useCache && index < featureCache.size()) { - features = std::move(featureCache[index]); - } else { - features = std::move(compute(audioDataWindow)); - } - auto size = features.size() / resizeScale; - auto sizeBytes = sizeof(T); + /* Filename is assumed in the form machine_id_00.wav */ + std::string delimiter = "_"; /* First character used to split the file name up. 
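The loop below extracts the third such token, e.g. 00.wav from machine_id_00.wav.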
*/ + size_t delimiterStart; + std::string subString; + size_t machineIdxInString = 3; /* Which part of the file name the machine id should be at. */ + + for (size_t i = 0; i < machineIdxInString; ++i) { + delimiterStart = wavFileName.find(delimiter); + subString = wavFileName.substr(0, delimiterStart); + wavFileName.erase(0, delimiterStart + delimiter.length()); + } - /* Input should be transposed and "resized" by skipping elements. */ - for (size_t outIndex = 0; outIndex < size; outIndex++) { - std::memcpy(tensorData + (outIndex*size) + index, &features[outIndex*resizeScale], sizeBytes); - } + /* At this point substring should be 00.wav */ + delimiter = "."; /* Second character used to split the file name up. */ + delimiterStart = subString.find(delimiter); + subString = (delimiterStart != std::string::npos) ? subString.substr(0, delimiterStart) : subString; - /* Start renewing cache as soon iteration goes out of the windows overlap. */ - if (index >= featuresOverlapIndex / resizeScale) { - featureCache[index - featuresOverlapIndex / resizeScale] = std::move(features); - } + auto is_number = [](const std::string& str) -> bool + { + std::string::const_iterator it = str.begin(); + while (it != str.end() && std::isdigit(*it)) ++it; + return !str.empty() && it == str.end(); }; - } - - template std::function&, size_t , bool, size_t, size_t)> - FeatureCalc(TfLiteTensor* inputTensor, - size_t cacheSize, - std::function (std::vector&)> compute); - - template std::function&, size_t , bool, size_t, size_t)> - FeatureCalc(TfLiteTensor* inputTensor, - size_t cacheSize, - std::function (std::vector&)> compute); - - template std::function&, size_t , bool, size_t, size_t)> - FeatureCalc(TfLiteTensor* inputTensor, - size_t cacheSize, - std::function (std::vector&)> compute); - - template std::function&, size_t, bool, size_t, size_t)> - FeatureCalc(TfLiteTensor *inputTensor, - size_t cacheSize, - std::function(std::vector&)> compute); - - - static std::function&, int, bool, size_t, size_t)> - GetFeatureCalculator(audio::AdMelSpectrogram& melSpec, TfLiteTensor* inputTensor, size_t cacheSize, float trainingMean) - { - std::function&, size_t, bool, size_t, size_t)> melSpecFeatureCalc; - - TfLiteQuantization quant = inputTensor->quantization; - - if (kTfLiteAffineQuantization == quant.type) { - - auto *quantParams = (TfLiteAffineQuantization *) quant.params; - const float quantScale = quantParams->scale->data[0]; - const int quantOffset = quantParams->zero_point->data[0]; - - switch (inputTensor->type) { - case kTfLiteInt8: { - melSpecFeatureCalc = FeatureCalc(inputTensor, - cacheSize, - [=, &melSpec](std::vector& audioDataWindow) { - return melSpec.MelSpecComputeQuant( - audioDataWindow, - quantScale, - quantOffset, - trainingMean); - } - ); - break; - } - case kTfLiteUInt8: { - melSpecFeatureCalc = FeatureCalc(inputTensor, - cacheSize, - [=, &melSpec](std::vector& audioDataWindow) { - return melSpec.MelSpecComputeQuant( - audioDataWindow, - quantScale, - quantOffset, - trainingMean); - } - ); - break; - } - case kTfLiteInt16: { - melSpecFeatureCalc = FeatureCalc(inputTensor, - cacheSize, - [=, &melSpec](std::vector& audioDataWindow) { - return melSpec.MelSpecComputeQuant( - audioDataWindow, - quantScale, - quantOffset, - trainingMean); - } - ); - break; - } - default: - printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type)); - } - + const int8_t machineIdx = is_number(subString) ? std::stoi(subString) : -1; + + /* Return corresponding index in the output vector. 
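Machine IDs 00, 02, 04 and 06 map to output indices 0 to 3.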
*/ + if (machineIdx == 0) { + return 0; + } else if (machineIdx == 2) { + return 1; + } else if (machineIdx == 4) { + return 2; + } else if (machineIdx == 6) { + return 3; } else { - melSpecFeatureCalc = melSpecFeatureCalc = FeatureCalc(inputTensor, - cacheSize, - [=, &melSpec]( - std::vector& audioDataWindow) { - return melSpec.ComputeMelSpec( - audioDataWindow, - trainingMean); - }); + printf_err("%d is an invalid machine index \n", machineIdx); + return -1; } - return melSpecFeatureCalc; } } /* namespace app */ diff --git a/source/use_case/asr/include/Wav2LetterModel.hpp b/source/use_case/asr/include/Wav2LetterModel.hpp index 0078e44..bec70ab 100644 --- a/source/use_case/asr/include/Wav2LetterModel.hpp +++ b/source/use_case/asr/include/Wav2LetterModel.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved.rved. + * Copyright (c) 2021 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/source/use_case/kws_asr/include/KwsProcessing.hpp b/source/use_case/kws_asr/include/KwsProcessing.hpp new file mode 100644 index 0000000..d3de3b3 --- /dev/null +++ b/source/use_case/kws_asr/include/KwsProcessing.hpp @@ -0,0 +1,138 @@ +/* + * Copyright (c) 2022 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef KWS_PROCESSING_HPP +#define KWS_PROCESSING_HPP + +#include +#include "BaseProcessing.hpp" +#include "Model.hpp" +#include "Classifier.hpp" +#include "MicroNetKwsMfcc.hpp" + +#include + +namespace arm { +namespace app { + + /** + * @brief Pre-processing class for Keyword Spotting use case. + * Implements methods declared by BasePreProcess and anything else needed + * to populate input tensors ready for inference. + */ + class KwsPreProcess : public BasePreProcess { + + public: + /** + * @brief Constructor + * @param[in] inputTensor Pointer to the TFLite Micro input Tensor. + * @param[in] numFeatures How many MFCC features to use. + * @param[in] numFeatureFrames Number of MFCC vectors that need to be calculated + * for an inference. + * @param[in] mfccFrameLength Number of audio samples used to calculate one set of MFCC values when + * sliding a window through the audio sample. + * @param[in] mfccFrameStride Number of audio samples between consecutive windows. + **/ + explicit KwsPreProcess(TfLiteTensor* inputTensor, size_t numFeatures, size_t numFeatureFrames, + int mfccFrameLength, int mfccFrameStride); + + /** + * @brief Should perform pre-processing of 'raw' input audio data and load it into + * TFLite Micro input tensors ready for inference. + * @param[in] input Pointer to the data that pre-processing will work on. + * @param[in] inputSize Size of the input data. + * @return true if successful, false otherwise. 
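+     * The input is expected to point to int16 audio samples spanning at least one audio window.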
+ **/ + bool DoPreProcess(const void* input, size_t inputSize) override; + + size_t m_audioWindowIndex = 0; /* Index of audio slider, used when caching features in longer clips. */ + size_t m_audioDataWindowSize; /* Amount of audio needed for 1 inference. */ + size_t m_audioDataStride; /* Amount of audio to stride across if doing >1 inference in longer clips. */ + + private: + TfLiteTensor* m_inputTensor; /* Model input tensor. */ + const int m_mfccFrameLength; + const int m_mfccFrameStride; + const size_t m_numMfccFrames; /* How many sets of m_numMfccFeats. */ + + audio::MicroNetKwsMFCC m_mfcc; + audio::SlidingWindow m_mfccSlidingWindow; + size_t m_numMfccVectorsInAudioStride; + size_t m_numReusedMfccVectors; + std::function&, int, bool, size_t)> m_mfccFeatureCalculator; + + /** + * @brief Returns a function to perform feature calculation and populates input tensor data with + * MFCC data. + * + * Input tensor data type check is performed to choose correct MFCC feature data type. + * If tensor has an integer data type then original features are quantised. + * + * Warning: MFCC calculator provided as input must have the same life scope as returned function. + * + * @param[in] mfcc MFCC feature calculator. + * @param[in,out] inputTensor Input tensor pointer to store calculated features. + * @param[in] cacheSize Size of the feature vectors cache (number of feature vectors). + * @return Function to be called providing audio sample and sliding window index. + */ + std::function&, int, bool, size_t)> + GetFeatureCalculator(audio::MicroNetKwsMFCC& mfcc, + TfLiteTensor* inputTensor, + size_t cacheSize); + + template + std::function&, size_t, bool, size_t)> + FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize, + std::function (std::vector& )> compute); + }; + + /** + * @brief Post-processing class for Keyword Spotting use case. + * Implements methods declared by BasePostProcess and anything else needed + * to populate result vector. + */ + class KwsPostProcess : public BasePostProcess { + + private: + TfLiteTensor* m_outputTensor; /* Model output tensor. */ + Classifier& m_kwsClassifier; /* KWS Classifier object. */ + const std::vector& m_labels; /* KWS Labels. */ + std::vector& m_results; /* Results vector for a single inference. */ + + public: + /** + * @brief Constructor + * @param[in] outputTensor Pointer to the TFLite Micro output Tensor. + * @param[in] classifier Classifier object used to get top N results from classification. + * @param[in] labels Vector of string labels to identify each output of the model. + * @param[in/out] results Vector of classification results to store decoded outputs. + **/ + KwsPostProcess(TfLiteTensor* outputTensor, Classifier& classifier, + const std::vector& labels, + std::vector& results); + + /** + * @brief Should perform post-processing of the result of inference then + * populate KWS result data for any later use. + * @return true if successful, false otherwise. + **/ + bool DoPostProcess() override; + }; + +} /* namespace app */ +} /* namespace arm */ + +#endif /* KWS_PROCESSING_HPP */ \ No newline at end of file diff --git a/source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp b/source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp index 43bd390..af6ba5f 100644 --- a/source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp +++ b/source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. 
* SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -24,7 +24,7 @@ namespace app { namespace audio { /* Class to provide MicroNet specific MFCC calculation requirements. */ - class MicroNetMFCC : public MFCC { + class MicroNetKwsMFCC : public MFCC { public: static constexpr uint32_t ms_defaultSamplingFreq = 16000; @@ -34,14 +34,14 @@ namespace audio { static constexpr bool ms_defaultUseHtkMethod = true; - explicit MicroNetMFCC(const size_t numFeats, const size_t frameLen) + explicit MicroNetKwsMFCC(const size_t numFeats, const size_t frameLen) : MFCC(MfccParams( ms_defaultSamplingFreq, ms_defaultNumFbankBins, ms_defaultMelLoFreq, ms_defaultMelHiFreq, numFeats, frameLen, ms_defaultUseHtkMethod)) {} - MicroNetMFCC() = delete; - ~MicroNetMFCC() = default; + MicroNetKwsMFCC() = delete; + ~MicroNetKwsMFCC() = default; }; } /* namespace audio */ diff --git a/source/use_case/kws_asr/include/Wav2LetterModel.hpp b/source/use_case/kws_asr/include/Wav2LetterModel.hpp index 7c327b3..0e1adc5 100644 --- a/source/use_case/kws_asr/include/Wav2LetterModel.hpp +++ b/source/use_case/kws_asr/include/Wav2LetterModel.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -34,14 +34,18 @@ namespace arm { namespace app { class Wav2LetterModel : public Model { - + public: /* Indices for the expected model - based on input and output tensor shapes */ - static constexpr uint32_t ms_inputRowsIdx = 1; - static constexpr uint32_t ms_inputColsIdx = 2; + static constexpr uint32_t ms_inputRowsIdx = 1; + static constexpr uint32_t ms_inputColsIdx = 2; static constexpr uint32_t ms_outputRowsIdx = 2; static constexpr uint32_t ms_outputColsIdx = 3; + /* Model specific constants. */ + static constexpr uint32_t ms_blankTokenIdx = 28; + static constexpr uint32_t ms_numMfccFeatures = 13; + protected: /** @brief Gets the reference to op resolver interface class. */ const tflite::MicroOpResolver& GetOpResolver() override; diff --git a/source/use_case/kws_asr/include/Wav2LetterPostprocess.hpp b/source/use_case/kws_asr/include/Wav2LetterPostprocess.hpp index 029a641..d1bc9a2 100644 --- a/source/use_case/kws_asr/include/Wav2LetterPostprocess.hpp +++ b/source/use_case/kws_asr/include/Wav2LetterPostprocess.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,88 +14,95 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef KWS_ASR_WAV2LET_POSTPROC_HPP -#define KWS_ASR_WAV2LET_POSTPROC_HPP +#ifndef KWS_ASR_WAV2LETTER_POSTPROCESS_HPP +#define KWS_ASR_WAV2LETTER_POSTPROCESS_HPP -#include "TensorFlowLiteMicro.hpp" /* TensorFlow headers */ +#include "TensorFlowLiteMicro.hpp" /* TensorFlow headers. */ +#include "BaseProcessing.hpp" +#include "AsrClassifier.hpp" +#include "AsrResult.hpp" +#include "log_macros.h" namespace arm { namespace app { -namespace audio { -namespace asr { /** * @brief Helper class to manage tensor post-processing for "wav2letter" * output. */ - class Postprocess { + class AsrPostProcess : public BasePostProcess { public: + bool m_lastIteration = false; /* Flag to set if processing the last set of data for a clip. 
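When set, the right-hand context of the final output is preserved rather than erased.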
*/ + /** - * @brief Constructor - * @param[in] contextLen Left and right context length for - * output tensor. - * @param[in] innerLen This is the length of the section - * between left and right context. - * @param[in] blankTokenIdx Blank token index. + * @brief Constructor + * @param[in] outputTensor Pointer to the TFLite Micro output Tensor. + * @param[in] classifier Object used to get top N results from classification. + * @param[in] labels Vector of string labels to identify each output of the model. + * @param[in/out] result Vector of classification results to store decoded outputs. + * @param[in] outputContextLen Left/right context length for output tensor. + * @param[in] blankTokenIdx Index in the labels that the "Blank token" takes. + * @param[in] reductionAxis The axis that the logits of each time step is on. **/ - Postprocess(uint32_t contextLen, - uint32_t innerLen, - uint32_t blankTokenIdx); - - Postprocess() = delete; - ~Postprocess() = default; + AsrPostProcess(TfLiteTensor* outputTensor, AsrClassifier& classifier, + const std::vector& labels, asr::ResultVec& result, + uint32_t outputContextLen, + uint32_t blankTokenIdx, uint32_t reductionAxis); /** - * @brief Erases the required part of the tensor based - * on context lengths set up during initialisation - * @param[in] tensor Pointer to the tensor - * @param[in] axisIdx Index of the axis on which erase is - * performed. - * @param[in] lastIteration Flag to signal is this is the - * last iteration in which case - * the right context is preserved. - * @return true if successful, false otherwise. - */ - bool Invoke(TfLiteTensor* tensor, - uint32_t axisIdx, - bool lastIteration = false); + * @brief Should perform post-processing of the result of inference then + * populate ASR result data for any later use. + * @return true if successful, false otherwise. + **/ + bool DoPostProcess() override; + + /** @brief Gets the output inner length for post-processing. */ + static uint32_t GetOutputInnerLen(const TfLiteTensor*, uint32_t outputCtxLen); + + /** @brief Gets the output context length (left/right) for post-processing. */ + static uint32_t GetOutputContextLen(const Model& model, uint32_t inputCtxLen); + + /** @brief Gets the number of feature vectors to be computed. */ + static uint32_t GetNumFeatureVectors(const Model& model); private: - uint32_t m_contextLen; /* Lengths of left and right contexts. */ - uint32_t m_innerLen; /* Length of inner context. */ - uint32_t m_totalLen; /* Total length of the required axis. */ - uint32_t m_countIterations; /* Current number of iterations. */ - uint32_t m_blankTokenIdx; /* Index of the labels blank token. */ + AsrClassifier& m_classifier; /* ASR Classifier object. */ + TfLiteTensor* m_outputTensor; /* Model output tensor. */ + const std::vector& m_labels; /* ASR Labels. */ + asr::ResultVec & m_results; /* Results vector for a single inference. */ + uint32_t m_outputContextLen; /* lengths of left/right contexts for output. */ + uint32_t m_outputInnerLen; /* Length of output inner context. */ + uint32_t m_totalLen; /* Total length of the required axis. */ + uint32_t m_countIterations; /* Current number of iterations. */ + uint32_t m_blankTokenIdx; /* Index of the labels blank token. */ + uint32_t m_reductionAxisIdx; /* Axis containing output logits for a single step. */ + /** - * @brief Checks if the tensor and axis index are valid - * inputs to the object - based on how it has been - * initialised. - * @return true if valid, false otherwise. 
+ * @brief Checks if the tensor and axis index are valid + * inputs to the object - based on how it has been initialised. + * @return true if valid, false otherwise. */ bool IsInputValid(TfLiteTensor* tensor, - const uint32_t axisIdx) const; + uint32_t axisIdx) const; /** - * @brief Gets the tensor data element size in bytes based - * on the tensor type. - * @return Size in bytes, 0 if not supported. + * @brief Gets the tensor data element size in bytes based + * on the tensor type. + * @return Size in bytes, 0 if not supported. */ - uint32_t GetTensorElementSize(TfLiteTensor* tensor); + static uint32_t GetTensorElementSize(TfLiteTensor* tensor); /** - * @brief Erases sections from the data assuming row-wise - * arrangement along the context axis. - * @return true if successful, false otherwise. + * @brief Erases sections from the data assuming row-wise + * arrangement along the context axis. + * @return true if successful, false otherwise. */ bool EraseSectionsRowWise(uint8_t* ptrData, - const uint32_t strideSzBytes, - const bool lastIteration); - + uint32_t strideSzBytes, + bool lastIteration); }; -} /* namespace asr */ -} /* namespace audio */ } /* namespace app */ } /* namespace arm */ -#endif /* KWS_ASR_WAV2LET_POSTPROC_HPP */ \ No newline at end of file +#endif /* KWS_ASR_WAV2LETTER_POSTPROCESS_HPP */ \ No newline at end of file diff --git a/source/use_case/kws_asr/include/Wav2LetterPreprocess.hpp b/source/use_case/kws_asr/include/Wav2LetterPreprocess.hpp index 3609c49..1224c23 100644 --- a/source/use_case/kws_asr/include/Wav2LetterPreprocess.hpp +++ b/source/use_case/kws_asr/include/Wav2LetterPreprocess.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,56 +14,51 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef KWS_ASR_WAV2LET_PREPROC_HPP -#define KWS_ASR_WAV2LET_PREPROC_HPP +#ifndef KWS_ASR_WAV2LETTER_PREPROCESS_HPP +#define KWS_ASR_WAV2LETTER_PREPROCESS_HPP #include "Wav2LetterModel.hpp" #include "Wav2LetterMfcc.hpp" #include "AudioUtils.hpp" #include "DataStructures.hpp" +#include "BaseProcessing.hpp" #include "log_macros.h" namespace arm { namespace app { -namespace audio { -namespace asr { /* Class to facilitate pre-processing calculation for Wav2Letter model * for ASR. */ - using AudioWindow = SlidingWindow ; + using AudioWindow = audio::SlidingWindow; - class Preprocess { + class AsrPreProcess : public BasePreProcess { public: /** - * @brief Constructor - * @param[in] numMfccFeatures Number of MFCC features per window. - * @param[in] windowLen Number of elements in a window. - * @param[in] windowStride Stride (in number of elements) for - * moving the window. - * @param[in] numMfccVectors Number of MFCC vectors per window. - */ - Preprocess( - uint32_t numMfccFeatures, - uint32_t windowLen, - uint32_t windowStride, - uint32_t numMfccVectors); - Preprocess() = delete; - ~Preprocess() = default; + * @brief Constructor. + * @param[in] inputTensor Pointer to the TFLite Micro input Tensor. + * @param[in] numMfccFeatures Number of MFCC features per window. + * @param[in] numFeatureFrames Number of MFCC vectors that need to be calculated + * for an inference. + * @param[in] mfccWindowLen Number of audio elements to calculate MFCC features per window. 
+ * @param[in] mfccWindowStride Stride (in number of elements) for moving the MFCC window. + */ + AsrPreProcess(TfLiteTensor* inputTensor, + uint32_t numMfccFeatures, + uint32_t numFeatureFrames, + uint32_t mfccWindowLen, + uint32_t mfccWindowStride); /** * @brief Calculates the features required from audio data. This * includes MFCC, first and second order deltas, * normalisation and finally, quantisation. The tensor is - * populated with feature from a given window placed along + * populated with features from a given window placed along * in a single row. * @param[in] audioData Pointer to the first element of audio data. * @param[in] audioDataLen Number of elements in the audio data. - * @param[in] tensor Tensor to be populated. * @return true if successful, false in case of error. */ - bool Invoke(const int16_t * audioData, - uint32_t audioDataLen, - TfLiteTensor * tensor); + bool DoPreProcess(const void* audioData, size_t audioDataLen) override; protected: /** @@ -73,49 +68,32 @@ namespace asr { * @param[in] mfcc MFCC buffers. * @param[out] delta1 Result of the first diff computation. * @param[out] delta2 Result of the second diff computation. - * - * @return true if successful, false otherwise. + * @return true if successful, false otherwise. */ static bool ComputeDeltas(Array2d& mfcc, Array2d& delta1, Array2d& delta2); /** - * @brief Given a 2D vector of floats, computes the mean. - * @param[in] vec Vector of vector of floats. - * @return Mean value. - */ - static float GetMean(Array2d& vec); - - /** - * @brief Given a 2D vector of floats, computes the stddev. - * @param[in] vec Vector of vector of floats. - * @param[in] mean Mean value of the vector passed in. - * @return stddev value. - */ - static float GetStdDev(Array2d& vec, - const float mean); - - /** - * @brief Given a 2D vector of floats, normalises it using - * the mean and the stddev + * @brief Given a 2D vector of floats, rescale it to have mean of 0 and + * standard deviation of 1. * @param[in,out] vec Vector of vector of floats. */ - static void NormaliseVec(Array2d& vec); + static void StandardizeVecF32(Array2d& vec); /** - * @brief Normalises the MFCC and delta buffers. + * @brief Standardizes all the MFCC and delta buffers to have mean 0 and std. dev 1. */ - void Normalise(); + void Standarize(); /** * @brief Given the quantisation and data type limits, computes * the quantised values of a floating point input data. - * @param[in] elem Element to be quantised. - * @param[in] quantScale Scale. - * @param[in] quantOffset Offset. - * @param[in] minVal Numerical limit - minimum. - * @param[in] maxVal Numerical limit - maximum. + * @param[in] elem Element to be quantised. + * @param[in] quantScale Scale. + * @param[in] quantOffset Offset. + * @param[in] minVal Numerical limit - minimum. + * @param[in] maxVal Numerical limit - maximum. * @return Floating point quantised value. */ static float GetQuantElem( @@ -133,44 +111,43 @@ namespace asr { * this being the convolution speed up (as we can use * contiguous memory). The output, however, requires the * time axis to be in column major arrangement. - * @param[in] outputBuf Pointer to the output buffer. - * @param[in] outputBufSz Output buffer's size. - * @param[in] quantScale Quantisation scale. - * @param[in] quantOffset Quantisation offset. + * @param[in] outputBuf Pointer to the output buffer. + * @param[in] outputBufSz Output buffer's size. + * @param[in] quantScale Quantisation scale. + * @param[in] quantOffset Quantisation offset. 
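+     * @return     true if successful, false otherwise.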
         */
        template <typename T>
        bool Quantise(
-            T * outputBuf,
+            T* outputBuf,
             const uint32_t outputBufSz,
             const float quantScale,
             const int quantOffset)
        {
-            /* Check the output size will for everything. */
+            /* Check the output size will fit everything. */
             if (outputBufSz < (this->m_mfccBuf.size(0) * 3 * sizeof(T))) {
                 printf_err("Tensor size too small for features\n");
                 return false;
             }

             /* Populate. */
-            T * outputBufMfcc = outputBuf;
-            T * outputBufD1 = outputBuf + this->m_numMfccFeats;
-            T * outputBufD2 = outputBufD1 + this->m_numMfccFeats;
+            T* outputBufMfcc = outputBuf;
+            T* outputBufD1 = outputBuf + this->m_numMfccFeats;
+            T* outputBufD2 = outputBufD1 + this->m_numMfccFeats;
             const uint32_t ptrIncr = this->m_numMfccFeats * 2;  /* (3 vectors - 1 vector) */

             const float minVal = std::numeric_limits<T>::min();
             const float maxVal = std::numeric_limits<T>::max();

-            /* We need to do a transpose while copying and concatenating
-             * the tensor. */
-            for (uint32_t j = 0; j < this->m_numFeatVectors; ++j) {
+            /* Need to transpose while copying and concatenating the tensor. */
+            for (uint32_t j = 0; j < this->m_numFeatureFrames; ++j) {
                 for (uint32_t i = 0; i < this->m_numMfccFeats; ++i) {
-                    *outputBufMfcc++ = static_cast<T>(this->GetQuantElem(
+                    *outputBufMfcc++ = static_cast<T>(AsrPreProcess::GetQuantElem(
                             this->m_mfccBuf(i, j), quantScale,
                             quantOffset, minVal, maxVal));
-                    *outputBufD1++ = static_cast<T>(this->GetQuantElem(
+                    *outputBufD1++ = static_cast<T>(AsrPreProcess::GetQuantElem(
                             this->m_delta1Buf(i, j), quantScale,
                             quantOffset, minVal, maxVal));
-                    *outputBufD2++ = static_cast<T>(this->GetQuantElem(
+                    *outputBufD2++ = static_cast<T>(AsrPreProcess::GetQuantElem(
                             this->m_delta2Buf(i, j), quantScale,
                             quantOffset, minVal, maxVal));
                 }
@@ -183,24 +160,23 @@
             }

         private:
-            Wav2LetterMFCC   m_mfcc;            /* MFCC instance. */
+            audio::Wav2LetterMFCC m_mfcc;       /* MFCC instance. */
+            TfLiteTensor*    m_inputTensor;     /* Model input tensor. */

             /* Actual buffers to be populated. */
-            Array2d<float>   m_mfccBuf;         /* Contiguous buffer 1D: MFCC */
-            Array2d<float>   m_delta1Buf;       /* Contiguous buffer 1D: Delta 1 */
-            Array2d<float>   m_delta2Buf;       /* Contiguous buffer 1D: Delta 2 */
+            Array2d<float>   m_mfccBuf;         /* Contiguous buffer 1D: MFCC */
+            Array2d<float>   m_delta1Buf;       /* Contiguous buffer 1D: Delta 1 */
+            Array2d<float>   m_delta2Buf;       /* Contiguous buffer 1D: Delta 2 */

-            uint32_t         m_windowLen;       /* Window length for MFCC. */
-            uint32_t         m_windowStride;    /* Window stride len for MFCC. */
-            uint32_t         m_numMfccFeats;    /* Number of MFCC features per window. */
-            uint32_t         m_numFeatVectors;  /* Number of m_numMfccFeats. */
-            AudioWindow      m_window;          /* Sliding window. */
+            uint32_t         m_mfccWindowLen;     /* Window length for MFCC. */
+            uint32_t         m_mfccWindowStride;  /* Window stride len for MFCC. */
+            uint32_t         m_numMfccFeats;      /* Number of MFCC features per window. */
+            uint32_t         m_numFeatureFrames;  /* How many sets of m_numMfccFeats. */
+            AudioWindow      m_mfccSlidingWindow; /* Sliding window to calculate MFCCs. */
    };

-} /* namespace asr */
-} /* namespace audio */
 } /* namespace app */
 } /* namespace arm */

-#endif /* KWS_ASR_WAV2LET_PREPROC_HPP */
\ No newline at end of file
+#endif /* KWS_ASR_WAV2LETTER_PREPROCESS_HPP */
\ No newline at end of file
diff --git a/source/use_case/kws_asr/src/KwsProcessing.cc b/source/use_case/kws_asr/src/KwsProcessing.cc
new file mode 100644
index 0000000..328709d
--- /dev/null
+++ b/source/use_case/kws_asr/src/KwsProcessing.cc
@@ -0,0 +1,212 @@
+/*
+ * Copyright (c) 2022 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "KwsProcessing.hpp"
+#include "ImageUtils.hpp"
+#include "log_macros.h"
+#include "MicroNetKwsModel.hpp"
+
+namespace arm {
+namespace app {
+
+    KwsPreProcess::KwsPreProcess(TfLiteTensor* inputTensor, size_t numFeatures, size_t numMfccFrames,
+                                 int mfccFrameLength, int mfccFrameStride):
+        m_inputTensor{inputTensor},
+        m_mfccFrameLength{mfccFrameLength},
+        m_mfccFrameStride{mfccFrameStride},
+        m_numMfccFrames{numMfccFrames},
+        m_mfcc{audio::MicroNetKwsMFCC(numFeatures, mfccFrameLength)}
+    {
+        this->m_mfcc.Init();
+
+        /* Deduce the data length required for 1 inference from the network parameters. */
+        this->m_audioDataWindowSize = this->m_numMfccFrames * this->m_mfccFrameStride +
+                (this->m_mfccFrameLength - this->m_mfccFrameStride);
+
+        /* Creating an MFCC feature sliding window for the data required for 1 inference. */
+        this->m_mfccSlidingWindow = audio::SlidingWindow<const int16_t>(nullptr, this->m_audioDataWindowSize,
+                this->m_mfccFrameLength, this->m_mfccFrameStride);
+
+        /* For longer audio clips we choose to move by half the audio window size
+         * => for a 1 second window size there is an overlap of 0.5 seconds. */
+        this->m_audioDataStride = this->m_audioDataWindowSize / 2;
+
+        /* To have the previously calculated features re-usable, stride must be a multiple
+         * of the MFCC features window stride. Reduce the stride through audio if needed. */
+        if (0 != this->m_audioDataStride % this->m_mfccFrameStride) {
+            this->m_audioDataStride -= this->m_audioDataStride % this->m_mfccFrameStride;
+        }
+
+        this->m_numMfccVectorsInAudioStride = this->m_audioDataStride / this->m_mfccFrameStride;
+
+        /* Calculate number of the feature vectors in the window overlap region.
+         * These feature vectors will be reused. */
+        this->m_numReusedMfccVectors = this->m_mfccSlidingWindow.TotalStrides() + 1
+                - this->m_numMfccVectorsInAudioStride;
+
+        /* Construct feature calculation function. */
+        this->m_mfccFeatureCalculator = GetFeatureCalculator(this->m_mfcc, this->m_inputTensor,
+                                                             this->m_numReusedMfccVectors);
+
+        if (!this->m_mfccFeatureCalculator) {
+            printf_err("Feature calculator not initialized.\n");
+        }
+    }
+
+    bool KwsPreProcess::DoPreProcess(const void* data, size_t inputSize)
+    {
+        UNUSED(inputSize);
+        if (data == nullptr) {
+            printf_err("Data pointer is null\n");
+            return false;
+        }
+
+        /* Set the features sliding window to the new address. */
+        auto input = static_cast<const int16_t*>(data);
+        this->m_mfccSlidingWindow.Reset(input);
+
+        /* Cache is only usable if we have more than 1 inference in an audio clip. */
+        bool useCache = this->m_audioWindowIndex > 0 && this->m_numReusedMfccVectors > 0;
+
+        /* Use a sliding window to calculate MFCC features frame by frame.
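+         *
+         * As an illustration (hypothetical MicroNet-style values): with
+         * numMfccFrames = 49, mfccFrameLength = 640 and mfccFrameStride = 320
+         * at 16 kHz, the constructor above yields m_audioDataWindowSize =
+         * 49 * 320 + (640 - 320) = 16000 samples (1 s), m_audioDataStride =
+         * 8000 samples and m_numMfccVectorsInAudioStride = 8000 / 320 = 25,
+         * leaving (48 + 1) - 25 = 24 cached feature vectors to reuse.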
+         */
+        while (this->m_mfccSlidingWindow.HasNext()) {
+            const int16_t* mfccWindow = this->m_mfccSlidingWindow.Next();
+
+            std::vector<int16_t> mfccFrameAudioData = std::vector<int16_t>(mfccWindow,
+                    mfccWindow + this->m_mfccFrameLength);
+
+            /* Compute features for this window and write them to input tensor. */
+            this->m_mfccFeatureCalculator(mfccFrameAudioData, this->m_mfccSlidingWindow.Index(),
+                                          useCache, this->m_numMfccVectorsInAudioStride);
+        }
+
+        debug("Input tensor populated\n");
+
+        return true;
+    }
+
+    /**
+     * @brief Generic feature calculator factory.
+     *
+     *        Returns lambda function to compute features using features cache.
+     *        Real features math is done by a lambda function provided as a parameter.
+     *        Features are written to input tensor memory.
+     *
+     * @tparam T              Feature vector type.
+     * @param[in] inputTensor Model input tensor pointer.
+     * @param[in] cacheSize   Number of feature vectors to cache. Defined by the sliding window overlap.
+     * @param[in] compute     Features calculator function.
+     * @return                Lambda function to compute features.
+     */
+    template<class T>
+    std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
+    KwsPreProcess::FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
+                               std::function<std::vector<T> (std::vector<int16_t>&)> compute)
+    {
+        /* Feature cache to be captured by lambda function. */
+        static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize);
+
+        return [=](std::vector<int16_t>& audioDataWindow,
+                   size_t index,
+                   bool useCache,
+                   size_t featuresOverlapIndex)
+        {
+            T* tensorData = tflite::GetTensorData<T>(inputTensor);
+            std::vector<T> features;
+
+            /* Reuse features from cache if cache is ready and sliding windows overlap.
+             * Overlap is in the beginning of sliding window with a size of a feature cache. */
+            if (useCache && index < featureCache.size()) {
+                features = std::move(featureCache[index]);
+            } else {
+                features = std::move(compute(audioDataWindow));
+            }
+            auto size = features.size();
+            auto sizeBytes = sizeof(T) * size;
+            std::memcpy(tensorData + (index * size), features.data(), sizeBytes);
+
+            /* Start renewing cache as soon as the iteration goes out of the windows overlap.
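+             *
+             * Continuing the illustrative numbers above: with
+             * featuresOverlapIndex = 25 and a 24-entry cache, windows 25..48 of
+             * the current stride refill cache slots 0..23, which are then
+             * consumed as windows 0..23 of the next audio stride.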
*/ + if (index >= featuresOverlapIndex) { + featureCache[index - featuresOverlapIndex] = std::move(features); + } + }; + } + + template std::function&, size_t , bool, size_t)> + KwsPreProcess::FeatureCalc(TfLiteTensor* inputTensor, + size_t cacheSize, + std::function (std::vector&)> compute); + + template std::function&, size_t, bool, size_t)> + KwsPreProcess::FeatureCalc(TfLiteTensor* inputTensor, + size_t cacheSize, + std::function(std::vector&)> compute); + + + std::function&, int, bool, size_t)> + KwsPreProcess::GetFeatureCalculator(audio::MicroNetKwsMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize) + { + std::function&, size_t, bool, size_t)> mfccFeatureCalc; + + TfLiteQuantization quant = inputTensor->quantization; + + if (kTfLiteAffineQuantization == quant.type) { + auto *quantParams = (TfLiteAffineQuantization *) quant.params; + const float quantScale = quantParams->scale->data[0]; + const int quantOffset = quantParams->zero_point->data[0]; + + switch (inputTensor->type) { + case kTfLiteInt8: { + mfccFeatureCalc = this->FeatureCalc(inputTensor, + cacheSize, + [=, &mfcc](std::vector& audioDataWindow) { + return mfcc.MfccComputeQuant(audioDataWindow, + quantScale, + quantOffset); + } + ); + break; + } + default: + printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type)); + } + } else { + mfccFeatureCalc = this->FeatureCalc(inputTensor, cacheSize, + [&mfcc](std::vector& audioDataWindow) { + return mfcc.MfccCompute(audioDataWindow); } + ); + } + return mfccFeatureCalc; + } + + KwsPostProcess::KwsPostProcess(TfLiteTensor* outputTensor, Classifier& classifier, + const std::vector& labels, + std::vector& results) + :m_outputTensor{outputTensor}, + m_kwsClassifier{classifier}, + m_labels{labels}, + m_results{results} + {} + + bool KwsPostProcess::DoPostProcess() + { + return this->m_kwsClassifier.GetClassificationResults( + this->m_outputTensor, this->m_results, + this->m_labels, 1, true); + } + +} /* namespace app */ +} /* namespace arm */ \ No newline at end of file diff --git a/source/use_case/kws_asr/src/MainLoop.cc b/source/use_case/kws_asr/src/MainLoop.cc index 5c1d0e0..f1d97a0 100644 --- a/source/use_case/kws_asr/src/MainLoop.cc +++ b/source/use_case/kws_asr/src/MainLoop.cc @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,7 +14,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "hal.h" /* Brings in platform definitions. */ #include "InputFiles.hpp" /* For input images. */ #include "Labels_micronetkws.hpp" /* For MicroNetKws label strings. */ #include "Labels_wav2letter.hpp" /* For Wav2Letter label strings. */ @@ -24,8 +23,6 @@ #include "Wav2LetterModel.hpp" /* ASR model class for running inference. */ #include "UseCaseCommonUtils.hpp" /* Utils functions. */ #include "UseCaseHandler.hpp" /* Handlers for different user options. */ -#include "Wav2LetterPreprocess.hpp" /* ASR pre-processing class. */ -#include "Wav2LetterPostprocess.hpp"/* ASR post-processing class. */ #include "log_macros.h" using KwsClassifier = arm::app::Classifier; @@ -53,19 +50,8 @@ static void DisplayMenu() fflush(stdout); } -/** @brief Gets the number of MFCC features for a single window. 
*/ -static uint32_t GetNumMfccFeatures(const arm::app::Model& model); - -/** @brief Gets the number of MFCC feature vectors to be computed. */ -static uint32_t GetNumMfccFeatureVectors(const arm::app::Model& model); - -/** @brief Gets the output context length (left and right) for post-processing. */ -static uint32_t GetOutputContextLen(const arm::app::Model& model, - uint32_t inputCtxLen); - -/** @brief Gets the output inner length for post-processing. */ -static uint32_t GetOutputInnerLen(const arm::app::Model& model, - uint32_t outputCtxLen); +/** @brief Verify input and output tensor are of certain min dimensions. */ +static bool VerifyTensorDimensions(const arm::app::Model& model); void main_loop() { @@ -84,61 +70,46 @@ void main_loop() if (!asrModel.Init(kwsModel.GetAllocator())) { printf_err("Failed to initialise ASR model\n"); return; + } else if (!VerifyTensorDimensions(asrModel)) { + printf_err("Model's input or output dimension verification failed\n"); + return; } - /* Initialise ASR pre-processing. */ - arm::app::audio::asr::Preprocess prep( - GetNumMfccFeatures(asrModel), - arm::app::asr::g_FrameLength, - arm::app::asr::g_FrameStride, - GetNumMfccFeatureVectors(asrModel)); - - /* Initialise ASR post-processing. */ - const uint32_t outputCtxLen = GetOutputContextLen(asrModel, arm::app::asr::g_ctxLen); - const uint32_t blankTokenIdx = 28; - arm::app::audio::asr::Postprocess postp( - outputCtxLen, - GetOutputInnerLen(asrModel, outputCtxLen), - blankTokenIdx); - /* Instantiate application context. */ arm::app::ApplicationContext caseContext; arm::app::Profiler profiler{"kws_asr"}; caseContext.Set("profiler", profiler); - caseContext.Set("kwsmodel", kwsModel); - caseContext.Set("asrmodel", asrModel); + caseContext.Set("kwsModel", kwsModel); + caseContext.Set("asrModel", asrModel); caseContext.Set("clipIndex", 0); caseContext.Set("ctxLen", arm::app::asr::g_ctxLen); /* Left and right context length (MFCC feat vectors). */ - caseContext.Set("kwsframeLength", arm::app::kws::g_FrameLength); - caseContext.Set("kwsframeStride", arm::app::kws::g_FrameStride); - caseContext.Set("kwsscoreThreshold", arm::app::kws::g_ScoreThreshold); /* Normalised score threshold. */ + caseContext.Set("kwsFrameLength", arm::app::kws::g_FrameLength); + caseContext.Set("kwsFrameStride", arm::app::kws::g_FrameStride); + caseContext.Set("kwsScoreThreshold", arm::app::kws::g_ScoreThreshold); /* Normalised score threshold. */ caseContext.Set("kwsNumMfcc", arm::app::kws::g_NumMfcc); caseContext.Set("kwsNumAudioWins", arm::app::kws::g_NumAudioWins); - caseContext.Set("asrframeLength", arm::app::asr::g_FrameLength); - caseContext.Set("asrframeStride", arm::app::asr::g_FrameStride); - caseContext.Set("asrscoreThreshold", arm::app::asr::g_ScoreThreshold); /* Normalised score threshold. */ + caseContext.Set("asrFrameLength", arm::app::asr::g_FrameLength); + caseContext.Set("asrFrameStride", arm::app::asr::g_FrameStride); + caseContext.Set("asrScoreThreshold", arm::app::asr::g_ScoreThreshold); /* Normalised score threshold. */ KwsClassifier kwsClassifier; /* Classifier wrapper object. */ arm::app::AsrClassifier asrClassifier; /* Classifier wrapper object. 
*/ - caseContext.Set("kwsclassifier", kwsClassifier); - caseContext.Set("asrclassifier", asrClassifier); - - caseContext.Set("preprocess", prep); - caseContext.Set("postprocess", postp); + caseContext.Set("kwsClassifier", kwsClassifier); + caseContext.Set("asrClassifier", asrClassifier); std::vector asrLabels; arm::app::asr::GetLabelsVector(asrLabels); std::vector kwsLabels; arm::app::kws::GetLabelsVector(kwsLabels); - caseContext.Set&>("asrlabels", asrLabels); - caseContext.Set&>("kwslabels", kwsLabels); + caseContext.Set&>("asrLabels", asrLabels); + caseContext.Set&>("kwsLabels", kwsLabels); /* KWS keyword that triggers ASR and associated checks */ - std::string triggerKeyword = std::string("yes"); + std::string triggerKeyword = std::string("no"); if (std::find(kwsLabels.begin(), kwsLabels.end(), triggerKeyword) != kwsLabels.end()) { - caseContext.Set("triggerkeyword", triggerKeyword); + caseContext.Set("triggerKeyword", triggerKeyword); } else { printf_err("Selected trigger keyword not found in labels file\n"); @@ -196,50 +167,26 @@ void main_loop() info("Main loop terminated.\n"); } -static uint32_t GetNumMfccFeatures(const arm::app::Model& model) -{ - TfLiteTensor* inputTensor = model.GetInputTensor(0); - const int inputCols = inputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputColsIdx]; - if (0 != inputCols % 3) { - printf_err("Number of input columns is not a multiple of 3\n"); - } - return std::max(inputCols/3, 0); -} - -static uint32_t GetNumMfccFeatureVectors(const arm::app::Model& model) +static bool VerifyTensorDimensions(const arm::app::Model& model) { + /* Populate tensor related parameters. */ TfLiteTensor* inputTensor = model.GetInputTensor(0); - const int inputRows = inputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx]; - return std::max(inputRows, 0); -} - -static uint32_t GetOutputContextLen(const arm::app::Model& model, const uint32_t inputCtxLen) -{ - const uint32_t inputRows = GetNumMfccFeatureVectors(model); - const uint32_t inputInnerLen = inputRows - (2 * inputCtxLen); - constexpr uint32_t ms_outputRowsIdx = arm::app::Wav2LetterModel::ms_outputRowsIdx; - - /* Check to make sure that the input tensor supports the above context and inner lengths. 
*/ - if (inputRows <= 2 * inputCtxLen || inputRows <= inputInnerLen) { - printf_err("Input rows not compatible with ctx of %" PRIu32 "\n", - inputCtxLen); - return 0; + if (!inputTensor->dims) { + printf_err("Invalid input tensor dims\n"); + return false; + } else if (inputTensor->dims->size < 3) { + printf_err("Input tensor dimension should be >= 3\n"); + return false; } TfLiteTensor* outputTensor = model.GetOutputTensor(0); - const uint32_t outputRows = std::max(outputTensor->dims->data[ms_outputRowsIdx], 0); - - const float tensorColRatio = static_cast(inputRows)/ - static_cast(outputRows); - - return std::round(static_cast(inputCtxLen)/tensorColRatio); -} + if (!outputTensor->dims) { + printf_err("Invalid output tensor dims\n"); + return false; + } else if (outputTensor->dims->size < 3) { + printf_err("Output tensor dimension should be >= 3\n"); + return false; + } -static uint32_t GetOutputInnerLen(const arm::app::Model& model, - const uint32_t outputCtxLen) -{ - constexpr uint32_t ms_outputRowsIdx = arm::app::Wav2LetterModel::ms_outputRowsIdx; - TfLiteTensor* outputTensor = model.GetOutputTensor(0); - const uint32_t outputRows = std::max(outputTensor->dims->data[ms_outputRowsIdx], 0); - return (outputRows - (2 * outputCtxLen)); + return true; } diff --git a/source/use_case/kws_asr/src/UseCaseHandler.cc b/source/use_case/kws_asr/src/UseCaseHandler.cc index 1e1a400..01aefae 100644 --- a/source/use_case/kws_asr/src/UseCaseHandler.cc +++ b/source/use_case/kws_asr/src/UseCaseHandler.cc @@ -28,6 +28,7 @@ #include "Wav2LetterMfcc.hpp" #include "Wav2LetterPreprocess.hpp" #include "Wav2LetterPostprocess.hpp" +#include "KwsProcessing.hpp" #include "AsrResult.hpp" #include "AsrClassifier.hpp" #include "OutputDecode.hpp" @@ -39,11 +40,6 @@ using KwsClassifier = arm::app::Classifier; namespace arm { namespace app { - enum AsrOutputReductionAxis { - AxisRow = 1, - AxisCol = 2 - }; - struct KWSOutput { bool executionSuccess = false; const int16_t* asrAudioStart = nullptr; @@ -51,73 +47,53 @@ namespace app { }; /** - * @brief Presents kws inference results using the data presentation - * object. - * @param[in] results vector of classification results to be displayed - * @return true if successful, false otherwise + * @brief Presents KWS inference results. + * @param[in] results Vector of KWS classification results to be displayed. + * @return true if successful, false otherwise. **/ - static bool PresentInferenceResult(std::vector& results); + static bool PresentInferenceResult(std::vector& results); /** - * @brief Presents asr inference results using the data presentation - * object. - * @param[in] platform reference to the hal platform object - * @param[in] results vector of classification results to be displayed - * @return true if successful, false otherwise + * @brief Presents ASR inference results. + * @param[in] results Vector of ASR classification results to be displayed. + * @return true if successful, false otherwise. **/ - static bool PresentInferenceResult(std::vector& results); + static bool PresentInferenceResult(std::vector& results); /** - * @brief Returns a function to perform feature calculation and populates input tensor data with - * MFCC data. - * - * Input tensor data type check is performed to choose correct MFCC feature data type. - * If tensor has an integer data type then original features are quantised. - * - * Warning: mfcc calculator provided as input must have the same life scope as returned function. - * - * @param[in] mfcc MFCC feature calculator. 
- * @param[in,out] inputTensor Input tensor pointer to store calculated features. - * @param[in] cacheSize Size of the feature vectors cache (number of feature vectors). - * - * @return function function to be called providing audio sample and sliding window index. + * @brief Performs the KWS pipeline. + * @param[in,out] ctx pointer to the application context object + * @return struct containing pointer to audio data where ASR should begin + * and how much data to process. **/ - static std::function&, int, bool, size_t)> - GetFeatureCalculator(audio::MicroNetMFCC& mfcc, - TfLiteTensor* inputTensor, - size_t cacheSize); + static KWSOutput doKws(ApplicationContext& ctx) + { + auto& profiler = ctx.Get("profiler"); + auto& kwsModel = ctx.Get("kwsModel"); + const auto kwsMfccFrameLength = ctx.Get("kwsFrameLength"); + const auto kwsMfccFrameStride = ctx.Get("kwsFrameStride"); + const auto kwsScoreThreshold = ctx.Get("kwsScoreThreshold"); + + auto currentIndex = ctx.Get("clipIndex"); - /** - * @brief Performs the KWS pipeline. - * @param[in,out] ctx pointer to the application context object - * - * @return KWSOutput struct containing pointer to audio data where ASR should begin - * and how much data to process. - */ - static KWSOutput doKws(ApplicationContext& ctx) { constexpr uint32_t dataPsnTxtInfStartX = 20; constexpr uint32_t dataPsnTxtInfStartY = 40; constexpr int minTensorDims = static_cast( - (arm::app::MicroNetKwsModel::ms_inputRowsIdx > arm::app::MicroNetKwsModel::ms_inputColsIdx)? - arm::app::MicroNetKwsModel::ms_inputRowsIdx : arm::app::MicroNetKwsModel::ms_inputColsIdx); + (MicroNetKwsModel::ms_inputRowsIdx > MicroNetKwsModel::ms_inputColsIdx)? + MicroNetKwsModel::ms_inputRowsIdx : MicroNetKwsModel::ms_inputColsIdx); - KWSOutput output; + /* Output struct from doing KWS. */ + KWSOutput output {}; - auto& profiler = ctx.Get("profiler"); - auto& kwsModel = ctx.Get("kwsmodel"); if (!kwsModel.IsInited()) { printf_err("KWS model has not been initialised\n"); return output; } - const int kwsFrameLength = ctx.Get("kwsframeLength"); - const int kwsFrameStride = ctx.Get("kwsframeStride"); - const float kwsScoreThreshold = ctx.Get("kwsscoreThreshold"); - - TfLiteTensor* kwsOutputTensor = kwsModel.GetOutputTensor(0); + /* Get Input and Output tensors for pre/post processing. */ TfLiteTensor* kwsInputTensor = kwsModel.GetInputTensor(0); - + TfLiteTensor* kwsOutputTensor = kwsModel.GetOutputTensor(0); if (!kwsInputTensor->dims) { printf_err("Invalid input tensor dims\n"); return output; @@ -126,63 +102,32 @@ namespace app { return output; } - const uint32_t kwsNumMfccFeats = ctx.Get("kwsNumMfcc"); - const uint32_t kwsNumAudioWindows = ctx.Get("kwsNumAudioWins"); - - audio::MicroNetMFCC kwsMfcc = audio::MicroNetMFCC(kwsNumMfccFeats, kwsFrameLength); - kwsMfcc.Init(); - - /* Deduce the data length required for 1 KWS inference from the network parameters. */ - auto kwsAudioDataWindowSize = kwsNumAudioWindows * kwsFrameStride + - (kwsFrameLength - kwsFrameStride); - auto kwsMfccWindowSize = kwsFrameLength; - auto kwsMfccWindowStride = kwsFrameStride; - - /* We are choosing to move by half the window size => for a 1 second window size, - * this means an overlap of 0.5 seconds. */ - auto kwsAudioDataStride = kwsAudioDataWindowSize / 2; - - info("KWS audio data window size %" PRIu32 "\n", kwsAudioDataWindowSize); - - /* Stride must be multiple of mfcc features window stride to re-use features. 
*/ - if (0 != kwsAudioDataStride % kwsMfccWindowStride) { - kwsAudioDataStride -= kwsAudioDataStride % kwsMfccWindowStride; - } - - auto kwsMfccVectorsInAudioStride = kwsAudioDataStride/kwsMfccWindowStride; + /* Get input shape for feature extraction. */ + TfLiteIntArray* inputShape = kwsModel.GetInputShape(0); + const uint32_t numMfccFeatures = inputShape->data[MicroNetKwsModel::ms_inputColsIdx]; + const uint32_t numMfccFrames = inputShape->data[MicroNetKwsModel::ms_inputRowsIdx]; /* We expect to be sampling 1 second worth of data at a time * NOTE: This is only used for time stamp calculation. */ - const float kwsAudioParamsSecondsPerSample = 1.0/audio::MicroNetMFCC::ms_defaultSamplingFreq; + const float kwsAudioParamsSecondsPerSample = 1.0 / audio::MicroNetKwsMFCC::ms_defaultSamplingFreq; - auto currentIndex = ctx.Get("clipIndex"); + /* Set up pre and post-processing. */ + KwsPreProcess preProcess = KwsPreProcess(kwsInputTensor, numMfccFeatures, numMfccFrames, + kwsMfccFrameLength, kwsMfccFrameStride); - /* Creating a mfcc features sliding window for the data required for 1 inference. */ - auto kwsAudioMFCCWindowSlider = audio::SlidingWindow( - get_audio_array(currentIndex), - kwsAudioDataWindowSize, kwsMfccWindowSize, - kwsMfccWindowStride); + std::vector singleInfResult; + KwsPostProcess postProcess = KwsPostProcess(kwsOutputTensor, ctx.Get("kwsClassifier"), + ctx.Get&>("kwsLabels"), + singleInfResult); /* Creating a sliding window through the whole audio clip. */ auto audioDataSlider = audio::SlidingWindow( get_audio_array(currentIndex), get_audio_array_size(currentIndex), - kwsAudioDataWindowSize, kwsAudioDataStride); - - /* Calculate number of the feature vectors in the window overlap region. - * These feature vectors will be reused.*/ - size_t numberOfReusedFeatureVectors = kwsAudioMFCCWindowSlider.TotalStrides() + 1 - - kwsMfccVectorsInAudioStride; + preProcess.m_audioDataWindowSize, preProcess.m_audioDataStride); - auto kwsMfccFeatureCalc = GetFeatureCalculator(kwsMfcc, kwsInputTensor, - numberOfReusedFeatureVectors); - - if (!kwsMfccFeatureCalc){ - return output; - } - - /* Container for KWS results. */ - std::vector kwsResults; + /* Declare a container to hold kws results from across the whole audio clip. */ + std::vector finalResults; /* Display message on the LCD - inference running. */ std::string str_inf{"Running KWS inference... "}; @@ -197,70 +142,56 @@ namespace app { while (audioDataSlider.HasNext()) { const int16_t* inferenceWindow = audioDataSlider.Next(); - /* We moved to the next window - set the features sliding to the new address. */ - kwsAudioMFCCWindowSlider.Reset(inferenceWindow); - /* The first window does not have cache ready. */ - bool useCache = audioDataSlider.Index() > 0 && numberOfReusedFeatureVectors > 0; - - /* Start calculating features inside one audio sliding window. */ - while (kwsAudioMFCCWindowSlider.HasNext()) { - const int16_t* kwsMfccWindow = kwsAudioMFCCWindowSlider.Next(); - std::vector kwsMfccAudioData = - std::vector(kwsMfccWindow, kwsMfccWindow + kwsMfccWindowSize); - - /* Compute features for this window and write them to input tensor. */ - kwsMfccFeatureCalc(kwsMfccAudioData, - kwsAudioMFCCWindowSlider.Index(), - useCache, - kwsMfccVectorsInAudioStride); - } + preProcess.m_audioWindowIndex = audioDataSlider.Index(); - info("Inference %zu/%zu\n", audioDataSlider.Index() + 1, - audioDataSlider.TotalStrides() + 1); + /* Run the pre-processing, inference and post-processing. 
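+             * With the illustrative sizes given earlier (16000-sample window,
+             * 8000-sample stride), a 4 s / 64000-sample clip would yield
+             * (64000 - 16000) / 8000 + 1 = 7 overlapping inference windows,
+             * each passing through this same three-stage sequence.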
*/ + if (!preProcess.DoPreProcess(inferenceWindow, audio::MicroNetKwsMFCC::ms_defaultSamplingFreq)) { + printf_err("KWS Pre-processing failed."); + return output; + } - /* Run inference over this audio clip sliding window. */ if (!RunInference(kwsModel, profiler)) { - printf_err("KWS inference failed\n"); + printf_err("KWS Inference failed."); return output; } - std::vector kwsClassificationResult; - auto& kwsClassifier = ctx.Get("kwsclassifier"); + if (!postProcess.DoPostProcess()) { + printf_err("KWS Post-processing failed."); + return output; + } - kwsClassifier.GetClassificationResults( - kwsOutputTensor, kwsClassificationResult, - ctx.Get&>("kwslabels"), 1, true); + info("Inference %zu/%zu\n", audioDataSlider.Index() + 1, + audioDataSlider.TotalStrides() + 1); - kwsResults.emplace_back( - kws::KwsResult( - kwsClassificationResult, - audioDataSlider.Index() * kwsAudioParamsSecondsPerSample * kwsAudioDataStride, - audioDataSlider.Index(), kwsScoreThreshold) - ); + /* Add results from this window to our final results vector. */ + finalResults.emplace_back( + kws::KwsResult(singleInfResult, + audioDataSlider.Index() * kwsAudioParamsSecondsPerSample * preProcess.m_audioDataStride, + audioDataSlider.Index(), kwsScoreThreshold)); - /* Keyword detected. */ - if (kwsClassificationResult[0].m_label == ctx.Get("triggerkeyword")) { - output.asrAudioStart = inferenceWindow + kwsAudioDataWindowSize; + /* Break out when trigger keyword is detected. */ + if (singleInfResult[0].m_label == ctx.Get("triggerKeyword") + && singleInfResult[0].m_normalisedVal > kwsScoreThreshold) { + output.asrAudioStart = inferenceWindow + preProcess.m_audioDataWindowSize; output.asrAudioSamples = get_audio_array_size(currentIndex) - (audioDataSlider.NextWindowStartIndex() - - kwsAudioDataStride + kwsAudioDataWindowSize); + preProcess.m_audioDataStride + preProcess.m_audioDataWindowSize); break; } #if VERIFY_TEST_OUTPUT - arm::app::DumpTensor(kwsOutputTensor); + DumpTensor(kwsOutputTensor); #endif /* VERIFY_TEST_OUTPUT */ } /* while (audioDataSlider.HasNext()) */ /* Erase. */ str_inf = std::string(str_inf.size(), ' '); - hal_lcd_display_text( - str_inf.c_str(), str_inf.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); + hal_lcd_display_text(str_inf.c_str(), str_inf.size(), + dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); - if (!PresentInferenceResult(kwsResults)) { + if (!PresentInferenceResult(finalResults)) { return output; } @@ -271,41 +202,41 @@ namespace app { } /** - * @brief Performs the ASR pipeline. - * - * @param[in,out] ctx pointer to the application context object - * @param[in] kwsOutput struct containing pointer to audio data where ASR should begin - * and how much data to process - * @return bool true if pipeline executed without failure - */ - static bool doAsr(ApplicationContext& ctx, const KWSOutput& kwsOutput) { + * @brief Performs the ASR pipeline. + * @param[in,out] ctx Pointer to the application context object. + * @param[in] kwsOutput Struct containing pointer to audio data where ASR should begin + * and how much data to process. + * @return true if pipeline executed without failure. 
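+     *
+     *         (Per doKws() above: if the keyword is detected in a window
+     *         starting at sample s, kwsOutput.asrAudioStart points at
+     *         s + m_audioDataWindowSize, i.e. ASR resumes immediately after
+     *         the keyword window, and kwsOutput.asrAudioSamples covers the
+     *         remainder of the clip.)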
+ **/ + static bool doAsr(ApplicationContext& ctx, const KWSOutput& kwsOutput) + { + auto& asrModel = ctx.Get("asrModel"); + auto& profiler = ctx.Get("profiler"); + auto asrMfccFrameLen = ctx.Get("asrFrameLength"); + auto asrMfccFrameStride = ctx.Get("asrFrameStride"); + auto asrScoreThreshold = ctx.Get("asrScoreThreshold"); + auto asrInputCtxLen = ctx.Get("ctxLen"); + constexpr uint32_t dataPsnTxtInfStartX = 20; constexpr uint32_t dataPsnTxtInfStartY = 40; - auto& profiler = ctx.Get("profiler"); - hal_lcd_clear(COLOR_BLACK); - - /* Get model reference. */ - auto& asrModel = ctx.Get("asrmodel"); if (!asrModel.IsInited()) { printf_err("ASR model has not been initialised\n"); return false; } - /* Get score threshold to be applied for the classifier (post-inference). */ - auto asrScoreThreshold = ctx.Get("asrscoreThreshold"); + hal_lcd_clear(COLOR_BLACK); - /* Dimensions of the tensor should have been verified by the callee. */ + /* Get Input and Output tensors for pre/post processing. */ TfLiteTensor* asrInputTensor = asrModel.GetInputTensor(0); TfLiteTensor* asrOutputTensor = asrModel.GetOutputTensor(0); - const uint32_t asrInputRows = asrInputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx]; - /* Populate ASR MFCC related parameters. */ - auto asrMfccParamsWinLen = ctx.Get("asrframeLength"); - auto asrMfccParamsWinStride = ctx.Get("asrframeStride"); + /* Get input shape. Dimensions of the tensor should have been verified by + * the callee. */ + TfLiteIntArray* inputShape = asrModel.GetInputShape(0); - /* Populate ASR inference context and inner lengths for input. */ - auto asrInputCtxLen = ctx.Get("ctxLen"); + + const uint32_t asrInputRows = asrInputTensor->dims->data[Wav2LetterModel::ms_inputRowsIdx]; const uint32_t asrInputInnerLen = asrInputRows - (2 * asrInputCtxLen); /* Make sure the input tensor supports the above context and inner lengths. */ @@ -316,18 +247,9 @@ namespace app { } /* Audio data stride corresponds to inputInnerLen feature vectors. */ - const uint32_t asrAudioParamsWinLen = (asrInputRows - 1) * - asrMfccParamsWinStride + (asrMfccParamsWinLen); - const uint32_t asrAudioParamsWinStride = asrInputInnerLen * asrMfccParamsWinStride; - const float asrAudioParamsSecondsPerSample = - (1.0/audio::Wav2LetterMFCC::ms_defaultSamplingFreq); - - /* Get pre/post-processing objects */ - auto& asrPrep = ctx.Get("preprocess"); - auto& asrPostp = ctx.Get("postprocess"); - - /* Set default reduction axis for post-processing. */ - const uint32_t reductionAxis = arm::app::Wav2LetterModel::ms_outputRowsIdx; + const uint32_t asrAudioDataWindowLen = (asrInputRows - 1) * asrMfccFrameStride + (asrMfccFrameLen); + const uint32_t asrAudioDataWindowStride = asrInputInnerLen * asrMfccFrameStride; + const float asrAudioParamsSecondsPerSample = 1.0 / audio::Wav2LetterMFCC::ms_defaultSamplingFreq; /* Get the remaining audio buffer and respective size from KWS results. */ const int16_t* audioArr = kwsOutput.asrAudioStart; @@ -335,9 +257,9 @@ namespace app { /* Audio clip must have enough samples to produce 1 MFCC feature. 
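     * (Sense check with illustrative Wav2Letter-style numbers: asrInputRows =
     * 296, asrMfccFrameLen = 512, asrMfccFrameStride = 160 and asrInputCtxLen =
     * 98 give asrAudioDataWindowLen = 295 * 160 + 512 = 47712 samples, roughly
     * 2.98 s at 16 kHz, and asrAudioDataWindowStride = (296 - 2 * 98) * 160 =
     * 16000 samples, i.e. 1 s of fresh audio per inference window.)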
*/ std::vector audioBuffer = std::vector(audioArr, audioArr + audioArrSize); - if (audioArrSize < asrMfccParamsWinLen) { + if (audioArrSize < asrMfccFrameLen) { printf_err("Not enough audio samples, minimum needed is %" PRIu32 "\n", - asrMfccParamsWinLen); + asrMfccFrameLen); return false; } @@ -345,26 +267,38 @@ namespace app { auto audioDataSlider = audio::FractionalSlidingWindow( audioBuffer.data(), audioBuffer.size(), - asrAudioParamsWinLen, - asrAudioParamsWinStride); + asrAudioDataWindowLen, + asrAudioDataWindowStride); /* Declare a container for results. */ - std::vector asrResults; + std::vector asrResults; /* Display message on the LCD - inference running. */ std::string str_inf{"Running ASR inference... "}; - hal_lcd_display_text( - str_inf.c_str(), str_inf.size(), + hal_lcd_display_text(str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); - size_t asrInferenceWindowLen = asrAudioParamsWinLen; - + size_t asrInferenceWindowLen = asrAudioDataWindowLen; + + /* Set up pre and post-processing objects. */ + AsrPreProcess asrPreProcess = AsrPreProcess(asrInputTensor, arm::app::Wav2LetterModel::ms_numMfccFeatures, + inputShape->data[Wav2LetterModel::ms_inputRowsIdx], + asrMfccFrameLen, asrMfccFrameStride); + + std::vector singleInfResult; + const uint32_t outputCtxLen = AsrPostProcess::GetOutputContextLen(asrModel, asrInputCtxLen); + AsrPostProcess asrPostProcess = AsrPostProcess( + asrOutputTensor, ctx.Get("asrClassifier"), + ctx.Get&>("asrLabels"), + singleInfResult, outputCtxLen, + Wav2LetterModel::ms_blankTokenIdx, Wav2LetterModel::ms_outputRowsIdx + ); /* Start sliding through audio clip. */ while (audioDataSlider.HasNext()) { /* If not enough audio see how much can be sent for processing. */ size_t nextStartIndex = audioDataSlider.NextWindowStartIndex(); - if (nextStartIndex + asrAudioParamsWinLen > audioBuffer.size()) { + if (nextStartIndex + asrAudioDataWindowLen > audioBuffer.size()) { asrInferenceWindowLen = audioBuffer.size() - nextStartIndex; } @@ -373,8 +307,11 @@ namespace app { info("Inference %zu/%zu\n", audioDataSlider.Index() + 1, static_cast(ceilf(audioDataSlider.FractionalTotalStrides() + 1))); - /* Calculate MFCCs, deltas and populate the input tensor. */ - asrPrep.Invoke(asrInferenceWindow, asrInferenceWindowLen, asrInputTensor); + /* Run the pre-processing, inference and post-processing. */ + if (!asrPreProcess.DoPreProcess(asrInferenceWindow, asrInferenceWindowLen)) { + printf_err("ASR pre-processing failed."); + return false; + } /* Run inference over this audio clip sliding window. */ if (!RunInference(asrModel, profiler)) { @@ -382,24 +319,28 @@ namespace app { return false; } - /* Post-process. */ - asrPostp.Invoke(asrOutputTensor, reductionAxis, !audioDataSlider.HasNext()); + /* Post processing needs to know if we are on the last audio window. */ + asrPostProcess.m_lastIteration = !audioDataSlider.HasNext(); + if (!asrPostProcess.DoPostProcess()) { + printf_err("ASR post-processing failed."); + return false; + } /* Get results. 
 */
            std::vector<ClassificationResult> asrClassificationResult;
-            auto& asrClassifier = ctx.Get<AsrClassifier&>("asrclassifier");
+            auto& asrClassifier = ctx.Get<AsrClassifier&>("asrClassifier");
            asrClassifier.GetClassificationResults(
                    asrOutputTensor, asrClassificationResult,
-                    ctx.Get<std::vector<std::string>&>("asrlabels"), 1);
+                    ctx.Get<std::vector<std::string>&>("asrLabels"), 1);

            asrResults.emplace_back(asr::AsrResult(asrClassificationResult,
                                    (audioDataSlider.Index() *
                                     asrAudioParamsSecondsPerSample *
-                                    asrAudioParamsWinStride),
+                                    asrAudioDataWindowStride),
                                    audioDataSlider.Index(), asrScoreThreshold));

 #if VERIFY_TEST_OUTPUT
-            arm::app::DumpTensor(asrOutputTensor, asrOutputTensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx]);
+            DumpTensor(asrOutputTensor, asrOutputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx]);
 #endif /* VERIFY_TEST_OUTPUT */

            /* Erase */
@@ -417,7 +358,7 @@
        return true;
    }

-    /* Audio inference classification handler. */
+    /* KWS and ASR inference handler. */
    bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
    {
        hal_lcd_clear(COLOR_BLACK);
@@ -434,13 +375,14 @@
        do {
            KWSOutput kwsOutput = doKws(ctx);
            if (!kwsOutput.executionSuccess) {
+                printf_err("KWS failed\n");
                return false;
            }

            if (kwsOutput.asrAudioStart != nullptr && kwsOutput.asrAudioSamples > 0) {
-                info("Keyword spotted\n");
+                info("Trigger keyword spotted\n");
                if(!doAsr(ctx, kwsOutput)) {
-                    printf_err("ASR failed");
+                    printf_err("ASR failed\n");
                    return false;
                }
            }
@@ -452,7 +394,6 @@
        return true;
    }

-
    static bool PresentInferenceResult(std::vector<kws::KwsResult>& results)
    {
        constexpr uint32_t dataPsnTxtStartX1 = 20;
@@ -464,33 +405,31 @@
        /* Display each result. */
        uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr;

-        for (uint32_t i = 0; i < results.size(); ++i) {
-
+        for (auto & result : results) {
            std::string topKeyword{""};
            float score = 0.f;
-            if (!results[i].m_resultVec.empty()) {
-                topKeyword = results[i].m_resultVec[0].m_label;
-                score = results[i].m_resultVec[0].m_normalisedVal;
+            if (!result.m_resultVec.empty()) {
+                topKeyword = result.m_resultVec[0].m_label;
+                score = result.m_resultVec[0].m_normalisedVal;
            }

            std::string resultStr =
-                std::string{"@"} + std::to_string(results[i].m_timeStamp) +
+                std::string{"@"} + std::to_string(result.m_timeStamp) +
                std::string{"s: "} + topKeyword + std::string{" ("} +
                std::to_string(static_cast<int>(score * 100)) + std::string{"%)"};

-            hal_lcd_display_text(
-                    resultStr.c_str(), resultStr.size(),
-                    dataPsnTxtStartX1, rowIdx1, 0);
+            hal_lcd_display_text(resultStr.c_str(), resultStr.size(),
+                                 dataPsnTxtStartX1, rowIdx1, 0);
            rowIdx1 += dataPsnTxtYIncr;

            info("For timestamp: %f (inference #: %" PRIu32 "); threshold: %f\n",
-                 results[i].m_timeStamp, results[i].m_inferenceNumber,
-                 results[i].m_threshold);
-            for (uint32_t j = 0; j < results[i].m_resultVec.size(); ++j) {
+                 result.m_timeStamp, result.m_inferenceNumber,
+                 result.m_threshold);
+            for (uint32_t j = 0; j < result.m_resultVec.size(); ++j) {
                info("\t\tlabel @ %" PRIu32 ": %s, score: %f\n", j,
-                     results[i].m_resultVec[j].m_label.c_str(),
-                     results[i].m_resultVec[j].m_normalisedVal);
+                     result.m_resultVec[j].m_label.c_str(),
+                     result.m_resultVec[j].m_normalisedVal);
            }
        }

@@ -523,143 +462,12 @@
        std::string finalResultStr = audio::asr::DecodeOutput(combinedResults);

-        hal_lcd_display_text(
-                finalResultStr.c_str(), finalResultStr.size(),
-                dataPsnTxtStartX1, dataPsnTxtStartY1, allow_multiple_lines);
+        hal_lcd_display_text(finalResultStr.c_str(), finalResultStr.size(),
+                             dataPsnTxtStartX1, dataPsnTxtStartY1,
allow_multiple_lines); info("Final result: %s\n", finalResultStr.c_str()); return true; } - /** - * @brief Generic feature calculator factory. - * - * Returns lambda function to compute features using features cache. - * Real features math is done by a lambda function provided as a parameter. - * Features are written to input tensor memory. - * - * @tparam T feature vector type. - * @param inputTensor model input tensor pointer. - * @param cacheSize number of feature vectors to cache. Defined by the sliding window overlap. - * @param compute features calculator function. - * @return lambda function to compute features. - **/ - template - std::function&, size_t, bool, size_t)> - FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize, - std::function (std::vector& )> compute) - { - /* Feature cache to be captured by lambda function. */ - static std::vector> featureCache = std::vector>(cacheSize); - - return [=](std::vector& audioDataWindow, - size_t index, - bool useCache, - size_t featuresOverlapIndex) - { - T* tensorData = tflite::GetTensorData(inputTensor); - std::vector features; - - /* Reuse features from cache if cache is ready and sliding windows overlap. - * Overlap is in the beginning of sliding window with a size of a feature cache. - */ - if (useCache && index < featureCache.size()) { - features = std::move(featureCache[index]); - } else { - features = std::move(compute(audioDataWindow)); - } - auto size = features.size(); - auto sizeBytes = sizeof(T) * size; - std::memcpy(tensorData + (index * size), features.data(), sizeBytes); - - /* Start renewing cache as soon iteration goes out of the windows overlap. */ - if (index >= featuresOverlapIndex) { - featureCache[index - featuresOverlapIndex] = std::move(features); - } - }; - } - - template std::function&, size_t , bool, size_t)> - FeatureCalc(TfLiteTensor* inputTensor, - size_t cacheSize, - std::function (std::vector& )> compute); - - template std::function&, size_t , bool, size_t)> - FeatureCalc(TfLiteTensor* inputTensor, - size_t cacheSize, - std::function (std::vector& )> compute); - - template std::function&, size_t , bool, size_t)> - FeatureCalc(TfLiteTensor* inputTensor, - size_t cacheSize, - std::function (std::vector& )> compute); - - template std::function&, size_t, bool, size_t)> - FeatureCalc(TfLiteTensor* inputTensor, - size_t cacheSize, - std::function(std::vector&)> compute); - - - static std::function&, int, bool, size_t)> - GetFeatureCalculator(audio::MicroNetMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize) - { - std::function&, size_t, bool, size_t)> mfccFeatureCalc; - - TfLiteQuantization quant = inputTensor->quantization; - - if (kTfLiteAffineQuantization == quant.type) { - - auto* quantParams = (TfLiteAffineQuantization*) quant.params; - const float quantScale = quantParams->scale->data[0]; - const int quantOffset = quantParams->zero_point->data[0]; - - switch (inputTensor->type) { - case kTfLiteInt8: { - mfccFeatureCalc = FeatureCalc(inputTensor, - cacheSize, - [=, &mfcc](std::vector& audioDataWindow) { - return mfcc.MfccComputeQuant(audioDataWindow, - quantScale, - quantOffset); - } - ); - break; - } - case kTfLiteUInt8: { - mfccFeatureCalc = FeatureCalc(inputTensor, - cacheSize, - [=, &mfcc](std::vector& audioDataWindow) { - return mfcc.MfccComputeQuant(audioDataWindow, - quantScale, - quantOffset); - } - ); - break; - } - case kTfLiteInt16: { - mfccFeatureCalc = FeatureCalc(inputTensor, - cacheSize, - [=, &mfcc](std::vector& audioDataWindow) { - return mfcc.MfccComputeQuant(audioDataWindow, - 
quantScale, - quantOffset); - } - ); - break; - } - default: - printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type)); - } - - - } else { - mfccFeatureCalc = mfccFeatureCalc = FeatureCalc(inputTensor, - cacheSize, - [&mfcc](std::vector& audioDataWindow) { - return mfcc.MfccCompute(audioDataWindow); - }); - } - return mfccFeatureCalc; - } } /* namespace app */ } /* namespace arm */ \ No newline at end of file diff --git a/source/use_case/kws_asr/src/Wav2LetterPostprocess.cc b/source/use_case/kws_asr/src/Wav2LetterPostprocess.cc index 2a76b1b..42f434e 100644 --- a/source/use_case/kws_asr/src/Wav2LetterPostprocess.cc +++ b/source/use_case/kws_asr/src/Wav2LetterPostprocess.cc @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,62 +15,71 @@ * limitations under the License. */ #include "Wav2LetterPostprocess.hpp" + #include "Wav2LetterModel.hpp" #include "log_macros.h" +#include + namespace arm { namespace app { -namespace audio { -namespace asr { - - Postprocess::Postprocess(const uint32_t contextLen, - const uint32_t innerLen, - const uint32_t blankTokenIdx) - : m_contextLen(contextLen), - m_innerLen(innerLen), - m_totalLen(2 * this->m_contextLen + this->m_innerLen), + + AsrPostProcess::AsrPostProcess(TfLiteTensor* outputTensor, AsrClassifier& classifier, + const std::vector& labels, std::vector& results, + const uint32_t outputContextLen, + const uint32_t blankTokenIdx, const uint32_t reductionAxisIdx + ): + m_classifier(classifier), + m_outputTensor(outputTensor), + m_labels{labels}, + m_results(results), + m_outputContextLen(outputContextLen), m_countIterations(0), - m_blankTokenIdx(blankTokenIdx) - {} + m_blankTokenIdx(blankTokenIdx), + m_reductionAxisIdx(reductionAxisIdx) + { + this->m_outputInnerLen = AsrPostProcess::GetOutputInnerLen(this->m_outputTensor, this->m_outputContextLen); + this->m_totalLen = (2 * this->m_outputContextLen + this->m_outputInnerLen); + } - bool Postprocess::Invoke(TfLiteTensor* tensor, - const uint32_t axisIdx, - const bool lastIteration) + bool AsrPostProcess::DoPostProcess() { /* Basic checks. */ - if (!this->IsInputValid(tensor, axisIdx)) { + if (!this->IsInputValid(this->m_outputTensor, this->m_reductionAxisIdx)) { return false; } /* Irrespective of tensor type, we use unsigned "byte" */ - uint8_t* ptrData = tflite::GetTensorData(tensor); - const uint32_t elemSz = this->GetTensorElementSize(tensor); + auto* ptrData = tflite::GetTensorData(this->m_outputTensor); + const uint32_t elemSz = AsrPostProcess::GetTensorElementSize(this->m_outputTensor); /* Other sanity checks. */ if (0 == elemSz) { printf_err("Tensor type not supported for post processing\n"); return false; - } else if (elemSz * this->m_totalLen > tensor->bytes) { + } else if (elemSz * this->m_totalLen > this->m_outputTensor->bytes) { printf_err("Insufficient number of tensor bytes\n"); return false; } /* Which axis do we need to process? 
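     * (Illustratively: an int8 output arranged as 148 rows by 29 classes must
     * provide at least 148 * 29 = 4292 bytes, and m_totalLen must match the
     * 148 rows on the reduction axis for the checks above to pass; the actual
     * shape comes from the loaded model.)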
*/ - switch (axisIdx) { - case arm::app::Wav2LetterModel::ms_outputRowsIdx: - return this->EraseSectionsRowWise(ptrData, - elemSz * - tensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx], - lastIteration); + switch (this->m_reductionAxisIdx) { + case Wav2LetterModel::ms_outputRowsIdx: + this->EraseSectionsRowWise( + ptrData, elemSz * this->m_outputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx], + this->m_lastIteration); + break; default: - printf_err("Unsupported axis index: %" PRIu32 "\n", axisIdx); + printf_err("Unsupported axis index: %" PRIu32 "\n", this->m_reductionAxisIdx); + return false; } + this->m_classifier.GetClassificationResults(this->m_outputTensor, + this->m_results, this->m_labels, 1); - return false; + return true; } - bool Postprocess::IsInputValid(TfLiteTensor* tensor, - const uint32_t axisIdx) const + bool AsrPostProcess::IsInputValid(TfLiteTensor* tensor, const uint32_t axisIdx) const { if (nullptr == tensor) { return false; @@ -84,25 +93,23 @@ namespace asr { if (static_cast(this->m_totalLen) != tensor->dims->data[axisIdx]) { - printf_err("Unexpected tensor dimension for axis %d, \n", - tensor->dims->data[axisIdx]); + printf_err("Unexpected tensor dimension for axis %d, got %d, \n", + axisIdx, tensor->dims->data[axisIdx]); return false; } return true; } - uint32_t Postprocess::GetTensorElementSize(TfLiteTensor* tensor) + uint32_t AsrPostProcess::GetTensorElementSize(TfLiteTensor* tensor) { switch(tensor->type) { case kTfLiteUInt8: - return 1; case kTfLiteInt8: return 1; case kTfLiteInt16: return 2; case kTfLiteInt32: - return 4; case kTfLiteFloat32: return 4; default: @@ -113,30 +120,30 @@ namespace asr { return 0; } - bool Postprocess::EraseSectionsRowWise( - uint8_t* ptrData, - const uint32_t strideSzBytes, - const bool lastIteration) + bool AsrPostProcess::EraseSectionsRowWise( + uint8_t* ptrData, + const uint32_t strideSzBytes, + const bool lastIteration) { /* In this case, the "zero-ing" is quite simple as the region * to be zeroed sits in contiguous memory (row-major). */ - const uint32_t eraseLen = strideSzBytes * this->m_contextLen; + const uint32_t eraseLen = strideSzBytes * this->m_outputContextLen; /* Erase left context? */ if (this->m_countIterations > 0) { /* Set output of each classification window to the blank token. */ std::memset(ptrData, 0, eraseLen); - for (size_t windowIdx = 0; windowIdx < this->m_contextLen; windowIdx++) { + for (size_t windowIdx = 0; windowIdx < this->m_outputContextLen; windowIdx++) { ptrData[windowIdx*strideSzBytes + this->m_blankTokenIdx] = 1; } } /* Erase right context? */ if (false == lastIteration) { - uint8_t * rightCtxPtr = ptrData + (strideSzBytes * (this->m_contextLen + this->m_innerLen)); + uint8_t* rightCtxPtr = ptrData + (strideSzBytes * (this->m_outputContextLen + this->m_outputInnerLen)); /* Set output of each classification window to the blank token. 
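             * (For example, with m_outputContextLen = 49 and a 29-class,
             * row-major output: the first 49 rows are zeroed and their entry at
             * m_blankTokenIdx (28 for Wav2Letter) is set to 1, so the CTC
             * decoder reads the overlapped left context as blanks instead of
             * duplicated characters; the right context is treated the same way
             * below, except on the last window.)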
*/ std::memset(rightCtxPtr, 0, eraseLen); - for (size_t windowIdx = 0; windowIdx < this->m_contextLen; windowIdx++) { + for (size_t windowIdx = 0; windowIdx < this->m_outputContextLen; windowIdx++) { rightCtxPtr[windowIdx*strideSzBytes + this->m_blankTokenIdx] = 1; } } @@ -150,7 +157,58 @@ namespace asr { return true; } -} /* namespace asr */ -} /* namespace audio */ + uint32_t AsrPostProcess::GetNumFeatureVectors(const Model& model) + { + TfLiteTensor* inputTensor = model.GetInputTensor(0); + const int inputRows = std::max(inputTensor->dims->data[Wav2LetterModel::ms_inputRowsIdx], 0); + if (inputRows == 0) { + printf_err("Error getting number of input rows for axis: %" PRIu32 "\n", + Wav2LetterModel::ms_inputRowsIdx); + } + return inputRows; + } + + uint32_t AsrPostProcess::GetOutputInnerLen(const TfLiteTensor* outputTensor, const uint32_t outputCtxLen) + { + const uint32_t outputRows = std::max(outputTensor->dims->data[Wav2LetterModel::ms_outputRowsIdx], 0); + if (outputRows == 0) { + printf_err("Error getting number of output rows for axis: %" PRIu32 "\n", + Wav2LetterModel::ms_outputRowsIdx); + } + + /* Watching for underflow. */ + int innerLen = (outputRows - (2 * outputCtxLen)); + + return std::max(innerLen, 0); + } + + uint32_t AsrPostProcess::GetOutputContextLen(const Model& model, const uint32_t inputCtxLen) + { + const uint32_t inputRows = AsrPostProcess::GetNumFeatureVectors(model); + const uint32_t inputInnerLen = inputRows - (2 * inputCtxLen); + constexpr uint32_t ms_outputRowsIdx = Wav2LetterModel::ms_outputRowsIdx; + + /* Check to make sure that the input tensor supports the above + * context and inner lengths. */ + if (inputRows <= 2 * inputCtxLen || inputRows <= inputInnerLen) { + printf_err("Input rows not compatible with ctx of %" PRIu32 "\n", + inputCtxLen); + return 0; + } + + TfLiteTensor* outputTensor = model.GetOutputTensor(0); + const uint32_t outputRows = std::max(outputTensor->dims->data[ms_outputRowsIdx], 0); + if (outputRows == 0) { + printf_err("Error getting number of output rows for axis: %" PRIu32 "\n", + Wav2LetterModel::ms_outputRowsIdx); + return 0; + } + + const float inOutRowRatio = static_cast(inputRows) / + static_cast(outputRows); + + return std::round(static_cast(inputCtxLen) / inOutRowRatio); + } + } /* namespace app */ } /* namespace arm */ \ No newline at end of file diff --git a/source/use_case/kws_asr/src/Wav2LetterPreprocess.cc b/source/use_case/kws_asr/src/Wav2LetterPreprocess.cc index d3f3579..92b0631 100644 --- a/source/use_case/kws_asr/src/Wav2LetterPreprocess.cc +++ b/source/use_case/kws_asr/src/Wav2LetterPreprocess.cc @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. 
* SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -20,41 +20,35 @@ #include "TensorFlowLiteMicro.hpp" #include -#include +#include namespace arm { namespace app { -namespace audio { -namespace asr { - - Preprocess::Preprocess( - const uint32_t numMfccFeatures, - const uint32_t windowLen, - const uint32_t windowStride, - const uint32_t numMfccVectors): - m_mfcc(numMfccFeatures, windowLen), - m_mfccBuf(numMfccFeatures, numMfccVectors), - m_delta1Buf(numMfccFeatures, numMfccVectors), - m_delta2Buf(numMfccFeatures, numMfccVectors), - m_windowLen(windowLen), - m_windowStride(windowStride), + + AsrPreProcess::AsrPreProcess(TfLiteTensor* inputTensor, const uint32_t numMfccFeatures, + const uint32_t numFeatureFrames, const uint32_t mfccWindowLen, + const uint32_t mfccWindowStride + ): + m_mfcc(numMfccFeatures, mfccWindowLen), + m_inputTensor(inputTensor), + m_mfccBuf(numMfccFeatures, numFeatureFrames), + m_delta1Buf(numMfccFeatures, numFeatureFrames), + m_delta2Buf(numMfccFeatures, numFeatureFrames), + m_mfccWindowLen(mfccWindowLen), + m_mfccWindowStride(mfccWindowStride), m_numMfccFeats(numMfccFeatures), - m_numFeatVectors(numMfccVectors), - m_window() + m_numFeatureFrames(numFeatureFrames) { - if (numMfccFeatures > 0 && windowLen > 0) { + if (numMfccFeatures > 0 && mfccWindowLen > 0) { this->m_mfcc.Init(); } } - bool Preprocess::Invoke( - const int16_t* audioData, - const uint32_t audioDataLen, - TfLiteTensor* tensor) + bool AsrPreProcess::DoPreProcess(const void* audioData, const size_t audioDataLen) { - this->m_window = SlidingWindow( - audioData, audioDataLen, - this->m_windowLen, this->m_windowStride); + this->m_mfccSlidingWindow = audio::SlidingWindow( + static_cast(audioData), audioDataLen, + this->m_mfccWindowLen, this->m_mfccWindowStride); uint32_t mfccBufIdx = 0; @@ -62,12 +56,12 @@ namespace asr { std::fill(m_delta1Buf.begin(), m_delta1Buf.end(), 0.f); std::fill(m_delta2Buf.begin(), m_delta2Buf.end(), 0.f); - /* While we can slide over the window. */ - while (this->m_window.HasNext()) { - const int16_t* mfccWindow = this->m_window.Next(); + /* While we can slide over the audio. */ + while (this->m_mfccSlidingWindow.HasNext()) { + const int16_t* mfccWindow = this->m_mfccSlidingWindow.Next(); auto mfccAudioData = std::vector( mfccWindow, - mfccWindow + this->m_windowLen); + mfccWindow + this->m_mfccWindowLen); auto mfcc = this->m_mfcc.MfccCompute(mfccAudioData); for (size_t i = 0; i < this->m_mfccBuf.size(0); ++i) { this->m_mfccBuf(i, mfccBufIdx) = mfcc[i]; @@ -76,11 +70,11 @@ namespace asr { } /* Pad MFCC if needed by adding MFCC for zeros. */ - if (mfccBufIdx != this->m_numFeatVectors) { - std::vector zerosWindow = std::vector(this->m_windowLen, 0); + if (mfccBufIdx != this->m_numFeatureFrames) { + std::vector zerosWindow = std::vector(this->m_mfccWindowLen, 0); std::vector mfccZeros = this->m_mfcc.MfccCompute(zerosWindow); - while (mfccBufIdx != this->m_numFeatVectors) { + while (mfccBufIdx != this->m_numFeatureFrames) { memcpy(&this->m_mfccBuf(0, mfccBufIdx), mfccZeros.data(), sizeof(float) * m_numMfccFeats); ++mfccBufIdx; @@ -88,41 +82,39 @@ namespace asr { } /* Compute first and second order deltas from MFCCs. */ - this->ComputeDeltas(this->m_mfccBuf, - this->m_delta1Buf, - this->m_delta2Buf); + AsrPreProcess::ComputeDeltas(this->m_mfccBuf, this->m_delta1Buf, this->m_delta2Buf); - /* Normalise. */ - this->Normalise(); + /* Standardize calculated features. */ + this->Standarize(); /* Quantise. 
*/ - QuantParams quantParams = GetTensorQuantParams(tensor); + QuantParams quantParams = GetTensorQuantParams(this->m_inputTensor); if (0 == quantParams.scale) { printf_err("Quantisation scale can't be 0\n"); return false; } - switch(tensor->type) { + switch(this->m_inputTensor->type) { case kTfLiteUInt8: return this->Quantise( - tflite::GetTensorData(tensor), tensor->bytes, + tflite::GetTensorData(this->m_inputTensor), this->m_inputTensor->bytes, quantParams.scale, quantParams.offset); case kTfLiteInt8: return this->Quantise( - tflite::GetTensorData(tensor), tensor->bytes, + tflite::GetTensorData(this->m_inputTensor), this->m_inputTensor->bytes, quantParams.scale, quantParams.offset); default: printf_err("Unsupported tensor type %s\n", - TfLiteTypeGetName(tensor->type)); + TfLiteTypeGetName(this->m_inputTensor->type)); } return false; } - bool Preprocess::ComputeDeltas(Array2d& mfcc, - Array2d& delta1, - Array2d& delta2) + bool AsrPreProcess::ComputeDeltas(Array2d& mfcc, + Array2d& delta1, + Array2d& delta2) { const std::vector delta1Coeffs = {6.66666667e-02, 5.00000000e-02, 3.33333333e-02, @@ -148,11 +140,11 @@ namespace asr { /* Iterate through features in MFCC vector. */ for (size_t i = 0; i < numFeatures; ++i) { /* For each feature, iterate through time (t) samples representing feature evolution and - * calculate d/dt and d^2/dt^2, using 1d convolution with differential kernels. + * calculate d/dt and d^2/dt^2, using 1D convolution with differential kernels. * Convolution padding = valid, result size is `time length - kernel length + 1`. * The result is padded with 0 from both sides to match the size of initial time samples data. * - * For the small filter, conv1d implementation as a simple loop is efficient enough. + * For the small filter, conv1D implementation as a simple loop is efficient enough. * Filters of a greater size would need CMSIS-DSP functions to be used, like arm_fir_f32. 
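     *
     * A minimal sketch of such a valid-mode loop (n-tap kernel, input length T,
     * output length T - n + 1, zero-padded back to T afterwards):
     *
     *     for (size_t t = 0; t + kernel.size() <= input.size(); ++t) {
     *         float acc = 0.f;
     *         for (size_t k = 0; k < kernel.size(); ++k) {
     *             acc += input[t + k] * kernel[k];
     *         }
     *         output[t] = acc;
     *     }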
*/ @@ -175,20 +167,10 @@ namespace asr { return true; } - float Preprocess::GetMean(Array2d& vec) - { - return math::MathUtils::MeanF32(vec.begin(), vec.totalSize()); - } - - float Preprocess::GetStdDev(Array2d& vec, const float mean) - { - return math::MathUtils::StdDevF32(vec.begin(), vec.totalSize(), mean); - } - - void Preprocess::NormaliseVec(Array2d& vec) + void AsrPreProcess::StandardizeVecF32(Array2d& vec) { - auto mean = Preprocess::GetMean(vec); - auto stddev = Preprocess::GetStdDev(vec, mean); + auto mean = math::MathUtils::MeanF32(vec.begin(), vec.totalSize()); + auto stddev = math::MathUtils::StdDevF32(vec.begin(), vec.totalSize(), mean); debug("Mean: %f, Stddev: %f\n", mean, stddev); if (stddev == 0) { @@ -204,14 +186,14 @@ namespace asr { } } - void Preprocess::Normalise() + void AsrPreProcess::Standarize() { - Preprocess::NormaliseVec(this->m_mfccBuf); - Preprocess::NormaliseVec(this->m_delta1Buf); - Preprocess::NormaliseVec(this->m_delta2Buf); + AsrPreProcess::StandardizeVecF32(this->m_mfccBuf); + AsrPreProcess::StandardizeVecF32(this->m_delta1Buf); + AsrPreProcess::StandardizeVecF32(this->m_delta2Buf); } - float Preprocess::GetQuantElem( + float AsrPreProcess::GetQuantElem( const float elem, const float quantScale, const int quantOffset, @@ -222,7 +204,5 @@ namespace asr { return std::min(std::max(val, minVal), maxVal); } -} /* namespace asr */ -} /* namespace audio */ } /* namespace app */ } /* namespace arm */ \ No newline at end of file diff --git a/source/use_case/kws_asr/usecase.cmake b/source/use_case/kws_asr/usecase.cmake index b3fe020..40df4d7 100644 --- a/source/use_case/kws_asr/usecase.cmake +++ b/source/use_case/kws_asr/usecase.cmake @@ -1,5 +1,5 @@ #---------------------------------------------------------------------------- -# Copyright (c) 2021 Arm Limited. All rights reserved. +# Copyright (c) 2021-2022 Arm Limited. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -59,7 +59,7 @@ USER_OPTION(${use_case}_ACTIVATION_BUF_SZ "Activation buffer size for the chosen STRING) USER_OPTION(${use_case}_MODEL_SCORE_THRESHOLD_KWS "Specify the score threshold [0.0, 1.0) that must be applied to the KWS results for a label to be deemed valid." - 0.9 + 0.7 STRING) USER_OPTION(${use_case}_MODEL_SCORE_THRESHOLD_ASR "Specify the score threshold [0.0, 1.0) that must be applied to the ASR results for a label to be deemed valid." diff --git a/source/use_case/noise_reduction/include/RNNoiseFeatureProcessor.hpp b/source/use_case/noise_reduction/include/RNNoiseFeatureProcessor.hpp new file mode 100644 index 0000000..cbf0e4e --- /dev/null +++ b/source/use_case/noise_reduction/include/RNNoiseFeatureProcessor.hpp @@ -0,0 +1,341 @@ +/* + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef RNNOISE_FEATURE_PROCESSOR_HPP +#define RNNOISE_FEATURE_PROCESSOR_HPP + +#include "PlatformMath.hpp" +#include +#include +#include +#include + +namespace arm { +namespace app { +namespace rnn { + + using vec1D32F = std::vector; + using vec2D32F = std::vector; + using arrHp = std::array; + using math::FftInstance; + using math::FftType; + + class FrameFeatures { + public: + bool m_silence{false}; /* If frame contains silence or not. */ + vec1D32F m_featuresVec{}; /* Calculated feature vector to feed to model. */ + vec1D32F m_fftX{}; /* Vector of floats arranged to represent complex numbers. */ + vec1D32F m_fftP{}; /* Vector of floats arranged to represent complex numbers. */ + vec1D32F m_Ex{}; /* Spectral band energy for audio x. */ + vec1D32F m_Ep{}; /* Spectral band energy for pitch p. */ + vec1D32F m_Exp{}; /* Correlated spectral energy between x and p. */ + }; + + /** + * @brief RNNoise pre and post processing class based on the 2018 paper from + * Jan-Marc Valin. Recommended reading: + * - https://jmvalin.ca/demo/rnnoise/ + * - https://arxiv.org/abs/1709.08243 + **/ + class RNNoiseFeatureProcessor { + /* Public interface */ + public: + RNNoiseFeatureProcessor(); + ~RNNoiseFeatureProcessor() = default; + + /** + * @brief Calculates the features from a given audio buffer ready to be sent to RNNoise model. + * @param[in] audioData Pointer to the floating point vector + * with audio data (within the numerical + * limits of int16_t type). + * @param[in] audioLen Number of elements in the audio window. + * @param[out] features FrameFeatures object reference. + **/ + void PreprocessFrame(const float* audioData, + size_t audioLen, + FrameFeatures& features); + + /** + * @brief Use the RNNoise model output gain values with pre-processing features + * to generate audio with noise suppressed. + * @param[in] modelOutput Output gain values from model. + * @param[in] features Calculated features from pre-processing step. + * @param[out] outFrame Output frame to be populated. + **/ + void PostProcessFrame(vec1D32F& modelOutput, FrameFeatures& features, vec1D32F& outFrame); + + + /* Public constants */ + public: + static constexpr uint32_t FRAME_SIZE_SHIFT{2}; + static constexpr uint32_t FRAME_SIZE{512}; + static constexpr uint32_t WINDOW_SIZE{2 * FRAME_SIZE}; + static constexpr uint32_t FREQ_SIZE{FRAME_SIZE + 1}; + + static constexpr uint32_t PITCH_MIN_PERIOD{64}; + static constexpr uint32_t PITCH_MAX_PERIOD{820}; + static constexpr uint32_t PITCH_FRAME_SIZE{1024}; + static constexpr uint32_t PITCH_BUF_SIZE{PITCH_MAX_PERIOD + PITCH_FRAME_SIZE}; + + static constexpr uint32_t NB_BANDS{22}; + static constexpr uint32_t CEPS_MEM{8}; + static constexpr uint32_t NB_DELTA_CEPS{6}; + + static constexpr uint32_t NB_FEATURES{NB_BANDS + 3*NB_DELTA_CEPS + 2}; + + /* Private functions */ + private: + + /** + * @brief Initialises the half window and DCT tables. + */ + void InitTables(); + + /** + * @brief Applies a bi-quadratic filter over the audio window. + * @param[in] bHp Constant coefficient set b (arrHp type). + * @param[in] aHp Constant coefficient set a (arrHp type). + * @param[in,out] memHpX Coefficients populated by this function. + * @param[in,out] audioWindow Floating point vector with audio data. + **/ + void BiQuad( + const arrHp& bHp, + const arrHp& aHp, + arrHp& memHpX, + vec1D32F& audioWindow); + + /** + * @brief Computes features from the "filtered" audio window. + * @param[in] audioWindow Floating point vector with audio data. + * @param[out] features FrameFeatures object reference. 
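+         *                          On return, features.m_featuresVec holds
+         *                          NB_FEATURES = 22 + 3*6 + 2 = 42 values: 22
+         *                          Bark-band cepstral coefficients, two sets of
+         *                          6 temporal delta features, 6 DCT terms of the
+         *                          pitch correlation, the pitch period and a
+         *                          spectral variability estimate.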
+ **/ + void ComputeFrameFeatures(vec1D32F& audioWindow, FrameFeatures& features); + + /** + * @brief Runs analysis on the audio buffer. + * @param[in] audioWindow Floating point vector with audio data. + * @param[out] fft Floating point FFT vector containing real and + * imaginary pairs of elements. NOTE: this vector + * does not contain the mirror image (conjugates) + * part of the spectrum. + * @param[out] energy Computed energy for each band in the Bark scale. + * @param[out] analysisMem Buffer sequentially, but partially, + * populated with new audio data. + **/ + void FrameAnalysis( + const vec1D32F& audioWindow, + vec1D32F& fft, + vec1D32F& energy, + vec1D32F& analysisMem); + + /** + * @brief Applies the window function, in-place, over the given + * floating point buffer. + * @param[in,out] x Buffer the window will be applied to. + **/ + void ApplyWindow(vec1D32F& x); + + /** + * @brief Computes the FFT for a given vector. + * @param[in] x Vector to compute the FFT from. + * @param[out] fft Floating point FFT vector containing real and + * imaginary pairs of elements. NOTE: this vector + * does not contain the mirror image (conjugates) + * part of the spectrum. + **/ + void ForwardTransform( + vec1D32F& x, + vec1D32F& fft); + + /** + * @brief Computes band energy for each of the 22 Bark scale bands. + * @param[in] fft_X FFT spectrum (as computed by ForwardTransform). + * @param[out] bandE Vector with 22 elements populated with energy for + * each band. + **/ + void ComputeBandEnergy(const vec1D32F& fft_X, vec1D32F& bandE); + + /** + * @brief Computes band energy correlation. + * @param[in] X FFT vector X. + * @param[in] P FFT vector P. + * @param[out] bandC Vector with 22 elements populated with band energy + * correlation for the two input FFT vectors. + **/ + void ComputeBandCorr(const vec1D32F& X, const vec1D32F& P, vec1D32F& bandC); + + /** + * @brief Performs pitch auto-correlation for a given vector for + * given lag. + * @param[in] x Input vector. + * @param[out] ac Auto-correlation output vector. + * @param[in] lag Lag value. + * @param[in] n Number of elements to consider for correlation + * computation. + **/ + void AutoCorr(const vec1D32F &x, + vec1D32F &ac, + size_t lag, + size_t n); + + /** + * @brief Computes pitch cross-correlation. + * @param[in] x Input vector 1. + * @param[in] y Input vector 2. + * @param[out] xCorr Cross-correlation output vector. + * @param[in] len Number of elements to consider for correlation. + * computation. + * @param[in] maxPitch Maximum pitch. + **/ + void PitchXCorr( + const vec1D32F& x, + const vec1D32F& y, + vec1D32F& xCorr, + size_t len, + size_t maxPitch); + + /** + * @brief Computes "Linear Predictor Coefficients". + * @param[in] ac Correlation vector. + * @param[in] p Number of elements of input vector to consider. + * @param[out] lpc Output coefficients vector. + **/ + void LPC(const vec1D32F& ac, int32_t p, vec1D32F& lpc); + + /** + * @brief Custom FIR implementation. + * @param[in] num FIR coefficient vector. + * @param[in] N Number of elements. + * @param[out] x Vector to be be processed. + **/ + void Fir5(const vec1D32F& num, uint32_t N, vec1D32F& x); + + /** + * @brief Down-sample the pitch buffer. + * @param[in,out] pitchBuf Pitch buffer. + * @param[in] pitchBufSz Buffer size. + **/ + void PitchDownsample(vec1D32F& pitchBuf, size_t pitchBufSz); + + /** + * @brief Pitch search function. + * @param[in] xLP Shifted pitch buffer input. + * @param[in] y Pitch buffer input. + * @param[in] len Length to search for. 
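+         *                       (The search runs coarse-to-fine: a first
+         *                       correlation pass over 4x decimated data, a
+         *                       refinement pass at 2x, then pseudo-interpolation
+         *                       around the best lag.)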
+ * @param[in] maxPitch Maximum pitch. + * @return pitch index. + **/ + int PitchSearch(vec1D32F& xLp, vec1D32F& y, uint32_t len, uint32_t maxPitch); + + /** + * @brief Finds the "best" pitch from the buffer. + * @param[in] xCorr Pitch correlation vector. + * @param[in] y Pitch buffer input. + * @param[in] len Length to search for. + * @param[in] maxPitch Maximum pitch. + * @return pitch array (2 elements). + **/ + arrHp FindBestPitch(vec1D32F& xCorr, vec1D32F& y, uint32_t len, uint32_t maxPitch); + + /** + * @brief Remove pitch period doubling errors. + * @param[in,out] pitchBuf Pitch buffer vector. + * @param[in] maxPeriod Maximum period. + * @param[in] minPeriod Minimum period. + * @param[in] frameSize Frame size. + * @param[in] pitchIdx0_ Pitch index 0. + * @return pitch index. + **/ + int RemoveDoubling( + vec1D32F& pitchBuf, + uint32_t maxPeriod, + uint32_t minPeriod, + uint32_t frameSize, + size_t pitchIdx0_); + + /** + * @brief Computes pitch gain. + * @param[in] xy Single xy cross correlation value. + * @param[in] xx Single xx auto correlation value. + * @param[in] yy Single yy auto correlation value. + * @return Calculated pitch gain. + **/ + float ComputePitchGain(float xy, float xx, float yy); + + /** + * @brief Computes DCT vector from the given input. + * @param[in] input Input vector. + * @param[out] output Output vector with DCT coefficients. + **/ + void DCT(vec1D32F& input, vec1D32F& output); + + /** + * @brief Perform inverse fourier transform on complex spectral vector. + * @param[out] out Output vector. + * @param[in] fftXIn Vector of floats arranged to represent complex numbers interleaved. + **/ + void InverseTransform(vec1D32F& out, vec1D32F& fftXIn); + + /** + * @brief Perform pitch filtering. + * @param[in] features Object with pre-processing calculated frame features. + * @param[in] g Gain values. + **/ + void PitchFilter(FrameFeatures& features, vec1D32F& g); + + /** + * @brief Interpolate the band gain values. + * @param[out] g Gain values. + * @param[in] bandE Vector with 22 elements populated with energy for + * each band. + **/ + void InterpBandGain(vec1D32F& g, vec1D32F& bandE); + + /** + * @brief Create de-noised frame. + * @param[out] outFrame Output vector for storing the created audio frame. + * @param[in] fftY Gain adjusted complex spectral vector. + */ + void FrameSynthesis(vec1D32F& outFrame, vec1D32F& fftY); + + /* Private objects */ + private: + FftInstance m_fftInstReal; /* FFT instance for real numbers */ + FftInstance m_fftInstCmplx; /* FFT instance for complex numbers */ + vec1D32F m_halfWindow; /* Window coefficients */ + vec1D32F m_dctTable; /* DCT table */ + vec1D32F m_analysisMem; /* Buffer used for frame analysis */ + vec2D32F m_cepstralMem; /* Cepstral coefficients */ + size_t m_memId; /* memory ID */ + vec1D32F m_synthesisMem; /* Synthesis mem (used by post-processing) */ + vec1D32F m_pitchBuf; /* Pitch buffer */ + float m_lastGain; /* Last gain calculated */ + int m_lastPeriod; /* Last period calculated */ + arrHp m_memHpX; /* HpX coefficients. 
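Bi-quad filter state carried across frames.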
*/ + vec1D32F m_lastGVec; /* Last gain vector (used by post-processing) */ + + /* Constants */ + const std::array m_eband5ms { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, + 14, 16, 20, 24, 28, 34, 40, 48, 60, 78, 100}; + }; + + +} /* namespace rnn */ +} /* namespace app */ +} /* namespace arm */ + +#endif /* RNNOISE_FEATURE_PROCESSOR_HPP */ diff --git a/source/use_case/noise_reduction/include/RNNoiseProcess.hpp b/source/use_case/noise_reduction/include/RNNoiseProcess.hpp deleted file mode 100644 index c188e42..0000000 --- a/source/use_case/noise_reduction/include/RNNoiseProcess.hpp +++ /dev/null @@ -1,337 +0,0 @@ -/* - * Copyright (c) 2021 Arm Limited. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "PlatformMath.hpp" -#include -#include -#include -#include - -namespace arm { -namespace app { -namespace rnn { - - using vec1D32F = std::vector; - using vec2D32F = std::vector; - using arrHp = std::array; - using math::FftInstance; - using math::FftType; - - class FrameFeatures { - public: - bool m_silence{false}; /* If frame contains silence or not. */ - vec1D32F m_featuresVec{}; /* Calculated feature vector to feed to model. */ - vec1D32F m_fftX{}; /* Vector of floats arranged to represent complex numbers. */ - vec1D32F m_fftP{}; /* Vector of floats arranged to represent complex numbers. */ - vec1D32F m_Ex{}; /* Spectral band energy for audio x. */ - vec1D32F m_Ep{}; /* Spectral band energy for pitch p. */ - vec1D32F m_Exp{}; /* Correlated spectral energy between x and p. */ - }; - - /** - * @brief RNNoise pre and post processing class based on the 2018 paper from - * Jan-Marc Valin. Recommended reading: - * - https://jmvalin.ca/demo/rnnoise/ - * - https://arxiv.org/abs/1709.08243 - **/ - class RNNoiseProcess { - /* Public interface */ - public: - RNNoiseProcess(); - ~RNNoiseProcess() = default; - - /** - * @brief Calculates the features from a given audio buffer ready to be sent to RNNoise model. - * @param[in] audioData Pointer to the floating point vector - * with audio data (within the numerical - * limits of int16_t type). - * @param[in] audioLen Number of elements in the audio window. - * @param[out] features FrameFeatures object reference. - **/ - void PreprocessFrame(const float* audioData, - size_t audioLen, - FrameFeatures& features); - - /** - * @brief Use the RNNoise model output gain values with pre-processing features - * to generate audio with noise suppressed. - * @param[in] modelOutput Output gain values from model. - * @param[in] features Calculated features from pre-processing step. - * @param[out] outFrame Output frame to be populated. 
- **/ - void PostProcessFrame(vec1D32F& modelOutput, FrameFeatures& features, vec1D32F& outFrame); - - - /* Public constants */ - public: - static constexpr uint32_t FRAME_SIZE_SHIFT{2}; - static constexpr uint32_t FRAME_SIZE{512}; - static constexpr uint32_t WINDOW_SIZE{2 * FRAME_SIZE}; - static constexpr uint32_t FREQ_SIZE{FRAME_SIZE + 1}; - - static constexpr uint32_t PITCH_MIN_PERIOD{64}; - static constexpr uint32_t PITCH_MAX_PERIOD{820}; - static constexpr uint32_t PITCH_FRAME_SIZE{1024}; - static constexpr uint32_t PITCH_BUF_SIZE{PITCH_MAX_PERIOD + PITCH_FRAME_SIZE}; - - static constexpr uint32_t NB_BANDS{22}; - static constexpr uint32_t CEPS_MEM{8}; - static constexpr uint32_t NB_DELTA_CEPS{6}; - - static constexpr uint32_t NB_FEATURES{NB_BANDS + 3*NB_DELTA_CEPS + 2}; - - /* Private functions */ - private: - - /** - * @brief Initialises the half window and DCT tables. - */ - void InitTables(); - - /** - * @brief Applies a bi-quadratic filter over the audio window. - * @param[in] bHp Constant coefficient set b (arrHp type). - * @param[in] aHp Constant coefficient set a (arrHp type). - * @param[in,out] memHpX Coefficients populated by this function. - * @param[in,out] audioWindow Floating point vector with audio data. - **/ - void BiQuad( - const arrHp& bHp, - const arrHp& aHp, - arrHp& memHpX, - vec1D32F& audioWindow); - - /** - * @brief Computes features from the "filtered" audio window. - * @param[in] audioWindow Floating point vector with audio data. - * @param[out] features FrameFeatures object reference. - **/ - void ComputeFrameFeatures(vec1D32F& audioWindow, FrameFeatures& features); - - /** - * @brief Runs analysis on the audio buffer. - * @param[in] audioWindow Floating point vector with audio data. - * @param[out] fft Floating point FFT vector containing real and - * imaginary pairs of elements. NOTE: this vector - * does not contain the mirror image (conjugates) - * part of the spectrum. - * @param[out] energy Computed energy for each band in the Bark scale. - * @param[out] analysisMem Buffer sequentially, but partially, - * populated with new audio data. - **/ - void FrameAnalysis( - const vec1D32F& audioWindow, - vec1D32F& fft, - vec1D32F& energy, - vec1D32F& analysisMem); - - /** - * @brief Applies the window function, in-place, over the given - * floating point buffer. - * @param[in,out] x Buffer the window will be applied to. - **/ - void ApplyWindow(vec1D32F& x); - - /** - * @brief Computes the FFT for a given vector. - * @param[in] x Vector to compute the FFT from. - * @param[out] fft Floating point FFT vector containing real and - * imaginary pairs of elements. NOTE: this vector - * does not contain the mirror image (conjugates) - * part of the spectrum. - **/ - void ForwardTransform( - vec1D32F& x, - vec1D32F& fft); - - /** - * @brief Computes band energy for each of the 22 Bark scale bands. - * @param[in] fft_X FFT spectrum (as computed by ForwardTransform). - * @param[out] bandE Vector with 22 elements populated with energy for - * each band. - **/ - void ComputeBandEnergy(const vec1D32F& fft_X, vec1D32F& bandE); - - /** - * @brief Computes band energy correlation. - * @param[in] X FFT vector X. - * @param[in] P FFT vector P. - * @param[out] bandC Vector with 22 elements populated with band energy - * correlation for the two input FFT vectors. - **/ - void ComputeBandCorr(const vec1D32F& X, const vec1D32F& P, vec1D32F& bandC); - - /** - * @brief Performs pitch auto-correlation for a given vector for - * given lag. - * @param[in] x Input vector. 
- * @param[out] ac Auto-correlation output vector. - * @param[in] lag Lag value. - * @param[in] n Number of elements to consider for correlation - * computation. - **/ - void AutoCorr(const vec1D32F &x, - vec1D32F &ac, - size_t lag, - size_t n); - - /** - * @brief Computes pitch cross-correlation. - * @param[in] x Input vector 1. - * @param[in] y Input vector 2. - * @param[out] xCorr Cross-correlation output vector. - * @param[in] len Number of elements to consider for correlation. - * computation. - * @param[in] maxPitch Maximum pitch. - **/ - void PitchXCorr( - const vec1D32F& x, - const vec1D32F& y, - vec1D32F& xCorr, - size_t len, - size_t maxPitch); - - /** - * @brief Computes "Linear Predictor Coefficients". - * @param[in] ac Correlation vector. - * @param[in] p Number of elements of input vector to consider. - * @param[out] lpc Output coefficients vector. - **/ - void LPC(const vec1D32F& ac, int32_t p, vec1D32F& lpc); - - /** - * @brief Custom FIR implementation. - * @param[in] num FIR coefficient vector. - * @param[in] N Number of elements. - * @param[out] x Vector to be be processed. - **/ - void Fir5(const vec1D32F& num, uint32_t N, vec1D32F& x); - - /** - * @brief Down-sample the pitch buffer. - * @param[in,out] pitchBuf Pitch buffer. - * @param[in] pitchBufSz Buffer size. - **/ - void PitchDownsample(vec1D32F& pitchBuf, size_t pitchBufSz); - - /** - * @brief Pitch search function. - * @param[in] xLP Shifted pitch buffer input. - * @param[in] y Pitch buffer input. - * @param[in] len Length to search for. - * @param[in] maxPitch Maximum pitch. - * @return pitch index. - **/ - int PitchSearch(vec1D32F& xLp, vec1D32F& y, uint32_t len, uint32_t maxPitch); - - /** - * @brief Finds the "best" pitch from the buffer. - * @param[in] xCorr Pitch correlation vector. - * @param[in] y Pitch buffer input. - * @param[in] len Length to search for. - * @param[in] maxPitch Maximum pitch. - * @return pitch array (2 elements). - **/ - arrHp FindBestPitch(vec1D32F& xCorr, vec1D32F& y, uint32_t len, uint32_t maxPitch); - - /** - * @brief Remove pitch period doubling errors. - * @param[in,out] pitchBuf Pitch buffer vector. - * @param[in] maxPeriod Maximum period. - * @param[in] minPeriod Minimum period. - * @param[in] frameSize Frame size. - * @param[in] pitchIdx0_ Pitch index 0. - * @return pitch index. - **/ - int RemoveDoubling( - vec1D32F& pitchBuf, - uint32_t maxPeriod, - uint32_t minPeriod, - uint32_t frameSize, - size_t pitchIdx0_); - - /** - * @brief Computes pitch gain. - * @param[in] xy Single xy cross correlation value. - * @param[in] xx Single xx auto correlation value. - * @param[in] yy Single yy auto correlation value. - * @return Calculated pitch gain. - **/ - float ComputePitchGain(float xy, float xx, float yy); - - /** - * @brief Computes DCT vector from the given input. - * @param[in] input Input vector. - * @param[out] output Output vector with DCT coefficients. - **/ - void DCT(vec1D32F& input, vec1D32F& output); - - /** - * @brief Perform inverse fourier transform on complex spectral vector. - * @param[out] out Output vector. - * @param[in] fftXIn Vector of floats arranged to represent complex numbers interleaved. - **/ - void InverseTransform(vec1D32F& out, vec1D32F& fftXIn); - - /** - * @brief Perform pitch filtering. - * @param[in] features Object with pre-processing calculated frame features. - * @param[in] g Gain values. - **/ - void PitchFilter(FrameFeatures& features, vec1D32F& g); - - /** - * @brief Interpolate the band gain values. - * @param[out] g Gain values. 
- * @param[in] bandE Vector with 22 elements populated with energy for - * each band. - **/ - void InterpBandGain(vec1D32F& g, vec1D32F& bandE); - - /** - * @brief Create de-noised frame. - * @param[out] outFrame Output vector for storing the created audio frame. - * @param[in] fftY Gain adjusted complex spectral vector. - */ - void FrameSynthesis(vec1D32F& outFrame, vec1D32F& fftY); - - /* Private objects */ - private: - FftInstance m_fftInstReal; /* FFT instance for real numbers */ - FftInstance m_fftInstCmplx; /* FFT instance for complex numbers */ - vec1D32F m_halfWindow; /* Window coefficients */ - vec1D32F m_dctTable; /* DCT table */ - vec1D32F m_analysisMem; /* Buffer used for frame analysis */ - vec2D32F m_cepstralMem; /* Cepstral coefficients */ - size_t m_memId; /* memory ID */ - vec1D32F m_synthesisMem; /* Synthesis mem (used by post-processing) */ - vec1D32F m_pitchBuf; /* Pitch buffer */ - float m_lastGain; /* Last gain calculated */ - int m_lastPeriod; /* Last period calculated */ - arrHp m_memHpX; /* HpX coefficients. */ - vec1D32F m_lastGVec; /* Last gain vector (used by post-processing) */ - - /* Constants */ - const std::array m_eband5ms { - 0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, - 14, 16, 20, 24, 28, 34, 40, 48, 60, 78, 100}; - - }; - - -} /* namespace rnn */ -} /* namspace app */ -} /* namespace arm */ diff --git a/source/use_case/noise_reduction/include/RNNoiseProcessing.hpp b/source/use_case/noise_reduction/include/RNNoiseProcessing.hpp new file mode 100644 index 0000000..15e62d9 --- /dev/null +++ b/source/use_case/noise_reduction/include/RNNoiseProcessing.hpp @@ -0,0 +1,113 @@ +/* + * Copyright (c) 2022 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef RNNOISE_PROCESSING_HPP +#define RNNOISE_PROCESSING_HPP + +#include "BaseProcessing.hpp" +#include "Model.hpp" +#include "RNNoiseFeatureProcessor.hpp" + +namespace arm { +namespace app { + + /** + * @brief Pre-processing class for Noise Reduction use case. + * Implements methods declared by BasePreProcess and anything else needed + * to populate input tensors ready for inference. + */ + class RNNoisePreProcess : public BasePreProcess { + + public: + /** + * @brief Constructor + * @param[in] inputTensor Pointer to the TFLite Micro input Tensor. + * @param[in/out] featureProcessor RNNoise specific feature extractor object. + * @param[in/out] frameFeatures RNNoise specific features shared between pre & post-processing. + * + **/ + explicit RNNoisePreProcess(TfLiteTensor* inputTensor, + std::shared_ptr featureProcessor, + std::shared_ptr frameFeatures); + + /** + * @brief Should perform pre-processing of 'raw' input audio data and load it into + * TFLite Micro input tensors ready for inference + * @param[in] input Pointer to the data that pre-processing will work on. + * @param[in] inputSize Size of the input data. + * @return true if successful, false otherwise. 
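+         *
+         * A minimal usage sketch (the tensor/object names and the exact size
+         * semantics are illustrative, not mandated by this API):
+         * @code
+         *   auto featureProcessor = std::make_shared<rnn::RNNoiseFeatureProcessor>();
+         *   auto frameFeatures    = std::make_shared<rnn::FrameFeatures>();
+         *   RNNoisePreProcess preProcess(inputTensor, featureProcessor, frameFeatures);
+         *   if (!preProcess.DoPreProcess(audioFrame.data(), audioFrame.size())) {
+         *       return false;
+         *   }
+         * @endcode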
+         **/
+        bool DoPreProcess(const void* input, size_t inputSize) override;
+
+    private:
+        TfLiteTensor* m_inputTensor;                              /* Model input tensor. */
+        std::shared_ptr<rnn::RNNoiseFeatureProcessor> m_featureProcessor;  /* RNNoise feature processor shared between pre & post-processing. */
+        std::shared_ptr<rnn::FrameFeatures> m_frameFeatures;      /* RNNoise features shared between pre & post-processing. */
+        rnn::vec1D32F m_audioFrame;                               /* Audio frame cast to FP32 */
+
+        /**
+         * @brief         Quantize the given features and populate the input Tensor.
+         * @param[in]     inputFeatures  Vector of floating point features to quantize.
+         * @param[in]     quantScale     Quantization scale for the inputTensor.
+         * @param[in]     quantOffset    Quantization offset for the inputTensor.
+         * @param[in,out] inputTensor    TFLite micro tensor to populate.
+         **/
+        static void QuantizeAndPopulateInput(rnn::vec1D32F& inputFeatures,
+                                             float quantScale, int quantOffset,
+                                             TfLiteTensor* inputTensor);
+    };
+
+    /**
+     * @brief   Post-processing class for Noise Reduction use case.
+     *          Implements methods declared by BasePostProcess and anything else needed
+     *          to populate result vector.
+     */
+    class RNNoisePostProcess : public BasePostProcess {
+
+    public:
+        /**
+         * @brief         Constructor
+         * @param[in]     outputTensor        Pointer to the TFLite Micro output Tensor.
+         * @param[out]    denoisedAudioFrame  Vector to store the final denoised audio frame.
+         * @param[in,out] featureProcessor    RNNoise specific feature extractor object.
+         * @param[in,out] frameFeatures       RNNoise specific features shared between pre & post-processing.
+         **/
+        RNNoisePostProcess(TfLiteTensor* outputTensor,
+                           std::vector<int16_t>& denoisedAudioFrame,
+                           std::shared_ptr<rnn::RNNoiseFeatureProcessor> featureProcessor,
+                           std::shared_ptr<rnn::FrameFeatures> frameFeatures);
+
+        /**
+         * @brief       Should perform post-processing of the result of inference then
+         *              populate result data for any later use.
+         * @return      true if successful, false otherwise.
+         **/
+        bool DoPostProcess() override;
+
+    private:
+        TfLiteTensor* m_outputTensor;                 /* Model output tensor. */
+        std::vector<int16_t>& m_denoisedAudioFrame;   /* Vector to store the final denoised frame. */
+        rnn::vec1D32F m_denoisedAudioFrameFloat;      /* Internal vector to store the final denoised frame (FP32). */
+        std::shared_ptr<rnn::RNNoiseFeatureProcessor> m_featureProcessor;  /* RNNoise feature processor shared between pre & post-processing. */
+        std::shared_ptr<rnn::FrameFeatures> m_frameFeatures;               /* RNNoise features shared between pre & post-processing. */
+        std::vector<float> m_modelOutputFloat;        /* Internal vector to store de-quantized model output. */
+
+    };
+
+} /* namespace app */
+} /* namespace arm */
+
+#endif /* RNNOISE_PROCESSING_HPP */
\ No newline at end of file
diff --git a/source/use_case/noise_reduction/src/MainLoop.cc b/source/use_case/noise_reduction/src/MainLoop.cc
index 5fd7823..fd72127 100644
--- a/source/use_case/noise_reduction/src/MainLoop.cc
+++ b/source/use_case/noise_reduction/src/MainLoop.cc
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
  * SPDX-License-Identifier: Apache-2.0
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,12 +14,10 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "hal.h"                    /* Brings in platform definitions. */
 #include "UseCaseHandler.hpp"       /* Handlers for different user options. */
 #include "UseCaseCommonUtils.hpp"   /* Utils functions. */
 #include "RNNoiseModel.hpp"         /* Model class for running inference. */
 #include "InputFiles.hpp"           /* For input audio clips.
*/ -#include "RNNoiseProcess.hpp" /* Pre-processing class */ #include "log_macros.h" enum opcodes diff --git a/source/use_case/noise_reduction/src/RNNoiseFeatureProcessor.cc b/source/use_case/noise_reduction/src/RNNoiseFeatureProcessor.cc new file mode 100644 index 0000000..036894c --- /dev/null +++ b/source/use_case/noise_reduction/src/RNNoiseFeatureProcessor.cc @@ -0,0 +1,892 @@ +/* + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "RNNoiseFeatureProcessor.hpp" +#include "log_macros.h" + +#include +#include +#include + +namespace arm { +namespace app { +namespace rnn { + +#define VERIFY(x) \ +do { \ + if (!(x)) { \ + printf_err("Assert failed:" #x "\n"); \ + exit(1); \ + } \ +} while(0) + +RNNoiseFeatureProcessor::RNNoiseFeatureProcessor() : + m_halfWindow(FRAME_SIZE, 0), + m_dctTable(NB_BANDS * NB_BANDS), + m_analysisMem(FRAME_SIZE, 0), + m_cepstralMem(CEPS_MEM, vec1D32F(NB_BANDS, 0)), + m_memId{0}, + m_synthesisMem(FRAME_SIZE, 0), + m_pitchBuf(PITCH_BUF_SIZE, 0), + m_lastGain{0.0}, + m_lastPeriod{0}, + m_memHpX{}, + m_lastGVec(NB_BANDS, 0) +{ + constexpr uint32_t numFFt = 2 * FRAME_SIZE; + static_assert(numFFt != 0, "Num FFT can't be 0"); + + math::MathUtils::FftInitF32(numFFt, this->m_fftInstReal, FftType::real); + math::MathUtils::FftInitF32(numFFt, this->m_fftInstCmplx, FftType::complex); + this->InitTables(); +} + +void RNNoiseFeatureProcessor::PreprocessFrame(const float* audioData, + const size_t audioLen, + FrameFeatures& features) +{ + /* Note audioWindow is modified in place */ + const arrHp aHp {-1.99599, 0.99600 }; + const arrHp bHp {-2.00000, 1.00000 }; + + vec1D32F audioWindow{audioData, audioData + audioLen}; + + this->BiQuad(bHp, aHp, this->m_memHpX, audioWindow); + this->ComputeFrameFeatures(audioWindow, features); +} + +void RNNoiseFeatureProcessor::PostProcessFrame(vec1D32F& modelOutput, FrameFeatures& features, vec1D32F& outFrame) +{ + std::vector outputBands = modelOutput; + std::vector gain(FREQ_SIZE, 0); + + if (!features.m_silence) { + PitchFilter(features, outputBands); + for (size_t i = 0; i < NB_BANDS; i++) { + float alpha = .6f; + outputBands[i] = std::max(outputBands[i], alpha * m_lastGVec[i]); + m_lastGVec[i] = outputBands[i]; + } + InterpBandGain(gain, outputBands); + for (size_t i = 0; i < FREQ_SIZE; i++) { + features.m_fftX[2 * i] *= gain[i]; /* Real. */ + features.m_fftX[2 * i + 1] *= gain[i]; /*imaginary. 
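The 22 band gains were interpolated by InterpBandGain into
               per-bin values; they scale the interleaved [re, im] pairs of all FREQ_SIZE bins.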
*/ + + } + + } + + FrameSynthesis(outFrame, features.m_fftX); +} + +void RNNoiseFeatureProcessor::InitTables() +{ + constexpr float pi = M_PI; + constexpr float halfPi = M_PI / 2; + constexpr float halfPiOverFrameSz = halfPi/FRAME_SIZE; + + for (uint32_t i = 0; i < FRAME_SIZE; i++) { + const float sinVal = math::MathUtils::SineF32(halfPiOverFrameSz * (i + 0.5f)); + m_halfWindow[i] = math::MathUtils::SineF32(halfPi * sinVal * sinVal); + } + + for (uint32_t i = 0; i < NB_BANDS; i++) { + for (uint32_t j = 0; j < NB_BANDS; j++) { + m_dctTable[i * NB_BANDS + j] = math::MathUtils::CosineF32((i + 0.5f) * j * pi / NB_BANDS); + } + m_dctTable[i * NB_BANDS] *= math::MathUtils::SqrtF32(0.5f); + } +} + +void RNNoiseFeatureProcessor::BiQuad( + const arrHp& bHp, + const arrHp& aHp, + arrHp& memHpX, + vec1D32F& audioWindow) +{ + for (float& audioElement : audioWindow) { + const auto xi = audioElement; + const auto yi = audioElement + memHpX[0]; + memHpX[0] = memHpX[1] + (bHp[0] * xi - aHp[0] * yi); + memHpX[1] = (bHp[1] * xi - aHp[1] * yi); + audioElement = yi; + } +} + +void RNNoiseFeatureProcessor::ComputeFrameFeatures(vec1D32F& audioWindow, + FrameFeatures& features) +{ + this->FrameAnalysis(audioWindow, + features.m_fftX, + features.m_Ex, + this->m_analysisMem); + + float energy = 0.0; + + vec1D32F Ly(NB_BANDS, 0); + vec1D32F p(WINDOW_SIZE, 0); + vec1D32F pitchBuf(PITCH_BUF_SIZE >> 1, 0); + + VERIFY(PITCH_BUF_SIZE >= this->m_pitchBuf.size()); + std::copy_n(this->m_pitchBuf.begin() + FRAME_SIZE, + PITCH_BUF_SIZE - FRAME_SIZE, + this->m_pitchBuf.begin()); + + VERIFY(FRAME_SIZE <= audioWindow.size() && PITCH_BUF_SIZE > FRAME_SIZE); + std::copy_n(audioWindow.begin(), + FRAME_SIZE, + this->m_pitchBuf.begin() + PITCH_BUF_SIZE - FRAME_SIZE); + + this->PitchDownsample(pitchBuf, PITCH_BUF_SIZE); + + VERIFY(pitchBuf.size() > PITCH_MAX_PERIOD/2); + vec1D32F xLp(pitchBuf.size() - PITCH_MAX_PERIOD/2, 0); + std::copy_n(pitchBuf.begin() + PITCH_MAX_PERIOD/2, xLp.size(), xLp.begin()); + + int pitchIdx = this->PitchSearch(xLp, pitchBuf, + PITCH_FRAME_SIZE, (PITCH_MAX_PERIOD - (3*PITCH_MIN_PERIOD))); + + pitchIdx = this->RemoveDoubling( + pitchBuf, + PITCH_MAX_PERIOD, + PITCH_MIN_PERIOD, + PITCH_FRAME_SIZE, + PITCH_MAX_PERIOD - pitchIdx); + + size_t stIdx = PITCH_BUF_SIZE - WINDOW_SIZE - pitchIdx; + VERIFY((static_cast(PITCH_BUF_SIZE) - static_cast(WINDOW_SIZE) - pitchIdx) >= 0); + std::copy_n(this->m_pitchBuf.begin() + stIdx, WINDOW_SIZE, p.begin()); + + this->ApplyWindow(p); + this->ForwardTransform(p, features.m_fftP); + this->ComputeBandEnergy(features.m_fftP, features.m_Ep); + this->ComputeBandCorr(features.m_fftX, features.m_fftP, features.m_Exp); + + for (uint32_t i = 0 ; i < NB_BANDS; ++i) { + features.m_Exp[i] /= math::MathUtils::SqrtF32( + 0.001f + features.m_Ex[i] * features.m_Ep[i]); + } + + vec1D32F dctVec(NB_BANDS, 0); + this->DCT(features.m_Exp, dctVec); + + features.m_featuresVec = vec1D32F (NB_FEATURES, 0); + for (uint32_t i = 0; i < NB_DELTA_CEPS; ++i) { + features.m_featuresVec[NB_BANDS + 2*NB_DELTA_CEPS + i] = dctVec[i]; + } + + features.m_featuresVec[NB_BANDS + 2*NB_DELTA_CEPS] -= 1.3; + features.m_featuresVec[NB_BANDS + 2*NB_DELTA_CEPS + 1] -= 0.9; + features.m_featuresVec[NB_BANDS + 3*NB_DELTA_CEPS] = 0.01 * (static_cast(pitchIdx) - 300); + + float logMax = -2.f; + float follow = -2.f; + for (uint32_t i = 0; i < NB_BANDS; ++i) { + Ly[i] = log10f(1e-2f + features.m_Ex[i]); + Ly[i] = std::max(logMax - 7, std::max(follow - 1.5, Ly[i])); + logMax = std::max(logMax, Ly[i]); + follow = std::max(follow 
- 1.5, Ly[i]); + energy += features.m_Ex[i]; + } + + /* If there's no audio avoid messing up the state. */ + features.m_silence = true; + if (energy < 0.04) { + return; + } else { + features.m_silence = false; + } + + this->DCT(Ly, features.m_featuresVec); + features.m_featuresVec[0] -= 12.0; + features.m_featuresVec[1] -= 4.0; + + VERIFY(CEPS_MEM > 2); + uint32_t stIdx1 = this->m_memId < 1 ? CEPS_MEM + this->m_memId - 1 : this->m_memId - 1; + uint32_t stIdx2 = this->m_memId < 2 ? CEPS_MEM + this->m_memId - 2 : this->m_memId - 2; + VERIFY(stIdx1 < this->m_cepstralMem.size()); + VERIFY(stIdx2 < this->m_cepstralMem.size()); + auto ceps1 = this->m_cepstralMem[stIdx1]; + auto ceps2 = this->m_cepstralMem[stIdx2]; + + /* Ceps 0 */ + for (uint32_t i = 0; i < NB_BANDS; ++i) { + this->m_cepstralMem[this->m_memId][i] = features.m_featuresVec[i]; + } + + for (uint32_t i = 0; i < NB_DELTA_CEPS; ++i) { + features.m_featuresVec[i] = this->m_cepstralMem[this->m_memId][i] + ceps1[i] + ceps2[i]; + features.m_featuresVec[NB_BANDS + i] = this->m_cepstralMem[this->m_memId][i] - ceps2[i]; + features.m_featuresVec[NB_BANDS + NB_DELTA_CEPS + i] = + this->m_cepstralMem[this->m_memId][i] - 2 * ceps1[i] + ceps2[i]; + } + + /* Spectral variability features. */ + this->m_memId += 1; + if (this->m_memId == CEPS_MEM) { + this->m_memId = 0; + } + + float specVariability = 0.f; + + VERIFY(this->m_cepstralMem.size() >= CEPS_MEM); + for (size_t i = 0; i < CEPS_MEM; ++i) { + float minDist = 1e15; + for (size_t j = 0; j < CEPS_MEM; ++j) { + float dist = 0.f; + for (size_t k = 0; k < NB_BANDS; ++k) { + VERIFY(this->m_cepstralMem[i].size() >= NB_BANDS); + auto tmp = this->m_cepstralMem[i][k] - this->m_cepstralMem[j][k]; + dist += tmp * tmp; + } + + if (j != i) { + minDist = std::min(minDist, dist); + } + } + specVariability += minDist; + } + + VERIFY(features.m_featuresVec.size() >= NB_BANDS + 3 * NB_DELTA_CEPS + 1); + features.m_featuresVec[NB_BANDS + 3 * NB_DELTA_CEPS + 1] = specVariability / CEPS_MEM - 2.1; +} + +void RNNoiseFeatureProcessor::FrameAnalysis( + const vec1D32F& audioWindow, + vec1D32F& fft, + vec1D32F& energy, + vec1D32F& analysisMem) +{ + vec1D32F x(WINDOW_SIZE, 0); + + /* Move old audio down and populate end with latest audio window. */ + VERIFY(x.size() >= FRAME_SIZE && analysisMem.size() >= FRAME_SIZE); + VERIFY(audioWindow.size() >= FRAME_SIZE); + + std::copy_n(analysisMem.begin(), FRAME_SIZE, x.begin()); + std::copy_n(audioWindow.begin(), x.size() - FRAME_SIZE, x.begin() + FRAME_SIZE); + std::copy_n(audioWindow.begin(), FRAME_SIZE, analysisMem.begin()); + + this->ApplyWindow(x); + + /* Calculate FFT. */ + ForwardTransform(x, fft); + + /* Compute band energy. */ + ComputeBandEnergy(fft, energy); +} + +void RNNoiseFeatureProcessor::ApplyWindow(vec1D32F& x) +{ + if (WINDOW_SIZE != x.size()) { + printf_err("Invalid size for vector to be windowed\n"); + return; + } + + VERIFY(this->m_halfWindow.size() >= FRAME_SIZE); + + /* Multiply input by sinusoidal function. */ + for (size_t i = 0; i < FRAME_SIZE; i++) { + x[i] *= this->m_halfWindow[i]; + x[WINDOW_SIZE - 1 - i] *= this->m_halfWindow[i]; + } +} + +void RNNoiseFeatureProcessor::ForwardTransform( + vec1D32F& x, + vec1D32F& fft) +{ + /* The input vector can be modified by the fft function. */ + fft.reserve(x.size() + 2); + fft.resize(x.size() + 2, 0); + math::MathUtils::FftF32(x, fft, this->m_fftInstReal); + + /* Normalise. 
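Dividing by the FFT length keeps magnitudes independent of the transform size.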
*/ + for (auto& f : fft) { + f /= this->m_fftInstReal.m_fftLen; + } + + /* Place the last freq element correctly */ + fft[fft.size()-2] = fft[1]; + fft[1] = 0; + + /* NOTE: We don't truncate out FFT vector as it already contains only the + * first half of the FFT's. The conjugates are not present. */ +} + +void RNNoiseFeatureProcessor::ComputeBandEnergy(const vec1D32F& fftX, vec1D32F& bandE) +{ + bandE = vec1D32F(NB_BANDS, 0); + + VERIFY(this->m_eband5ms.size() >= NB_BANDS); + for (uint32_t i = 0; i < NB_BANDS - 1; i++) { + const auto bandSize = (this->m_eband5ms[i + 1] - this->m_eband5ms[i]) + << FRAME_SIZE_SHIFT; + + for (uint32_t j = 0; j < bandSize; j++) { + const auto frac = static_cast(j) / bandSize; + const auto idx = (this->m_eband5ms[i] << FRAME_SIZE_SHIFT) + j; + + auto tmp = fftX[2 * idx] * fftX[2 * idx]; /* Real part */ + tmp += fftX[2 * idx + 1] * fftX[2 * idx + 1]; /* Imaginary part */ + + bandE[i] += (1 - frac) * tmp; + bandE[i + 1] += frac * tmp; + } + } + bandE[0] *= 2; + bandE[NB_BANDS - 1] *= 2; +} + +void RNNoiseFeatureProcessor::ComputeBandCorr(const vec1D32F& X, const vec1D32F& P, vec1D32F& bandC) +{ + bandC = vec1D32F(NB_BANDS, 0); + VERIFY(this->m_eband5ms.size() >= NB_BANDS); + + for (uint32_t i = 0; i < NB_BANDS - 1; i++) { + const auto bandSize = (this->m_eband5ms[i + 1] - this->m_eband5ms[i]) << FRAME_SIZE_SHIFT; + + for (uint32_t j = 0; j < bandSize; j++) { + const auto frac = static_cast(j) / bandSize; + const auto idx = (this->m_eband5ms[i] << FRAME_SIZE_SHIFT) + j; + + auto tmp = X[2 * idx] * P[2 * idx]; /* Real part */ + tmp += X[2 * idx + 1] * P[2 * idx + 1]; /* Imaginary part */ + + bandC[i] += (1 - frac) * tmp; + bandC[i + 1] += frac * tmp; + } + } + bandC[0] *= 2; + bandC[NB_BANDS - 1] *= 2; +} + +void RNNoiseFeatureProcessor::DCT(vec1D32F& input, vec1D32F& output) +{ + VERIFY(this->m_dctTable.size() >= NB_BANDS * NB_BANDS); + for (uint32_t i = 0; i < NB_BANDS; ++i) { + float sum = 0; + + for (uint32_t j = 0, k = 0; j < NB_BANDS; ++j, k += NB_BANDS) { + sum += input[j] * this->m_dctTable[k + i]; + } + output[i] = sum * math::MathUtils::SqrtF32(2.0/22); + } +} + +void RNNoiseFeatureProcessor::PitchDownsample(vec1D32F& pitchBuf, size_t pitchBufSz) { + for (size_t i = 1; i < (pitchBufSz >> 1); ++i) { + pitchBuf[i] = 0.5 * ( + 0.5 * (this->m_pitchBuf[2 * i - 1] + this->m_pitchBuf[2 * i + 1]) + + this->m_pitchBuf[2 * i]); + } + + pitchBuf[0] = 0.5*(0.5*(this->m_pitchBuf[1]) + this->m_pitchBuf[0]); + + vec1D32F ac(5, 0); + size_t numLags = 4; + + this->AutoCorr(pitchBuf, ac, numLags, pitchBufSz >> 1); + + /* Noise floor -40db */ + ac[0] *= 1.0001; + + /* Lag windowing. */ + for (size_t i = 1; i < numLags + 1; ++i) { + ac[i] -= ac[i] * (0.008 * i) * (0.008 * i); + } + + vec1D32F lpc(numLags, 0); + this->LPC(ac, numLags, lpc); + + float tmp = 1.0; + for (size_t i = 0; i < numLags; ++i) { + tmp = 0.9f * tmp; + lpc[i] = lpc[i] * tmp; + } + + vec1D32F lpc2(numLags + 1, 0); + float c1 = 0.8; + + /* Add a zero. */ + lpc2[0] = lpc[0] + 0.8; + lpc2[1] = lpc[1] + (c1 * lpc[0]); + lpc2[2] = lpc[2] + (c1 * lpc[1]); + lpc2[3] = lpc[3] + (c1 * lpc[2]); + lpc2[4] = (c1 * lpc[3]); + + this->Fir5(lpc2, pitchBufSz >> 1, pitchBuf); +} + +int RNNoiseFeatureProcessor::PitchSearch(vec1D32F& xLp, vec1D32F& y, uint32_t len, uint32_t maxPitch) { + uint32_t lag = len + maxPitch; + vec1D32F xLp4(len >> 2, 0); + vec1D32F yLp4(lag >> 2, 0); + vec1D32F xCorr(maxPitch >> 1, 0); + + /* Downsample by 2 again. 
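The pitch buffer was already decimated by 2 in PitchDownsample,
       so xLp4 and yLp4 sit at a quarter of the original rate for this cheap first correlation pass.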
*/ + for (size_t j = 0; j < (len >> 2); ++j) { + xLp4[j] = xLp[2*j]; + } + for (size_t j = 0; j < (lag >> 2); ++j) { + yLp4[j] = y[2*j]; + } + + this->PitchXCorr(xLp4, yLp4, xCorr, len >> 2, maxPitch >> 2); + + /* Coarse search with 4x decimation. */ + arrHp bestPitch = this->FindBestPitch(xCorr, yLp4, len >> 2, maxPitch >> 2); + + /* Finer search with 2x decimation. */ + const int maxIdx = (maxPitch >> 1); + for (int i = 0; i < maxIdx; ++i) { + xCorr[i] = 0; + if (std::abs(i - 2*bestPitch[0]) > 2 and std::abs(i - 2*bestPitch[1]) > 2) { + continue; + } + float sum = 0; + for (size_t j = 0; j < len >> 1; ++j) { + sum += xLp[j] * y[i+j]; + } + + xCorr[i] = std::max(-1.0f, sum); + } + + bestPitch = this->FindBestPitch(xCorr, y, len >> 1, maxPitch >> 1); + + int offset; + /* Refine by pseudo-interpolation. */ + if ( 0 < bestPitch[0] && bestPitch[0] < ((maxPitch >> 1) - 1)) { + float a = xCorr[bestPitch[0] - 1]; + float b = xCorr[bestPitch[0]]; + float c = xCorr[bestPitch[0] + 1]; + + if ( (c-a) > 0.7*(b-a) ) { + offset = 1; + } else if ( (a-c) > 0.7*(b-c) ) { + offset = -1; + } else { + offset = 0; + } + } else { + offset = 0; + } + + return 2*bestPitch[0] - offset; +} + +arrHp RNNoiseFeatureProcessor::FindBestPitch(vec1D32F& xCorr, vec1D32F& y, uint32_t len, uint32_t maxPitch) +{ + float Syy = 1; + arrHp bestNum {-1, -1}; + arrHp bestDen {0, 0}; + arrHp bestPitch {0, 1}; + + for (size_t j = 0; j < len; ++j) { + Syy += (y[j] * y[j]); + } + + for (size_t i = 0; i < maxPitch; ++i ) { + if (xCorr[i] > 0) { + float xCorr16 = xCorr[i] * 1e-12f; /* Avoid problems when squaring. */ + + float num = xCorr16 * xCorr16; + if (num*bestDen[1] > bestNum[1]*Syy) { + if (num*bestDen[0] > bestNum[0]*Syy) { + bestNum[1] = bestNum[0]; + bestDen[1] = bestDen[0]; + bestPitch[1] = bestPitch[0]; + bestNum[0] = num; + bestDen[0] = Syy; + bestPitch[0] = i; + } else { + bestNum[1] = num; + bestDen[1] = Syy; + bestPitch[1] = i; + } + } + } + + Syy += (y[i+len]*y[i+len]) - (y[i]*y[i]); + Syy = std::max(1.0f, Syy); + } + + return bestPitch; +} + +int RNNoiseFeatureProcessor::RemoveDoubling( + vec1D32F& pitchBuf, + uint32_t maxPeriod, + uint32_t minPeriod, + uint32_t frameSize, + size_t pitchIdx0_) +{ + constexpr std::array secondCheck {0, 0, 3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2}; + uint32_t minPeriod0 = minPeriod; + float lastPeriod = static_cast(this->m_lastPeriod)/2; + float lastGain = static_cast(this->m_lastGain); + + maxPeriod /= 2; + minPeriod /= 2; + pitchIdx0_ /= 2; + frameSize /= 2; + uint32_t xStart = maxPeriod; + + if (pitchIdx0_ >= maxPeriod) { + pitchIdx0_ = maxPeriod - 1; + } + + size_t pitchIdx = pitchIdx0_; + const size_t pitchIdx0 = pitchIdx0_; + + float xx = 0; + for ( size_t i = xStart; i < xStart+frameSize; ++i) { + xx += (pitchBuf[i] * pitchBuf[i]); + } + + float xy = 0; + for ( size_t i = xStart; i < xStart+frameSize; ++i) { + xy += (pitchBuf[i] * pitchBuf[i-pitchIdx0]); + } + + vec1D32F yyLookup (maxPeriod+1, 0); + yyLookup[0] = xx; + float yy = xx; + + for ( size_t i = 1; i < yyLookup.size(); ++i) { + yy = yy + (pitchBuf[xStart-i] * pitchBuf[xStart-i]) - + (pitchBuf[xStart+frameSize-i] * pitchBuf[xStart+frameSize-i]); + yyLookup[i] = std::max(0.0f, yy); + } + + yy = yyLookup[pitchIdx0]; + float bestXy = xy; + float bestYy = yy; + + float g = this->ComputePitchGain(xy, xx, yy); + float g0 = g; + + /* Look for any pitch at pitchIndex/k. 
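Sub-multiples of the detected period are tested so that a
       strong correlation at period/k can replace a spurious detection at the doubled period.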
*/ + for ( size_t k = 2; k < 16; ++k) { + size_t pitchIdx1 = (2*pitchIdx0+k) / (2*k); + if (pitchIdx1 < minPeriod) { + break; + } + + size_t pitchIdx1b; + /* Look for another strong correlation at T1b. */ + if (k == 2) { + if ((pitchIdx1 + pitchIdx0) > maxPeriod) { + pitchIdx1b = pitchIdx0; + } else { + pitchIdx1b = pitchIdx0 + pitchIdx1; + } + } else { + pitchIdx1b = (2*(secondCheck[k])*pitchIdx0 + k) / (2*k); + } + + xy = 0; + for ( size_t i = xStart; i < xStart+frameSize; ++i) { + xy += (pitchBuf[i] * pitchBuf[i-pitchIdx1]); + } + + float xy2 = 0; + for ( size_t i = xStart; i < xStart+frameSize; ++i) { + xy2 += (pitchBuf[i] * pitchBuf[i-pitchIdx1b]); + } + xy = 0.5f * (xy + xy2); + VERIFY(pitchIdx1b < maxPeriod+1); + yy = 0.5f * (yyLookup[pitchIdx1] + yyLookup[pitchIdx1b]); + + float g1 = this->ComputePitchGain(xy, xx, yy); + + float cont; + if (std::abs(pitchIdx1-lastPeriod) <= 1) { + cont = lastGain; + } else if (std::abs(pitchIdx1-lastPeriod) <= 2 and 5*k*k < pitchIdx0) { + cont = 0.5f*lastGain; + } else { + cont = 0.0f; + } + + float thresh = std::max(0.3, 0.7*g0-cont); + + /* Bias against very high pitch (very short period) to avoid false-positives + * due to short-term correlation */ + if (pitchIdx1 < 3*minPeriod) { + thresh = std::max(0.4, 0.85*g0-cont); + } else if (pitchIdx1 < 2*minPeriod) { + thresh = std::max(0.5, 0.9*g0-cont); + } + if (g1 > thresh) { + bestXy = xy; + bestYy = yy; + pitchIdx = pitchIdx1; + g = g1; + } + } + + bestXy = std::max(0.0f, bestXy); + float pg; + if (bestYy <= bestXy) { + pg = 1.0; + } else { + pg = bestXy/(bestYy+1); + } + + std::array xCorr {0}; + for ( size_t k = 0; k < 3; ++k ) { + for ( size_t i = xStart; i < xStart+frameSize; ++i) { + xCorr[k] += (pitchBuf[i] * pitchBuf[i-(pitchIdx+k-1)]); + } + } + + size_t offset; + if ((xCorr[2]-xCorr[0]) > 0.7*(xCorr[1]-xCorr[0])) { + offset = 1; + } else if ((xCorr[0]-xCorr[2]) > 0.7*(xCorr[1]-xCorr[2])) { + offset = -1; + } else { + offset = 0; + } + + if (pg > g) { + pg = g; + } + + pitchIdx0_ = 2*pitchIdx + offset; + + if (pitchIdx0_ < minPeriod0) { + pitchIdx0_ = minPeriod0; + } + + this->m_lastPeriod = pitchIdx0_; + this->m_lastGain = pg; + + return this->m_lastPeriod; +} + +float RNNoiseFeatureProcessor::ComputePitchGain(float xy, float xx, float yy) +{ + return xy / math::MathUtils::SqrtF32(1+xx*yy); +} + +void RNNoiseFeatureProcessor::AutoCorr( + const vec1D32F& x, + vec1D32F& ac, + size_t lag, + size_t n) +{ + if (n < lag) { + printf_err("Invalid parameters for AutoCorr\n"); + return; + } + + auto fastN = n - lag; + + /* Auto-correlation - can be done by PlatformMath functions */ + this->PitchXCorr(x, x, ac, fastN, lag + 1); + + /* Modify auto-correlation by summing with auto-correlation for different lags. 
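PitchXCorr above only accumulates the first fastN products; the loop below adds
       the remaining terms so that ac[k] = sum of x[i] * x[i - k] for i in [k, n).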
*/ + for (size_t k = 0; k < lag + 1; k++) { + float d = 0; + for (size_t i = k + fastN; i < n; i++) { + d += x[i] * x[i - k]; + } + ac[k] += d; + } +} + + +void RNNoiseFeatureProcessor::PitchXCorr( + const vec1D32F& x, + const vec1D32F& y, + vec1D32F& xCorr, + size_t len, + size_t maxPitch) +{ + for (size_t i = 0; i < maxPitch; i++) { + float sum = 0; + for (size_t j = 0; j < len; j++) { + sum += x[j] * y[i + j]; + } + xCorr[i] = sum; + } +} + +/* Linear predictor coefficients */ +void RNNoiseFeatureProcessor::LPC( + const vec1D32F& correlation, + int32_t p, + vec1D32F& lpc) +{ + auto error = correlation[0]; + + if (error != 0) { + for (int i = 0; i < p; i++) { + + /* Sum up this iteration's reflection coefficient */ + float rr = 0; + for (int j = 0; j < i; j++) { + rr += lpc[j] * correlation[i - j]; + } + + rr += correlation[i + 1]; + auto r = -rr / error; + + /* Update LP coefficients and total error */ + lpc[i] = r; + for (int j = 0; j < ((i + 1) >> 1); j++) { + auto tmp1 = lpc[j]; + auto tmp2 = lpc[i - 1 - j]; + lpc[j] = tmp1 + (r * tmp2); + lpc[i - 1 - j] = tmp2 + (r * tmp1); + } + + error = error - (r * r * error); + + /* Bail out once we get 30dB gain */ + if (error < (0.001 * correlation[0])) { + break; + } + } + } +} + +void RNNoiseFeatureProcessor::Fir5( + const vec1D32F &num, + uint32_t N, + vec1D32F &x) +{ + auto num0 = num[0]; + auto num1 = num[1]; + auto num2 = num[2]; + auto num3 = num[3]; + auto num4 = num[4]; + auto mem0 = 0; + auto mem1 = 0; + auto mem2 = 0; + auto mem3 = 0; + auto mem4 = 0; + for (uint32_t i = 0; i < N; i++) + { + auto sum_ = x[i] + (num0 * mem0) + (num1 * mem1) + + (num2 * mem2) + (num3 * mem3) + (num4 * mem4); + mem4 = mem3; + mem3 = mem2; + mem2 = mem1; + mem1 = mem0; + mem0 = x[i]; + x[i] = sum_; + } +} + +void RNNoiseFeatureProcessor::PitchFilter(FrameFeatures &features, vec1D32F &gain) { + std::vector r(NB_BANDS, 0); + std::vector rf(FREQ_SIZE, 0); + std::vector newE(NB_BANDS); + + for (size_t i = 0; i < NB_BANDS; i++) { + if (features.m_Exp[i] > gain[i]) { + r[i] = 1; + } else { + + + r[i] = std::pow(features.m_Exp[i], 2) * (1 - std::pow(gain[i], 2)) / + (.001 + std::pow(gain[i], 2) * (1 - std::pow(features.m_Exp[i], 2))); + } + + + r[i] = math::MathUtils::SqrtF32(std::min(1.0f, std::max(0.0f, r[i]))); + r[i] *= math::MathUtils::SqrtF32(features.m_Ex[i] / (1e-8f + features.m_Ep[i])); + } + + InterpBandGain(rf, r); + for (size_t i = 0; i < FREQ_SIZE - 1; i++) { + features.m_fftX[2 * i] += rf[i] * features.m_fftP[2 * i]; /* Real. */ + features.m_fftX[2 * i + 1] += rf[i] * features.m_fftP[2 * i + 1]; /* Imaginary. */ + + } + ComputeBandEnergy(features.m_fftX, newE); + std::vector norm(NB_BANDS); + std::vector normf(FRAME_SIZE, 0); + for (size_t i = 0; i < NB_BANDS; i++) { + norm[i] = math::MathUtils::SqrtF32(features.m_Ex[i] / (1e-8f + newE[i])); + } + + InterpBandGain(normf, norm); + for (size_t i = 0; i < FREQ_SIZE - 1; i++) { + features.m_fftX[2 * i] *= normf[i]; /* Real. */ + features.m_fftX[2 * i + 1] *= normf[i]; /* Imaginary. 
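The normf gains rescale each bin so that the per-band energy
               matches the original Ex after the pitch-correlated component was added.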
*/
+
+    }
+}
+
+void RNNoiseFeatureProcessor::FrameSynthesis(vec1D32F& outFrame, vec1D32F& fftY) {
+    std::vector<float> x(WINDOW_SIZE, 0);
+    InverseTransform(x, fftY);
+    ApplyWindow(x);
+    for (size_t i = 0; i < FRAME_SIZE; i++) {
+        outFrame[i] = x[i] + m_synthesisMem[i];
+    }
+    memcpy((m_synthesisMem.data()), &x[FRAME_SIZE], FRAME_SIZE*sizeof(float));
+}
+
+void RNNoiseFeatureProcessor::InterpBandGain(vec1D32F& g, vec1D32F& bandE) {
+    for (size_t i = 0; i < NB_BANDS - 1; i++) {
+        int bandSize = (m_eband5ms[i + 1] - m_eband5ms[i]) << FRAME_SIZE_SHIFT;
+        for (int j = 0; j < bandSize; j++) {
+            float frac = static_cast<float>(j) / bandSize;
+            g[(m_eband5ms[i] << FRAME_SIZE_SHIFT) + j] = (1 - frac) * bandE[i] + frac * bandE[i + 1];
+        }
+    }
+}
+
+void RNNoiseFeatureProcessor::InverseTransform(vec1D32F& out, vec1D32F& fftXIn) {
+
+    std::vector<float> x(WINDOW_SIZE * 2);  /* This is complex. */
+    vec1D32F newFFT;  /* This is complex. */
+
+    size_t i;
+    for (i = 0; i < FREQ_SIZE * 2; i++) {
+        x[i] = fftXIn[i];
+    }
+    for (i = FREQ_SIZE; i < WINDOW_SIZE; i++) {
+        x[2 * i] = x[2 * (WINDOW_SIZE - i)];           /* Real. */
+        x[2 * i + 1] = -x[2 * (WINDOW_SIZE - i) + 1];  /* Imaginary. */
+    }
+
+    constexpr uint32_t numFFt = 2 * FRAME_SIZE;
+    static_assert(numFFt != 0, "numFFt cannot be 0!");
+
+    vec1D32F fftOut = vec1D32F(x.size(), 0);
+    math::MathUtils::FftF32(x, fftOut, m_fftInstCmplx);
+
+    /* Normalize. */
+    for (auto &f: fftOut) {
+        f /= numFFt;
+    }
+
+    out[0] = WINDOW_SIZE * fftOut[0];  /* Real. */
+    for (i = 1; i < WINDOW_SIZE; i++) {
+        out[i] = WINDOW_SIZE * fftOut[(WINDOW_SIZE * 2) - (2 * i)];  /* Real. */
+    }
+}
+
+
+} /* namespace rnn */
+} /* namespace app */
+} /* namespace arm */
diff --git a/source/use_case/noise_reduction/src/RNNoiseProcess.cc b/source/use_case/noise_reduction/src/RNNoiseProcess.cc
deleted file mode 100644
index 4c568fa..0000000
--- a/source/use_case/noise_reduction/src/RNNoiseProcess.cc
+++ /dev/null
@@ -1,892 +0,0 @@
-/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
- * SPDX-License-Identifier: Apache-2.0
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ -#include "RNNoiseProcess.hpp" -#include "log_macros.h" - -#include -#include -#include - -namespace arm { -namespace app { -namespace rnn { - -#define VERIFY(x) \ -do { \ - if (!(x)) { \ - printf_err("Assert failed:" #x "\n"); \ - exit(1); \ - } \ -} while(0) - -RNNoiseProcess::RNNoiseProcess() : - m_halfWindow(FRAME_SIZE, 0), - m_dctTable(NB_BANDS * NB_BANDS), - m_analysisMem(FRAME_SIZE, 0), - m_cepstralMem(CEPS_MEM, vec1D32F(NB_BANDS, 0)), - m_memId{0}, - m_synthesisMem(FRAME_SIZE, 0), - m_pitchBuf(PITCH_BUF_SIZE, 0), - m_lastGain{0.0}, - m_lastPeriod{0}, - m_memHpX{}, - m_lastGVec(NB_BANDS, 0) -{ - constexpr uint32_t numFFt = 2 * FRAME_SIZE; - static_assert(numFFt != 0, "Num FFT can't be 0"); - - math::MathUtils::FftInitF32(numFFt, this->m_fftInstReal, FftType::real); - math::MathUtils::FftInitF32(numFFt, this->m_fftInstCmplx, FftType::complex); - this->InitTables(); -} - -void RNNoiseProcess::PreprocessFrame(const float* audioData, - const size_t audioLen, - FrameFeatures& features) -{ - /* Note audioWindow is modified in place */ - const arrHp aHp {-1.99599, 0.99600 }; - const arrHp bHp {-2.00000, 1.00000 }; - - vec1D32F audioWindow{audioData, audioData + audioLen}; - - this->BiQuad(bHp, aHp, this->m_memHpX, audioWindow); - this->ComputeFrameFeatures(audioWindow, features); -} - -void RNNoiseProcess::PostProcessFrame(vec1D32F& modelOutput, FrameFeatures& features, vec1D32F& outFrame) -{ - std::vector outputBands = modelOutput; - std::vector gain(FREQ_SIZE, 0); - - if (!features.m_silence) { - PitchFilter(features, outputBands); - for (size_t i = 0; i < NB_BANDS; i++) { - float alpha = .6f; - outputBands[i] = std::max(outputBands[i], alpha * m_lastGVec[i]); - m_lastGVec[i] = outputBands[i]; - } - InterpBandGain(gain, outputBands); - for (size_t i = 0; i < FREQ_SIZE; i++) { - features.m_fftX[2 * i] *= gain[i]; /* Real. */ - features.m_fftX[2 * i + 1] *= gain[i]; /*imaginary. 
*/ - - } - - } - - FrameSynthesis(outFrame, features.m_fftX); -} - -void RNNoiseProcess::InitTables() -{ - constexpr float pi = M_PI; - constexpr float halfPi = M_PI / 2; - constexpr float halfPiOverFrameSz = halfPi/FRAME_SIZE; - - for (uint32_t i = 0; i < FRAME_SIZE; i++) { - const float sinVal = math::MathUtils::SineF32(halfPiOverFrameSz * (i + 0.5f)); - m_halfWindow[i] = math::MathUtils::SineF32(halfPi * sinVal * sinVal); - } - - for (uint32_t i = 0; i < NB_BANDS; i++) { - for (uint32_t j = 0; j < NB_BANDS; j++) { - m_dctTable[i * NB_BANDS + j] = math::MathUtils::CosineF32((i + 0.5f) * j * pi / NB_BANDS); - } - m_dctTable[i * NB_BANDS] *= math::MathUtils::SqrtF32(0.5f); - } -} - -void RNNoiseProcess::BiQuad( - const arrHp& bHp, - const arrHp& aHp, - arrHp& memHpX, - vec1D32F& audioWindow) -{ - for (float& audioElement : audioWindow) { - const auto xi = audioElement; - const auto yi = audioElement + memHpX[0]; - memHpX[0] = memHpX[1] + (bHp[0] * xi - aHp[0] * yi); - memHpX[1] = (bHp[1] * xi - aHp[1] * yi); - audioElement = yi; - } -} - -void RNNoiseProcess::ComputeFrameFeatures(vec1D32F& audioWindow, - FrameFeatures& features) -{ - this->FrameAnalysis(audioWindow, - features.m_fftX, - features.m_Ex, - this->m_analysisMem); - - float energy = 0.0; - - vec1D32F Ly(NB_BANDS, 0); - vec1D32F p(WINDOW_SIZE, 0); - vec1D32F pitchBuf(PITCH_BUF_SIZE >> 1, 0); - - VERIFY(PITCH_BUF_SIZE >= this->m_pitchBuf.size()); - std::copy_n(this->m_pitchBuf.begin() + FRAME_SIZE, - PITCH_BUF_SIZE - FRAME_SIZE, - this->m_pitchBuf.begin()); - - VERIFY(FRAME_SIZE <= audioWindow.size() && PITCH_BUF_SIZE > FRAME_SIZE); - std::copy_n(audioWindow.begin(), - FRAME_SIZE, - this->m_pitchBuf.begin() + PITCH_BUF_SIZE - FRAME_SIZE); - - this->PitchDownsample(pitchBuf, PITCH_BUF_SIZE); - - VERIFY(pitchBuf.size() > PITCH_MAX_PERIOD/2); - vec1D32F xLp(pitchBuf.size() - PITCH_MAX_PERIOD/2, 0); - std::copy_n(pitchBuf.begin() + PITCH_MAX_PERIOD/2, xLp.size(), xLp.begin()); - - int pitchIdx = this->PitchSearch(xLp, pitchBuf, - PITCH_FRAME_SIZE, (PITCH_MAX_PERIOD - (3*PITCH_MIN_PERIOD))); - - pitchIdx = this->RemoveDoubling( - pitchBuf, - PITCH_MAX_PERIOD, - PITCH_MIN_PERIOD, - PITCH_FRAME_SIZE, - PITCH_MAX_PERIOD - pitchIdx); - - size_t stIdx = PITCH_BUF_SIZE - WINDOW_SIZE - pitchIdx; - VERIFY((static_cast(PITCH_BUF_SIZE) - static_cast(WINDOW_SIZE) - pitchIdx) >= 0); - std::copy_n(this->m_pitchBuf.begin() + stIdx, WINDOW_SIZE, p.begin()); - - this->ApplyWindow(p); - this->ForwardTransform(p, features.m_fftP); - this->ComputeBandEnergy(features.m_fftP, features.m_Ep); - this->ComputeBandCorr(features.m_fftX, features.m_fftP, features.m_Exp); - - for (uint32_t i = 0 ; i < NB_BANDS; ++i) { - features.m_Exp[i] /= math::MathUtils::SqrtF32( - 0.001f + features.m_Ex[i] * features.m_Ep[i]); - } - - vec1D32F dctVec(NB_BANDS, 0); - this->DCT(features.m_Exp, dctVec); - - features.m_featuresVec = vec1D32F (NB_FEATURES, 0); - for (uint32_t i = 0; i < NB_DELTA_CEPS; ++i) { - features.m_featuresVec[NB_BANDS + 2*NB_DELTA_CEPS + i] = dctVec[i]; - } - - features.m_featuresVec[NB_BANDS + 2*NB_DELTA_CEPS] -= 1.3; - features.m_featuresVec[NB_BANDS + 2*NB_DELTA_CEPS + 1] -= 0.9; - features.m_featuresVec[NB_BANDS + 3*NB_DELTA_CEPS] = 0.01 * (static_cast(pitchIdx) - 300); - - float logMax = -2.f; - float follow = -2.f; - for (uint32_t i = 0; i < NB_BANDS; ++i) { - Ly[i] = log10f(1e-2f + features.m_Ex[i]); - Ly[i] = std::max(logMax - 7, std::max(follow - 1.5, Ly[i])); - logMax = std::max(logMax, Ly[i]); - follow = std::max(follow - 1.5, Ly[i]); - energy += 
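/* [Editor's note: illustration, not part of this patch] This Ly loop floors
 * each band's log energy against both the running maximum and the previous
 * band (a "follower"), which bounds how fast per-band log energy may fall and
 * stabilises the DCT input for quiet bands. A minimal sketch, assuming ex
 * holds the linear band energies:
 *
 *     float logMax = -2.f;
 *     float follow = -2.f;
 *     for (size_t i = 0; i < ex.size(); ++i) {
 *         float ly = std::max(logMax - 7.0f,
 *                             std::max(follow - 1.5f, log10f(1e-2f + ex[i])));
 *         logMax = std::max(logMax, ly);        // Track the loudest band so far.
 *         follow = std::max(follow - 1.5f, ly); // Limit band-to-band decay.
 *     }
 *
 * The summed linear energy is then checked against the 0.04 silence threshold
 * just below, so silent frames skip the cepstral-history update. */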
features.m_Ex[i]; - } - - /* If there's no audio avoid messing up the state. */ - features.m_silence = true; - if (energy < 0.04) { - return; - } else { - features.m_silence = false; - } - - this->DCT(Ly, features.m_featuresVec); - features.m_featuresVec[0] -= 12.0; - features.m_featuresVec[1] -= 4.0; - - VERIFY(CEPS_MEM > 2); - uint32_t stIdx1 = this->m_memId < 1 ? CEPS_MEM + this->m_memId - 1 : this->m_memId - 1; - uint32_t stIdx2 = this->m_memId < 2 ? CEPS_MEM + this->m_memId - 2 : this->m_memId - 2; - VERIFY(stIdx1 < this->m_cepstralMem.size()); - VERIFY(stIdx2 < this->m_cepstralMem.size()); - auto ceps1 = this->m_cepstralMem[stIdx1]; - auto ceps2 = this->m_cepstralMem[stIdx2]; - - /* Ceps 0 */ - for (uint32_t i = 0; i < NB_BANDS; ++i) { - this->m_cepstralMem[this->m_memId][i] = features.m_featuresVec[i]; - } - - for (uint32_t i = 0; i < NB_DELTA_CEPS; ++i) { - features.m_featuresVec[i] = this->m_cepstralMem[this->m_memId][i] + ceps1[i] + ceps2[i]; - features.m_featuresVec[NB_BANDS + i] = this->m_cepstralMem[this->m_memId][i] - ceps2[i]; - features.m_featuresVec[NB_BANDS + NB_DELTA_CEPS + i] = - this->m_cepstralMem[this->m_memId][i] - 2 * ceps1[i] + ceps2[i]; - } - - /* Spectral variability features. */ - this->m_memId += 1; - if (this->m_memId == CEPS_MEM) { - this->m_memId = 0; - } - - float specVariability = 0.f; - - VERIFY(this->m_cepstralMem.size() >= CEPS_MEM); - for (size_t i = 0; i < CEPS_MEM; ++i) { - float minDist = 1e15; - for (size_t j = 0; j < CEPS_MEM; ++j) { - float dist = 0.f; - for (size_t k = 0; k < NB_BANDS; ++k) { - VERIFY(this->m_cepstralMem[i].size() >= NB_BANDS); - auto tmp = this->m_cepstralMem[i][k] - this->m_cepstralMem[j][k]; - dist += tmp * tmp; - } - - if (j != i) { - minDist = std::min(minDist, dist); - } - } - specVariability += minDist; - } - - VERIFY(features.m_featuresVec.size() >= NB_BANDS + 3 * NB_DELTA_CEPS + 1); - features.m_featuresVec[NB_BANDS + 3 * NB_DELTA_CEPS + 1] = specVariability / CEPS_MEM - 2.1; -} - -void RNNoiseProcess::FrameAnalysis( - const vec1D32F& audioWindow, - vec1D32F& fft, - vec1D32F& energy, - vec1D32F& analysisMem) -{ - vec1D32F x(WINDOW_SIZE, 0); - - /* Move old audio down and populate end with latest audio window. */ - VERIFY(x.size() >= FRAME_SIZE && analysisMem.size() >= FRAME_SIZE); - VERIFY(audioWindow.size() >= FRAME_SIZE); - - std::copy_n(analysisMem.begin(), FRAME_SIZE, x.begin()); - std::copy_n(audioWindow.begin(), x.size() - FRAME_SIZE, x.begin() + FRAME_SIZE); - std::copy_n(audioWindow.begin(), FRAME_SIZE, analysisMem.begin()); - - this->ApplyWindow(x); - - /* Calculate FFT. */ - ForwardTransform(x, fft); - - /* Compute band energy. */ - ComputeBandEnergy(fft, energy); -} - -void RNNoiseProcess::ApplyWindow(vec1D32F& x) -{ - if (WINDOW_SIZE != x.size()) { - printf_err("Invalid size for vector to be windowed\n"); - return; - } - - VERIFY(this->m_halfWindow.size() >= FRAME_SIZE); - - /* Multiply input by sinusoidal function. */ - for (size_t i = 0; i < FRAME_SIZE; i++) { - x[i] *= this->m_halfWindow[i]; - x[WINDOW_SIZE - 1 - i] *= this->m_halfWindow[i]; - } -} - -void RNNoiseProcess::ForwardTransform( - vec1D32F& x, - vec1D32F& fft) -{ - /* The input vector can be modified by the fft function. */ - fft.reserve(x.size() + 2); - fft.resize(x.size() + 2, 0); - math::MathUtils::FftF32(x, fft, this->m_fftInstReal); - - /* Normalise. 
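 *
 * [Editor's note: illustration, not part of this patch] The cepstral-history
 * block above keeps CEPS_MEM past frames in a ring buffer and derives the
 * delta features as finite differences over the three most recent frames, for
 * each of the first NB_DELTA_CEPS coefficients:
 *
 *     smoothed[i] = c0[i] + c1[i] + c2[i];    // Sum of the last three frames.
 *     delta1[i]   = c0[i] - c2[i];            // First difference.
 *     delta2[i]   = c0[i] - 2*c1[i] + c2[i];  // Second difference.
 *
 * where c0 is the current frame and c1/c2 are the two previous entries of
 * m_cepstralMem, indexed modulo CEPS_MEM.
 *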
*/ - for (auto& f : fft) { - f /= this->m_fftInstReal.m_fftLen; - } - - /* Place the last freq element correctly */ - fft[fft.size()-2] = fft[1]; - fft[1] = 0; - - /* NOTE: We don't truncate out FFT vector as it already contains only the - * first half of the FFT's. The conjugates are not present. */ -} - -void RNNoiseProcess::ComputeBandEnergy(const vec1D32F& fftX, vec1D32F& bandE) -{ - bandE = vec1D32F(NB_BANDS, 0); - - VERIFY(this->m_eband5ms.size() >= NB_BANDS); - for (uint32_t i = 0; i < NB_BANDS - 1; i++) { - const auto bandSize = (this->m_eband5ms[i + 1] - this->m_eband5ms[i]) - << FRAME_SIZE_SHIFT; - - for (uint32_t j = 0; j < bandSize; j++) { - const auto frac = static_cast(j) / bandSize; - const auto idx = (this->m_eband5ms[i] << FRAME_SIZE_SHIFT) + j; - - auto tmp = fftX[2 * idx] * fftX[2 * idx]; /* Real part */ - tmp += fftX[2 * idx + 1] * fftX[2 * idx + 1]; /* Imaginary part */ - - bandE[i] += (1 - frac) * tmp; - bandE[i + 1] += frac * tmp; - } - } - bandE[0] *= 2; - bandE[NB_BANDS - 1] *= 2; -} - -void RNNoiseProcess::ComputeBandCorr(const vec1D32F& X, const vec1D32F& P, vec1D32F& bandC) -{ - bandC = vec1D32F(NB_BANDS, 0); - VERIFY(this->m_eband5ms.size() >= NB_BANDS); - - for (uint32_t i = 0; i < NB_BANDS - 1; i++) { - const auto bandSize = (this->m_eband5ms[i + 1] - this->m_eband5ms[i]) << FRAME_SIZE_SHIFT; - - for (uint32_t j = 0; j < bandSize; j++) { - const auto frac = static_cast(j) / bandSize; - const auto idx = (this->m_eband5ms[i] << FRAME_SIZE_SHIFT) + j; - - auto tmp = X[2 * idx] * P[2 * idx]; /* Real part */ - tmp += X[2 * idx + 1] * P[2 * idx + 1]; /* Imaginary part */ - - bandC[i] += (1 - frac) * tmp; - bandC[i + 1] += frac * tmp; - } - } - bandC[0] *= 2; - bandC[NB_BANDS - 1] *= 2; -} - -void RNNoiseProcess::DCT(vec1D32F& input, vec1D32F& output) -{ - VERIFY(this->m_dctTable.size() >= NB_BANDS * NB_BANDS); - for (uint32_t i = 0; i < NB_BANDS; ++i) { - float sum = 0; - - for (uint32_t j = 0, k = 0; j < NB_BANDS; ++j, k += NB_BANDS) { - sum += input[j] * this->m_dctTable[k + i]; - } - output[i] = sum * math::MathUtils::SqrtF32(2.0/22); - } -} - -void RNNoiseProcess::PitchDownsample(vec1D32F& pitchBuf, size_t pitchBufSz) { - for (size_t i = 1; i < (pitchBufSz >> 1); ++i) { - pitchBuf[i] = 0.5 * ( - 0.5 * (this->m_pitchBuf[2 * i - 1] + this->m_pitchBuf[2 * i + 1]) - + this->m_pitchBuf[2 * i]); - } - - pitchBuf[0] = 0.5*(0.5*(this->m_pitchBuf[1]) + this->m_pitchBuf[0]); - - vec1D32F ac(5, 0); - size_t numLags = 4; - - this->AutoCorr(pitchBuf, ac, numLags, pitchBufSz >> 1); - - /* Noise floor -40db */ - ac[0] *= 1.0001; - - /* Lag windowing. */ - for (size_t i = 1; i < numLags + 1; ++i) { - ac[i] -= ac[i] * (0.008 * i) * (0.008 * i); - } - - vec1D32F lpc(numLags, 0); - this->LPC(ac, numLags, lpc); - - float tmp = 1.0; - for (size_t i = 0; i < numLags; ++i) { - tmp = 0.9f * tmp; - lpc[i] = lpc[i] * tmp; - } - - vec1D32F lpc2(numLags + 1, 0); - float c1 = 0.8; - - /* Add a zero. */ - lpc2[0] = lpc[0] + 0.8; - lpc2[1] = lpc[1] + (c1 * lpc[0]); - lpc2[2] = lpc[2] + (c1 * lpc[1]); - lpc2[3] = lpc[3] + (c1 * lpc[2]); - lpc2[4] = (c1 * lpc[3]); - - this->Fir5(lpc2, pitchBufSz >> 1, pitchBuf); -} - -int RNNoiseProcess::PitchSearch(vec1D32F& xLp, vec1D32F& y, uint32_t len, uint32_t maxPitch) { - uint32_t lag = len + maxPitch; - vec1D32F xLp4(len >> 2, 0); - vec1D32F yLp4(lag >> 2, 0); - vec1D32F xCorr(maxPitch >> 1, 0); - - /* Downsample by 2 again. 
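 *
 * [Editor's note: illustration, not part of this patch] PitchDownsample above
 * halves the sample rate with a [0.25, 0.5, 0.25] smoothing kernel before the
 * LPC whitening stage, i.e. equivalently:
 *
 *     for (size_t i = 1; i < (pitchBufSz >> 1); ++i) {
 *         out[i] = 0.25f * in[2*i - 1] + 0.5f * in[2*i] + 0.25f * in[2*i + 1];
 *     }
 *
 * PitchSearch then decimates by two again (below), so the coarse
 * cross-correlation runs at one quarter rate before candidates are refined at
 * half rate.
 *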
*/ - for (size_t j = 0; j < (len >> 2); ++j) { - xLp4[j] = xLp[2*j]; - } - for (size_t j = 0; j < (lag >> 2); ++j) { - yLp4[j] = y[2*j]; - } - - this->PitchXCorr(xLp4, yLp4, xCorr, len >> 2, maxPitch >> 2); - - /* Coarse search with 4x decimation. */ - arrHp bestPitch = this->FindBestPitch(xCorr, yLp4, len >> 2, maxPitch >> 2); - - /* Finer search with 2x decimation. */ - const int maxIdx = (maxPitch >> 1); - for (int i = 0; i < maxIdx; ++i) { - xCorr[i] = 0; - if (std::abs(i - 2*bestPitch[0]) > 2 and std::abs(i - 2*bestPitch[1]) > 2) { - continue; - } - float sum = 0; - for (size_t j = 0; j < len >> 1; ++j) { - sum += xLp[j] * y[i+j]; - } - - xCorr[i] = std::max(-1.0f, sum); - } - - bestPitch = this->FindBestPitch(xCorr, y, len >> 1, maxPitch >> 1); - - int offset; - /* Refine by pseudo-interpolation. */ - if ( 0 < bestPitch[0] && bestPitch[0] < ((maxPitch >> 1) - 1)) { - float a = xCorr[bestPitch[0] - 1]; - float b = xCorr[bestPitch[0]]; - float c = xCorr[bestPitch[0] + 1]; - - if ( (c-a) > 0.7*(b-a) ) { - offset = 1; - } else if ( (a-c) > 0.7*(b-c) ) { - offset = -1; - } else { - offset = 0; - } - } else { - offset = 0; - } - - return 2*bestPitch[0] - offset; -} - -arrHp RNNoiseProcess::FindBestPitch(vec1D32F& xCorr, vec1D32F& y, uint32_t len, uint32_t maxPitch) -{ - float Syy = 1; - arrHp bestNum {-1, -1}; - arrHp bestDen {0, 0}; - arrHp bestPitch {0, 1}; - - for (size_t j = 0; j < len; ++j) { - Syy += (y[j] * y[j]); - } - - for (size_t i = 0; i < maxPitch; ++i ) { - if (xCorr[i] > 0) { - float xCorr16 = xCorr[i] * 1e-12f; /* Avoid problems when squaring. */ - - float num = xCorr16 * xCorr16; - if (num*bestDen[1] > bestNum[1]*Syy) { - if (num*bestDen[0] > bestNum[0]*Syy) { - bestNum[1] = bestNum[0]; - bestDen[1] = bestDen[0]; - bestPitch[1] = bestPitch[0]; - bestNum[0] = num; - bestDen[0] = Syy; - bestPitch[0] = i; - } else { - bestNum[1] = num; - bestDen[1] = Syy; - bestPitch[1] = i; - } - } - } - - Syy += (y[i+len]*y[i+len]) - (y[i]*y[i]); - Syy = std::max(1.0f, Syy); - } - - return bestPitch; -} - -int RNNoiseProcess::RemoveDoubling( - vec1D32F& pitchBuf, - uint32_t maxPeriod, - uint32_t minPeriod, - uint32_t frameSize, - size_t pitchIdx0_) -{ - constexpr std::array secondCheck {0, 0, 3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2}; - uint32_t minPeriod0 = minPeriod; - float lastPeriod = static_cast(this->m_lastPeriod)/2; - float lastGain = static_cast(this->m_lastGain); - - maxPeriod /= 2; - minPeriod /= 2; - pitchIdx0_ /= 2; - frameSize /= 2; - uint32_t xStart = maxPeriod; - - if (pitchIdx0_ >= maxPeriod) { - pitchIdx0_ = maxPeriod - 1; - } - - size_t pitchIdx = pitchIdx0_; - const size_t pitchIdx0 = pitchIdx0_; - - float xx = 0; - for ( size_t i = xStart; i < xStart+frameSize; ++i) { - xx += (pitchBuf[i] * pitchBuf[i]); - } - - float xy = 0; - for ( size_t i = xStart; i < xStart+frameSize; ++i) { - xy += (pitchBuf[i] * pitchBuf[i-pitchIdx0]); - } - - vec1D32F yyLookup (maxPeriod+1, 0); - yyLookup[0] = xx; - float yy = xx; - - for ( size_t i = 1; i < yyLookup.size(); ++i) { - yy = yy + (pitchBuf[xStart-i] * pitchBuf[xStart-i]) - - (pitchBuf[xStart+frameSize-i] * pitchBuf[xStart+frameSize-i]); - yyLookup[i] = std::max(0.0f, yy); - } - - yy = yyLookup[pitchIdx0]; - float bestXy = xy; - float bestYy = yy; - - float g = this->ComputePitchGain(xy, xx, yy); - float g0 = g; - - /* Look for any pitch at pitchIndex/k. 
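 *
 * [Editor's note: illustration, not part of this patch] FindBestPitch above
 * ranks candidates by xCorr[i]^2 / Syy without ever dividing: a candidate
 * beats the stored best when num * bestDen > bestNum * Syy, the
 * cross-multiplied form of num / Syy > bestNum / bestDen. Sketch of the
 * comparison:
 *
 *     // score(a) > score(b), where score = num / den and both dens are > 0:
 *     bool Better(float numA, float denA, float numB, float denB)
 *     {
 *         return numA * denB > numB * denA;
 *     }
 *
 * This avoids a per-candidate division and is safe here because Syy is
 * clamped to at least 1.
 *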
*/ - for ( size_t k = 2; k < 16; ++k) { - size_t pitchIdx1 = (2*pitchIdx0+k) / (2*k); - if (pitchIdx1 < minPeriod) { - break; - } - - size_t pitchIdx1b; - /* Look for another strong correlation at T1b. */ - if (k == 2) { - if ((pitchIdx1 + pitchIdx0) > maxPeriod) { - pitchIdx1b = pitchIdx0; - } else { - pitchIdx1b = pitchIdx0 + pitchIdx1; - } - } else { - pitchIdx1b = (2*(secondCheck[k])*pitchIdx0 + k) / (2*k); - } - - xy = 0; - for ( size_t i = xStart; i < xStart+frameSize; ++i) { - xy += (pitchBuf[i] * pitchBuf[i-pitchIdx1]); - } - - float xy2 = 0; - for ( size_t i = xStart; i < xStart+frameSize; ++i) { - xy2 += (pitchBuf[i] * pitchBuf[i-pitchIdx1b]); - } - xy = 0.5f * (xy + xy2); - VERIFY(pitchIdx1b < maxPeriod+1); - yy = 0.5f * (yyLookup[pitchIdx1] + yyLookup[pitchIdx1b]); - - float g1 = this->ComputePitchGain(xy, xx, yy); - - float cont; - if (std::abs(pitchIdx1-lastPeriod) <= 1) { - cont = lastGain; - } else if (std::abs(pitchIdx1-lastPeriod) <= 2 and 5*k*k < pitchIdx0) { - cont = 0.5f*lastGain; - } else { - cont = 0.0f; - } - - float thresh = std::max(0.3, 0.7*g0-cont); - - /* Bias against very high pitch (very short period) to avoid false-positives - * due to short-term correlation */ - if (pitchIdx1 < 3*minPeriod) { - thresh = std::max(0.4, 0.85*g0-cont); - } else if (pitchIdx1 < 2*minPeriod) { - thresh = std::max(0.5, 0.9*g0-cont); - } - if (g1 > thresh) { - bestXy = xy; - bestYy = yy; - pitchIdx = pitchIdx1; - g = g1; - } - } - - bestXy = std::max(0.0f, bestXy); - float pg; - if (bestYy <= bestXy) { - pg = 1.0; - } else { - pg = bestXy/(bestYy+1); - } - - std::array xCorr {0}; - for ( size_t k = 0; k < 3; ++k ) { - for ( size_t i = xStart; i < xStart+frameSize; ++i) { - xCorr[k] += (pitchBuf[i] * pitchBuf[i-(pitchIdx+k-1)]); - } - } - - size_t offset; - if ((xCorr[2]-xCorr[0]) > 0.7*(xCorr[1]-xCorr[0])) { - offset = 1; - } else if ((xCorr[0]-xCorr[2]) > 0.7*(xCorr[1]-xCorr[2])) { - offset = -1; - } else { - offset = 0; - } - - if (pg > g) { - pg = g; - } - - pitchIdx0_ = 2*pitchIdx + offset; - - if (pitchIdx0_ < minPeriod0) { - pitchIdx0_ = minPeriod0; - } - - this->m_lastPeriod = pitchIdx0_; - this->m_lastGain = pg; - - return this->m_lastPeriod; -} - -float RNNoiseProcess::ComputePitchGain(float xy, float xx, float yy) -{ - return xy / math::MathUtils::SqrtF32(1+xx*yy); -} - -void RNNoiseProcess::AutoCorr( - const vec1D32F& x, - vec1D32F& ac, - size_t lag, - size_t n) -{ - if (n < lag) { - printf_err("Invalid parameters for AutoCorr\n"); - return; - } - - auto fastN = n - lag; - - /* Auto-correlation - can be done by PlatformMath functions */ - this->PitchXCorr(x, x, ac, fastN, lag + 1); - - /* Modify auto-correlation by summing with auto-correlation for different lags. 
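 *
 * [Editor's note: illustration, not part of this patch] AutoCorr above splits
 * the sum into a vectorisable part over fastN samples (via PitchXCorr) plus
 * the tail correction below. The combined result equals the naive
 * autocorrelation:
 *
 *     // ac[k] = sum over i in [k, n) of x[i] * x[i - k]
 *     for (size_t k = 0; k <= lag; ++k) {
 *         float sum = 0.0f;
 *         for (size_t i = k; i < n; ++i) {
 *             sum += x[i] * x[i - k];
 *         }
 *         ac[k] = sum;
 *     }
 *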
*/ - for (size_t k = 0; k < lag + 1; k++) { - float d = 0; - for (size_t i = k + fastN; i < n; i++) { - d += x[i] * x[i - k]; - } - ac[k] += d; - } -} - - -void RNNoiseProcess::PitchXCorr( - const vec1D32F& x, - const vec1D32F& y, - vec1D32F& xCorr, - size_t len, - size_t maxPitch) -{ - for (size_t i = 0; i < maxPitch; i++) { - float sum = 0; - for (size_t j = 0; j < len; j++) { - sum += x[j] * y[i + j]; - } - xCorr[i] = sum; - } -} - -/* Linear predictor coefficients */ -void RNNoiseProcess::LPC( - const vec1D32F& correlation, - int32_t p, - vec1D32F& lpc) -{ - auto error = correlation[0]; - - if (error != 0) { - for (int i = 0; i < p; i++) { - - /* Sum up this iteration's reflection coefficient */ - float rr = 0; - for (int j = 0; j < i; j++) { - rr += lpc[j] * correlation[i - j]; - } - - rr += correlation[i + 1]; - auto r = -rr / error; - - /* Update LP coefficients and total error */ - lpc[i] = r; - for (int j = 0; j < ((i + 1) >> 1); j++) { - auto tmp1 = lpc[j]; - auto tmp2 = lpc[i - 1 - j]; - lpc[j] = tmp1 + (r * tmp2); - lpc[i - 1 - j] = tmp2 + (r * tmp1); - } - - error = error - (r * r * error); - - /* Bail out once we get 30dB gain */ - if (error < (0.001 * correlation[0])) { - break; - } - } - } -} - -void RNNoiseProcess::Fir5( - const vec1D32F &num, - uint32_t N, - vec1D32F &x) -{ - auto num0 = num[0]; - auto num1 = num[1]; - auto num2 = num[2]; - auto num3 = num[3]; - auto num4 = num[4]; - auto mem0 = 0; - auto mem1 = 0; - auto mem2 = 0; - auto mem3 = 0; - auto mem4 = 0; - for (uint32_t i = 0; i < N; i++) - { - auto sum_ = x[i] + (num0 * mem0) + (num1 * mem1) + - (num2 * mem2) + (num3 * mem3) + (num4 * mem4); - mem4 = mem3; - mem3 = mem2; - mem2 = mem1; - mem1 = mem0; - mem0 = x[i]; - x[i] = sum_; - } -} - -void RNNoiseProcess::PitchFilter(FrameFeatures &features, vec1D32F &gain) { - std::vector r(NB_BANDS, 0); - std::vector rf(FREQ_SIZE, 0); - std::vector newE(NB_BANDS); - - for (size_t i = 0; i < NB_BANDS; i++) { - if (features.m_Exp[i] > gain[i]) { - r[i] = 1; - } else { - - - r[i] = std::pow(features.m_Exp[i], 2) * (1 - std::pow(gain[i], 2)) / - (.001 + std::pow(gain[i], 2) * (1 - std::pow(features.m_Exp[i], 2))); - } - - - r[i] = math::MathUtils::SqrtF32(std::min(1.0f, std::max(0.0f, r[i]))); - r[i] *= math::MathUtils::SqrtF32(features.m_Ex[i] / (1e-8f + features.m_Ep[i])); - } - - InterpBandGain(rf, r); - for (size_t i = 0; i < FREQ_SIZE - 1; i++) { - features.m_fftX[2 * i] += rf[i] * features.m_fftP[2 * i]; /* Real. */ - features.m_fftX[2 * i + 1] += rf[i] * features.m_fftP[2 * i + 1]; /* Imaginary. */ - - } - ComputeBandEnergy(features.m_fftX, newE); - std::vector norm(NB_BANDS); - std::vector normf(FRAME_SIZE, 0); - for (size_t i = 0; i < NB_BANDS; i++) { - norm[i] = math::MathUtils::SqrtF32(features.m_Ex[i] / (1e-8f + newE[i])); - } - - InterpBandGain(normf, norm); - for (size_t i = 0; i < FREQ_SIZE - 1; i++) { - features.m_fftX[2 * i] *= normf[i]; /* Real. */ - features.m_fftX[2 * i + 1] *= normf[i]; /* Imaginary. 
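 *
 * [Editor's note: illustration, not part of this patch] The LPC routine above
 * is the Levinson-Durbin recursion: at order i the reflection coefficient is
 * r = -(corr[i+1] + sum over j of lpc[j] * corr[i-j]) / error, the
 * coefficients are updated symmetrically in place, and the residual shrinks as
 *
 *     error *= (1.0f - r * r);  // Same as error - (r * r * error) above.
 *
 * The early exit once error < 0.001 * corr[0] corresponds to roughly 30 dB of
 * prediction gain, beyond which further orders add little whitening.
 *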
*/ - - } -} - -void RNNoiseProcess::FrameSynthesis(vec1D32F& outFrame, vec1D32F& fftY) { - std::vector x(WINDOW_SIZE, 0); - InverseTransform(x, fftY); - ApplyWindow(x); - for (size_t i = 0; i < FRAME_SIZE; i++) { - outFrame[i] = x[i] + m_synthesisMem[i]; - } - memcpy((m_synthesisMem.data()), &x[FRAME_SIZE], FRAME_SIZE*sizeof(float)); -} - -void RNNoiseProcess::InterpBandGain(vec1D32F& g, vec1D32F& bandE) { - for (size_t i = 0; i < NB_BANDS - 1; i++) { - int bandSize = (m_eband5ms[i + 1] - m_eband5ms[i]) << FRAME_SIZE_SHIFT; - for (int j = 0; j < bandSize; j++) { - float frac = static_cast(j) / bandSize; - g[(m_eband5ms[i] << FRAME_SIZE_SHIFT) + j] = (1 - frac) * bandE[i] + frac * bandE[i + 1]; - } - } -} - -void RNNoiseProcess::InverseTransform(vec1D32F& out, vec1D32F& fftXIn) { - - std::vector x(WINDOW_SIZE * 2); /* This is complex. */ - vec1D32F newFFT; /* This is complex. */ - - size_t i; - for (i = 0; i < FREQ_SIZE * 2; i++) { - x[i] = fftXIn[i]; - } - for (i = FREQ_SIZE; i < WINDOW_SIZE; i++) { - x[2 * i] = x[2 * (WINDOW_SIZE - i)]; /* Real. */ - x[2 * i + 1] = -x[2 * (WINDOW_SIZE - i) + 1]; /* Imaginary. */ - } - - constexpr uint32_t numFFt = 2 * FRAME_SIZE; - static_assert(numFFt != 0, "numFFt cannot be 0!"); - - vec1D32F fftOut = vec1D32F(x.size(), 0); - math::MathUtils::FftF32(x,fftOut, m_fftInstCmplx); - - /* Normalize. */ - for (auto &f: fftOut) { - f /= numFFt; - } - - out[0] = WINDOW_SIZE * fftOut[0]; /* Real. */ - for (i = 1; i < WINDOW_SIZE; i++) { - out[i] = WINDOW_SIZE * fftOut[(WINDOW_SIZE * 2) - (2 * i)]; /* Real. */ - } -} - - -} /* namespace rnn */ -} /* namespace app */ -} /* namspace arm */ diff --git a/source/use_case/noise_reduction/src/RNNoiseProcessing.cc b/source/use_case/noise_reduction/src/RNNoiseProcessing.cc new file mode 100644 index 0000000..f6a3ec4 --- /dev/null +++ b/source/use_case/noise_reduction/src/RNNoiseProcessing.cc @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2022 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include "RNNoiseProcessing.hpp"
+#include "log_macros.h"
+
+namespace arm {
+namespace app {
+
+    RNNoisePreProcess::RNNoisePreProcess(TfLiteTensor* inputTensor,
+            std::shared_ptr<rnn::RNNoiseFeatureProcessor> featureProcessor,
+            std::shared_ptr<rnn::FrameFeatures> frameFeatures)
+    :   m_inputTensor{inputTensor},
+        m_featureProcessor{featureProcessor},
+        m_frameFeatures{frameFeatures}
+    {}
+
+    bool RNNoisePreProcess::DoPreProcess(const void* data, size_t inputSize)
+    {
+        if (data == nullptr) {
+            printf_err("Data pointer is null");
+            return false;
+        }
+
+        auto input = static_cast<const int16_t*>(data);
+        this->m_audioFrame = rnn::vec1D32F(input, input + inputSize);
+        m_featureProcessor->PreprocessFrame(this->m_audioFrame.data(), inputSize, *this->m_frameFeatures);
+
+        QuantizeAndPopulateInput(this->m_frameFeatures->m_featuresVec,
+                this->m_inputTensor->params.scale, this->m_inputTensor->params.zero_point,
+                this->m_inputTensor);
+
+        debug("Input tensor populated \n");
+
+        return true;
+    }
+
+    void RNNoisePreProcess::QuantizeAndPopulateInput(rnn::vec1D32F& inputFeatures,
+            const float quantScale, const int quantOffset,
+            TfLiteTensor* inputTensor)
+    {
+        const float minVal = std::numeric_limits<int8_t>::min();
+        const float maxVal = std::numeric_limits<int8_t>::max();
+
+        auto* inputTensorData = tflite::GetTensorData<int8_t>(inputTensor);
+
+        for (size_t i = 0; i < inputFeatures.size(); ++i) {
+            float quantValue = ((inputFeatures[i] / quantScale) + quantOffset);
+            inputTensorData[i] = static_cast<int8_t>(std::min(std::max(quantValue, minVal), maxVal));
+        }
+    }
+
+    RNNoisePostProcess::RNNoisePostProcess(TfLiteTensor* outputTensor,
+            std::vector<int16_t>& denoisedAudioFrame,
+            std::shared_ptr<rnn::RNNoiseFeatureProcessor> featureProcessor,
+            std::shared_ptr<rnn::FrameFeatures> frameFeatures)
+    :   m_outputTensor{outputTensor},
+        m_denoisedAudioFrame{denoisedAudioFrame},
+        m_featureProcessor{featureProcessor},
+        m_frameFeatures{frameFeatures}
+    {
+        this->m_denoisedAudioFrameFloat.reserve(denoisedAudioFrame.size());
+        this->m_modelOutputFloat.resize(outputTensor->bytes);
+    }
+
+    bool RNNoisePostProcess::DoPostProcess()
+    {
+        const auto* outputData = tflite::GetTensorData<int8_t>(this->m_outputTensor);
+        auto outputQuantParams = GetTensorQuantParams(this->m_outputTensor);
+
+        for (size_t i = 0; i < this->m_outputTensor->bytes; ++i) {
+            this->m_modelOutputFloat[i] = (static_cast<float>(outputData[i]) - outputQuantParams.offset)
+                                          * outputQuantParams.scale;
+        }
+
+        this->m_featureProcessor->PostProcessFrame(this->m_modelOutputFloat,
+                *this->m_frameFeatures, this->m_denoisedAudioFrameFloat);
+
+        for (size_t i = 0; i < this->m_denoisedAudioFrame.size(); ++i) {
+            this->m_denoisedAudioFrame[i] = static_cast<int16_t>(
+                    std::roundf(this->m_denoisedAudioFrameFloat[i]));
+        }
+
+        return true;
+    }
+
+} /* namespace app */
+} /* namespace arm */
\ No newline at end of file
diff --git a/source/use_case/noise_reduction/src/UseCaseHandler.cc b/source/use_case/noise_reduction/src/UseCaseHandler.cc
index acb8ba7..53bb43e 100644
--- a/source/use_case/noise_reduction/src/UseCaseHandler.cc
+++ b/source/use_case/noise_reduction/src/UseCaseHandler.cc
@@ -21,12 +21,10 @@
 #include "ImageUtils.hpp"
 #include "InputFiles.hpp"
 #include "RNNoiseModel.hpp"
-#include "RNNoiseProcess.hpp"
+#include "RNNoiseFeatureProcessor.hpp"
+#include "RNNoiseProcessing.hpp"
 #include "log_macros.h"
 
-#include
-#include
-
 namespace arm {
 namespace app {
 
@@ -36,17 +34,6 @@ namespace app {
      **/
    static void IncrementAppCtxClipIdx(ApplicationContext& ctx);
 
-    /**
-     * @brief Quantize the given features and populate the input Tensor.
-     * @param[in] inputFeatures Vector of floating point features to quantize.
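 *
 * [Editor's note: illustration, not part of this patch] The affine int8
 * quantisation in QuantizeAndPopulateInput above maps
 * q = clamp(x / scale + zeroPoint, -128, 127), and DoPostProcess inverts it
 * with x = (q - offset) * scale. A worked round trip with illustrative values
 * scale = 0.5, zeroPoint = -11:
 *
 *     float x = 3.0f;
 *     int8_t q = static_cast<int8_t>(3.0f / 0.5f + (-11));  // q == -5
 *     float back = (q - (-11)) * 0.5f;                      // back == 3.0f
 *
 * Note the cast truncates rather than rounds; the value is clamped to the
 * int8_t range first, so the cast is well defined.
 *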
- * @param[in] quantScale Quantization scale for the inputTensor. - * @param[in] quantOffset Quantization offset for the inputTensor. - * @param[in,out] inputTensor TFLite micro tensor to populate. - **/ - static void QuantizeAndPopulateInput(rnn::vec1D32F& inputFeatures, - float quantScale, int quantOffset, - TfLiteTensor* inputTensor); - /* Noise reduction inference handler. */ bool NoiseReductionHandler(ApplicationContext& ctx, bool runAll) { @@ -57,7 +44,7 @@ namespace app { size_t memDumpMaxLen = 0; uint8_t* memDumpBaseAddr = nullptr; size_t undefMemDumpBytesWritten = 0; - size_t *pMemDumpBytesWritten = &undefMemDumpBytesWritten; + size_t* pMemDumpBytesWritten = &undefMemDumpBytesWritten; if (ctx.Has("MEM_DUMP_LEN") && ctx.Has("MEM_DUMP_BASE_ADDR") && ctx.Has("MEM_DUMP_BYTE_WRITTEN")) { memDumpMaxLen = ctx.Get("MEM_DUMP_LEN"); memDumpBaseAddr = ctx.Get("MEM_DUMP_BASE_ADDR"); @@ -74,8 +61,8 @@ namespace app { } /* Populate Pre-Processing related parameters. */ - auto audioParamsWinLen = ctx.Get("frameLength"); - auto audioParamsWinStride = ctx.Get("frameStride"); + auto audioFrameLen = ctx.Get("frameLength"); + auto audioFrameStride = ctx.Get("frameStride"); auto nrNumInputFeatures = ctx.Get("numInputFeatures"); TfLiteTensor* inputTensor = model.GetInputTensor(0); @@ -103,7 +90,7 @@ namespace app { if (ctx.Has("featureFileNames")) { audioFileAccessorFunc = ctx.Get>("featureFileNames"); } - do{ + do { hal_lcd_clear(COLOR_BLACK); auto startDumpAddress = memDumpBaseAddr + memDumpBytesWritten; @@ -112,32 +99,38 @@ namespace app { /* Creating a sliding window through the audio. */ auto audioDataSlider = audio::SlidingWindow( audioAccessorFunc(currentIndex), - audioSizeAccessorFunc(currentIndex), audioParamsWinLen, - audioParamsWinStride); + audioSizeAccessorFunc(currentIndex), audioFrameLen, + audioFrameStride); info("Running inference on input feature map %" PRIu32 " => %s\n", currentIndex, audioFileAccessorFunc(currentIndex)); memDumpBytesWritten += DumpDenoisedAudioHeader(audioFileAccessorFunc(currentIndex), - (audioDataSlider.TotalStrides() + 1) * audioParamsWinLen, + (audioDataSlider.TotalStrides() + 1) * audioFrameLen, memDumpBaseAddr + memDumpBytesWritten, memDumpMaxLen - memDumpBytesWritten); - rnn::RNNoiseProcess featureProcessor = rnn::RNNoiseProcess(); - rnn::vec1D32F audioFrame(audioParamsWinLen); - rnn::vec1D32F inputFeatures(nrNumInputFeatures); - rnn::vec1D32F denoisedAudioFrameFloat(audioParamsWinLen); - std::vector denoisedAudioFrame(audioParamsWinLen); + /* Set up pre and post-processing. */ + std::shared_ptr featureProcessor = + std::make_shared(); + std::shared_ptr frameFeatures = + std::make_shared(); + + RNNoisePreProcess preProcess = RNNoisePreProcess(inputTensor, featureProcessor, frameFeatures); + + std::vector denoisedAudioFrame(audioFrameLen); + RNNoisePostProcess postProcess = RNNoisePostProcess(outputTensor, denoisedAudioFrame, + featureProcessor, frameFeatures); - std::vector modelOutputFloat(outputTensor->bytes); - rnn::FrameFeatures frameFeatures; bool resetGRU = true; while (audioDataSlider.HasNext()) { const int16_t* inferenceWindow = audioDataSlider.Next(); - audioFrame = rnn::vec1D32F(inferenceWindow, inferenceWindow+audioParamsWinLen); - featureProcessor.PreprocessFrame(audioFrame.data(), audioParamsWinLen, frameFeatures); + if (!preProcess.DoPreProcess(inferenceWindow, audioFrameLen)) { + printf_err("Pre-processing failed."); + return false; + } /* Reset or copy over GRU states first to avoid TFLu memory overlap issues. 
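 *
 * [Editor's note: illustration, not part of this patch] RNNoise's GRU layers
 * are stateful across audio windows: the first window of a clip starts from a
 * zeroed state, and every later window feeds the previous output state back
 * in. Sketched against the resetGRU flag used below (only CopyGruStates is
 * visible in this hunk; the reset helper's name is assumed here):
 *
 *     if (resetGRU) {
 *         model.ResetGruState();   // Assumed helper: zero the GRU state tensors.
 *     } else {
 *         model.CopyGruStates();   // Feed output state back as input state.
 *     }
 *     // resetGRU is cleared after the first successful inference below.
 *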
*/ if (resetGRU){ @@ -148,53 +141,35 @@ namespace app { model.CopyGruStates(); } - QuantizeAndPopulateInput(frameFeatures.m_featuresVec, - inputTensor->params.scale, inputTensor->params.zero_point, - inputTensor); - /* Strings for presentation/logging. */ std::string str_inf{"Running inference... "}; /* Display message on the LCD - inference running. */ - hal_lcd_display_text( - str_inf.c_str(), str_inf.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); + hal_lcd_display_text(str_inf.c_str(), str_inf.size(), + dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); info("Inference %zu/%zu\n", audioDataSlider.Index() + 1, audioDataSlider.TotalStrides() + 1); /* Run inference over this feature sliding window. */ - profiler.StartProfiling("Inference"); - bool success = model.RunInference(); - profiler.StopProfiling(); - resetGRU = false; - - if (!success) { + if (!RunInference(model, profiler)) { + printf_err("Inference failed."); return false; } + resetGRU = false; - /* De-quantize main model output ready for post-processing. */ - const auto* outputData = tflite::GetTensorData(outputTensor); - auto outputQuantParams = arm::app::GetTensorQuantParams(outputTensor); - - for (size_t i = 0; i < outputTensor->bytes; ++i) { - modelOutputFloat[i] = (static_cast(outputData[i]) - outputQuantParams.offset) - * outputQuantParams.scale; - } - - /* Round and cast the post-processed results for dumping to wav. */ - featureProcessor.PostProcessFrame(modelOutputFloat, frameFeatures, denoisedAudioFrameFloat); - for (size_t i = 0; i < audioParamsWinLen; ++i) { - denoisedAudioFrame[i] = static_cast(std::roundf(denoisedAudioFrameFloat[i])); + /* Carry out post-processing. */ + if (!postProcess.DoPostProcess()) { + printf_err("Post-processing failed."); + return false; } /* Erase. */ str_inf = std::string(str_inf.size(), ' '); - hal_lcd_display_text( - str_inf.c_str(), str_inf.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); + hal_lcd_display_text(str_inf.c_str(), str_inf.size(), + dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); if (memDumpMaxLen > 0) { - /* Dump output tensors to memory. */ + /* Dump final post processed output to memory. */ memDumpBytesWritten += DumpOutputDenoisedAudioFrame( denoisedAudioFrame, memDumpBaseAddr + memDumpBytesWritten, @@ -209,6 +184,7 @@ namespace app { valMemDumpBytesWritten, startDumpAddress); } + /* Finish by dumping the footer. */ DumpDenoisedAudioFooter(memDumpBaseAddr + memDumpBytesWritten, memDumpMaxLen - memDumpBytesWritten); info("All inferences for audio clip complete.\n"); @@ -216,15 +192,13 @@ namespace app { IncrementAppCtxClipIdx(ctx); std::string clearString{' '}; - hal_lcd_display_text( - clearString.c_str(), clearString.size(), + hal_lcd_display_text(clearString.c_str(), clearString.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); std::string completeMsg{"Inference complete!"}; /* Display message on the LCD - inference complete. 
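 *
 * [Editor's note: illustration, not part of this patch] The MEM_DUMP region
 * written above appears to form a simple self-describing stream per clip: a
 * header carrying the source file name and total sample count, the denoised
 * int16_t frames in order, then an int32 footer. Each Dump* helper (defined
 * below) returns the number of bytes it wrote, so memDumpBytesWritten always
 * points at the next free byte. Conceptually:
 *
 *     [header: filename, dumpSize] [frame 0] ... [frame N-1] [int32 footer]
 *
 * which lets a host-side reader recover each clip without knowing the window
 * count in advance.
 *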
*/ - hal_lcd_display_text( - completeMsg.c_str(), completeMsg.size(), + hal_lcd_display_text(completeMsg.c_str(), completeMsg.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); } while (runAll && ctx.Get("clipIndex") != startClipIdx); @@ -233,7 +207,7 @@ namespace app { } size_t DumpDenoisedAudioHeader(const char* filename, size_t dumpSize, - uint8_t *memAddress, size_t memSize){ + uint8_t* memAddress, size_t memSize){ if (memAddress == nullptr){ return 0; @@ -284,7 +258,7 @@ namespace app { return numBytesWritten; } - size_t DumpDenoisedAudioFooter(uint8_t *memAddress, size_t memSize){ + size_t DumpDenoisedAudioFooter(uint8_t* memAddress, size_t memSize){ if ((memAddress == nullptr) || (memSize < 4)) { return 0; } @@ -294,8 +268,8 @@ namespace app { return sizeof(int32_t); } - size_t DumpOutputDenoisedAudioFrame(const std::vector &audioFrame, - uint8_t *memAddress, size_t memSize) + size_t DumpOutputDenoisedAudioFrame(const std::vector& audioFrame, + uint8_t* memAddress, size_t memSize) { if (memAddress == nullptr) { return 0; @@ -324,7 +298,7 @@ namespace app { const TfLiteTensor* tensor = model.GetOutputTensor(i); const auto* tData = tflite::GetTensorData(tensor); #if VERIFY_TEST_OUTPUT - arm::app::DumpTensor(tensor); + DumpTensor(tensor); #endif /* VERIFY_TEST_OUTPUT */ /* Ensure that we don't overflow the allowed limit. */ if (numBytesWritten + tensor->bytes <= memSize) { @@ -360,20 +334,5 @@ namespace app { ctx.Set("clipIndex", curClipIdx); } - void QuantizeAndPopulateInput(rnn::vec1D32F& inputFeatures, - const float quantScale, const int quantOffset, TfLiteTensor* inputTensor) - { - const float minVal = std::numeric_limits::min(); - const float maxVal = std::numeric_limits::max(); - - auto* inputTensorData = tflite::GetTensorData(inputTensor); - - for (size_t i=0; i < inputFeatures.size(); ++i) { - float quantValue = ((inputFeatures[i] / quantScale) + quantOffset); - inputTensorData[i] = static_cast(std::min(std::max(quantValue, minVal), maxVal)); - } - } - - } /* namespace app */ } /* namespace arm */ diff --git a/tests/use_case/ad/PostProcessTests.cc b/tests/use_case/ad/PostProcessTests.cc deleted file mode 100644 index 62fa9e7..0000000 --- a/tests/use_case/ad/PostProcessTests.cc +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Copyright (c) 2021 Arm Limited. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "AdPostProcessing.hpp" -#include - -TEST_CASE("Softmax_vector") { - - std::vector testVec = {1, 2, 3, 4, 1, 2, 3}; - arm::app::Softmax(testVec); - CHECK((testVec[0] - 0.024) == Approx(0.0).margin(0.001)); - CHECK((testVec[1] - 0.064) == Approx(0.0).margin(0.001)); - CHECK((testVec[2] - 0.175) == Approx(0.0).margin(0.001)); - CHECK((testVec[3] - 0.475) == Approx(0.0).margin(0.001)); - CHECK((testVec[4] - 0.024) == Approx(0.0).margin(0.001)); - CHECK((testVec[5] - 0.064) == Approx(0.0).margin(0.001)); - CHECK((testVec[6] - 0.175) == Approx(0.0).margin(0.001)); -} - -TEST_CASE("Output machine index") { - - auto index = arm::app::OutputIndexFromFileName("test_id_00.wav"); - CHECK(index == 0); - - auto index1 = arm::app::OutputIndexFromFileName("test_id_02.wav"); - CHECK(index1 == 1); - - auto index2 = arm::app::OutputIndexFromFileName("test_id_4.wav"); - CHECK(index2 == 2); - - auto index3 = arm::app::OutputIndexFromFileName("test_id_6.wav"); - CHECK(index3 == 3); - - auto index4 = arm::app::OutputIndexFromFileName("test_id_id_00.wav"); - CHECK(index4 == -1); - - auto index5 = arm::app::OutputIndexFromFileName("test_id_7.wav"); - CHECK(index5 == -1); -} \ No newline at end of file diff --git a/tests/use_case/kws_asr/MfccTests.cc b/tests/use_case/kws_asr/MfccTests.cc index 3ebdcf4..883c215 100644 --- a/tests/use_case/kws_asr/MfccTests.cc +++ b/tests/use_case/kws_asr/MfccTests.cc @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -93,13 +93,13 @@ const std::vector testWavMfcc { -22.67135, -0.61615, 2.07233, 0.58137, 1.01655, 0.85816, 0.46039, 0.03393, 1.16511, 0.0072, }; -arm::app::audio::MicroNetMFCC GetMFCCInstance() { - const int sampFreq = arm::app::audio::MicroNetMFCC::ms_defaultSamplingFreq; +arm::app::audio::MicroNetKwsMFCC GetMFCCInstance() { + const int sampFreq = arm::app::audio::MicroNetKwsMFCC::ms_defaultSamplingFreq; const int frameLenMs = 40; const int frameLenSamples = sampFreq * frameLenMs * 0.001; const int numMfccFeats = 10; - return arm::app::audio::MicroNetMFCC(numMfccFeats, frameLenSamples); + return arm::app::audio::MicroNetKwsMFCC(numMfccFeats, frameLenSamples); } template diff --git a/tests/use_case/kws_asr/Wav2LetterPostprocessingTest.cc b/tests/use_case/kws_asr/Wav2LetterPostprocessingTest.cc index 6fd7df3..e343b66 100644 --- a/tests/use_case/kws_asr/Wav2LetterPostprocessingTest.cc +++ b/tests/use_case/kws_asr/Wav2LetterPostprocessingTest.cc @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. 
* SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,15 +16,17 @@ */ #include "Wav2LetterPostprocess.hpp" #include "Wav2LetterModel.hpp" +#include "ClassificationResult.hpp" #include #include #include template -static TfLiteTensor GetTestTensor(std::vector & shape, - T initVal, - std::vector& vectorBuf) +static TfLiteTensor GetTestTensor( + std::vector& shape, + T initVal, + std::vector& vectorBuf) { REQUIRE(0 != shape.size()); @@ -38,91 +40,112 @@ static TfLiteTensor GetTestTensor(std::vector & shape, vectorBuf = std::vector(sizeInBytes, initVal); TfLiteIntArray* dims = tflite::testing::IntArrayFromInts(shape.data()); return tflite::testing::CreateQuantizedTensor( - vectorBuf.data(), dims, - 1, 0, "test-tensor"); + vectorBuf.data(), dims, + 1, 0, "test-tensor"); } TEST_CASE("Checking return value") { SECTION("Mismatched post processing parameters and tensor size") { - const uint32_t ctxLen = 5; - const uint32_t innerLen = 3; - arm::app::audio::asr::Postprocess post{ctxLen, innerLen, 0}; - + const uint32_t outputCtxLen = 5; + arm::app::AsrClassifier classifier; + arm::app::Wav2LetterModel model; + model.Init(); + std::vector dummyLabels = {"a", "b", "$"}; + const uint32_t blankTokenIdx = 2; + std::vector dummyResult; std::vector tensorShape = {1, 1, 1, 13}; std::vector tensorVec; TfLiteTensor tensor = GetTestTensor( - tensorShape, 100, tensorVec); - REQUIRE(false == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false)); + tensorShape, 100, tensorVec); + + arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen, + blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; + + REQUIRE(!post.DoPostProcess()); } SECTION("Post processing succeeds") { - const uint32_t ctxLen = 5; - const uint32_t innerLen = 3; - arm::app::audio::asr::Postprocess post{ctxLen, innerLen, 0}; - - std::vector tensorShape = {1, 1, 13, 1}; - std::vector tensorVec; + const uint32_t outputCtxLen = 5; + arm::app::AsrClassifier classifier; + arm::app::Wav2LetterModel model; + model.Init(); + std::vector dummyLabels = {"a", "b", "$"}; + const uint32_t blankTokenIdx = 2; + std::vector dummyResult; + std::vector tensorShape = {1, 1, 13, 1}; + std::vector tensorVec; TfLiteTensor tensor = GetTestTensor( - tensorShape, 100, tensorVec); + tensorShape, 100, tensorVec); + + arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen, + blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; /* Copy elements to compare later. */ - std::vector originalVec = tensorVec; + std::vector originalVec = tensorVec; /* This step should not erase anything. 
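 *
 * [Editor's note: illustration, not part of this patch] The Wav2Letter output
 * rows are laid out as [left context | inner | right context] =
 * [outputCtxLen | innerLen | outputCtxLen]. Because inference windows overlap,
 * each context half is also covered by a neighbouring window, so DoPostProcess
 * zeroes the stale context rows and forces the blank-token column there to 1,
 * which the CTC decoder reads as "emit nothing". With the values used in the
 * erasure test below:
 *
 *     nRows = 2 * outputCtxLen + innerLen = 2 * 5 + 3 = 13,
 *     // rows 0..4 left context, rows 5..7 inner, rows 8..12 right context.
 *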
*/ - REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false)); + REQUIRE(post.DoPostProcess()); } } + TEST_CASE("Postprocessing - erasing required elements") { - constexpr uint32_t ctxLen = 5; + constexpr uint32_t outputCtxLen = 5; constexpr uint32_t innerLen = 3; - constexpr uint32_t nRows = 2*ctxLen + innerLen; + constexpr uint32_t nRows = 2*outputCtxLen + innerLen; constexpr uint32_t nCols = 10; constexpr uint32_t blankTokenIdx = nCols - 1; - std::vector tensorShape = {1, 1, nRows, nCols}; + std::vector tensorShape = {1, 1, nRows, nCols}; + arm::app::AsrClassifier classifier; + arm::app::Wav2LetterModel model; + model.Init(); + std::vector dummyLabels = {"a", "b", "$"}; + std::vector dummyResult; SECTION("First and last iteration") { - arm::app::audio::asr::Postprocess post{ctxLen, innerLen, blankTokenIdx}; - std::vector tensorVec; - TfLiteTensor tensor = GetTestTensor( - tensorShape, 100, tensorVec); + std::vector tensorVec; + TfLiteTensor tensor = GetTestTensor(tensorShape, 100, tensorVec); + arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen, + blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; /* Copy elements to compare later. */ - std::vector originalVec = tensorVec; + std::vectororiginalVec = tensorVec; /* This step should not erase anything. */ - REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, true)); + post.m_lastIteration = true; + REQUIRE(post.DoPostProcess()); REQUIRE(originalVec == tensorVec); } SECTION("Right context erase") { - arm::app::audio::asr::Postprocess post{ctxLen, innerLen, blankTokenIdx}; - std::vector tensorVec; TfLiteTensor tensor = GetTestTensor( - tensorShape, 100, tensorVec); + tensorShape, 100, tensorVec); + arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen, + blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; /* Copy elements to compare later. */ - std::vector originalVec = tensorVec; + std::vector originalVec = tensorVec; /* This step should erase the right context only. */ - REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false)); + post.m_lastIteration = false; + REQUIRE(post.DoPostProcess()); REQUIRE(originalVec != tensorVec); /* The last ctxLen * 10 elements should be gone. */ - for (size_t i = 0; i < ctxLen; ++i) { + for (size_t i = 0; i < outputCtxLen; ++i) { for (size_t j = 0; j < nCols; ++j) { - /* Check right context elements are zeroed. */ + /* Check right context elements are zeroed. Blank token idx should be set to 1 when erasing. */ if (j == blankTokenIdx) { - CHECK(tensorVec[(ctxLen + innerLen) * nCols + i*nCols + j] == 1); + CHECK(tensorVec[(outputCtxLen + innerLen) * nCols + i*nCols + j] == 1); } else { - CHECK(tensorVec[(ctxLen + innerLen) * nCols + i*nCols + j] == 0); + CHECK(tensorVec[(outputCtxLen + innerLen) * nCols + i*nCols + j] == 0); } /* Check left context is preserved. */ @@ -131,45 +154,47 @@ TEST_CASE("Postprocessing - erasing required elements") } /* Check inner elements are preserved. 
*/ - for (size_t i = ctxLen * nCols; i < (ctxLen + innerLen) * nCols; ++i) { + for (size_t i = outputCtxLen * nCols; i < (outputCtxLen + innerLen) * nCols; ++i) { CHECK(tensorVec[i] == originalVec[i]); } } SECTION("Left and right context erase") { - arm::app::audio::asr::Postprocess post{ctxLen, innerLen, blankTokenIdx}; - std::vector tensorVec; - TfLiteTensor tensor = GetTestTensor(tensorShape, 100, tensorVec); + TfLiteTensor tensor = GetTestTensor( + tensorShape, 100, tensorVec); + arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen, + blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; /* Copy elements to compare later. */ std::vector originalVec = tensorVec; /* This step should erase right context. */ - REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false)); + post.m_lastIteration = false; + REQUIRE(post.DoPostProcess()); /* Calling it the second time should erase the left context. */ - REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false)); + REQUIRE(post.DoPostProcess()); REQUIRE(originalVec != tensorVec); /* The first and last ctxLen * 10 elements should be gone. */ - for (size_t i = 0; i < ctxLen; ++i) { + for (size_t i = 0; i < outputCtxLen; ++i) { for (size_t j = 0; j < nCols; ++j) { /* Check left and right context elements are zeroed. */ if (j == blankTokenIdx) { - CHECK(tensorVec[(ctxLen + innerLen) * nCols + i * nCols + j] == 1); - CHECK(tensorVec[i * nCols + j] == 1); + CHECK(tensorVec[(outputCtxLen + innerLen) * nCols + i*nCols + j] == 1); + CHECK(tensorVec[i*nCols + j] == 1); } else { - CHECK(tensorVec[(ctxLen + innerLen) * nCols + i * nCols + j] == 0); - CHECK(tensorVec[i * nCols + j] == 0); + CHECK(tensorVec[(outputCtxLen + innerLen) * nCols + i*nCols + j] == 0); + CHECK(tensorVec[i*nCols + j] == 0); } } } /* Check inner elements are preserved. */ - for (size_t i = ctxLen * nCols; i < (ctxLen + innerLen) * nCols; ++i) { + for (size_t i = outputCtxLen * nCols; i < (outputCtxLen + innerLen) * nCols; ++i) { /* Check left context is preserved. */ CHECK(tensorVec[i] == originalVec[i]); } @@ -177,18 +202,21 @@ TEST_CASE("Postprocessing - erasing required elements") SECTION("Try left context erase") { - /* Should not be able to erase the left context if it is the first iteration. */ - arm::app::audio::asr::Postprocess post{ctxLen, innerLen, blankTokenIdx}; - std::vector tensorVec; TfLiteTensor tensor = GetTestTensor( - tensorShape, 100, tensorVec); + tensorShape, 100, tensorVec); + + /* Should not be able to erase the left context if it is the first iteration. */ + arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen, + blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; /* Copy elements to compare later. */ std::vector originalVec = tensorVec; /* Calling it the second time should erase the left context. */ - REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, true)); + post.m_lastIteration = true; + REQUIRE(post.DoPostProcess()); + REQUIRE(originalVec == tensorVec); } -} \ No newline at end of file +} diff --git a/tests/use_case/kws_asr/Wav2LetterPreprocessingTest.cc b/tests/use_case/kws_asr/Wav2LetterPreprocessingTest.cc index 26ddb24..372152d 100644 --- a/tests/use_case/kws_asr/Wav2LetterPreprocessingTest.cc +++ b/tests/use_case/kws_asr/Wav2LetterPreprocessingTest.cc @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. 
All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,64 +16,54 @@ */ #include "Wav2LetterPreprocess.hpp" -#include -#include #include +#include constexpr uint32_t numMfccFeatures = 13; constexpr uint32_t numMfccVectors = 10; /* Test vector output: generated using test-asr-preprocessing.py. */ -int8_t expectedResult[numMfccVectors][numMfccFeatures*3] = { - /* Feature vec 0. */ - -32, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11, /* MFCCs. */ - -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, /* Delta 1. */ - -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, /* Delta 2. */ - - /* Feature vec 1. */ - -31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11, - -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, - -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, - - /* Feature vec 2. */ - -31, 4, -9, -9, -10, -10, -11, -11, -11, -11, -12, -12, -12, - -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, - -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, - - /* Feature vec 3. */ - -31, 4, -9, -9, -10, -10, -11, -11, -11, -11, -11, -12, -12, - -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, - -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, - - /* Feature vec 4 : this should have valid delta 1 and delta 2. */ - -31, 4, -9, -9, -10, -10, -11, -11, -11, -11, -11, -12, -12, - -38, -29, -9, 1, -2, -7, -8, -8, -12, -16, -14, -5, 5, - -68, -50, -13, 5, 0, -9, -9, -8, -13, -20, -19, -3, 15, - - /* Feature vec 5 : this should have valid delta 1 and delta 2. */ - -31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -11, -12, -12, - -62, -45, -11, 5, 0, -8, -9, -8, -12, -19, -17, -3, 13, - -27, -22, -13, -9, -11, -12, -12, -11, -11, -13, -13, -10, -6, - - /* Feature vec 6. */ - -31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11, - -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, - -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, - - /* Feature vec 7. */ - -32, 4, -9, -8, -10, -10, -11, -11, -11, -12, -12, -11, -11, - -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, - -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, - - /* Feature vec 8. */ - -32, 4, -9, -8, -10, -10, -11, -11, -11, -12, -12, -11, -11, - -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, - -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, - - /* Feature vec 9. */ - -31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11, - -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, - -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10 +int8_t expectedResult[numMfccVectors][numMfccFeatures * 3] = { + /* Feature vec 0. */ + {-32, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11, /* MFCCs. */ + -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, /* Delta 1. */ + -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10}, /* Delta 2. */ + /* Feature vec 1. */ + {-31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11, + -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, + -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10}, + /* Feature vec 2. */ + {-31, 4, -9, -9, -10, -10, -11, -11, -11, -11, -12, -12, -12, + -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, + -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10}, + /* Feature vec 3. 
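 *
 * [Editor's note: illustration, not part of this patch] Each row above is
 * [13 MFCCs | 13 delta-1 | 13 delta-2], quantised with the parameters used in
 * the test below (quantScale ~= 0.14102, quantOffset = -11), so a zero-valued
 * feature encodes as the offset:
 *
 *     q(0.0)   = 0.0 / 0.14102 + (-11) = -11
 *     q(0.141) ~= 1 + (-11)            = -10
 *
 * Both -11 and -10 are therefore within one quantisation step of zero, which
 * is why only vectors 4 and 5 (where enough history exists for valid deltas)
 * show non-constant delta rows.
 *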
*/ + {-31, 4, -9, -9, -10, -10, -11, -11, -11, -11, -11, -12, -12, + -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, + -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10}, + /* Feature vec 4 : this should have valid delta 1 and delta 2. */ + {-31, 4, -9, -9, -10, -10, -11, -11, -11, -11, -11, -12, -12, + -38, -29, -9, 1, -2, -7, -8, -8, -12, -16, -14, -5, 5, + -68, -50, -13, 5, 0, -9, -9, -8, -13, -20, -19, -3, 15}, + /* Feature vec 5 : this should have valid delta 1 and delta 2. */ + {-31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -11, -12, -12, + -62, -45, -11, 5, 0, -8, -9, -8, -12, -19, -17, -3, 13, + -27, -22, -13, -9, -11, -12, -12, -11, -11, -13, -13, -10, -6}, + /* Feature vec 6. */ + {-31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11, + -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, + -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10}, + /* Feature vec 7. */ + {-32, 4, -9, -8, -10, -10, -11, -11, -11, -12, -12, -11, -11, + -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, + -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10}, + /* Feature vec 8. */ + {-32, 4, -9, -8, -10, -10, -11, -11, -11, -12, -12, -11, -11, + -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, + -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10}, + /* Feature vec 9. */ + {-31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11, + -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, + -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10} }; void PopulateTestWavVector(std::vector& vec) @@ -97,17 +87,17 @@ void PopulateTestWavVector(std::vector& vec) TEST_CASE("Preprocessing calculation INT8") { - /* Constants. */ - const uint32_t windowLen = 512; - const uint32_t windowStride = 160; - int dimArray[] = {3, 1, numMfccFeatures * 3, numMfccVectors}; - const float quantScale = 0.1410219967365265; - const int quantOffset = -11; + const uint32_t mfccWindowLen = 512; + const uint32_t mfccWindowStride = 160; + int dimArray[] = {3, 1, numMfccFeatures * 3, numMfccVectors}; + const float quantScale = 0.1410219967365265; + const int quantOffset = -11; /* Test wav memory. */ - std::vector testWav((windowStride * numMfccVectors) + - (windowLen - windowStride)); + std::vector testWav((mfccWindowStride * numMfccVectors) + + (mfccWindowLen - mfccWindowStride) + ); /* Populate with dummy input. */ PopulateTestWavVector(testWav); @@ -117,20 +107,20 @@ TEST_CASE("Preprocessing calculation INT8") /* Initialise dimensions and the test tensor. */ TfLiteIntArray* dims= tflite::testing::IntArrayFromInts(dimArray); - TfLiteTensor tensor = tflite::testing::CreateQuantizedTensor( - tensorVec.data(), dims, quantScale, quantOffset, "preprocessedInput"); + TfLiteTensor inputTensor = tflite::testing::CreateQuantizedTensor( + tensorVec.data(), dims, quantScale, quantOffset, "preprocessedInput"); /* Initialise pre-processing module. */ - arm::app::audio::asr::Preprocess prep{ - numMfccFeatures, windowLen, windowStride, numMfccVectors}; + arm::app::AsrPreProcess prep{&inputTensor, + numMfccFeatures, numMfccVectors, mfccWindowLen, mfccWindowStride}; /* Invoke pre-processing. */ - REQUIRE(prep.Invoke(testWav.data(), testWav.size(), &tensor)); + REQUIRE(prep.DoPreProcess(testWav.data(), testWav.size())); /* Wrap the tensor with a std::vector for ease. 
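 *
 * [Editor's note: illustration, not part of this patch] The size check below
 * follows directly from the tensor shape {1, numMfccFeatures * 3,
 * numMfccVectors} declared in dimArray:
 *
 *     // bytes = 1 * (13 * 3) * 10 * sizeof(int8_t) = 390
 *     // sizeof(expectedResult) = 10 * 39 * sizeof(int8_t) = 390
 *
 * and the input wav length is chosen so exactly numMfccVectors windows fit:
 * (160 * 10) + (512 - 160) = 1952 samples.
 *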
*/ - int8_t * tensorData = tflite::GetTensorData(&tensor); + auto* tensorData = tflite::GetTensorData(&inputTensor); std::vector vecResults = - std::vector(tensorData, tensorData + tensor.bytes); + std::vector(tensorData, tensorData + inputTensor.bytes); /* Check sizes. */ REQUIRE(vecResults.size() == sizeof(expectedResult)); diff --git a/tests/use_case/noise_reduction/RNNoiseProcessingTests.cpp b/tests/use_case/noise_reduction/RNNoiseProcessingTests.cpp index e28a6da..ca5aab1 100644 --- a/tests/use_case/noise_reduction/RNNoiseProcessingTests.cpp +++ b/tests/use_case/noise_reduction/RNNoiseProcessingTests.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "RNNoiseProcess.hpp" +#include "RNNoiseFeatureProcessor.hpp" #include #include @@ -208,7 +208,7 @@ TEST_CASE("RNNoise preprocessing calculation test", "[RNNoise]") { SECTION("FP32") { - arm::app::rnn::RNNoiseProcess rnnoiseProcessor; + arm::app::rnn::RNNoiseFeatureProcessor rnnoiseProcessor; arm::app::rnn::FrameFeatures features; rnnoiseProcessor.PreprocessFrame(testWav0.data(), testWav0.size(), features); @@ -223,7 +223,7 @@ TEST_CASE("RNNoise preprocessing calculation test", "[RNNoise]") TEST_CASE("RNNoise postprocessing test", "[RNNoise]") { - arm::app::rnn::RNNoiseProcess rnnoiseProcessor; + arm::app::rnn::RNNoiseFeatureProcessor rnnoiseProcessor; arm::app::rnn::FrameFeatures p; rnnoiseProcessor.PreprocessFrame(testWav0.data(), testWav0.size(), p); std::vector denoised(testWav0.size()); -- cgit v1.2.1