diff options
Diffstat (limited to 'source/use_case')
25 files changed, 1635 insertions, 1208 deletions
diff --git a/source/use_case/ad/include/AdModel.hpp b/source/use_case/ad/include/AdModel.hpp index 8d914c4..2195a7c 100644 --- a/source/use_case/ad/include/AdModel.hpp +++ b/source/use_case/ad/include/AdModel.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,6 +28,12 @@ namespace arm { namespace app { class AdModel : public Model { + + public: + /* Indices for the expected model - based on input tensor shape */ + static constexpr uint32_t ms_inputRowsIdx = 1; + static constexpr uint32_t ms_inputColsIdx = 2; + protected: /** @brief Gets the reference to op resolver interface class */ const tflite::MicroOpResolver& GetOpResolver() override; diff --git a/source/use_case/ad/include/AdPostProcessing.hpp b/source/use_case/ad/include/AdPostProcessing.hpp deleted file mode 100644 index 7eaec84..0000000 --- a/source/use_case/ad/include/AdPostProcessing.hpp +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Copyright (c) 2021 Arm Limited. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef ADPOSTPROCESSING_HPP -#define ADPOSTPROCESSING_HPP - -#include "TensorFlowLiteMicro.hpp" - -#include <vector> - -namespace arm { -namespace app { - - /** @brief Dequantize TensorFlow Lite Micro tensor. - * @param[in] tensor Pointer to the TensorFlow Lite Micro tensor to be dequantized. - * @return Vector with the dequantized tensor values. - **/ - template<typename T> - std::vector<float> Dequantize(TfLiteTensor* tensor); - - /** - * @brief Calculates the softmax of vector in place. **/ - void Softmax(std::vector<float>& inputVector); - - - /** @brief Given a wav file name return AD model output index. - * @param[in] wavFileName Audio WAV filename. - * File name should be in format anything_goes_XX_here.wav - * where XX is the machine ID e.g. 00, 02, 04 or 06 - * @return AD model output index as 8 bit integer. - **/ - int8_t OutputIndexFromFileName(std::string wavFileName); - -} /* namespace app */ -} /* namespace arm */ - -#endif /* ADPOSTPROCESSING_HPP */ diff --git a/source/use_case/ad/include/AdProcessing.hpp b/source/use_case/ad/include/AdProcessing.hpp new file mode 100644 index 0000000..9abf6f1 --- /dev/null +++ b/source/use_case/ad/include/AdProcessing.hpp @@ -0,0 +1,230 @@ +/* + * Copyright (c) 2022 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef AD_PROCESSING_HPP +#define AD_PROCESSING_HPP + +#include "BaseProcessing.hpp" +#include "AudioUtils.hpp" +#include "AdMelSpectrogram.hpp" +#include "log_macros.h" + +namespace arm { +namespace app { + + /** + * @brief Pre-processing class for anomaly detection use case. + * Implements methods declared by BasePreProcess and anything else needed + * to populate input tensors ready for inference. + */ + class AdPreProcess : public BasePreProcess { + + public: + /** + * @brief Constructor for AdPreProcess class objects + * @param[in] inputTensor input tensor pointer from the tensor arena. + * @param[in] melSpectrogramFrameLen MEL spectrogram's frame length + * @param[in] melSpectrogramFrameStride MEL spectrogram's frame stride + * @param[in] adModelTrainingMean Training mean for the Anomaly detection model being used. + */ + explicit AdPreProcess(TfLiteTensor* inputTensor, + uint32_t melSpectrogramFrameLen, + uint32_t melSpectrogramFrameStride, + float adModelTrainingMean); + + ~AdPreProcess() = default; + + /** + * @brief Function to invoke pre-processing and populate the input vector + * @param input pointer to input data. For anomaly detection, this is the pointer to + * the audio data. + * @param inputSize Size of the data being passed in for pre-processing. + * @return True if successful, false otherwise. + */ + bool DoPreProcess(const void* input, size_t inputSize) override; + + /** + * @brief Getter function for audio window size computed when constructing + * the class object. + * @return Audio window size as 32 bit unsigned integer. + */ + uint32_t GetAudioWindowSize(); + + /** + * @brief Getter function for audio window stride computed when constructing + * the class object. + * @return Audio window stride as 32 bit unsigned integer. + */ + uint32_t GetAudioDataStride(); + + /** + * @brief Setter function for current audio index. This is only used for evaluating + * if previously computed features can be re-used from cache. + */ + void SetAudioWindowIndex(uint32_t idx); + + private: + bool m_validInstance{false}; /**< Indicates the current object is valid. */ + uint32_t m_melSpectrogramFrameLen{}; /**< MEL spectrogram's window frame length */ + uint32_t m_melSpectrogramFrameStride{}; /**< MEL spectrogram's window frame stride */ + uint8_t m_inputResizeScale{}; /**< Downscaling factor for the MEL energy matrix. */ + uint32_t m_numMelSpecVectorsInAudioStride{}; /**< Number of frames to move across the audio. */ + uint32_t m_audioDataWindowSize{}; /**< Audio window size computed based on other parameters. */ + uint32_t m_audioDataStride{}; /**< Audio window stride computed. */ + uint32_t m_numReusedFeatureVectors{}; /**< Number of MEL vectors that can be re-used */ + uint32_t m_audioWindowIndex{}; /**< Current audio window index (from audio's sliding window) */ + + audio::SlidingWindow<const int16_t> m_melWindowSlider; /**< Internal MEL spectrogram window slider */ + audio::AdMelSpectrogram m_melSpec; /**< MEL spectrogram computation object */ + std::function<void + (std::vector<int16_t>&, int, bool, size_t, size_t)> m_featureCalc; /**< Feature calculator object */ + }; + + class AdPostProcess : public BasePostProcess { + public: + /** + * @brief Constructor for AdPostProcess object. + * @param[in] outputTensor Output tensor pointer. + */ + explicit AdPostProcess(TfLiteTensor* outputTensor); + + ~AdPostProcess() = default; + + /** + * @brief Function to do the post-processing on the output tensor. + * @return True if successful, false otherwise. + */ + bool DoPostProcess() override; + + /** + * @brief Getter function for an element from the de-quantised output vector. + * @param index Index of the element to be retrieved. + * @return index represented as a 32 bit floating point number. + */ + float GetOutputValue(uint32_t index); + + private: + TfLiteTensor* m_outputTensor{}; /**< Output tensor pointer */ + std::vector<float> m_dequantizedOutputVec{}; /**< Internal output vector */ + + /** + * @brief De-quantizes and flattens the output tensor into a vector. + * @tparam T template parameter to indicate data type. + * @return True if successful, false otherwise. + */ + template<typename T> + bool Dequantize() + { + TfLiteTensor* tensor = this->m_outputTensor; + if (tensor == nullptr) { + printf_err("Invalid output tensor.\n"); + return false; + } + T* tensorData = tflite::GetTensorData<T>(tensor); + + uint32_t totalOutputSize = 1; + for (int inputDim = 0; inputDim < tensor->dims->size; inputDim++){ + totalOutputSize *= tensor->dims->data[inputDim]; + } + + /* For getting the floating point values, we need quantization parameters */ + QuantParams quantParams = GetTensorQuantParams(tensor); + + this->m_dequantizedOutputVec = std::vector<float>(totalOutputSize, 0); + + for (size_t i = 0; i < totalOutputSize; ++i) { + this->m_dequantizedOutputVec[i] = quantParams.scale * (tensorData[i] - quantParams.offset); + } + + return true; + } + }; + + /* Templated instances available: */ + template bool AdPostProcess::Dequantize<int8_t>(); + + /** + * @brief Generic feature calculator factory. + * + * Returns lambda function to compute features using features cache. + * Real features math is done by a lambda function provided as a parameter. + * Features are written to input tensor memory. + * + * @tparam T feature vector type. + * @param inputTensor model input tensor pointer. + * @param cacheSize number of feature vectors to cache. Defined by the sliding window overlap. + * @param compute features calculator function. + * @return lambda function to compute features. + */ + template<class T> + std::function<void (std::vector<int16_t>&, size_t, bool, size_t, size_t)> + FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize, + std::function<std::vector<T> (std::vector<int16_t>& )> compute) + { + /* Feature cache to be captured by lambda function*/ + static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize); + + return [=](std::vector<int16_t>& audioDataWindow, + size_t index, + bool useCache, + size_t featuresOverlapIndex, + size_t resizeScale) + { + T* tensorData = tflite::GetTensorData<T>(inputTensor); + std::vector<T> features; + + /* Reuse features from cache if cache is ready and sliding windows overlap. + * Overlap is in the beginning of sliding window with a size of a feature cache. */ + if (useCache && index < featureCache.size()) { + features = std::move(featureCache[index]); + } else { + features = std::move(compute(audioDataWindow)); + } + auto size = features.size() / resizeScale; + auto sizeBytes = sizeof(T); + + /* Input should be transposed and "resized" by skipping elements. */ + for (size_t outIndex = 0; outIndex < size; outIndex++) { + std::memcpy(tensorData + (outIndex*size) + index, &features[outIndex*resizeScale], sizeBytes); + } + + /* Start renewing cache as soon iteration goes out of the windows overlap. */ + if (index >= featuresOverlapIndex / resizeScale) { + featureCache[index - featuresOverlapIndex / resizeScale] = std::move(features); + } + }; + } + + template std::function<void (std::vector<int16_t>&, size_t , bool, size_t, size_t)> + FeatureCalc<int8_t>(TfLiteTensor* inputTensor, + size_t cacheSize, + std::function<std::vector<int8_t> (std::vector<int16_t>&)> compute); + + template std::function<void(std::vector<int16_t>&, size_t, bool, size_t, size_t)> + FeatureCalc<float>(TfLiteTensor *inputTensor, + size_t cacheSize, + std::function<std::vector<float>(std::vector<int16_t>&)> compute); + + std::function<void (std::vector<int16_t>&, int, bool, size_t, size_t)> + GetFeatureCalculator(audio::AdMelSpectrogram& melSpec, + TfLiteTensor* inputTensor, + size_t cacheSize, + float trainingMean); + +} /* namespace app */ +} /* namespace arm */ + +#endif /* AD_PROCESSING_HPP */ diff --git a/source/use_case/ad/src/AdPostProcessing.cc b/source/use_case/ad/src/AdPostProcessing.cc deleted file mode 100644 index c461875..0000000 --- a/source/use_case/ad/src/AdPostProcessing.cc +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Copyright (c) 2021 Arm Limited. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "AdPostProcessing.hpp" -#include "log_macros.h" - -#include <numeric> -#include <cmath> -#include <string> - -namespace arm { -namespace app { - - template<typename T> - std::vector<float> Dequantize(TfLiteTensor* tensor) { - - if (tensor == nullptr) { - printf_err("Tensor is null pointer can not dequantize.\n"); - return std::vector<float>(); - } - T* tensorData = tflite::GetTensorData<T>(tensor); - - uint32_t totalOutputSize = 1; - for (int inputDim = 0; inputDim < tensor->dims->size; inputDim++){ - totalOutputSize *= tensor->dims->data[inputDim]; - } - - /* For getting the floating point values, we need quantization parameters */ - QuantParams quantParams = GetTensorQuantParams(tensor); - - std::vector<float> dequantizedOutput(totalOutputSize); - - for (size_t i = 0; i < totalOutputSize; ++i) { - dequantizedOutput[i] = quantParams.scale * (tensorData[i] - quantParams.offset); - } - - return dequantizedOutput; - } - - void Softmax(std::vector<float>& inputVector) { - auto start = inputVector.begin(); - auto end = inputVector.end(); - - /* Fix for numerical stability and apply exp. */ - float maxValue = *std::max_element(start, end); - for (auto it = start; it!=end; ++it) { - *it = std::exp((*it) - maxValue); - } - - float sumExp = std::accumulate(start, end, 0.0f); - - for (auto it = start; it!=end; ++it) { - *it = (*it)/sumExp; - } - } - - int8_t OutputIndexFromFileName(std::string wavFileName) { - /* Filename is assumed in the form machine_id_00.wav */ - std::string delimiter = "_"; /* First character used to split the file name up. */ - size_t delimiterStart; - std::string subString; - size_t machineIdxInString = 3; /* Which part of the file name the machine id should be at. */ - - for (size_t i = 0; i < machineIdxInString; ++i) { - delimiterStart = wavFileName.find(delimiter); - subString = wavFileName.substr(0, delimiterStart); - wavFileName.erase(0, delimiterStart + delimiter.length()); - } - - /* At this point substring should be 00.wav */ - delimiter = "."; /* Second character used to split the file name up. */ - delimiterStart = subString.find(delimiter); - subString = (delimiterStart != std::string::npos) ? subString.substr(0, delimiterStart) : subString; - - auto is_number = [](const std::string& str) -> bool - { - std::string::const_iterator it = str.begin(); - while (it != str.end() && std::isdigit(*it)) ++it; - return !str.empty() && it == str.end(); - }; - - const int8_t machineIdx = is_number(subString) ? std::stoi(subString) : -1; - - /* Return corresponding index in the output vector. */ - if (machineIdx == 0) { - return 0; - } else if (machineIdx == 2) { - return 1; - } else if (machineIdx == 4) { - return 2; - } else if (machineIdx == 6) { - return 3; - } else { - printf_err("%d is an invalid machine index \n", machineIdx); - return -1; - } - } - - template std::vector<float> Dequantize<uint8_t>(TfLiteTensor* tensor); - template std::vector<float> Dequantize<int8_t>(TfLiteTensor* tensor); -} /* namespace app */ -} /* namespace arm */
\ No newline at end of file diff --git a/source/use_case/ad/src/AdProcessing.cc b/source/use_case/ad/src/AdProcessing.cc new file mode 100644 index 0000000..a33131c --- /dev/null +++ b/source/use_case/ad/src/AdProcessing.cc @@ -0,0 +1,208 @@ +/* + * Copyright (c) 2022 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "AdProcessing.hpp" + +#include "AdModel.hpp" + +namespace arm { +namespace app { + +AdPreProcess::AdPreProcess(TfLiteTensor* inputTensor, + uint32_t melSpectrogramFrameLen, + uint32_t melSpectrogramFrameStride, + float adModelTrainingMean): + m_validInstance{false}, + m_melSpectrogramFrameLen{melSpectrogramFrameLen}, + m_melSpectrogramFrameStride{melSpectrogramFrameStride}, + /**< Model is trained on features downsampled 2x */ + m_inputResizeScale{2}, + /**< We are choosing to move by 20 frames across the audio for each inference. */ + m_numMelSpecVectorsInAudioStride{20}, + m_audioDataStride{m_numMelSpecVectorsInAudioStride * melSpectrogramFrameStride}, + m_melSpec{melSpectrogramFrameLen} +{ + if (!inputTensor) { + printf_err("Invalid input tensor provided to pre-process\n"); + return; + } + + TfLiteIntArray* inputShape = inputTensor->dims; + + if (!inputShape) { + printf_err("Invalid input tensor dims\n"); + return; + } + + const uint32_t kNumRows = inputShape->data[AdModel::ms_inputRowsIdx]; + const uint32_t kNumCols = inputShape->data[AdModel::ms_inputColsIdx]; + + /* Deduce the data length required for 1 inference from the network parameters. */ + this->m_audioDataWindowSize = (((this->m_inputResizeScale * kNumCols) - 1) * + melSpectrogramFrameStride) + + melSpectrogramFrameLen; + this->m_numReusedFeatureVectors = kNumRows - + (this->m_numMelSpecVectorsInAudioStride / + this->m_inputResizeScale); + this->m_melSpec.Init(); + + /* Creating a Mel Spectrogram sliding window for the data required for 1 inference. + * "resizing" done here by multiplying stride by resize scale. */ + this->m_melWindowSlider = audio::SlidingWindow<const int16_t>( + nullptr, /* to be populated later. */ + this->m_audioDataWindowSize, + melSpectrogramFrameLen, + melSpectrogramFrameStride * this->m_inputResizeScale); + + /* Construct feature calculation function. */ + this->m_featureCalc = GetFeatureCalculator(this->m_melSpec, inputTensor, + this->m_numReusedFeatureVectors, + adModelTrainingMean); + this->m_validInstance = true; +} + +bool AdPreProcess::DoPreProcess(const void* input, size_t inputSize) +{ + /* Check that we have a valid instance. */ + if (!this->m_validInstance) { + printf_err("Invalid pre-processor instance\n"); + return false; + } + + /* We expect that we can traverse the size with which the MEL spectrogram + * sliding window was initialised with. */ + if (!input || inputSize < this->m_audioDataWindowSize) { + printf_err("Invalid input provided for pre-processing\n"); + return false; + } + + /* We moved to the next window - set the features sliding to the new address. */ + this->m_melWindowSlider.Reset(static_cast<const int16_t*>(input)); + + /* The first window does not have cache ready. */ + const bool useCache = this->m_audioWindowIndex > 0 && this->m_numReusedFeatureVectors > 0; + + /* Start calculating features inside one audio sliding window. */ + while (this->m_melWindowSlider.HasNext()) { + const int16_t* melSpecWindow = this->m_melWindowSlider.Next(); + std::vector<int16_t> melSpecAudioData = std::vector<int16_t>( + melSpecWindow, + melSpecWindow + this->m_melSpectrogramFrameLen); + + /* Compute features for this window and write them to input tensor. */ + this->m_featureCalc(melSpecAudioData, + this->m_melWindowSlider.Index(), + useCache, + this->m_numMelSpecVectorsInAudioStride, + this->m_inputResizeScale); + } + + return true; +} + +uint32_t AdPreProcess::GetAudioWindowSize() +{ + return this->m_audioDataWindowSize; +} + +uint32_t AdPreProcess::GetAudioDataStride() +{ + return this->m_audioDataStride; +} + +void AdPreProcess::SetAudioWindowIndex(uint32_t idx) +{ + this->m_audioWindowIndex = idx; +} + +AdPostProcess::AdPostProcess(TfLiteTensor* outputTensor) : + m_outputTensor {outputTensor} +{} + +bool AdPostProcess::DoPostProcess() +{ + switch (this->m_outputTensor->type) { + case kTfLiteInt8: + this->Dequantize<int8_t>(); + break; + default: + printf_err("Unsupported tensor type"); + return false; + } + + math::MathUtils::SoftmaxF32(this->m_dequantizedOutputVec); + return true; +} + +float AdPostProcess::GetOutputValue(uint32_t index) +{ + if (index < this->m_dequantizedOutputVec.size()) { + return this->m_dequantizedOutputVec[index]; + } + printf_err("Invalid index for output\n"); + return 0.0; +} + +std::function<void (std::vector<int16_t>&, int, bool, size_t, size_t)> +GetFeatureCalculator(audio::AdMelSpectrogram& melSpec, + TfLiteTensor* inputTensor, + size_t cacheSize, + float trainingMean) +{ + std::function<void (std::vector<int16_t>&, size_t, bool, size_t, size_t)> melSpecFeatureCalc; + + TfLiteQuantization quant = inputTensor->quantization; + + if (kTfLiteAffineQuantization == quant.type) { + + auto* quantParams = static_cast<TfLiteAffineQuantization*>(quant.params); + const float quantScale = quantParams->scale->data[0]; + const int quantOffset = quantParams->zero_point->data[0]; + + switch (inputTensor->type) { + case kTfLiteInt8: { + melSpecFeatureCalc = FeatureCalc<int8_t>( + inputTensor, + cacheSize, + [=, &melSpec](std::vector<int16_t>& audioDataWindow) { + return melSpec.MelSpecComputeQuant<int8_t>( + audioDataWindow, + quantScale, + quantOffset, + trainingMean); + } + ); + break; + } + default: + printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type)); + } + } else { + melSpecFeatureCalc = FeatureCalc<float>( + inputTensor, + cacheSize, + [=, &melSpec]( + std::vector<int16_t>& audioDataWindow) { + return melSpec.ComputeMelSpec( + audioDataWindow, + trainingMean); + }); + } + return melSpecFeatureCalc; +} + +} /* namespace app */ +} /* namespace arm */ diff --git a/source/use_case/ad/src/MainLoop.cc b/source/use_case/ad/src/MainLoop.cc index 23d1e51..140359b 100644 --- a/source/use_case/ad/src/MainLoop.cc +++ b/source/use_case/ad/src/MainLoop.cc @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,7 +14,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "hal.h" /* Brings in platform definitions */ #include "InputFiles.hpp" /* For input data */ #include "AdModel.hpp" /* Model class for running inference */ #include "UseCaseCommonUtils.hpp" /* Utils functions */ @@ -63,8 +62,8 @@ void main_loop() caseContext.Set<arm::app::Profiler&>("profiler", profiler); caseContext.Set<arm::app::Model&>("model", model); caseContext.Set<uint32_t>("clipIndex", 0); - caseContext.Set<int>("frameLength", g_FrameLength); - caseContext.Set<int>("frameStride", g_FrameStride); + caseContext.Set<uint32_t>("frameLength", g_FrameLength); + caseContext.Set<uint32_t>("frameStride", g_FrameStride); caseContext.Set<float>("scoreThreshold", g_ScoreThreshold); caseContext.Set<float>("trainingMean", g_TrainingMean); diff --git a/source/use_case/ad/src/UseCaseHandler.cc b/source/use_case/ad/src/UseCaseHandler.cc index 5585f36..0179d6b 100644 --- a/source/use_case/ad/src/UseCaseHandler.cc +++ b/source/use_case/ad/src/UseCaseHandler.cc @@ -24,8 +24,8 @@ #include "AudioUtils.hpp" #include "ImageUtils.hpp" #include "UseCaseCommonUtils.hpp" -#include "AdPostProcessing.hpp" #include "log_macros.h" +#include "AdProcessing.hpp" namespace arm { namespace app { @@ -39,32 +39,17 @@ namespace app { **/ static bool PresentInferenceResult(float result, float threshold); - /** - * @brief Returns a function to perform feature calculation and populates input tensor data with - * MelSpe data. - * - * Input tensor data type check is performed to choose correct MFCC feature data type. - * If tensor has an integer data type then original features are quantised. - * - * Warning: mfcc calculator provided as input must have the same life scope as returned function. - * - * @param[in] melSpec MFCC feature calculator. - * @param[in,out] inputTensor Input tensor pointer to store calculated features. - * @param[in] cacheSize Size of the feture vectors cache (number of feature vectors). - * @param[in] trainingMean Training mean. - * @return function function to be called providing audio sample and sliding window index. - */ - static std::function<void (std::vector<int16_t>&, int, bool, size_t, size_t)> - GetFeatureCalculator(audio::AdMelSpectrogram& melSpec, - TfLiteTensor* inputTensor, - size_t cacheSize, - float trainingMean); - - /* Vibration classification handler */ + /** @brief Given a wav file name return AD model output index. + * @param[in] wavFileName Audio WAV filename. + * File name should be in format anything_goes_XX_here.wav + * where XX is the machine ID e.g. 00, 02, 04 or 06 + * @return AD model output index as 8 bit integer. + **/ + static int8_t OutputIndexFromFileName(std::string wavFileName); + + /* Anomaly Detection inference handler */ bool ClassifyVibrationHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll) { - auto& profiler = ctx.Get<Profiler&>("profiler"); - constexpr uint32_t dataPsnTxtInfStartX = 20; constexpr uint32_t dataPsnTxtInfStartY = 40; @@ -81,8 +66,9 @@ namespace app { return false; } - const auto frameLength = ctx.Get<int>("frameLength"); - const auto frameStride = ctx.Get<int>("frameStride"); + auto& profiler = ctx.Get<Profiler&>("profiler"); + const auto melSpecFrameLength = ctx.Get<uint32_t>("frameLength"); + const auto melSpecFrameStride = ctx.Get<uint32_t>("frameStride"); const auto scoreThreshold = ctx.Get<float>("scoreThreshold"); const auto trainingMean = ctx.Get<float>("trainingMean"); auto startClipIdx = ctx.Get<uint32_t>("clipIndex"); @@ -95,21 +81,13 @@ namespace app { return false; } - TfLiteIntArray* inputShape = model.GetInputShape(0); - const uint32_t kNumRows = inputShape->data[1]; - const uint32_t kNumCols = inputShape->data[2]; + AdPreProcess preProcess{ + inputTensor, + melSpecFrameLength, + melSpecFrameStride, + trainingMean}; - audio::AdMelSpectrogram melSpec = audio::AdMelSpectrogram(frameLength); - melSpec.Init(); - - /* Deduce the data length required for 1 inference from the network parameters. */ - const uint8_t inputResizeScale = 2; - const uint32_t audioDataWindowSize = (((inputResizeScale * kNumCols) - 1) * frameStride) + frameLength; - - /* We are choosing to move by 20 frames across the audio for each inference. */ - const uint8_t nMelSpecVectorsInAudioStride = 20; - - auto audioDataStride = nMelSpecVectorsInAudioStride * frameStride; + AdPostProcess postProcess{outputTensor}; do { hal_lcd_clear(COLOR_BLACK); @@ -122,29 +100,12 @@ namespace app { return false; } - /* Creating a Mel Spectrogram sliding window for the data required for 1 inference. - * "resizing" done here by multiplying stride by resize scale. */ - auto audioMelSpecWindowSlider = audio::SlidingWindow<const int16_t>( - get_audio_array(currentIndex), - audioDataWindowSize, frameLength, - frameStride * inputResizeScale); - /* Creating a sliding window through the whole audio clip. */ auto audioDataSlider = audio::SlidingWindow<const int16_t>( - get_audio_array(currentIndex), - get_audio_array_size(currentIndex), - audioDataWindowSize, audioDataStride); - - /* Calculate number of the feature vectors in the window overlap region taking into account resizing. - * These feature vectors will be reused.*/ - auto numberOfReusedFeatureVectors = kNumRows - (nMelSpecVectorsInAudioStride / inputResizeScale); - - /* Construct feature calculation function. */ - auto melSpecFeatureCalc = GetFeatureCalculator(melSpec, inputTensor, - numberOfReusedFeatureVectors, trainingMean); - if (!melSpecFeatureCalc){ - return false; - } + get_audio_array(currentIndex), + get_audio_array_size(currentIndex), + preProcess.GetAudioWindowSize(), + preProcess.GetAudioDataStride()); /* Result is an averaged sum over inferences. */ float result = 0; @@ -152,30 +113,18 @@ namespace app { /* Display message on the LCD - inference running. */ std::string str_inf{"Running inference... "}; hal_lcd_display_text( - str_inf.c_str(), str_inf.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); - info("Running inference on audio clip %" PRIu32 " => %s\n", currentIndex, get_filename(currentIndex)); + str_inf.c_str(), str_inf.size(), + dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); + + info("Running inference on audio clip %" PRIu32 " => %s\n", + currentIndex, get_filename(currentIndex)); /* Start sliding through audio clip. */ while (audioDataSlider.HasNext()) { - const int16_t *inferenceWindow = audioDataSlider.Next(); - - /* We moved to the next window - set the features sliding to the new address. */ - audioMelSpecWindowSlider.Reset(inferenceWindow); + const int16_t* inferenceWindow = audioDataSlider.Next(); - /* The first window does not have cache ready. */ - bool useCache = audioDataSlider.Index() > 0 && numberOfReusedFeatureVectors > 0; - - /* Start calculating features inside one audio sliding window. */ - while (audioMelSpecWindowSlider.HasNext()) { - const int16_t *melSpecWindow = audioMelSpecWindowSlider.Next(); - std::vector<int16_t> melSpecAudioData = std::vector<int16_t>(melSpecWindow, - melSpecWindow + frameLength); - - /* Compute features for this window and write them to input tensor. */ - melSpecFeatureCalc(melSpecAudioData, audioMelSpecWindowSlider.Index(), - useCache, nMelSpecVectorsInAudioStride, inputResizeScale); - } + preProcess.SetAudioWindowIndex(audioDataSlider.Index()); + preProcess.DoPreProcess(inferenceWindow, preProcess.GetAudioWindowSize()); info("Inference %zu/%zu\n", audioDataSlider.Index() + 1, audioDataSlider.TotalStrides() + 1); @@ -185,13 +134,11 @@ namespace app { return false; } - /* Use the negative softmax score of the corresponding index as the outlier score */ - std::vector<float> dequantOutput = Dequantize<int8_t>(outputTensor); - Softmax(dequantOutput); - result += -dequantOutput[machineOutputIndex]; + postProcess.DoPostProcess(); + result += 0 - postProcess.GetOutputValue(machineOutputIndex); #if VERIFY_TEST_OUTPUT - arm::app::DumpTensor(outputTensor); + DumpTensor(outputTensor); #endif /* VERIFY_TEST_OUTPUT */ } /* while (audioDataSlider.HasNext()) */ @@ -218,7 +165,6 @@ namespace app { return true; } - static bool PresentInferenceResult(float result, float threshold) { constexpr uint32_t dataPsnTxtStartX1 = 20; @@ -251,148 +197,47 @@ namespace app { return true; } - /** - * @brief Generic feature calculator factory. - * - * Returns lambda function to compute features using features cache. - * Real features math is done by a lambda function provided as a parameter. - * Features are written to input tensor memory. - * - * @tparam T feature vector type. - * @param inputTensor model input tensor pointer. - * @param cacheSize number of feature vectors to cache. Defined by the sliding window overlap. - * @param compute features calculator function. - * @return lambda function to compute features. - */ - template<class T> - std::function<void (std::vector<int16_t>&, size_t, bool, size_t, size_t)> - FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize, - std::function<std::vector<T> (std::vector<int16_t>& )> compute) + static int8_t OutputIndexFromFileName(std::string wavFileName) { - /* Feature cache to be captured by lambda function*/ - static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize); - - return [=](std::vector<int16_t>& audioDataWindow, - size_t index, - bool useCache, - size_t featuresOverlapIndex, - size_t resizeScale) - { - T *tensorData = tflite::GetTensorData<T>(inputTensor); - std::vector<T> features; - - /* Reuse features from cache if cache is ready and sliding windows overlap. - * Overlap is in the beginning of sliding window with a size of a feature cache. */ - if (useCache && index < featureCache.size()) { - features = std::move(featureCache[index]); - } else { - features = std::move(compute(audioDataWindow)); - } - auto size = features.size() / resizeScale; - auto sizeBytes = sizeof(T); + /* Filename is assumed in the form machine_id_00.wav */ + std::string delimiter = "_"; /* First character used to split the file name up. */ + size_t delimiterStart; + std::string subString; + size_t machineIdxInString = 3; /* Which part of the file name the machine id should be at. */ + + for (size_t i = 0; i < machineIdxInString; ++i) { + delimiterStart = wavFileName.find(delimiter); + subString = wavFileName.substr(0, delimiterStart); + wavFileName.erase(0, delimiterStart + delimiter.length()); + } - /* Input should be transposed and "resized" by skipping elements. */ - for (size_t outIndex = 0; outIndex < size; outIndex++) { - std::memcpy(tensorData + (outIndex*size) + index, &features[outIndex*resizeScale], sizeBytes); - } + /* At this point substring should be 00.wav */ + delimiter = "."; /* Second character used to split the file name up. */ + delimiterStart = subString.find(delimiter); + subString = (delimiterStart != std::string::npos) ? subString.substr(0, delimiterStart) : subString; - /* Start renewing cache as soon iteration goes out of the windows overlap. */ - if (index >= featuresOverlapIndex / resizeScale) { - featureCache[index - featuresOverlapIndex / resizeScale] = std::move(features); - } + auto is_number = [](const std::string& str) -> bool + { + std::string::const_iterator it = str.begin(); + while (it != str.end() && std::isdigit(*it)) ++it; + return !str.empty() && it == str.end(); }; - } - - template std::function<void (std::vector<int16_t>&, size_t , bool, size_t, size_t)> - FeatureCalc<int8_t>(TfLiteTensor* inputTensor, - size_t cacheSize, - std::function<std::vector<int8_t> (std::vector<int16_t>&)> compute); - - template std::function<void (std::vector<int16_t>&, size_t , bool, size_t, size_t)> - FeatureCalc<uint8_t>(TfLiteTensor* inputTensor, - size_t cacheSize, - std::function<std::vector<uint8_t> (std::vector<int16_t>&)> compute); - - template std::function<void (std::vector<int16_t>&, size_t , bool, size_t, size_t)> - FeatureCalc<int16_t>(TfLiteTensor* inputTensor, - size_t cacheSize, - std::function<std::vector<int16_t> (std::vector<int16_t>&)> compute); - - template std::function<void(std::vector<int16_t>&, size_t, bool, size_t, size_t)> - FeatureCalc<float>(TfLiteTensor *inputTensor, - size_t cacheSize, - std::function<std::vector<float>(std::vector<int16_t>&)> compute); - - - static std::function<void (std::vector<int16_t>&, int, bool, size_t, size_t)> - GetFeatureCalculator(audio::AdMelSpectrogram& melSpec, TfLiteTensor* inputTensor, size_t cacheSize, float trainingMean) - { - std::function<void (std::vector<int16_t>&, size_t, bool, size_t, size_t)> melSpecFeatureCalc; - - TfLiteQuantization quant = inputTensor->quantization; - - if (kTfLiteAffineQuantization == quant.type) { - - auto *quantParams = (TfLiteAffineQuantization *) quant.params; - const float quantScale = quantParams->scale->data[0]; - const int quantOffset = quantParams->zero_point->data[0]; - - switch (inputTensor->type) { - case kTfLiteInt8: { - melSpecFeatureCalc = FeatureCalc<int8_t>(inputTensor, - cacheSize, - [=, &melSpec](std::vector<int16_t>& audioDataWindow) { - return melSpec.MelSpecComputeQuant<int8_t>( - audioDataWindow, - quantScale, - quantOffset, - trainingMean); - } - ); - break; - } - case kTfLiteUInt8: { - melSpecFeatureCalc = FeatureCalc<uint8_t>(inputTensor, - cacheSize, - [=, &melSpec](std::vector<int16_t>& audioDataWindow) { - return melSpec.MelSpecComputeQuant<uint8_t>( - audioDataWindow, - quantScale, - quantOffset, - trainingMean); - } - ); - break; - } - case kTfLiteInt16: { - melSpecFeatureCalc = FeatureCalc<int16_t>(inputTensor, - cacheSize, - [=, &melSpec](std::vector<int16_t>& audioDataWindow) { - return melSpec.MelSpecComputeQuant<int16_t>( - audioDataWindow, - quantScale, - quantOffset, - trainingMean); - } - ); - break; - } - default: - printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type)); - } - + const int8_t machineIdx = is_number(subString) ? std::stoi(subString) : -1; + + /* Return corresponding index in the output vector. */ + if (machineIdx == 0) { + return 0; + } else if (machineIdx == 2) { + return 1; + } else if (machineIdx == 4) { + return 2; + } else if (machineIdx == 6) { + return 3; } else { - melSpecFeatureCalc = melSpecFeatureCalc = FeatureCalc<float>(inputTensor, - cacheSize, - [=, &melSpec]( - std::vector<int16_t>& audioDataWindow) { - return melSpec.ComputeMelSpec( - audioDataWindow, - trainingMean); - }); + printf_err("%d is an invalid machine index \n", machineIdx); + return -1; } - return melSpecFeatureCalc; } } /* namespace app */ diff --git a/source/use_case/asr/include/Wav2LetterModel.hpp b/source/use_case/asr/include/Wav2LetterModel.hpp index 0078e44..bec70ab 100644 --- a/source/use_case/asr/include/Wav2LetterModel.hpp +++ b/source/use_case/asr/include/Wav2LetterModel.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved.rved. + * Copyright (c) 2021 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/source/use_case/kws_asr/include/KwsProcessing.hpp b/source/use_case/kws_asr/include/KwsProcessing.hpp new file mode 100644 index 0000000..d3de3b3 --- /dev/null +++ b/source/use_case/kws_asr/include/KwsProcessing.hpp @@ -0,0 +1,138 @@ +/* + * Copyright (c) 2022 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef KWS_PROCESSING_HPP +#define KWS_PROCESSING_HPP + +#include <AudioUtils.hpp> +#include "BaseProcessing.hpp" +#include "Model.hpp" +#include "Classifier.hpp" +#include "MicroNetKwsMfcc.hpp" + +#include <functional> + +namespace arm { +namespace app { + + /** + * @brief Pre-processing class for Keyword Spotting use case. + * Implements methods declared by BasePreProcess and anything else needed + * to populate input tensors ready for inference. + */ + class KwsPreProcess : public BasePreProcess { + + public: + /** + * @brief Constructor + * @param[in] inputTensor Pointer to the TFLite Micro input Tensor. + * @param[in] numFeatures How many MFCC features to use. + * @param[in] numFeatureFrames Number of MFCC vectors that need to be calculated + * for an inference. + * @param[in] mfccFrameLength Number of audio samples used to calculate one set of MFCC values when + * sliding a window through the audio sample. + * @param[in] mfccFrameStride Number of audio samples between consecutive windows. + **/ + explicit KwsPreProcess(TfLiteTensor* inputTensor, size_t numFeatures, size_t numFeatureFrames, + int mfccFrameLength, int mfccFrameStride); + + /** + * @brief Should perform pre-processing of 'raw' input audio data and load it into + * TFLite Micro input tensors ready for inference. + * @param[in] input Pointer to the data that pre-processing will work on. + * @param[in] inputSize Size of the input data. + * @return true if successful, false otherwise. + **/ + bool DoPreProcess(const void* input, size_t inputSize) override; + + size_t m_audioWindowIndex = 0; /* Index of audio slider, used when caching features in longer clips. */ + size_t m_audioDataWindowSize; /* Amount of audio needed for 1 inference. */ + size_t m_audioDataStride; /* Amount of audio to stride across if doing >1 inference in longer clips. */ + + private: + TfLiteTensor* m_inputTensor; /* Model input tensor. */ + const int m_mfccFrameLength; + const int m_mfccFrameStride; + const size_t m_numMfccFrames; /* How many sets of m_numMfccFeats. */ + + audio::MicroNetKwsMFCC m_mfcc; + audio::SlidingWindow<const int16_t> m_mfccSlidingWindow; + size_t m_numMfccVectorsInAudioStride; + size_t m_numReusedMfccVectors; + std::function<void (std::vector<int16_t>&, int, bool, size_t)> m_mfccFeatureCalculator; + + /** + * @brief Returns a function to perform feature calculation and populates input tensor data with + * MFCC data. + * + * Input tensor data type check is performed to choose correct MFCC feature data type. + * If tensor has an integer data type then original features are quantised. + * + * Warning: MFCC calculator provided as input must have the same life scope as returned function. + * + * @param[in] mfcc MFCC feature calculator. + * @param[in,out] inputTensor Input tensor pointer to store calculated features. + * @param[in] cacheSize Size of the feature vectors cache (number of feature vectors). + * @return Function to be called providing audio sample and sliding window index. + */ + std::function<void (std::vector<int16_t>&, int, bool, size_t)> + GetFeatureCalculator(audio::MicroNetKwsMFCC& mfcc, + TfLiteTensor* inputTensor, + size_t cacheSize); + + template<class T> + std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> + FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize, + std::function<std::vector<T> (std::vector<int16_t>& )> compute); + }; + + /** + * @brief Post-processing class for Keyword Spotting use case. + * Implements methods declared by BasePostProcess and anything else needed + * to populate result vector. + */ + class KwsPostProcess : public BasePostProcess { + + private: + TfLiteTensor* m_outputTensor; /* Model output tensor. */ + Classifier& m_kwsClassifier; /* KWS Classifier object. */ + const std::vector<std::string>& m_labels; /* KWS Labels. */ + std::vector<ClassificationResult>& m_results; /* Results vector for a single inference. */ + + public: + /** + * @brief Constructor + * @param[in] outputTensor Pointer to the TFLite Micro output Tensor. + * @param[in] classifier Classifier object used to get top N results from classification. + * @param[in] labels Vector of string labels to identify each output of the model. + * @param[in/out] results Vector of classification results to store decoded outputs. + **/ + KwsPostProcess(TfLiteTensor* outputTensor, Classifier& classifier, + const std::vector<std::string>& labels, + std::vector<ClassificationResult>& results); + + /** + * @brief Should perform post-processing of the result of inference then + * populate KWS result data for any later use. + * @return true if successful, false otherwise. + **/ + bool DoPostProcess() override; + }; + +} /* namespace app */ +} /* namespace arm */ + +#endif /* KWS_PROCESSING_HPP */
\ No newline at end of file diff --git a/source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp b/source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp index 43bd390..af6ba5f 100644 --- a/source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp +++ b/source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -24,7 +24,7 @@ namespace app { namespace audio { /* Class to provide MicroNet specific MFCC calculation requirements. */ - class MicroNetMFCC : public MFCC { + class MicroNetKwsMFCC : public MFCC { public: static constexpr uint32_t ms_defaultSamplingFreq = 16000; @@ -34,14 +34,14 @@ namespace audio { static constexpr bool ms_defaultUseHtkMethod = true; - explicit MicroNetMFCC(const size_t numFeats, const size_t frameLen) + explicit MicroNetKwsMFCC(const size_t numFeats, const size_t frameLen) : MFCC(MfccParams( ms_defaultSamplingFreq, ms_defaultNumFbankBins, ms_defaultMelLoFreq, ms_defaultMelHiFreq, numFeats, frameLen, ms_defaultUseHtkMethod)) {} - MicroNetMFCC() = delete; - ~MicroNetMFCC() = default; + MicroNetKwsMFCC() = delete; + ~MicroNetKwsMFCC() = default; }; } /* namespace audio */ diff --git a/source/use_case/kws_asr/include/Wav2LetterModel.hpp b/source/use_case/kws_asr/include/Wav2LetterModel.hpp index 7c327b3..0e1adc5 100644 --- a/source/use_case/kws_asr/include/Wav2LetterModel.hpp +++ b/source/use_case/kws_asr/include/Wav2LetterModel.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -34,14 +34,18 @@ namespace arm { namespace app { class Wav2LetterModel : public Model { - + public: /* Indices for the expected model - based on input and output tensor shapes */ - static constexpr uint32_t ms_inputRowsIdx = 1; - static constexpr uint32_t ms_inputColsIdx = 2; + static constexpr uint32_t ms_inputRowsIdx = 1; + static constexpr uint32_t ms_inputColsIdx = 2; static constexpr uint32_t ms_outputRowsIdx = 2; static constexpr uint32_t ms_outputColsIdx = 3; + /* Model specific constants. */ + static constexpr uint32_t ms_blankTokenIdx = 28; + static constexpr uint32_t ms_numMfccFeatures = 13; + protected: /** @brief Gets the reference to op resolver interface class. */ const tflite::MicroOpResolver& GetOpResolver() override; diff --git a/source/use_case/kws_asr/include/Wav2LetterPostprocess.hpp b/source/use_case/kws_asr/include/Wav2LetterPostprocess.hpp index 029a641..d1bc9a2 100644 --- a/source/use_case/kws_asr/include/Wav2LetterPostprocess.hpp +++ b/source/use_case/kws_asr/include/Wav2LetterPostprocess.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,88 +14,95 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef KWS_ASR_WAV2LET_POSTPROC_HPP -#define KWS_ASR_WAV2LET_POSTPROC_HPP +#ifndef KWS_ASR_WAV2LETTER_POSTPROCESS_HPP +#define KWS_ASR_WAV2LETTER_POSTPROCESS_HPP -#include "TensorFlowLiteMicro.hpp" /* TensorFlow headers */ +#include "TensorFlowLiteMicro.hpp" /* TensorFlow headers. */ +#include "BaseProcessing.hpp" +#include "AsrClassifier.hpp" +#include "AsrResult.hpp" +#include "log_macros.h" namespace arm { namespace app { -namespace audio { -namespace asr { /** * @brief Helper class to manage tensor post-processing for "wav2letter" * output. */ - class Postprocess { + class AsrPostProcess : public BasePostProcess { public: + bool m_lastIteration = false; /* Flag to set if processing the last set of data for a clip. */ + /** - * @brief Constructor - * @param[in] contextLen Left and right context length for - * output tensor. - * @param[in] innerLen This is the length of the section - * between left and right context. - * @param[in] blankTokenIdx Blank token index. + * @brief Constructor + * @param[in] outputTensor Pointer to the TFLite Micro output Tensor. + * @param[in] classifier Object used to get top N results from classification. + * @param[in] labels Vector of string labels to identify each output of the model. + * @param[in/out] result Vector of classification results to store decoded outputs. + * @param[in] outputContextLen Left/right context length for output tensor. + * @param[in] blankTokenIdx Index in the labels that the "Blank token" takes. + * @param[in] reductionAxis The axis that the logits of each time step is on. **/ - Postprocess(uint32_t contextLen, - uint32_t innerLen, - uint32_t blankTokenIdx); - - Postprocess() = delete; - ~Postprocess() = default; + AsrPostProcess(TfLiteTensor* outputTensor, AsrClassifier& classifier, + const std::vector<std::string>& labels, asr::ResultVec& result, + uint32_t outputContextLen, + uint32_t blankTokenIdx, uint32_t reductionAxis); /** - * @brief Erases the required part of the tensor based - * on context lengths set up during initialisation - * @param[in] tensor Pointer to the tensor - * @param[in] axisIdx Index of the axis on which erase is - * performed. - * @param[in] lastIteration Flag to signal is this is the - * last iteration in which case - * the right context is preserved. - * @return true if successful, false otherwise. - */ - bool Invoke(TfLiteTensor* tensor, - uint32_t axisIdx, - bool lastIteration = false); + * @brief Should perform post-processing of the result of inference then + * populate ASR result data for any later use. + * @return true if successful, false otherwise. + **/ + bool DoPostProcess() override; + + /** @brief Gets the output inner length for post-processing. */ + static uint32_t GetOutputInnerLen(const TfLiteTensor*, uint32_t outputCtxLen); + + /** @brief Gets the output context length (left/right) for post-processing. */ + static uint32_t GetOutputContextLen(const Model& model, uint32_t inputCtxLen); + + /** @brief Gets the number of feature vectors to be computed. */ + static uint32_t GetNumFeatureVectors(const Model& model); private: - uint32_t m_contextLen; /* Lengths of left and right contexts. */ - uint32_t m_innerLen; /* Length of inner context. */ - uint32_t m_totalLen; /* Total length of the required axis. */ - uint32_t m_countIterations; /* Current number of iterations. */ - uint32_t m_blankTokenIdx; /* Index of the labels blank token. */ + AsrClassifier& m_classifier; /* ASR Classifier object. */ + TfLiteTensor* m_outputTensor; /* Model output tensor. */ + const std::vector<std::string>& m_labels; /* ASR Labels. */ + asr::ResultVec & m_results; /* Results vector for a single inference. */ + uint32_t m_outputContextLen; /* lengths of left/right contexts for output. */ + uint32_t m_outputInnerLen; /* Length of output inner context. */ + uint32_t m_totalLen; /* Total length of the required axis. */ + uint32_t m_countIterations; /* Current number of iterations. */ + uint32_t m_blankTokenIdx; /* Index of the labels blank token. */ + uint32_t m_reductionAxisIdx; /* Axis containing output logits for a single step. */ + /** - * @brief Checks if the tensor and axis index are valid - * inputs to the object - based on how it has been - * initialised. - * @return true if valid, false otherwise. + * @brief Checks if the tensor and axis index are valid + * inputs to the object - based on how it has been initialised. + * @return true if valid, false otherwise. */ bool IsInputValid(TfLiteTensor* tensor, - const uint32_t axisIdx) const; + uint32_t axisIdx) const; /** - * @brief Gets the tensor data element size in bytes based - * on the tensor type. - * @return Size in bytes, 0 if not supported. + * @brief Gets the tensor data element size in bytes based + * on the tensor type. + * @return Size in bytes, 0 if not supported. */ - uint32_t GetTensorElementSize(TfLiteTensor* tensor); + static uint32_t GetTensorElementSize(TfLiteTensor* tensor); /** - * @brief Erases sections from the data assuming row-wise - * arrangement along the context axis. - * @return true if successful, false otherwise. + * @brief Erases sections from the data assuming row-wise + * arrangement along the context axis. + * @return true if successful, false otherwise. */ bool EraseSectionsRowWise(uint8_t* ptrData, - const uint32_t strideSzBytes, - const bool lastIteration); - + uint32_t strideSzBytes, + bool lastIteration); }; -} /* namespace asr */ -} /* namespace audio */ } /* namespace app */ } /* namespace arm */ -#endif /* KWS_ASR_WAV2LET_POSTPROC_HPP */
\ No newline at end of file +#endif /* KWS_ASR_WAV2LETTER_POSTPROCESS_HPP */
\ No newline at end of file diff --git a/source/use_case/kws_asr/include/Wav2LetterPreprocess.hpp b/source/use_case/kws_asr/include/Wav2LetterPreprocess.hpp index 3609c49..1224c23 100644 --- a/source/use_case/kws_asr/include/Wav2LetterPreprocess.hpp +++ b/source/use_case/kws_asr/include/Wav2LetterPreprocess.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,56 +14,51 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef KWS_ASR_WAV2LET_PREPROC_HPP -#define KWS_ASR_WAV2LET_PREPROC_HPP +#ifndef KWS_ASR_WAV2LETTER_PREPROCESS_HPP +#define KWS_ASR_WAV2LETTER_PREPROCESS_HPP #include "Wav2LetterModel.hpp" #include "Wav2LetterMfcc.hpp" #include "AudioUtils.hpp" #include "DataStructures.hpp" +#include "BaseProcessing.hpp" #include "log_macros.h" namespace arm { namespace app { -namespace audio { -namespace asr { /* Class to facilitate pre-processing calculation for Wav2Letter model * for ASR. */ - using AudioWindow = SlidingWindow <const int16_t>; + using AudioWindow = audio::SlidingWindow<const int16_t>; - class Preprocess { + class AsrPreProcess : public BasePreProcess { public: /** - * @brief Constructor - * @param[in] numMfccFeatures Number of MFCC features per window. - * @param[in] windowLen Number of elements in a window. - * @param[in] windowStride Stride (in number of elements) for - * moving the window. - * @param[in] numMfccVectors Number of MFCC vectors per window. - */ - Preprocess( - uint32_t numMfccFeatures, - uint32_t windowLen, - uint32_t windowStride, - uint32_t numMfccVectors); - Preprocess() = delete; - ~Preprocess() = default; + * @brief Constructor. + * @param[in] inputTensor Pointer to the TFLite Micro input Tensor. + * @param[in] numMfccFeatures Number of MFCC features per window. + * @param[in] numFeatureFrames Number of MFCC vectors that need to be calculated + * for an inference. + * @param[in] mfccWindowLen Number of audio elements to calculate MFCC features per window. + * @param[in] mfccWindowStride Stride (in number of elements) for moving the MFCC window. + */ + AsrPreProcess(TfLiteTensor* inputTensor, + uint32_t numMfccFeatures, + uint32_t numFeatureFrames, + uint32_t mfccWindowLen, + uint32_t mfccWindowStride); /** * @brief Calculates the features required from audio data. This * includes MFCC, first and second order deltas, * normalisation and finally, quantisation. The tensor is - * populated with feature from a given window placed along + * populated with features from a given window placed along * in a single row. * @param[in] audioData Pointer to the first element of audio data. * @param[in] audioDataLen Number of elements in the audio data. - * @param[in] tensor Tensor to be populated. * @return true if successful, false in case of error. */ - bool Invoke(const int16_t * audioData, - uint32_t audioDataLen, - TfLiteTensor * tensor); + bool DoPreProcess(const void* audioData, size_t audioDataLen) override; protected: /** @@ -73,49 +68,32 @@ namespace asr { * @param[in] mfcc MFCC buffers. * @param[out] delta1 Result of the first diff computation. * @param[out] delta2 Result of the second diff computation. - * - * @return true if successful, false otherwise. + * @return true if successful, false otherwise. */ static bool ComputeDeltas(Array2d<float>& mfcc, Array2d<float>& delta1, Array2d<float>& delta2); /** - * @brief Given a 2D vector of floats, computes the mean. - * @param[in] vec Vector of vector of floats. - * @return Mean value. - */ - static float GetMean(Array2d<float>& vec); - - /** - * @brief Given a 2D vector of floats, computes the stddev. - * @param[in] vec Vector of vector of floats. - * @param[in] mean Mean value of the vector passed in. - * @return stddev value. - */ - static float GetStdDev(Array2d<float>& vec, - const float mean); - - /** - * @brief Given a 2D vector of floats, normalises it using - * the mean and the stddev + * @brief Given a 2D vector of floats, rescale it to have mean of 0 and + * standard deviation of 1. * @param[in,out] vec Vector of vector of floats. */ - static void NormaliseVec(Array2d<float>& vec); + static void StandardizeVecF32(Array2d<float>& vec); /** - * @brief Normalises the MFCC and delta buffers. + * @brief Standardizes all the MFCC and delta buffers to have mean 0 and std. dev 1. */ - void Normalise(); + void Standarize(); /** * @brief Given the quantisation and data type limits, computes * the quantised values of a floating point input data. - * @param[in] elem Element to be quantised. - * @param[in] quantScale Scale. - * @param[in] quantOffset Offset. - * @param[in] minVal Numerical limit - minimum. - * @param[in] maxVal Numerical limit - maximum. + * @param[in] elem Element to be quantised. + * @param[in] quantScale Scale. + * @param[in] quantOffset Offset. + * @param[in] minVal Numerical limit - minimum. + * @param[in] maxVal Numerical limit - maximum. * @return Floating point quantised value. */ static float GetQuantElem( @@ -133,44 +111,43 @@ namespace asr { * this being the convolution speed up (as we can use * contiguous memory). The output, however, requires the * time axis to be in column major arrangement. - * @param[in] outputBuf Pointer to the output buffer. - * @param[in] outputBufSz Output buffer's size. - * @param[in] quantScale Quantisation scale. - * @param[in] quantOffset Quantisation offset. + * @param[in] outputBuf Pointer to the output buffer. + * @param[in] outputBufSz Output buffer's size. + * @param[in] quantScale Quantisation scale. + * @param[in] quantOffset Quantisation offset. */ template <typename T> bool Quantise( - T * outputBuf, + T* outputBuf, const uint32_t outputBufSz, const float quantScale, const int quantOffset) { - /* Check the output size will for everything. */ + /* Check the output size will fit everything. */ if (outputBufSz < (this->m_mfccBuf.size(0) * 3 * sizeof(T))) { printf_err("Tensor size too small for features\n"); return false; } /* Populate. */ - T * outputBufMfcc = outputBuf; - T * outputBufD1 = outputBuf + this->m_numMfccFeats; - T * outputBufD2 = outputBufD1 + this->m_numMfccFeats; + T* outputBufMfcc = outputBuf; + T* outputBufD1 = outputBuf + this->m_numMfccFeats; + T* outputBufD2 = outputBufD1 + this->m_numMfccFeats; const uint32_t ptrIncr = this->m_numMfccFeats * 2; /* (3 vectors - 1 vector) */ const float minVal = std::numeric_limits<T>::min(); const float maxVal = std::numeric_limits<T>::max(); - /* We need to do a transpose while copying and concatenating - * the tensor. */ - for (uint32_t j = 0; j < this->m_numFeatVectors; ++j) { + /* Need to transpose while copying and concatenating the tensor. */ + for (uint32_t j = 0; j < this->m_numFeatureFrames; ++j) { for (uint32_t i = 0; i < this->m_numMfccFeats; ++i) { - *outputBufMfcc++ = static_cast<T>(this->GetQuantElem( + *outputBufMfcc++ = static_cast<T>(AsrPreProcess::GetQuantElem( this->m_mfccBuf(i, j), quantScale, quantOffset, minVal, maxVal)); - *outputBufD1++ = static_cast<T>(this->GetQuantElem( + *outputBufD1++ = static_cast<T>(AsrPreProcess::GetQuantElem( this->m_delta1Buf(i, j), quantScale, quantOffset, minVal, maxVal)); - *outputBufD2++ = static_cast<T>(this->GetQuantElem( + *outputBufD2++ = static_cast<T>(AsrPreProcess::GetQuantElem( this->m_delta2Buf(i, j), quantScale, quantOffset, minVal, maxVal)); } @@ -183,24 +160,23 @@ namespace asr { } private: - Wav2LetterMFCC m_mfcc; /* MFCC instance. */ + audio::Wav2LetterMFCC m_mfcc; /* MFCC instance. */ + TfLiteTensor* m_inputTensor; /* Model input tensor. */ /* Actual buffers to be populated. */ - Array2d<float> m_mfccBuf; /* Contiguous buffer 1D: MFCC */ - Array2d<float> m_delta1Buf; /* Contiguous buffer 1D: Delta 1 */ - Array2d<float> m_delta2Buf; /* Contiguous buffer 1D: Delta 2 */ + Array2d<float> m_mfccBuf; /* Contiguous buffer 1D: MFCC */ + Array2d<float> m_delta1Buf; /* Contiguous buffer 1D: Delta 1 */ + Array2d<float> m_delta2Buf; /* Contiguous buffer 1D: Delta 2 */ - uint32_t m_windowLen; /* Window length for MFCC. */ - uint32_t m_windowStride; /* Window stride len for MFCC. */ - uint32_t m_numMfccFeats; /* Number of MFCC features per window. */ - uint32_t m_numFeatVectors; /* Number of m_numMfccFeats. */ - AudioWindow m_window; /* Sliding window. */ + uint32_t m_mfccWindowLen; /* Window length for MFCC. */ + uint32_t m_mfccWindowStride; /* Window stride len for MFCC. */ + uint32_t m_numMfccFeats; /* Number of MFCC features per window. */ + uint32_t m_numFeatureFrames; /* How many sets of m_numMfccFeats. */ + AudioWindow m_mfccSlidingWindow; /* Sliding window to calculate MFCCs. */ }; -} /* namespace asr */ -} /* namespace audio */ } /* namespace app */ } /* namespace arm */ -#endif /* KWS_ASR_WAV2LET_PREPROC_HPP */
\ No newline at end of file +#endif /* KWS_ASR_WAV2LETTER_PREPROCESS_HPP */
\ No newline at end of file diff --git a/source/use_case/kws_asr/src/KwsProcessing.cc b/source/use_case/kws_asr/src/KwsProcessing.cc new file mode 100644 index 0000000..328709d --- /dev/null +++ b/source/use_case/kws_asr/src/KwsProcessing.cc @@ -0,0 +1,212 @@ +/* + * Copyright (c) 2022 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "KwsProcessing.hpp" +#include "ImageUtils.hpp" +#include "log_macros.h" +#include "MicroNetKwsModel.hpp" + +namespace arm { +namespace app { + + KwsPreProcess::KwsPreProcess(TfLiteTensor* inputTensor, size_t numFeatures, size_t numMfccFrames, + int mfccFrameLength, int mfccFrameStride + ): + m_inputTensor{inputTensor}, + m_mfccFrameLength{mfccFrameLength}, + m_mfccFrameStride{mfccFrameStride}, + m_numMfccFrames{numMfccFrames}, + m_mfcc{audio::MicroNetKwsMFCC(numFeatures, mfccFrameLength)} + { + this->m_mfcc.Init(); + + /* Deduce the data length required for 1 inference from the network parameters. */ + this->m_audioDataWindowSize = this->m_numMfccFrames * this->m_mfccFrameStride + + (this->m_mfccFrameLength - this->m_mfccFrameStride); + + /* Creating an MFCC feature sliding window for the data required for 1 inference. */ + this->m_mfccSlidingWindow = audio::SlidingWindow<const int16_t>(nullptr, this->m_audioDataWindowSize, + this->m_mfccFrameLength, this->m_mfccFrameStride); + + /* For longer audio clips we choose to move by half the audio window size + * => for a 1 second window size there is an overlap of 0.5 seconds. */ + this->m_audioDataStride = this->m_audioDataWindowSize / 2; + + /* To have the previously calculated features re-usable, stride must be multiple + * of MFCC features window stride. Reduce stride through audio if needed. */ + if (0 != this->m_audioDataStride % this->m_mfccFrameStride) { + this->m_audioDataStride -= this->m_audioDataStride % this->m_mfccFrameStride; + } + + this->m_numMfccVectorsInAudioStride = this->m_audioDataStride / this->m_mfccFrameStride; + + /* Calculate number of the feature vectors in the window overlap region. + * These feature vectors will be reused.*/ + this->m_numReusedMfccVectors = this->m_mfccSlidingWindow.TotalStrides() + 1 + - this->m_numMfccVectorsInAudioStride; + + /* Construct feature calculation function. */ + this->m_mfccFeatureCalculator = GetFeatureCalculator(this->m_mfcc, this->m_inputTensor, + this->m_numReusedMfccVectors); + + if (!this->m_mfccFeatureCalculator) { + printf_err("Feature calculator not initialized."); + } + } + + bool KwsPreProcess::DoPreProcess(const void* data, size_t inputSize) + { + UNUSED(inputSize); + if (data == nullptr) { + printf_err("Data pointer is null"); + } + + /* Set the features sliding window to the new address. */ + auto input = static_cast<const int16_t*>(data); + this->m_mfccSlidingWindow.Reset(input); + + /* Cache is only usable if we have more than 1 inference in an audio clip. */ + bool useCache = this->m_audioWindowIndex > 0 && this->m_numReusedMfccVectors > 0; + + /* Use a sliding window to calculate MFCC features frame by frame. */ + while (this->m_mfccSlidingWindow.HasNext()) { + const int16_t* mfccWindow = this->m_mfccSlidingWindow.Next(); + + std::vector<int16_t> mfccFrameAudioData = std::vector<int16_t>(mfccWindow, + mfccWindow + this->m_mfccFrameLength); + + /* Compute features for this window and write them to input tensor. */ + this->m_mfccFeatureCalculator(mfccFrameAudioData, this->m_mfccSlidingWindow.Index(), + useCache, this->m_numMfccVectorsInAudioStride); + } + + debug("Input tensor populated \n"); + + return true; + } + + /** + * @brief Generic feature calculator factory. + * + * Returns lambda function to compute features using features cache. + * Real features math is done by a lambda function provided as a parameter. + * Features are written to input tensor memory. + * + * @tparam T Feature vector type. + * @param[in] inputTensor Model input tensor pointer. + * @param[in] cacheSize Number of feature vectors to cache. Defined by the sliding window overlap. + * @param[in] compute Features calculator function. + * @return Lambda function to compute features. + */ + template<class T> + std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> + KwsPreProcess::FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize, + std::function<std::vector<T> (std::vector<int16_t>& )> compute) + { + /* Feature cache to be captured by lambda function. */ + static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize); + + return [=](std::vector<int16_t>& audioDataWindow, + size_t index, + bool useCache, + size_t featuresOverlapIndex) + { + T* tensorData = tflite::GetTensorData<T>(inputTensor); + std::vector<T> features; + + /* Reuse features from cache if cache is ready and sliding windows overlap. + * Overlap is in the beginning of sliding window with a size of a feature cache. */ + if (useCache && index < featureCache.size()) { + features = std::move(featureCache[index]); + } else { + features = std::move(compute(audioDataWindow)); + } + auto size = features.size(); + auto sizeBytes = sizeof(T) * size; + std::memcpy(tensorData + (index * size), features.data(), sizeBytes); + + /* Start renewing cache as soon iteration goes out of the windows overlap. */ + if (index >= featuresOverlapIndex) { + featureCache[index - featuresOverlapIndex] = std::move(features); + } + }; + } + + template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)> + KwsPreProcess::FeatureCalc<int8_t>(TfLiteTensor* inputTensor, + size_t cacheSize, + std::function<std::vector<int8_t> (std::vector<int16_t>&)> compute); + + template std::function<void(std::vector<int16_t>&, size_t, bool, size_t)> + KwsPreProcess::FeatureCalc<float>(TfLiteTensor* inputTensor, + size_t cacheSize, + std::function<std::vector<float>(std::vector<int16_t>&)> compute); + + + std::function<void (std::vector<int16_t>&, int, bool, size_t)> + KwsPreProcess::GetFeatureCalculator(audio::MicroNetKwsMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize) + { + std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> mfccFeatureCalc; + + TfLiteQuantization quant = inputTensor->quantization; + + if (kTfLiteAffineQuantization == quant.type) { + auto *quantParams = (TfLiteAffineQuantization *) quant.params; + const float quantScale = quantParams->scale->data[0]; + const int quantOffset = quantParams->zero_point->data[0]; + + switch (inputTensor->type) { + case kTfLiteInt8: { + mfccFeatureCalc = this->FeatureCalc<int8_t>(inputTensor, + cacheSize, + [=, &mfcc](std::vector<int16_t>& audioDataWindow) { + return mfcc.MfccComputeQuant<int8_t>(audioDataWindow, + quantScale, + quantOffset); + } + ); + break; + } + default: + printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type)); + } + } else { + mfccFeatureCalc = this->FeatureCalc<float>(inputTensor, cacheSize, + [&mfcc](std::vector<int16_t>& audioDataWindow) { + return mfcc.MfccCompute(audioDataWindow); } + ); + } + return mfccFeatureCalc; + } + + KwsPostProcess::KwsPostProcess(TfLiteTensor* outputTensor, Classifier& classifier, + const std::vector<std::string>& labels, + std::vector<ClassificationResult>& results) + :m_outputTensor{outputTensor}, + m_kwsClassifier{classifier}, + m_labels{labels}, + m_results{results} + {} + + bool KwsPostProcess::DoPostProcess() + { + return this->m_kwsClassifier.GetClassificationResults( + this->m_outputTensor, this->m_results, + this->m_labels, 1, true); + } + +} /* namespace app */ +} /* namespace arm */
\ No newline at end of file diff --git a/source/use_case/kws_asr/src/MainLoop.cc b/source/use_case/kws_asr/src/MainLoop.cc index 5c1d0e0..f1d97a0 100644 --- a/source/use_case/kws_asr/src/MainLoop.cc +++ b/source/use_case/kws_asr/src/MainLoop.cc @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,7 +14,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "hal.h" /* Brings in platform definitions. */ #include "InputFiles.hpp" /* For input images. */ #include "Labels_micronetkws.hpp" /* For MicroNetKws label strings. */ #include "Labels_wav2letter.hpp" /* For Wav2Letter label strings. */ @@ -24,8 +23,6 @@ #include "Wav2LetterModel.hpp" /* ASR model class for running inference. */ #include "UseCaseCommonUtils.hpp" /* Utils functions. */ #include "UseCaseHandler.hpp" /* Handlers for different user options. */ -#include "Wav2LetterPreprocess.hpp" /* ASR pre-processing class. */ -#include "Wav2LetterPostprocess.hpp"/* ASR post-processing class. */ #include "log_macros.h" using KwsClassifier = arm::app::Classifier; @@ -53,19 +50,8 @@ static void DisplayMenu() fflush(stdout); } -/** @brief Gets the number of MFCC features for a single window. */ -static uint32_t GetNumMfccFeatures(const arm::app::Model& model); - -/** @brief Gets the number of MFCC feature vectors to be computed. */ -static uint32_t GetNumMfccFeatureVectors(const arm::app::Model& model); - -/** @brief Gets the output context length (left and right) for post-processing. */ -static uint32_t GetOutputContextLen(const arm::app::Model& model, - uint32_t inputCtxLen); - -/** @brief Gets the output inner length for post-processing. */ -static uint32_t GetOutputInnerLen(const arm::app::Model& model, - uint32_t outputCtxLen); +/** @brief Verify input and output tensor are of certain min dimensions. */ +static bool VerifyTensorDimensions(const arm::app::Model& model); void main_loop() { @@ -84,61 +70,46 @@ void main_loop() if (!asrModel.Init(kwsModel.GetAllocator())) { printf_err("Failed to initialise ASR model\n"); return; + } else if (!VerifyTensorDimensions(asrModel)) { + printf_err("Model's input or output dimension verification failed\n"); + return; } - /* Initialise ASR pre-processing. */ - arm::app::audio::asr::Preprocess prep( - GetNumMfccFeatures(asrModel), - arm::app::asr::g_FrameLength, - arm::app::asr::g_FrameStride, - GetNumMfccFeatureVectors(asrModel)); - - /* Initialise ASR post-processing. */ - const uint32_t outputCtxLen = GetOutputContextLen(asrModel, arm::app::asr::g_ctxLen); - const uint32_t blankTokenIdx = 28; - arm::app::audio::asr::Postprocess postp( - outputCtxLen, - GetOutputInnerLen(asrModel, outputCtxLen), - blankTokenIdx); - /* Instantiate application context. */ arm::app::ApplicationContext caseContext; arm::app::Profiler profiler{"kws_asr"}; caseContext.Set<arm::app::Profiler&>("profiler", profiler); - caseContext.Set<arm::app::Model&>("kwsmodel", kwsModel); - caseContext.Set<arm::app::Model&>("asrmodel", asrModel); + caseContext.Set<arm::app::Model&>("kwsModel", kwsModel); + caseContext.Set<arm::app::Model&>("asrModel", asrModel); caseContext.Set<uint32_t>("clipIndex", 0); caseContext.Set<uint32_t>("ctxLen", arm::app::asr::g_ctxLen); /* Left and right context length (MFCC feat vectors). */ - caseContext.Set<int>("kwsframeLength", arm::app::kws::g_FrameLength); - caseContext.Set<int>("kwsframeStride", arm::app::kws::g_FrameStride); - caseContext.Set<float>("kwsscoreThreshold", arm::app::kws::g_ScoreThreshold); /* Normalised score threshold. */ + caseContext.Set<int>("kwsFrameLength", arm::app::kws::g_FrameLength); + caseContext.Set<int>("kwsFrameStride", arm::app::kws::g_FrameStride); + caseContext.Set<float>("kwsScoreThreshold", arm::app::kws::g_ScoreThreshold); /* Normalised score threshold. */ caseContext.Set<uint32_t >("kwsNumMfcc", arm::app::kws::g_NumMfcc); caseContext.Set<uint32_t >("kwsNumAudioWins", arm::app::kws::g_NumAudioWins); - caseContext.Set<int>("asrframeLength", arm::app::asr::g_FrameLength); - caseContext.Set<int>("asrframeStride", arm::app::asr::g_FrameStride); - caseContext.Set<float>("asrscoreThreshold", arm::app::asr::g_ScoreThreshold); /* Normalised score threshold. */ + caseContext.Set<int>("asrFrameLength", arm::app::asr::g_FrameLength); + caseContext.Set<int>("asrFrameStride", arm::app::asr::g_FrameStride); + caseContext.Set<float>("asrScoreThreshold", arm::app::asr::g_ScoreThreshold); /* Normalised score threshold. */ KwsClassifier kwsClassifier; /* Classifier wrapper object. */ arm::app::AsrClassifier asrClassifier; /* Classifier wrapper object. */ - caseContext.Set<arm::app::Classifier&>("kwsclassifier", kwsClassifier); - caseContext.Set<arm::app::AsrClassifier&>("asrclassifier", asrClassifier); - - caseContext.Set<arm::app::audio::asr::Preprocess&>("preprocess", prep); - caseContext.Set<arm::app::audio::asr::Postprocess&>("postprocess", postp); + caseContext.Set<arm::app::Classifier&>("kwsClassifier", kwsClassifier); + caseContext.Set<arm::app::AsrClassifier&>("asrClassifier", asrClassifier); std::vector<std::string> asrLabels; arm::app::asr::GetLabelsVector(asrLabels); std::vector<std::string> kwsLabels; arm::app::kws::GetLabelsVector(kwsLabels); - caseContext.Set<const std::vector <std::string>&>("asrlabels", asrLabels); - caseContext.Set<const std::vector <std::string>&>("kwslabels", kwsLabels); + caseContext.Set<const std::vector <std::string>&>("asrLabels", asrLabels); + caseContext.Set<const std::vector <std::string>&>("kwsLabels", kwsLabels); /* KWS keyword that triggers ASR and associated checks */ - std::string triggerKeyword = std::string("yes"); + std::string triggerKeyword = std::string("no"); if (std::find(kwsLabels.begin(), kwsLabels.end(), triggerKeyword) != kwsLabels.end()) { - caseContext.Set<const std::string &>("triggerkeyword", triggerKeyword); + caseContext.Set<const std::string &>("triggerKeyword", triggerKeyword); } else { printf_err("Selected trigger keyword not found in labels file\n"); @@ -196,50 +167,26 @@ void main_loop() info("Main loop terminated.\n"); } -static uint32_t GetNumMfccFeatures(const arm::app::Model& model) -{ - TfLiteTensor* inputTensor = model.GetInputTensor(0); - const int inputCols = inputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputColsIdx]; - if (0 != inputCols % 3) { - printf_err("Number of input columns is not a multiple of 3\n"); - } - return std::max(inputCols/3, 0); -} - -static uint32_t GetNumMfccFeatureVectors(const arm::app::Model& model) +static bool VerifyTensorDimensions(const arm::app::Model& model) { + /* Populate tensor related parameters. */ TfLiteTensor* inputTensor = model.GetInputTensor(0); - const int inputRows = inputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx]; - return std::max(inputRows, 0); -} - -static uint32_t GetOutputContextLen(const arm::app::Model& model, const uint32_t inputCtxLen) -{ - const uint32_t inputRows = GetNumMfccFeatureVectors(model); - const uint32_t inputInnerLen = inputRows - (2 * inputCtxLen); - constexpr uint32_t ms_outputRowsIdx = arm::app::Wav2LetterModel::ms_outputRowsIdx; - - /* Check to make sure that the input tensor supports the above context and inner lengths. */ - if (inputRows <= 2 * inputCtxLen || inputRows <= inputInnerLen) { - printf_err("Input rows not compatible with ctx of %" PRIu32 "\n", - inputCtxLen); - return 0; + if (!inputTensor->dims) { + printf_err("Invalid input tensor dims\n"); + return false; + } else if (inputTensor->dims->size < 3) { + printf_err("Input tensor dimension should be >= 3\n"); + return false; } TfLiteTensor* outputTensor = model.GetOutputTensor(0); - const uint32_t outputRows = std::max(outputTensor->dims->data[ms_outputRowsIdx], 0); - - const float tensorColRatio = static_cast<float>(inputRows)/ - static_cast<float>(outputRows); - - return std::round(static_cast<float>(inputCtxLen)/tensorColRatio); -} + if (!outputTensor->dims) { + printf_err("Invalid output tensor dims\n"); + return false; + } else if (outputTensor->dims->size < 3) { + printf_err("Output tensor dimension should be >= 3\n"); + return false; + } -static uint32_t GetOutputInnerLen(const arm::app::Model& model, - const uint32_t outputCtxLen) -{ - constexpr uint32_t ms_outputRowsIdx = arm::app::Wav2LetterModel::ms_outputRowsIdx; - TfLiteTensor* outputTensor = model.GetOutputTensor(0); - const uint32_t outputRows = std::max(outputTensor->dims->data[ms_outputRowsIdx], 0); - return (outputRows - (2 * outputCtxLen)); + return true; } diff --git a/source/use_case/kws_asr/src/UseCaseHandler.cc b/source/use_case/kws_asr/src/UseCaseHandler.cc index 1e1a400..01aefae 100644 --- a/source/use_case/kws_asr/src/UseCaseHandler.cc +++ b/source/use_case/kws_asr/src/UseCaseHandler.cc @@ -28,6 +28,7 @@ #include "Wav2LetterMfcc.hpp" #include "Wav2LetterPreprocess.hpp" #include "Wav2LetterPostprocess.hpp" +#include "KwsProcessing.hpp" #include "AsrResult.hpp" #include "AsrClassifier.hpp" #include "OutputDecode.hpp" @@ -39,11 +40,6 @@ using KwsClassifier = arm::app::Classifier; namespace arm { namespace app { - enum AsrOutputReductionAxis { - AxisRow = 1, - AxisCol = 2 - }; - struct KWSOutput { bool executionSuccess = false; const int16_t* asrAudioStart = nullptr; @@ -51,73 +47,53 @@ namespace app { }; /** - * @brief Presents kws inference results using the data presentation - * object. - * @param[in] results vector of classification results to be displayed - * @return true if successful, false otherwise + * @brief Presents KWS inference results. + * @param[in] results Vector of KWS classification results to be displayed. + * @return true if successful, false otherwise. **/ - static bool PresentInferenceResult(std::vector<arm::app::kws::KwsResult>& results); + static bool PresentInferenceResult(std::vector<kws::KwsResult>& results); /** - * @brief Presents asr inference results using the data presentation - * object. - * @param[in] platform reference to the hal platform object - * @param[in] results vector of classification results to be displayed - * @return true if successful, false otherwise + * @brief Presents ASR inference results. + * @param[in] results Vector of ASR classification results to be displayed. + * @return true if successful, false otherwise. **/ - static bool PresentInferenceResult(std::vector<arm::app::asr::AsrResult>& results); + static bool PresentInferenceResult(std::vector<asr::AsrResult>& results); /** - * @brief Returns a function to perform feature calculation and populates input tensor data with - * MFCC data. - * - * Input tensor data type check is performed to choose correct MFCC feature data type. - * If tensor has an integer data type then original features are quantised. - * - * Warning: mfcc calculator provided as input must have the same life scope as returned function. - * - * @param[in] mfcc MFCC feature calculator. - * @param[in,out] inputTensor Input tensor pointer to store calculated features. - * @param[in] cacheSize Size of the feature vectors cache (number of feature vectors). - * - * @return function function to be called providing audio sample and sliding window index. + * @brief Performs the KWS pipeline. + * @param[in,out] ctx pointer to the application context object + * @return struct containing pointer to audio data where ASR should begin + * and how much data to process. **/ - static std::function<void (std::vector<int16_t>&, int, bool, size_t)> - GetFeatureCalculator(audio::MicroNetMFCC& mfcc, - TfLiteTensor* inputTensor, - size_t cacheSize); + static KWSOutput doKws(ApplicationContext& ctx) + { + auto& profiler = ctx.Get<Profiler&>("profiler"); + auto& kwsModel = ctx.Get<Model&>("kwsModel"); + const auto kwsMfccFrameLength = ctx.Get<int>("kwsFrameLength"); + const auto kwsMfccFrameStride = ctx.Get<int>("kwsFrameStride"); + const auto kwsScoreThreshold = ctx.Get<float>("kwsScoreThreshold"); + + auto currentIndex = ctx.Get<uint32_t>("clipIndex"); - /** - * @brief Performs the KWS pipeline. - * @param[in,out] ctx pointer to the application context object - * - * @return KWSOutput struct containing pointer to audio data where ASR should begin - * and how much data to process. - */ - static KWSOutput doKws(ApplicationContext& ctx) { constexpr uint32_t dataPsnTxtInfStartX = 20; constexpr uint32_t dataPsnTxtInfStartY = 40; constexpr int minTensorDims = static_cast<int>( - (arm::app::MicroNetKwsModel::ms_inputRowsIdx > arm::app::MicroNetKwsModel::ms_inputColsIdx)? - arm::app::MicroNetKwsModel::ms_inputRowsIdx : arm::app::MicroNetKwsModel::ms_inputColsIdx); + (MicroNetKwsModel::ms_inputRowsIdx > MicroNetKwsModel::ms_inputColsIdx)? + MicroNetKwsModel::ms_inputRowsIdx : MicroNetKwsModel::ms_inputColsIdx); - KWSOutput output; + /* Output struct from doing KWS. */ + KWSOutput output {}; - auto& profiler = ctx.Get<Profiler&>("profiler"); - auto& kwsModel = ctx.Get<Model&>("kwsmodel"); if (!kwsModel.IsInited()) { printf_err("KWS model has not been initialised\n"); return output; } - const int kwsFrameLength = ctx.Get<int>("kwsframeLength"); - const int kwsFrameStride = ctx.Get<int>("kwsframeStride"); - const float kwsScoreThreshold = ctx.Get<float>("kwsscoreThreshold"); - - TfLiteTensor* kwsOutputTensor = kwsModel.GetOutputTensor(0); + /* Get Input and Output tensors for pre/post processing. */ TfLiteTensor* kwsInputTensor = kwsModel.GetInputTensor(0); - + TfLiteTensor* kwsOutputTensor = kwsModel.GetOutputTensor(0); if (!kwsInputTensor->dims) { printf_err("Invalid input tensor dims\n"); return output; @@ -126,63 +102,32 @@ namespace app { return output; } - const uint32_t kwsNumMfccFeats = ctx.Get<uint32_t>("kwsNumMfcc"); - const uint32_t kwsNumAudioWindows = ctx.Get<uint32_t>("kwsNumAudioWins"); - - audio::MicroNetMFCC kwsMfcc = audio::MicroNetMFCC(kwsNumMfccFeats, kwsFrameLength); - kwsMfcc.Init(); - - /* Deduce the data length required for 1 KWS inference from the network parameters. */ - auto kwsAudioDataWindowSize = kwsNumAudioWindows * kwsFrameStride + - (kwsFrameLength - kwsFrameStride); - auto kwsMfccWindowSize = kwsFrameLength; - auto kwsMfccWindowStride = kwsFrameStride; - - /* We are choosing to move by half the window size => for a 1 second window size, - * this means an overlap of 0.5 seconds. */ - auto kwsAudioDataStride = kwsAudioDataWindowSize / 2; - - info("KWS audio data window size %" PRIu32 "\n", kwsAudioDataWindowSize); - - /* Stride must be multiple of mfcc features window stride to re-use features. */ - if (0 != kwsAudioDataStride % kwsMfccWindowStride) { - kwsAudioDataStride -= kwsAudioDataStride % kwsMfccWindowStride; - } - - auto kwsMfccVectorsInAudioStride = kwsAudioDataStride/kwsMfccWindowStride; + /* Get input shape for feature extraction. */ + TfLiteIntArray* inputShape = kwsModel.GetInputShape(0); + const uint32_t numMfccFeatures = inputShape->data[MicroNetKwsModel::ms_inputColsIdx]; + const uint32_t numMfccFrames = inputShape->data[MicroNetKwsModel::ms_inputRowsIdx]; /* We expect to be sampling 1 second worth of data at a time * NOTE: This is only used for time stamp calculation. */ - const float kwsAudioParamsSecondsPerSample = 1.0/audio::MicroNetMFCC::ms_defaultSamplingFreq; + const float kwsAudioParamsSecondsPerSample = 1.0 / audio::MicroNetKwsMFCC::ms_defaultSamplingFreq; - auto currentIndex = ctx.Get<uint32_t>("clipIndex"); + /* Set up pre and post-processing. */ + KwsPreProcess preProcess = KwsPreProcess(kwsInputTensor, numMfccFeatures, numMfccFrames, + kwsMfccFrameLength, kwsMfccFrameStride); - /* Creating a mfcc features sliding window for the data required for 1 inference. */ - auto kwsAudioMFCCWindowSlider = audio::SlidingWindow<const int16_t>( - get_audio_array(currentIndex), - kwsAudioDataWindowSize, kwsMfccWindowSize, - kwsMfccWindowStride); + std::vector<ClassificationResult> singleInfResult; + KwsPostProcess postProcess = KwsPostProcess(kwsOutputTensor, ctx.Get<KwsClassifier &>("kwsClassifier"), + ctx.Get<std::vector<std::string>&>("kwsLabels"), + singleInfResult); /* Creating a sliding window through the whole audio clip. */ auto audioDataSlider = audio::SlidingWindow<const int16_t>( get_audio_array(currentIndex), get_audio_array_size(currentIndex), - kwsAudioDataWindowSize, kwsAudioDataStride); - - /* Calculate number of the feature vectors in the window overlap region. - * These feature vectors will be reused.*/ - size_t numberOfReusedFeatureVectors = kwsAudioMFCCWindowSlider.TotalStrides() + 1 - - kwsMfccVectorsInAudioStride; + preProcess.m_audioDataWindowSize, preProcess.m_audioDataStride); - auto kwsMfccFeatureCalc = GetFeatureCalculator(kwsMfcc, kwsInputTensor, - numberOfReusedFeatureVectors); - - if (!kwsMfccFeatureCalc){ - return output; - } - - /* Container for KWS results. */ - std::vector<arm::app::kws::KwsResult> kwsResults; + /* Declare a container to hold kws results from across the whole audio clip. */ + std::vector<kws::KwsResult> finalResults; /* Display message on the LCD - inference running. */ std::string str_inf{"Running KWS inference... "}; @@ -197,70 +142,56 @@ namespace app { while (audioDataSlider.HasNext()) { const int16_t* inferenceWindow = audioDataSlider.Next(); - /* We moved to the next window - set the features sliding to the new address. */ - kwsAudioMFCCWindowSlider.Reset(inferenceWindow); - /* The first window does not have cache ready. */ - bool useCache = audioDataSlider.Index() > 0 && numberOfReusedFeatureVectors > 0; - - /* Start calculating features inside one audio sliding window. */ - while (kwsAudioMFCCWindowSlider.HasNext()) { - const int16_t* kwsMfccWindow = kwsAudioMFCCWindowSlider.Next(); - std::vector<int16_t> kwsMfccAudioData = - std::vector<int16_t>(kwsMfccWindow, kwsMfccWindow + kwsMfccWindowSize); - - /* Compute features for this window and write them to input tensor. */ - kwsMfccFeatureCalc(kwsMfccAudioData, - kwsAudioMFCCWindowSlider.Index(), - useCache, - kwsMfccVectorsInAudioStride); - } + preProcess.m_audioWindowIndex = audioDataSlider.Index(); - info("Inference %zu/%zu\n", audioDataSlider.Index() + 1, - audioDataSlider.TotalStrides() + 1); + /* Run the pre-processing, inference and post-processing. */ + if (!preProcess.DoPreProcess(inferenceWindow, audio::MicroNetKwsMFCC::ms_defaultSamplingFreq)) { + printf_err("KWS Pre-processing failed."); + return output; + } - /* Run inference over this audio clip sliding window. */ if (!RunInference(kwsModel, profiler)) { - printf_err("KWS inference failed\n"); + printf_err("KWS Inference failed."); return output; } - std::vector<ClassificationResult> kwsClassificationResult; - auto& kwsClassifier = ctx.Get<KwsClassifier&>("kwsclassifier"); + if (!postProcess.DoPostProcess()) { + printf_err("KWS Post-processing failed."); + return output; + } - kwsClassifier.GetClassificationResults( - kwsOutputTensor, kwsClassificationResult, - ctx.Get<std::vector<std::string>&>("kwslabels"), 1, true); + info("Inference %zu/%zu\n", audioDataSlider.Index() + 1, + audioDataSlider.TotalStrides() + 1); - kwsResults.emplace_back( - kws::KwsResult( - kwsClassificationResult, - audioDataSlider.Index() * kwsAudioParamsSecondsPerSample * kwsAudioDataStride, - audioDataSlider.Index(), kwsScoreThreshold) - ); + /* Add results from this window to our final results vector. */ + finalResults.emplace_back( + kws::KwsResult(singleInfResult, + audioDataSlider.Index() * kwsAudioParamsSecondsPerSample * preProcess.m_audioDataStride, + audioDataSlider.Index(), kwsScoreThreshold)); - /* Keyword detected. */ - if (kwsClassificationResult[0].m_label == ctx.Get<const std::string&>("triggerkeyword")) { - output.asrAudioStart = inferenceWindow + kwsAudioDataWindowSize; + /* Break out when trigger keyword is detected. */ + if (singleInfResult[0].m_label == ctx.Get<const std::string&>("triggerKeyword") + && singleInfResult[0].m_normalisedVal > kwsScoreThreshold) { + output.asrAudioStart = inferenceWindow + preProcess.m_audioDataWindowSize; output.asrAudioSamples = get_audio_array_size(currentIndex) - (audioDataSlider.NextWindowStartIndex() - - kwsAudioDataStride + kwsAudioDataWindowSize); + preProcess.m_audioDataStride + preProcess.m_audioDataWindowSize); break; } #if VERIFY_TEST_OUTPUT - arm::app::DumpTensor(kwsOutputTensor); + DumpTensor(kwsOutputTensor); #endif /* VERIFY_TEST_OUTPUT */ } /* while (audioDataSlider.HasNext()) */ /* Erase. */ str_inf = std::string(str_inf.size(), ' '); - hal_lcd_display_text( - str_inf.c_str(), str_inf.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); + hal_lcd_display_text(str_inf.c_str(), str_inf.size(), + dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); - if (!PresentInferenceResult(kwsResults)) { + if (!PresentInferenceResult(finalResults)) { return output; } @@ -271,41 +202,41 @@ namespace app { } /** - * @brief Performs the ASR pipeline. - * - * @param[in,out] ctx pointer to the application context object - * @param[in] kwsOutput struct containing pointer to audio data where ASR should begin - * and how much data to process - * @return bool true if pipeline executed without failure - */ - static bool doAsr(ApplicationContext& ctx, const KWSOutput& kwsOutput) { + * @brief Performs the ASR pipeline. + * @param[in,out] ctx Pointer to the application context object. + * @param[in] kwsOutput Struct containing pointer to audio data where ASR should begin + * and how much data to process. + * @return true if pipeline executed without failure. + **/ + static bool doAsr(ApplicationContext& ctx, const KWSOutput& kwsOutput) + { + auto& asrModel = ctx.Get<Model&>("asrModel"); + auto& profiler = ctx.Get<Profiler&>("profiler"); + auto asrMfccFrameLen = ctx.Get<uint32_t>("asrFrameLength"); + auto asrMfccFrameStride = ctx.Get<uint32_t>("asrFrameStride"); + auto asrScoreThreshold = ctx.Get<float>("asrScoreThreshold"); + auto asrInputCtxLen = ctx.Get<uint32_t>("ctxLen"); + constexpr uint32_t dataPsnTxtInfStartX = 20; constexpr uint32_t dataPsnTxtInfStartY = 40; - auto& profiler = ctx.Get<Profiler&>("profiler"); - hal_lcd_clear(COLOR_BLACK); - - /* Get model reference. */ - auto& asrModel = ctx.Get<Model&>("asrmodel"); if (!asrModel.IsInited()) { printf_err("ASR model has not been initialised\n"); return false; } - /* Get score threshold to be applied for the classifier (post-inference). */ - auto asrScoreThreshold = ctx.Get<float>("asrscoreThreshold"); + hal_lcd_clear(COLOR_BLACK); - /* Dimensions of the tensor should have been verified by the callee. */ + /* Get Input and Output tensors for pre/post processing. */ TfLiteTensor* asrInputTensor = asrModel.GetInputTensor(0); TfLiteTensor* asrOutputTensor = asrModel.GetOutputTensor(0); - const uint32_t asrInputRows = asrInputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx]; - /* Populate ASR MFCC related parameters. */ - auto asrMfccParamsWinLen = ctx.Get<uint32_t>("asrframeLength"); - auto asrMfccParamsWinStride = ctx.Get<uint32_t>("asrframeStride"); + /* Get input shape. Dimensions of the tensor should have been verified by + * the callee. */ + TfLiteIntArray* inputShape = asrModel.GetInputShape(0); - /* Populate ASR inference context and inner lengths for input. */ - auto asrInputCtxLen = ctx.Get<uint32_t>("ctxLen"); + + const uint32_t asrInputRows = asrInputTensor->dims->data[Wav2LetterModel::ms_inputRowsIdx]; const uint32_t asrInputInnerLen = asrInputRows - (2 * asrInputCtxLen); /* Make sure the input tensor supports the above context and inner lengths. */ @@ -316,18 +247,9 @@ namespace app { } /* Audio data stride corresponds to inputInnerLen feature vectors. */ - const uint32_t asrAudioParamsWinLen = (asrInputRows - 1) * - asrMfccParamsWinStride + (asrMfccParamsWinLen); - const uint32_t asrAudioParamsWinStride = asrInputInnerLen * asrMfccParamsWinStride; - const float asrAudioParamsSecondsPerSample = - (1.0/audio::Wav2LetterMFCC::ms_defaultSamplingFreq); - - /* Get pre/post-processing objects */ - auto& asrPrep = ctx.Get<audio::asr::Preprocess&>("preprocess"); - auto& asrPostp = ctx.Get<audio::asr::Postprocess&>("postprocess"); - - /* Set default reduction axis for post-processing. */ - const uint32_t reductionAxis = arm::app::Wav2LetterModel::ms_outputRowsIdx; + const uint32_t asrAudioDataWindowLen = (asrInputRows - 1) * asrMfccFrameStride + (asrMfccFrameLen); + const uint32_t asrAudioDataWindowStride = asrInputInnerLen * asrMfccFrameStride; + const float asrAudioParamsSecondsPerSample = 1.0 / audio::Wav2LetterMFCC::ms_defaultSamplingFreq; /* Get the remaining audio buffer and respective size from KWS results. */ const int16_t* audioArr = kwsOutput.asrAudioStart; @@ -335,9 +257,9 @@ namespace app { /* Audio clip must have enough samples to produce 1 MFCC feature. */ std::vector<int16_t> audioBuffer = std::vector<int16_t>(audioArr, audioArr + audioArrSize); - if (audioArrSize < asrMfccParamsWinLen) { + if (audioArrSize < asrMfccFrameLen) { printf_err("Not enough audio samples, minimum needed is %" PRIu32 "\n", - asrMfccParamsWinLen); + asrMfccFrameLen); return false; } @@ -345,26 +267,38 @@ namespace app { auto audioDataSlider = audio::FractionalSlidingWindow<const int16_t>( audioBuffer.data(), audioBuffer.size(), - asrAudioParamsWinLen, - asrAudioParamsWinStride); + asrAudioDataWindowLen, + asrAudioDataWindowStride); /* Declare a container for results. */ - std::vector<arm::app::asr::AsrResult> asrResults; + std::vector<asr::AsrResult> asrResults; /* Display message on the LCD - inference running. */ std::string str_inf{"Running ASR inference... "}; - hal_lcd_display_text( - str_inf.c_str(), str_inf.size(), + hal_lcd_display_text(str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); - size_t asrInferenceWindowLen = asrAudioParamsWinLen; - + size_t asrInferenceWindowLen = asrAudioDataWindowLen; + + /* Set up pre and post-processing objects. */ + AsrPreProcess asrPreProcess = AsrPreProcess(asrInputTensor, arm::app::Wav2LetterModel::ms_numMfccFeatures, + inputShape->data[Wav2LetterModel::ms_inputRowsIdx], + asrMfccFrameLen, asrMfccFrameStride); + + std::vector<ClassificationResult> singleInfResult; + const uint32_t outputCtxLen = AsrPostProcess::GetOutputContextLen(asrModel, asrInputCtxLen); + AsrPostProcess asrPostProcess = AsrPostProcess( + asrOutputTensor, ctx.Get<AsrClassifier&>("asrClassifier"), + ctx.Get<std::vector<std::string>&>("asrLabels"), + singleInfResult, outputCtxLen, + Wav2LetterModel::ms_blankTokenIdx, Wav2LetterModel::ms_outputRowsIdx + ); /* Start sliding through audio clip. */ while (audioDataSlider.HasNext()) { /* If not enough audio see how much can be sent for processing. */ size_t nextStartIndex = audioDataSlider.NextWindowStartIndex(); - if (nextStartIndex + asrAudioParamsWinLen > audioBuffer.size()) { + if (nextStartIndex + asrAudioDataWindowLen > audioBuffer.size()) { asrInferenceWindowLen = audioBuffer.size() - nextStartIndex; } @@ -373,8 +307,11 @@ namespace app { info("Inference %zu/%zu\n", audioDataSlider.Index() + 1, static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1))); - /* Calculate MFCCs, deltas and populate the input tensor. */ - asrPrep.Invoke(asrInferenceWindow, asrInferenceWindowLen, asrInputTensor); + /* Run the pre-processing, inference and post-processing. */ + if (!asrPreProcess.DoPreProcess(asrInferenceWindow, asrInferenceWindowLen)) { + printf_err("ASR pre-processing failed."); + return false; + } /* Run inference over this audio clip sliding window. */ if (!RunInference(asrModel, profiler)) { @@ -382,24 +319,28 @@ namespace app { return false; } - /* Post-process. */ - asrPostp.Invoke(asrOutputTensor, reductionAxis, !audioDataSlider.HasNext()); + /* Post processing needs to know if we are on the last audio window. */ + asrPostProcess.m_lastIteration = !audioDataSlider.HasNext(); + if (!asrPostProcess.DoPostProcess()) { + printf_err("ASR post-processing failed."); + return false; + } /* Get results. */ std::vector<ClassificationResult> asrClassificationResult; - auto& asrClassifier = ctx.Get<AsrClassifier&>("asrclassifier"); + auto& asrClassifier = ctx.Get<AsrClassifier&>("asrClassifier"); asrClassifier.GetClassificationResults( asrOutputTensor, asrClassificationResult, - ctx.Get<std::vector<std::string>&>("asrlabels"), 1); + ctx.Get<std::vector<std::string>&>("asrLabels"), 1); asrResults.emplace_back(asr::AsrResult(asrClassificationResult, (audioDataSlider.Index() * asrAudioParamsSecondsPerSample * - asrAudioParamsWinStride), + asrAudioDataWindowStride), audioDataSlider.Index(), asrScoreThreshold)); #if VERIFY_TEST_OUTPUT - arm::app::DumpTensor(asrOutputTensor, asrOutputTensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx]); + armDumpTensor(asrOutputTensor, asrOutputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx]); #endif /* VERIFY_TEST_OUTPUT */ /* Erase */ @@ -417,7 +358,7 @@ namespace app { return true; } - /* Audio inference classification handler. */ + /* KWS and ASR inference handler. */ bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll) { hal_lcd_clear(COLOR_BLACK); @@ -434,13 +375,14 @@ namespace app { do { KWSOutput kwsOutput = doKws(ctx); if (!kwsOutput.executionSuccess) { + printf_err("KWS failed\n"); return false; } if (kwsOutput.asrAudioStart != nullptr && kwsOutput.asrAudioSamples > 0) { - info("Keyword spotted\n"); + info("Trigger keyword spotted\n"); if(!doAsr(ctx, kwsOutput)) { - printf_err("ASR failed"); + printf_err("ASR failed\n"); return false; } } @@ -452,7 +394,6 @@ namespace app { return true; } - static bool PresentInferenceResult(std::vector<arm::app::kws::KwsResult>& results) { constexpr uint32_t dataPsnTxtStartX1 = 20; @@ -464,33 +405,31 @@ namespace app { /* Display each result. */ uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr; - for (uint32_t i = 0; i < results.size(); ++i) { - + for (auto & result : results) { std::string topKeyword{"<none>"}; float score = 0.f; - if (!results[i].m_resultVec.empty()) { - topKeyword = results[i].m_resultVec[0].m_label; - score = results[i].m_resultVec[0].m_normalisedVal; + if (!result.m_resultVec.empty()) { + topKeyword = result.m_resultVec[0].m_label; + score = result.m_resultVec[0].m_normalisedVal; } std::string resultStr = - std::string{"@"} + std::to_string(results[i].m_timeStamp) + + std::string{"@"} + std::to_string(result.m_timeStamp) + std::string{"s: "} + topKeyword + std::string{" ("} + std::to_string(static_cast<int>(score * 100)) + std::string{"%)"}; - hal_lcd_display_text( - resultStr.c_str(), resultStr.size(), - dataPsnTxtStartX1, rowIdx1, 0); + hal_lcd_display_text(resultStr.c_str(), resultStr.size(), + dataPsnTxtStartX1, rowIdx1, 0); rowIdx1 += dataPsnTxtYIncr; info("For timestamp: %f (inference #: %" PRIu32 "); threshold: %f\n", - results[i].m_timeStamp, results[i].m_inferenceNumber, - results[i].m_threshold); - for (uint32_t j = 0; j < results[i].m_resultVec.size(); ++j) { + result.m_timeStamp, result.m_inferenceNumber, + result.m_threshold); + for (uint32_t j = 0; j < result.m_resultVec.size(); ++j) { info("\t\tlabel @ %" PRIu32 ": %s, score: %f\n", j, - results[i].m_resultVec[j].m_label.c_str(), - results[i].m_resultVec[j].m_normalisedVal); + result.m_resultVec[j].m_label.c_str(), + result.m_resultVec[j].m_normalisedVal); } } @@ -523,143 +462,12 @@ namespace app { std::string finalResultStr = audio::asr::DecodeOutput(combinedResults); - hal_lcd_display_text( - finalResultStr.c_str(), finalResultStr.size(), - dataPsnTxtStartX1, dataPsnTxtStartY1, allow_multiple_lines); + hal_lcd_display_text(finalResultStr.c_str(), finalResultStr.size(), + dataPsnTxtStartX1, dataPsnTxtStartY1, allow_multiple_lines); info("Final result: %s\n", finalResultStr.c_str()); return true; } - /** - * @brief Generic feature calculator factory. - * - * Returns lambda function to compute features using features cache. - * Real features math is done by a lambda function provided as a parameter. - * Features are written to input tensor memory. - * - * @tparam T feature vector type. - * @param inputTensor model input tensor pointer. - * @param cacheSize number of feature vectors to cache. Defined by the sliding window overlap. - * @param compute features calculator function. - * @return lambda function to compute features. - **/ - template<class T> - std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> - FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize, - std::function<std::vector<T> (std::vector<int16_t>& )> compute) - { - /* Feature cache to be captured by lambda function. */ - static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize); - - return [=](std::vector<int16_t>& audioDataWindow, - size_t index, - bool useCache, - size_t featuresOverlapIndex) - { - T* tensorData = tflite::GetTensorData<T>(inputTensor); - std::vector<T> features; - - /* Reuse features from cache if cache is ready and sliding windows overlap. - * Overlap is in the beginning of sliding window with a size of a feature cache. - */ - if (useCache && index < featureCache.size()) { - features = std::move(featureCache[index]); - } else { - features = std::move(compute(audioDataWindow)); - } - auto size = features.size(); - auto sizeBytes = sizeof(T) * size; - std::memcpy(tensorData + (index * size), features.data(), sizeBytes); - - /* Start renewing cache as soon iteration goes out of the windows overlap. */ - if (index >= featuresOverlapIndex) { - featureCache[index - featuresOverlapIndex] = std::move(features); - } - }; - } - - template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)> - FeatureCalc<int8_t>(TfLiteTensor* inputTensor, - size_t cacheSize, - std::function<std::vector<int8_t> (std::vector<int16_t>& )> compute); - - template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)> - FeatureCalc<uint8_t>(TfLiteTensor* inputTensor, - size_t cacheSize, - std::function<std::vector<uint8_t> (std::vector<int16_t>& )> compute); - - template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)> - FeatureCalc<int16_t>(TfLiteTensor* inputTensor, - size_t cacheSize, - std::function<std::vector<int16_t> (std::vector<int16_t>& )> compute); - - template std::function<void(std::vector<int16_t>&, size_t, bool, size_t)> - FeatureCalc<float>(TfLiteTensor* inputTensor, - size_t cacheSize, - std::function<std::vector<float>(std::vector<int16_t>&)> compute); - - - static std::function<void (std::vector<int16_t>&, int, bool, size_t)> - GetFeatureCalculator(audio::MicroNetMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize) - { - std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> mfccFeatureCalc; - - TfLiteQuantization quant = inputTensor->quantization; - - if (kTfLiteAffineQuantization == quant.type) { - - auto* quantParams = (TfLiteAffineQuantization*) quant.params; - const float quantScale = quantParams->scale->data[0]; - const int quantOffset = quantParams->zero_point->data[0]; - - switch (inputTensor->type) { - case kTfLiteInt8: { - mfccFeatureCalc = FeatureCalc<int8_t>(inputTensor, - cacheSize, - [=, &mfcc](std::vector<int16_t>& audioDataWindow) { - return mfcc.MfccComputeQuant<int8_t>(audioDataWindow, - quantScale, - quantOffset); - } - ); - break; - } - case kTfLiteUInt8: { - mfccFeatureCalc = FeatureCalc<uint8_t>(inputTensor, - cacheSize, - [=, &mfcc](std::vector<int16_t>& audioDataWindow) { - return mfcc.MfccComputeQuant<uint8_t>(audioDataWindow, - quantScale, - quantOffset); - } - ); - break; - } - case kTfLiteInt16: { - mfccFeatureCalc = FeatureCalc<int16_t>(inputTensor, - cacheSize, - [=, &mfcc](std::vector<int16_t>& audioDataWindow) { - return mfcc.MfccComputeQuant<int16_t>(audioDataWindow, - quantScale, - quantOffset); - } - ); - break; - } - default: - printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type)); - } - - - } else { - mfccFeatureCalc = mfccFeatureCalc = FeatureCalc<float>(inputTensor, - cacheSize, - [&mfcc](std::vector<int16_t>& audioDataWindow) { - return mfcc.MfccCompute(audioDataWindow); - }); - } - return mfccFeatureCalc; - } } /* namespace app */ } /* namespace arm */
\ No newline at end of file diff --git a/source/use_case/kws_asr/src/Wav2LetterPostprocess.cc b/source/use_case/kws_asr/src/Wav2LetterPostprocess.cc index 2a76b1b..42f434e 100644 --- a/source/use_case/kws_asr/src/Wav2LetterPostprocess.cc +++ b/source/use_case/kws_asr/src/Wav2LetterPostprocess.cc @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,62 +15,71 @@ * limitations under the License. */ #include "Wav2LetterPostprocess.hpp" + #include "Wav2LetterModel.hpp" #include "log_macros.h" +#include <cmath> + namespace arm { namespace app { -namespace audio { -namespace asr { - - Postprocess::Postprocess(const uint32_t contextLen, - const uint32_t innerLen, - const uint32_t blankTokenIdx) - : m_contextLen(contextLen), - m_innerLen(innerLen), - m_totalLen(2 * this->m_contextLen + this->m_innerLen), + + AsrPostProcess::AsrPostProcess(TfLiteTensor* outputTensor, AsrClassifier& classifier, + const std::vector<std::string>& labels, std::vector<ClassificationResult>& results, + const uint32_t outputContextLen, + const uint32_t blankTokenIdx, const uint32_t reductionAxisIdx + ): + m_classifier(classifier), + m_outputTensor(outputTensor), + m_labels{labels}, + m_results(results), + m_outputContextLen(outputContextLen), m_countIterations(0), - m_blankTokenIdx(blankTokenIdx) - {} + m_blankTokenIdx(blankTokenIdx), + m_reductionAxisIdx(reductionAxisIdx) + { + this->m_outputInnerLen = AsrPostProcess::GetOutputInnerLen(this->m_outputTensor, this->m_outputContextLen); + this->m_totalLen = (2 * this->m_outputContextLen + this->m_outputInnerLen); + } - bool Postprocess::Invoke(TfLiteTensor* tensor, - const uint32_t axisIdx, - const bool lastIteration) + bool AsrPostProcess::DoPostProcess() { /* Basic checks. */ - if (!this->IsInputValid(tensor, axisIdx)) { + if (!this->IsInputValid(this->m_outputTensor, this->m_reductionAxisIdx)) { return false; } /* Irrespective of tensor type, we use unsigned "byte" */ - uint8_t* ptrData = tflite::GetTensorData<uint8_t>(tensor); - const uint32_t elemSz = this->GetTensorElementSize(tensor); + auto* ptrData = tflite::GetTensorData<uint8_t>(this->m_outputTensor); + const uint32_t elemSz = AsrPostProcess::GetTensorElementSize(this->m_outputTensor); /* Other sanity checks. */ if (0 == elemSz) { printf_err("Tensor type not supported for post processing\n"); return false; - } else if (elemSz * this->m_totalLen > tensor->bytes) { + } else if (elemSz * this->m_totalLen > this->m_outputTensor->bytes) { printf_err("Insufficient number of tensor bytes\n"); return false; } /* Which axis do we need to process? */ - switch (axisIdx) { - case arm::app::Wav2LetterModel::ms_outputRowsIdx: - return this->EraseSectionsRowWise(ptrData, - elemSz * - tensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx], - lastIteration); + switch (this->m_reductionAxisIdx) { + case Wav2LetterModel::ms_outputRowsIdx: + this->EraseSectionsRowWise( + ptrData, elemSz * this->m_outputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx], + this->m_lastIteration); + break; default: - printf_err("Unsupported axis index: %" PRIu32 "\n", axisIdx); + printf_err("Unsupported axis index: %" PRIu32 "\n", this->m_reductionAxisIdx); + return false; } + this->m_classifier.GetClassificationResults(this->m_outputTensor, + this->m_results, this->m_labels, 1); - return false; + return true; } - bool Postprocess::IsInputValid(TfLiteTensor* tensor, - const uint32_t axisIdx) const + bool AsrPostProcess::IsInputValid(TfLiteTensor* tensor, const uint32_t axisIdx) const { if (nullptr == tensor) { return false; @@ -84,25 +93,23 @@ namespace asr { if (static_cast<int>(this->m_totalLen) != tensor->dims->data[axisIdx]) { - printf_err("Unexpected tensor dimension for axis %d, \n", - tensor->dims->data[axisIdx]); + printf_err("Unexpected tensor dimension for axis %d, got %d, \n", + axisIdx, tensor->dims->data[axisIdx]); return false; } return true; } - uint32_t Postprocess::GetTensorElementSize(TfLiteTensor* tensor) + uint32_t AsrPostProcess::GetTensorElementSize(TfLiteTensor* tensor) { switch(tensor->type) { case kTfLiteUInt8: - return 1; case kTfLiteInt8: return 1; case kTfLiteInt16: return 2; case kTfLiteInt32: - return 4; case kTfLiteFloat32: return 4; default: @@ -113,30 +120,30 @@ namespace asr { return 0; } - bool Postprocess::EraseSectionsRowWise( - uint8_t* ptrData, - const uint32_t strideSzBytes, - const bool lastIteration) + bool AsrPostProcess::EraseSectionsRowWise( + uint8_t* ptrData, + const uint32_t strideSzBytes, + const bool lastIteration) { /* In this case, the "zero-ing" is quite simple as the region * to be zeroed sits in contiguous memory (row-major). */ - const uint32_t eraseLen = strideSzBytes * this->m_contextLen; + const uint32_t eraseLen = strideSzBytes * this->m_outputContextLen; /* Erase left context? */ if (this->m_countIterations > 0) { /* Set output of each classification window to the blank token. */ std::memset(ptrData, 0, eraseLen); - for (size_t windowIdx = 0; windowIdx < this->m_contextLen; windowIdx++) { + for (size_t windowIdx = 0; windowIdx < this->m_outputContextLen; windowIdx++) { ptrData[windowIdx*strideSzBytes + this->m_blankTokenIdx] = 1; } } /* Erase right context? */ if (false == lastIteration) { - uint8_t * rightCtxPtr = ptrData + (strideSzBytes * (this->m_contextLen + this->m_innerLen)); + uint8_t* rightCtxPtr = ptrData + (strideSzBytes * (this->m_outputContextLen + this->m_outputInnerLen)); /* Set output of each classification window to the blank token. */ std::memset(rightCtxPtr, 0, eraseLen); - for (size_t windowIdx = 0; windowIdx < this->m_contextLen; windowIdx++) { + for (size_t windowIdx = 0; windowIdx < this->m_outputContextLen; windowIdx++) { rightCtxPtr[windowIdx*strideSzBytes + this->m_blankTokenIdx] = 1; } } @@ -150,7 +157,58 @@ namespace asr { return true; } -} /* namespace asr */ -} /* namespace audio */ + uint32_t AsrPostProcess::GetNumFeatureVectors(const Model& model) + { + TfLiteTensor* inputTensor = model.GetInputTensor(0); + const int inputRows = std::max(inputTensor->dims->data[Wav2LetterModel::ms_inputRowsIdx], 0); + if (inputRows == 0) { + printf_err("Error getting number of input rows for axis: %" PRIu32 "\n", + Wav2LetterModel::ms_inputRowsIdx); + } + return inputRows; + } + + uint32_t AsrPostProcess::GetOutputInnerLen(const TfLiteTensor* outputTensor, const uint32_t outputCtxLen) + { + const uint32_t outputRows = std::max(outputTensor->dims->data[Wav2LetterModel::ms_outputRowsIdx], 0); + if (outputRows == 0) { + printf_err("Error getting number of output rows for axis: %" PRIu32 "\n", + Wav2LetterModel::ms_outputRowsIdx); + } + + /* Watching for underflow. */ + int innerLen = (outputRows - (2 * outputCtxLen)); + + return std::max(innerLen, 0); + } + + uint32_t AsrPostProcess::GetOutputContextLen(const Model& model, const uint32_t inputCtxLen) + { + const uint32_t inputRows = AsrPostProcess::GetNumFeatureVectors(model); + const uint32_t inputInnerLen = inputRows - (2 * inputCtxLen); + constexpr uint32_t ms_outputRowsIdx = Wav2LetterModel::ms_outputRowsIdx; + + /* Check to make sure that the input tensor supports the above + * context and inner lengths. */ + if (inputRows <= 2 * inputCtxLen || inputRows <= inputInnerLen) { + printf_err("Input rows not compatible with ctx of %" PRIu32 "\n", + inputCtxLen); + return 0; + } + + TfLiteTensor* outputTensor = model.GetOutputTensor(0); + const uint32_t outputRows = std::max(outputTensor->dims->data[ms_outputRowsIdx], 0); + if (outputRows == 0) { + printf_err("Error getting number of output rows for axis: %" PRIu32 "\n", + Wav2LetterModel::ms_outputRowsIdx); + return 0; + } + + const float inOutRowRatio = static_cast<float>(inputRows) / + static_cast<float>(outputRows); + + return std::round(static_cast<float>(inputCtxLen) / inOutRowRatio); + } + } /* namespace app */ } /* namespace arm */
\ No newline at end of file diff --git a/source/use_case/kws_asr/src/Wav2LetterPreprocess.cc b/source/use_case/kws_asr/src/Wav2LetterPreprocess.cc index d3f3579..92b0631 100644 --- a/source/use_case/kws_asr/src/Wav2LetterPreprocess.cc +++ b/source/use_case/kws_asr/src/Wav2LetterPreprocess.cc @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -20,41 +20,35 @@ #include "TensorFlowLiteMicro.hpp" #include <algorithm> -#include <math.h> +#include <cmath> namespace arm { namespace app { -namespace audio { -namespace asr { - - Preprocess::Preprocess( - const uint32_t numMfccFeatures, - const uint32_t windowLen, - const uint32_t windowStride, - const uint32_t numMfccVectors): - m_mfcc(numMfccFeatures, windowLen), - m_mfccBuf(numMfccFeatures, numMfccVectors), - m_delta1Buf(numMfccFeatures, numMfccVectors), - m_delta2Buf(numMfccFeatures, numMfccVectors), - m_windowLen(windowLen), - m_windowStride(windowStride), + + AsrPreProcess::AsrPreProcess(TfLiteTensor* inputTensor, const uint32_t numMfccFeatures, + const uint32_t numFeatureFrames, const uint32_t mfccWindowLen, + const uint32_t mfccWindowStride + ): + m_mfcc(numMfccFeatures, mfccWindowLen), + m_inputTensor(inputTensor), + m_mfccBuf(numMfccFeatures, numFeatureFrames), + m_delta1Buf(numMfccFeatures, numFeatureFrames), + m_delta2Buf(numMfccFeatures, numFeatureFrames), + m_mfccWindowLen(mfccWindowLen), + m_mfccWindowStride(mfccWindowStride), m_numMfccFeats(numMfccFeatures), - m_numFeatVectors(numMfccVectors), - m_window() + m_numFeatureFrames(numFeatureFrames) { - if (numMfccFeatures > 0 && windowLen > 0) { + if (numMfccFeatures > 0 && mfccWindowLen > 0) { this->m_mfcc.Init(); } } - bool Preprocess::Invoke( - const int16_t* audioData, - const uint32_t audioDataLen, - TfLiteTensor* tensor) + bool AsrPreProcess::DoPreProcess(const void* audioData, const size_t audioDataLen) { - this->m_window = SlidingWindow<const int16_t>( - audioData, audioDataLen, - this->m_windowLen, this->m_windowStride); + this->m_mfccSlidingWindow = audio::SlidingWindow<const int16_t>( + static_cast<const int16_t*>(audioData), audioDataLen, + this->m_mfccWindowLen, this->m_mfccWindowStride); uint32_t mfccBufIdx = 0; @@ -62,12 +56,12 @@ namespace asr { std::fill(m_delta1Buf.begin(), m_delta1Buf.end(), 0.f); std::fill(m_delta2Buf.begin(), m_delta2Buf.end(), 0.f); - /* While we can slide over the window. */ - while (this->m_window.HasNext()) { - const int16_t* mfccWindow = this->m_window.Next(); + /* While we can slide over the audio. */ + while (this->m_mfccSlidingWindow.HasNext()) { + const int16_t* mfccWindow = this->m_mfccSlidingWindow.Next(); auto mfccAudioData = std::vector<int16_t>( mfccWindow, - mfccWindow + this->m_windowLen); + mfccWindow + this->m_mfccWindowLen); auto mfcc = this->m_mfcc.MfccCompute(mfccAudioData); for (size_t i = 0; i < this->m_mfccBuf.size(0); ++i) { this->m_mfccBuf(i, mfccBufIdx) = mfcc[i]; @@ -76,11 +70,11 @@ namespace asr { } /* Pad MFCC if needed by adding MFCC for zeros. */ - if (mfccBufIdx != this->m_numFeatVectors) { - std::vector<int16_t> zerosWindow = std::vector<int16_t>(this->m_windowLen, 0); + if (mfccBufIdx != this->m_numFeatureFrames) { + std::vector<int16_t> zerosWindow = std::vector<int16_t>(this->m_mfccWindowLen, 0); std::vector<float> mfccZeros = this->m_mfcc.MfccCompute(zerosWindow); - while (mfccBufIdx != this->m_numFeatVectors) { + while (mfccBufIdx != this->m_numFeatureFrames) { memcpy(&this->m_mfccBuf(0, mfccBufIdx), mfccZeros.data(), sizeof(float) * m_numMfccFeats); ++mfccBufIdx; @@ -88,41 +82,39 @@ namespace asr { } /* Compute first and second order deltas from MFCCs. */ - this->ComputeDeltas(this->m_mfccBuf, - this->m_delta1Buf, - this->m_delta2Buf); + AsrPreProcess::ComputeDeltas(this->m_mfccBuf, this->m_delta1Buf, this->m_delta2Buf); - /* Normalise. */ - this->Normalise(); + /* Standardize calculated features. */ + this->Standarize(); /* Quantise. */ - QuantParams quantParams = GetTensorQuantParams(tensor); + QuantParams quantParams = GetTensorQuantParams(this->m_inputTensor); if (0 == quantParams.scale) { printf_err("Quantisation scale can't be 0\n"); return false; } - switch(tensor->type) { + switch(this->m_inputTensor->type) { case kTfLiteUInt8: return this->Quantise<uint8_t>( - tflite::GetTensorData<uint8_t>(tensor), tensor->bytes, + tflite::GetTensorData<uint8_t>(this->m_inputTensor), this->m_inputTensor->bytes, quantParams.scale, quantParams.offset); case kTfLiteInt8: return this->Quantise<int8_t>( - tflite::GetTensorData<int8_t>(tensor), tensor->bytes, + tflite::GetTensorData<int8_t>(this->m_inputTensor), this->m_inputTensor->bytes, quantParams.scale, quantParams.offset); default: printf_err("Unsupported tensor type %s\n", - TfLiteTypeGetName(tensor->type)); + TfLiteTypeGetName(this->m_inputTensor->type)); } return false; } - bool Preprocess::ComputeDeltas(Array2d<float>& mfcc, - Array2d<float>& delta1, - Array2d<float>& delta2) + bool AsrPreProcess::ComputeDeltas(Array2d<float>& mfcc, + Array2d<float>& delta1, + Array2d<float>& delta2) { const std::vector <float> delta1Coeffs = {6.66666667e-02, 5.00000000e-02, 3.33333333e-02, @@ -148,11 +140,11 @@ namespace asr { /* Iterate through features in MFCC vector. */ for (size_t i = 0; i < numFeatures; ++i) { /* For each feature, iterate through time (t) samples representing feature evolution and - * calculate d/dt and d^2/dt^2, using 1d convolution with differential kernels. + * calculate d/dt and d^2/dt^2, using 1D convolution with differential kernels. * Convolution padding = valid, result size is `time length - kernel length + 1`. * The result is padded with 0 from both sides to match the size of initial time samples data. * - * For the small filter, conv1d implementation as a simple loop is efficient enough. + * For the small filter, conv1D implementation as a simple loop is efficient enough. * Filters of a greater size would need CMSIS-DSP functions to be used, like arm_fir_f32. */ @@ -175,20 +167,10 @@ namespace asr { return true; } - float Preprocess::GetMean(Array2d<float>& vec) - { - return math::MathUtils::MeanF32(vec.begin(), vec.totalSize()); - } - - float Preprocess::GetStdDev(Array2d<float>& vec, const float mean) - { - return math::MathUtils::StdDevF32(vec.begin(), vec.totalSize(), mean); - } - - void Preprocess::NormaliseVec(Array2d<float>& vec) + void AsrPreProcess::StandardizeVecF32(Array2d<float>& vec) { - auto mean = Preprocess::GetMean(vec); - auto stddev = Preprocess::GetStdDev(vec, mean); + auto mean = math::MathUtils::MeanF32(vec.begin(), vec.totalSize()); + auto stddev = math::MathUtils::StdDevF32(vec.begin(), vec.totalSize(), mean); debug("Mean: %f, Stddev: %f\n", mean, stddev); if (stddev == 0) { @@ -204,14 +186,14 @@ namespace asr { } } - void Preprocess::Normalise() + void AsrPreProcess::Standarize() { - Preprocess::NormaliseVec(this->m_mfccBuf); - Preprocess::NormaliseVec(this->m_delta1Buf); - Preprocess::NormaliseVec(this->m_delta2Buf); + AsrPreProcess::StandardizeVecF32(this->m_mfccBuf); + AsrPreProcess::StandardizeVecF32(this->m_delta1Buf); + AsrPreProcess::StandardizeVecF32(this->m_delta2Buf); } - float Preprocess::GetQuantElem( + float AsrPreProcess::GetQuantElem( const float elem, const float quantScale, const int quantOffset, @@ -222,7 +204,5 @@ namespace asr { return std::min<float>(std::max<float>(val, minVal), maxVal); } -} /* namespace asr */ -} /* namespace audio */ } /* namespace app */ } /* namespace arm */
\ No newline at end of file diff --git a/source/use_case/kws_asr/usecase.cmake b/source/use_case/kws_asr/usecase.cmake index b3fe020..40df4d7 100644 --- a/source/use_case/kws_asr/usecase.cmake +++ b/source/use_case/kws_asr/usecase.cmake @@ -1,5 +1,5 @@ #---------------------------------------------------------------------------- -# Copyright (c) 2021 Arm Limited. All rights reserved. +# Copyright (c) 2021-2022 Arm Limited. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -59,7 +59,7 @@ USER_OPTION(${use_case}_ACTIVATION_BUF_SZ "Activation buffer size for the chosen STRING) USER_OPTION(${use_case}_MODEL_SCORE_THRESHOLD_KWS "Specify the score threshold [0.0, 1.0) that must be applied to the KWS results for a label to be deemed valid." - 0.9 + 0.7 STRING) USER_OPTION(${use_case}_MODEL_SCORE_THRESHOLD_ASR "Specify the score threshold [0.0, 1.0) that must be applied to the ASR results for a label to be deemed valid." diff --git a/source/use_case/noise_reduction/include/RNNoiseProcess.hpp b/source/use_case/noise_reduction/include/RNNoiseFeatureProcessor.hpp index c188e42..cbf0e4e 100644 --- a/source/use_case/noise_reduction/include/RNNoiseProcess.hpp +++ b/source/use_case/noise_reduction/include/RNNoiseFeatureProcessor.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,6 +14,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#ifndef RNNOISE_FEATURE_PROCESSOR_HPP +#define RNNOISE_FEATURE_PROCESSOR_HPP + #include "PlatformMath.hpp" #include <cstdint> #include <vector> @@ -47,11 +50,11 @@ namespace rnn { * - https://jmvalin.ca/demo/rnnoise/ * - https://arxiv.org/abs/1709.08243 **/ - class RNNoiseProcess { + class RNNoiseFeatureProcessor { /* Public interface */ public: - RNNoiseProcess(); - ~RNNoiseProcess() = default; + RNNoiseFeatureProcessor(); + ~RNNoiseFeatureProcessor() = default; /** * @brief Calculates the features from a given audio buffer ready to be sent to RNNoise model. @@ -328,10 +331,11 @@ namespace rnn { const std::array <uint32_t, NB_BANDS> m_eband5ms { 0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 34, 40, 48, 60, 78, 100}; - }; } /* namespace rnn */ -} /* namspace app */ +} /* namespace app */ } /* namespace arm */ + +#endif /* RNNOISE_FEATURE_PROCESSOR_HPP */ diff --git a/source/use_case/noise_reduction/include/RNNoiseProcessing.hpp b/source/use_case/noise_reduction/include/RNNoiseProcessing.hpp new file mode 100644 index 0000000..15e62d9 --- /dev/null +++ b/source/use_case/noise_reduction/include/RNNoiseProcessing.hpp @@ -0,0 +1,113 @@ +/* + * Copyright (c) 2022 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef RNNOISE_PROCESSING_HPP +#define RNNOISE_PROCESSING_HPP + +#include "BaseProcessing.hpp" +#include "Model.hpp" +#include "RNNoiseFeatureProcessor.hpp" + +namespace arm { +namespace app { + + /** + * @brief Pre-processing class for Noise Reduction use case. + * Implements methods declared by BasePreProcess and anything else needed + * to populate input tensors ready for inference. + */ + class RNNoisePreProcess : public BasePreProcess { + + public: + /** + * @brief Constructor + * @param[in] inputTensor Pointer to the TFLite Micro input Tensor. + * @param[in/out] featureProcessor RNNoise specific feature extractor object. + * @param[in/out] frameFeatures RNNoise specific features shared between pre & post-processing. + * + **/ + explicit RNNoisePreProcess(TfLiteTensor* inputTensor, + std::shared_ptr<rnn::RNNoiseFeatureProcessor> featureProcessor, + std::shared_ptr<rnn::FrameFeatures> frameFeatures); + + /** + * @brief Should perform pre-processing of 'raw' input audio data and load it into + * TFLite Micro input tensors ready for inference + * @param[in] input Pointer to the data that pre-processing will work on. + * @param[in] inputSize Size of the input data. + * @return true if successful, false otherwise. + **/ + bool DoPreProcess(const void* input, size_t inputSize) override; + + private: + TfLiteTensor* m_inputTensor; /* Model input tensor. */ + std::shared_ptr<rnn::RNNoiseFeatureProcessor> m_featureProcessor; /* RNNoise feature processor shared between pre & post-processing. */ + std::shared_ptr<rnn::FrameFeatures> m_frameFeatures; /* RNNoise features shared between pre & post-processing. */ + rnn::vec1D32F m_audioFrame; /* Audio frame cast to FP32 */ + + /** + * @brief Quantize the given features and populate the input Tensor. + * @param[in] inputFeatures Vector of floating point features to quantize. + * @param[in] quantScale Quantization scale for the inputTensor. + * @param[in] quantOffset Quantization offset for the inputTensor. + * @param[in,out] inputTensor TFLite micro tensor to populate. + **/ + static void QuantizeAndPopulateInput(rnn::vec1D32F& inputFeatures, + float quantScale, int quantOffset, + TfLiteTensor* inputTensor); + }; + + /** + * @brief Post-processing class for Noise Reduction use case. + * Implements methods declared by BasePostProcess and anything else needed + * to populate result vector. + */ + class RNNoisePostProcess : public BasePostProcess { + + public: + /** + * @brief Constructor + * @param[in] outputTensor Pointer to the TFLite Micro output Tensor. + * @param[out] denoisedAudioFrame Vector to store the final denoised audio frame. + * @param[in/out] featureProcessor RNNoise specific feature extractor object. + * @param[in/out] frameFeatures RNNoise specific features shared between pre & post-processing. + **/ + RNNoisePostProcess(TfLiteTensor* outputTensor, + std::vector<int16_t>& denoisedAudioFrame, + std::shared_ptr<rnn::RNNoiseFeatureProcessor> featureProcessor, + std::shared_ptr<rnn::FrameFeatures> frameFeatures); + + /** + * @brief Should perform post-processing of the result of inference then + * populate result data for any later use. + * @return true if successful, false otherwise. + **/ + bool DoPostProcess() override; + + private: + TfLiteTensor* m_outputTensor; /* Model output tensor. */ + std::vector<int16_t>& m_denoisedAudioFrame; /* Vector to store the final denoised frame. */ + rnn::vec1D32F m_denoisedAudioFrameFloat; /* Internal vector to store the final denoised frame (FP32). */ + std::shared_ptr<rnn::RNNoiseFeatureProcessor> m_featureProcessor; /* RNNoise feature processor shared between pre & post-processing. */ + std::shared_ptr<rnn::FrameFeatures> m_frameFeatures; /* RNNoise features shared between pre & post-processing. */ + std::vector<float> m_modelOutputFloat; /* Internal vector to store de-quantized model output. */ + + }; + +} /* namespace app */ +} /* namespace arm */ + +#endif /* RNNOISE_PROCESSING_HPP */
\ No newline at end of file diff --git a/source/use_case/noise_reduction/src/MainLoop.cc b/source/use_case/noise_reduction/src/MainLoop.cc index 5fd7823..fd72127 100644 --- a/source/use_case/noise_reduction/src/MainLoop.cc +++ b/source/use_case/noise_reduction/src/MainLoop.cc @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,12 +14,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "hal.h" /* Brings in platform definitions. */ #include "UseCaseHandler.hpp" /* Handlers for different user options. */ #include "UseCaseCommonUtils.hpp" /* Utils functions. */ #include "RNNoiseModel.hpp" /* Model class for running inference. */ #include "InputFiles.hpp" /* For input audio clips. */ -#include "RNNoiseProcess.hpp" /* Pre-processing class */ #include "log_macros.h" enum opcodes diff --git a/source/use_case/noise_reduction/src/RNNoiseProcess.cc b/source/use_case/noise_reduction/src/RNNoiseFeatureProcessor.cc index 4c568fa..036894c 100644 --- a/source/use_case/noise_reduction/src/RNNoiseProcess.cc +++ b/source/use_case/noise_reduction/src/RNNoiseFeatureProcessor.cc @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "RNNoiseProcess.hpp" +#include "RNNoiseFeatureProcessor.hpp" #include "log_macros.h" #include <algorithm> @@ -33,7 +33,7 @@ do { \ } \ } while(0) -RNNoiseProcess::RNNoiseProcess() : +RNNoiseFeatureProcessor::RNNoiseFeatureProcessor() : m_halfWindow(FRAME_SIZE, 0), m_dctTable(NB_BANDS * NB_BANDS), m_analysisMem(FRAME_SIZE, 0), @@ -54,9 +54,9 @@ RNNoiseProcess::RNNoiseProcess() : this->InitTables(); } -void RNNoiseProcess::PreprocessFrame(const float* audioData, - const size_t audioLen, - FrameFeatures& features) +void RNNoiseFeatureProcessor::PreprocessFrame(const float* audioData, + const size_t audioLen, + FrameFeatures& features) { /* Note audioWindow is modified in place */ const arrHp aHp {-1.99599, 0.99600 }; @@ -68,7 +68,7 @@ void RNNoiseProcess::PreprocessFrame(const float* audioData, this->ComputeFrameFeatures(audioWindow, features); } -void RNNoiseProcess::PostProcessFrame(vec1D32F& modelOutput, FrameFeatures& features, vec1D32F& outFrame) +void RNNoiseFeatureProcessor::PostProcessFrame(vec1D32F& modelOutput, FrameFeatures& features, vec1D32F& outFrame) { std::vector<float> outputBands = modelOutput; std::vector<float> gain(FREQ_SIZE, 0); @@ -92,7 +92,7 @@ void RNNoiseProcess::PostProcessFrame(vec1D32F& modelOutput, FrameFeatures& feat FrameSynthesis(outFrame, features.m_fftX); } -void RNNoiseProcess::InitTables() +void RNNoiseFeatureProcessor::InitTables() { constexpr float pi = M_PI; constexpr float halfPi = M_PI / 2; @@ -111,7 +111,7 @@ void RNNoiseProcess::InitTables() } } -void RNNoiseProcess::BiQuad( +void RNNoiseFeatureProcessor::BiQuad( const arrHp& bHp, const arrHp& aHp, arrHp& memHpX, @@ -126,8 +126,8 @@ void RNNoiseProcess::BiQuad( } } -void RNNoiseProcess::ComputeFrameFeatures(vec1D32F& audioWindow, - FrameFeatures& features) +void RNNoiseFeatureProcessor::ComputeFrameFeatures(vec1D32F& audioWindow, + FrameFeatures& features) { this->FrameAnalysis(audioWindow, features.m_fftX, @@ -264,7 +264,7 @@ void RNNoiseProcess::ComputeFrameFeatures(vec1D32F& audioWindow, features.m_featuresVec[NB_BANDS + 3 * NB_DELTA_CEPS + 1] = specVariability / CEPS_MEM - 2.1; } -void RNNoiseProcess::FrameAnalysis( +void RNNoiseFeatureProcessor::FrameAnalysis( const vec1D32F& audioWindow, vec1D32F& fft, vec1D32F& energy, @@ -289,7 +289,7 @@ void RNNoiseProcess::FrameAnalysis( ComputeBandEnergy(fft, energy); } -void RNNoiseProcess::ApplyWindow(vec1D32F& x) +void RNNoiseFeatureProcessor::ApplyWindow(vec1D32F& x) { if (WINDOW_SIZE != x.size()) { printf_err("Invalid size for vector to be windowed\n"); @@ -305,7 +305,7 @@ void RNNoiseProcess::ApplyWindow(vec1D32F& x) } } -void RNNoiseProcess::ForwardTransform( +void RNNoiseFeatureProcessor::ForwardTransform( vec1D32F& x, vec1D32F& fft) { @@ -327,7 +327,7 @@ void RNNoiseProcess::ForwardTransform( * first half of the FFT's. The conjugates are not present. */ } -void RNNoiseProcess::ComputeBandEnergy(const vec1D32F& fftX, vec1D32F& bandE) +void RNNoiseFeatureProcessor::ComputeBandEnergy(const vec1D32F& fftX, vec1D32F& bandE) { bandE = vec1D32F(NB_BANDS, 0); @@ -351,7 +351,7 @@ void RNNoiseProcess::ComputeBandEnergy(const vec1D32F& fftX, vec1D32F& bandE) bandE[NB_BANDS - 1] *= 2; } -void RNNoiseProcess::ComputeBandCorr(const vec1D32F& X, const vec1D32F& P, vec1D32F& bandC) +void RNNoiseFeatureProcessor::ComputeBandCorr(const vec1D32F& X, const vec1D32F& P, vec1D32F& bandC) { bandC = vec1D32F(NB_BANDS, 0); VERIFY(this->m_eband5ms.size() >= NB_BANDS); @@ -374,7 +374,7 @@ void RNNoiseProcess::ComputeBandCorr(const vec1D32F& X, const vec1D32F& P, vec1D bandC[NB_BANDS - 1] *= 2; } -void RNNoiseProcess::DCT(vec1D32F& input, vec1D32F& output) +void RNNoiseFeatureProcessor::DCT(vec1D32F& input, vec1D32F& output) { VERIFY(this->m_dctTable.size() >= NB_BANDS * NB_BANDS); for (uint32_t i = 0; i < NB_BANDS; ++i) { @@ -387,7 +387,7 @@ void RNNoiseProcess::DCT(vec1D32F& input, vec1D32F& output) } } -void RNNoiseProcess::PitchDownsample(vec1D32F& pitchBuf, size_t pitchBufSz) { +void RNNoiseFeatureProcessor::PitchDownsample(vec1D32F& pitchBuf, size_t pitchBufSz) { for (size_t i = 1; i < (pitchBufSz >> 1); ++i) { pitchBuf[i] = 0.5 * ( 0.5 * (this->m_pitchBuf[2 * i - 1] + this->m_pitchBuf[2 * i + 1]) @@ -431,7 +431,7 @@ void RNNoiseProcess::PitchDownsample(vec1D32F& pitchBuf, size_t pitchBufSz) { this->Fir5(lpc2, pitchBufSz >> 1, pitchBuf); } -int RNNoiseProcess::PitchSearch(vec1D32F& xLp, vec1D32F& y, uint32_t len, uint32_t maxPitch) { +int RNNoiseFeatureProcessor::PitchSearch(vec1D32F& xLp, vec1D32F& y, uint32_t len, uint32_t maxPitch) { uint32_t lag = len + maxPitch; vec1D32F xLp4(len >> 2, 0); vec1D32F yLp4(lag >> 2, 0); @@ -488,7 +488,7 @@ int RNNoiseProcess::PitchSearch(vec1D32F& xLp, vec1D32F& y, uint32_t len, uint32 return 2*bestPitch[0] - offset; } -arrHp RNNoiseProcess::FindBestPitch(vec1D32F& xCorr, vec1D32F& y, uint32_t len, uint32_t maxPitch) +arrHp RNNoiseFeatureProcessor::FindBestPitch(vec1D32F& xCorr, vec1D32F& y, uint32_t len, uint32_t maxPitch) { float Syy = 1; arrHp bestNum {-1, -1}; @@ -527,7 +527,7 @@ arrHp RNNoiseProcess::FindBestPitch(vec1D32F& xCorr, vec1D32F& y, uint32_t len, return bestPitch; } -int RNNoiseProcess::RemoveDoubling( +int RNNoiseFeatureProcessor::RemoveDoubling( vec1D32F& pitchBuf, uint32_t maxPeriod, uint32_t minPeriod, @@ -679,12 +679,12 @@ int RNNoiseProcess::RemoveDoubling( return this->m_lastPeriod; } -float RNNoiseProcess::ComputePitchGain(float xy, float xx, float yy) +float RNNoiseFeatureProcessor::ComputePitchGain(float xy, float xx, float yy) { return xy / math::MathUtils::SqrtF32(1+xx*yy); } -void RNNoiseProcess::AutoCorr( +void RNNoiseFeatureProcessor::AutoCorr( const vec1D32F& x, vec1D32F& ac, size_t lag, @@ -711,7 +711,7 @@ void RNNoiseProcess::AutoCorr( } -void RNNoiseProcess::PitchXCorr( +void RNNoiseFeatureProcessor::PitchXCorr( const vec1D32F& x, const vec1D32F& y, vec1D32F& xCorr, @@ -728,7 +728,7 @@ void RNNoiseProcess::PitchXCorr( } /* Linear predictor coefficients */ -void RNNoiseProcess::LPC( +void RNNoiseFeatureProcessor::LPC( const vec1D32F& correlation, int32_t p, vec1D32F& lpc) @@ -766,7 +766,7 @@ void RNNoiseProcess::LPC( } } -void RNNoiseProcess::Fir5( +void RNNoiseFeatureProcessor::Fir5( const vec1D32F &num, uint32_t N, vec1D32F &x) @@ -794,7 +794,7 @@ void RNNoiseProcess::Fir5( } } -void RNNoiseProcess::PitchFilter(FrameFeatures &features, vec1D32F &gain) { +void RNNoiseFeatureProcessor::PitchFilter(FrameFeatures &features, vec1D32F &gain) { std::vector<float> r(NB_BANDS, 0); std::vector<float> rf(FREQ_SIZE, 0); std::vector<float> newE(NB_BANDS); @@ -835,7 +835,7 @@ void RNNoiseProcess::PitchFilter(FrameFeatures &features, vec1D32F &gain) { } } -void RNNoiseProcess::FrameSynthesis(vec1D32F& outFrame, vec1D32F& fftY) { +void RNNoiseFeatureProcessor::FrameSynthesis(vec1D32F& outFrame, vec1D32F& fftY) { std::vector<float> x(WINDOW_SIZE, 0); InverseTransform(x, fftY); ApplyWindow(x); @@ -845,7 +845,7 @@ void RNNoiseProcess::FrameSynthesis(vec1D32F& outFrame, vec1D32F& fftY) { memcpy((m_synthesisMem.data()), &x[FRAME_SIZE], FRAME_SIZE*sizeof(float)); } -void RNNoiseProcess::InterpBandGain(vec1D32F& g, vec1D32F& bandE) { +void RNNoiseFeatureProcessor::InterpBandGain(vec1D32F& g, vec1D32F& bandE) { for (size_t i = 0; i < NB_BANDS - 1; i++) { int bandSize = (m_eband5ms[i + 1] - m_eband5ms[i]) << FRAME_SIZE_SHIFT; for (int j = 0; j < bandSize; j++) { @@ -855,7 +855,7 @@ void RNNoiseProcess::InterpBandGain(vec1D32F& g, vec1D32F& bandE) { } } -void RNNoiseProcess::InverseTransform(vec1D32F& out, vec1D32F& fftXIn) { +void RNNoiseFeatureProcessor::InverseTransform(vec1D32F& out, vec1D32F& fftXIn) { std::vector<float> x(WINDOW_SIZE * 2); /* This is complex. */ vec1D32F newFFT; /* This is complex. */ diff --git a/source/use_case/noise_reduction/src/RNNoiseProcessing.cc b/source/use_case/noise_reduction/src/RNNoiseProcessing.cc new file mode 100644 index 0000000..f6a3ec4 --- /dev/null +++ b/source/use_case/noise_reduction/src/RNNoiseProcessing.cc @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2022 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "RNNoiseProcessing.hpp" +#include "log_macros.h" + +namespace arm { +namespace app { + + RNNoisePreProcess::RNNoisePreProcess(TfLiteTensor* inputTensor, + std::shared_ptr<rnn::RNNoiseFeatureProcessor> featureProcessor, std::shared_ptr<rnn::FrameFeatures> frameFeatures) + : m_inputTensor{inputTensor}, + m_featureProcessor{featureProcessor}, + m_frameFeatures{frameFeatures} + {} + + bool RNNoisePreProcess::DoPreProcess(const void* data, size_t inputSize) + { + if (data == nullptr) { + printf_err("Data pointer is null"); + return false; + } + + auto input = static_cast<const int16_t*>(data); + this->m_audioFrame = rnn::vec1D32F(input, input + inputSize); + m_featureProcessor->PreprocessFrame(this->m_audioFrame.data(), inputSize, *this->m_frameFeatures); + + QuantizeAndPopulateInput(this->m_frameFeatures->m_featuresVec, + this->m_inputTensor->params.scale, this->m_inputTensor->params.zero_point, + this->m_inputTensor); + + debug("Input tensor populated \n"); + + return true; + } + + void RNNoisePreProcess::QuantizeAndPopulateInput(rnn::vec1D32F& inputFeatures, + const float quantScale, const int quantOffset, + TfLiteTensor* inputTensor) + { + const float minVal = std::numeric_limits<int8_t>::min(); + const float maxVal = std::numeric_limits<int8_t>::max(); + + auto* inputTensorData = tflite::GetTensorData<int8_t>(inputTensor); + + for (size_t i=0; i < inputFeatures.size(); ++i) { + float quantValue = ((inputFeatures[i] / quantScale) + quantOffset); + inputTensorData[i] = static_cast<int8_t>(std::min<float>(std::max<float>(quantValue, minVal), maxVal)); + } + } + + RNNoisePostProcess::RNNoisePostProcess(TfLiteTensor* outputTensor, + std::vector<int16_t>& denoisedAudioFrame, + std::shared_ptr<rnn::RNNoiseFeatureProcessor> featureProcessor, + std::shared_ptr<rnn::FrameFeatures> frameFeatures) + : m_outputTensor{outputTensor}, + m_denoisedAudioFrame{denoisedAudioFrame}, + m_featureProcessor{featureProcessor}, + m_frameFeatures{frameFeatures} + { + this->m_denoisedAudioFrameFloat.reserve(denoisedAudioFrame.size()); + this->m_modelOutputFloat.resize(outputTensor->bytes); + } + + bool RNNoisePostProcess::DoPostProcess() + { + const auto* outputData = tflite::GetTensorData<int8_t>(this->m_outputTensor); + auto outputQuantParams = GetTensorQuantParams(this->m_outputTensor); + + for (size_t i = 0; i < this->m_outputTensor->bytes; ++i) { + this->m_modelOutputFloat[i] = (static_cast<float>(outputData[i]) - outputQuantParams.offset) + * outputQuantParams.scale; + } + + this->m_featureProcessor->PostProcessFrame(this->m_modelOutputFloat, + *this->m_frameFeatures, this->m_denoisedAudioFrameFloat); + + for (size_t i = 0; i < this->m_denoisedAudioFrame.size(); ++i) { + this->m_denoisedAudioFrame[i] = static_cast<int16_t>( + std::roundf(this->m_denoisedAudioFrameFloat[i])); + } + + return true; + } + +} /* namespace app */ +} /* namespace arm */
\ No newline at end of file diff --git a/source/use_case/noise_reduction/src/UseCaseHandler.cc b/source/use_case/noise_reduction/src/UseCaseHandler.cc index acb8ba7..53bb43e 100644 --- a/source/use_case/noise_reduction/src/UseCaseHandler.cc +++ b/source/use_case/noise_reduction/src/UseCaseHandler.cc @@ -21,12 +21,10 @@ #include "ImageUtils.hpp" #include "InputFiles.hpp" #include "RNNoiseModel.hpp" -#include "RNNoiseProcess.hpp" +#include "RNNoiseFeatureProcessor.hpp" +#include "RNNoiseProcessing.hpp" #include "log_macros.h" -#include <cmath> -#include <algorithm> - namespace arm { namespace app { @@ -36,17 +34,6 @@ namespace app { **/ static void IncrementAppCtxClipIdx(ApplicationContext& ctx); - /** - * @brief Quantize the given features and populate the input Tensor. - * @param[in] inputFeatures Vector of floating point features to quantize. - * @param[in] quantScale Quantization scale for the inputTensor. - * @param[in] quantOffset Quantization offset for the inputTensor. - * @param[in,out] inputTensor TFLite micro tensor to populate. - **/ - static void QuantizeAndPopulateInput(rnn::vec1D32F& inputFeatures, - float quantScale, int quantOffset, - TfLiteTensor* inputTensor); - /* Noise reduction inference handler. */ bool NoiseReductionHandler(ApplicationContext& ctx, bool runAll) { @@ -57,7 +44,7 @@ namespace app { size_t memDumpMaxLen = 0; uint8_t* memDumpBaseAddr = nullptr; size_t undefMemDumpBytesWritten = 0; - size_t *pMemDumpBytesWritten = &undefMemDumpBytesWritten; + size_t* pMemDumpBytesWritten = &undefMemDumpBytesWritten; if (ctx.Has("MEM_DUMP_LEN") && ctx.Has("MEM_DUMP_BASE_ADDR") && ctx.Has("MEM_DUMP_BYTE_WRITTEN")) { memDumpMaxLen = ctx.Get<size_t>("MEM_DUMP_LEN"); memDumpBaseAddr = ctx.Get<uint8_t*>("MEM_DUMP_BASE_ADDR"); @@ -74,8 +61,8 @@ namespace app { } /* Populate Pre-Processing related parameters. */ - auto audioParamsWinLen = ctx.Get<uint32_t>("frameLength"); - auto audioParamsWinStride = ctx.Get<uint32_t>("frameStride"); + auto audioFrameLen = ctx.Get<uint32_t>("frameLength"); + auto audioFrameStride = ctx.Get<uint32_t>("frameStride"); auto nrNumInputFeatures = ctx.Get<uint32_t>("numInputFeatures"); TfLiteTensor* inputTensor = model.GetInputTensor(0); @@ -103,7 +90,7 @@ namespace app { if (ctx.Has("featureFileNames")) { audioFileAccessorFunc = ctx.Get<std::function<const char*(const uint32_t)>>("featureFileNames"); } - do{ + do { hal_lcd_clear(COLOR_BLACK); auto startDumpAddress = memDumpBaseAddr + memDumpBytesWritten; @@ -112,32 +99,38 @@ namespace app { /* Creating a sliding window through the audio. */ auto audioDataSlider = audio::SlidingWindow<const int16_t>( audioAccessorFunc(currentIndex), - audioSizeAccessorFunc(currentIndex), audioParamsWinLen, - audioParamsWinStride); + audioSizeAccessorFunc(currentIndex), audioFrameLen, + audioFrameStride); info("Running inference on input feature map %" PRIu32 " => %s\n", currentIndex, audioFileAccessorFunc(currentIndex)); memDumpBytesWritten += DumpDenoisedAudioHeader(audioFileAccessorFunc(currentIndex), - (audioDataSlider.TotalStrides() + 1) * audioParamsWinLen, + (audioDataSlider.TotalStrides() + 1) * audioFrameLen, memDumpBaseAddr + memDumpBytesWritten, memDumpMaxLen - memDumpBytesWritten); - rnn::RNNoiseProcess featureProcessor = rnn::RNNoiseProcess(); - rnn::vec1D32F audioFrame(audioParamsWinLen); - rnn::vec1D32F inputFeatures(nrNumInputFeatures); - rnn::vec1D32F denoisedAudioFrameFloat(audioParamsWinLen); - std::vector<int16_t> denoisedAudioFrame(audioParamsWinLen); + /* Set up pre and post-processing. */ + std::shared_ptr<rnn::RNNoiseFeatureProcessor> featureProcessor = + std::make_shared<rnn::RNNoiseFeatureProcessor>(); + std::shared_ptr<rnn::FrameFeatures> frameFeatures = + std::make_shared<rnn::FrameFeatures>(); + + RNNoisePreProcess preProcess = RNNoisePreProcess(inputTensor, featureProcessor, frameFeatures); + + std::vector<int16_t> denoisedAudioFrame(audioFrameLen); + RNNoisePostProcess postProcess = RNNoisePostProcess(outputTensor, denoisedAudioFrame, + featureProcessor, frameFeatures); - std::vector<float> modelOutputFloat(outputTensor->bytes); - rnn::FrameFeatures frameFeatures; bool resetGRU = true; while (audioDataSlider.HasNext()) { const int16_t* inferenceWindow = audioDataSlider.Next(); - audioFrame = rnn::vec1D32F(inferenceWindow, inferenceWindow+audioParamsWinLen); - featureProcessor.PreprocessFrame(audioFrame.data(), audioParamsWinLen, frameFeatures); + if (!preProcess.DoPreProcess(inferenceWindow, audioFrameLen)) { + printf_err("Pre-processing failed."); + return false; + } /* Reset or copy over GRU states first to avoid TFLu memory overlap issues. */ if (resetGRU){ @@ -148,53 +141,35 @@ namespace app { model.CopyGruStates(); } - QuantizeAndPopulateInput(frameFeatures.m_featuresVec, - inputTensor->params.scale, inputTensor->params.zero_point, - inputTensor); - /* Strings for presentation/logging. */ std::string str_inf{"Running inference... "}; /* Display message on the LCD - inference running. */ - hal_lcd_display_text( - str_inf.c_str(), str_inf.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); + hal_lcd_display_text(str_inf.c_str(), str_inf.size(), + dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); info("Inference %zu/%zu\n", audioDataSlider.Index() + 1, audioDataSlider.TotalStrides() + 1); /* Run inference over this feature sliding window. */ - profiler.StartProfiling("Inference"); - bool success = model.RunInference(); - profiler.StopProfiling(); - resetGRU = false; - - if (!success) { + if (!RunInference(model, profiler)) { + printf_err("Inference failed."); return false; } + resetGRU = false; - /* De-quantize main model output ready for post-processing. */ - const auto* outputData = tflite::GetTensorData<int8_t>(outputTensor); - auto outputQuantParams = arm::app::GetTensorQuantParams(outputTensor); - - for (size_t i = 0; i < outputTensor->bytes; ++i) { - modelOutputFloat[i] = (static_cast<float>(outputData[i]) - outputQuantParams.offset) - * outputQuantParams.scale; - } - - /* Round and cast the post-processed results for dumping to wav. */ - featureProcessor.PostProcessFrame(modelOutputFloat, frameFeatures, denoisedAudioFrameFloat); - for (size_t i = 0; i < audioParamsWinLen; ++i) { - denoisedAudioFrame[i] = static_cast<int16_t>(std::roundf(denoisedAudioFrameFloat[i])); + /* Carry out post-processing. */ + if (!postProcess.DoPostProcess()) { + printf_err("Post-processing failed."); + return false; } /* Erase. */ str_inf = std::string(str_inf.size(), ' '); - hal_lcd_display_text( - str_inf.c_str(), str_inf.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); + hal_lcd_display_text(str_inf.c_str(), str_inf.size(), + dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); if (memDumpMaxLen > 0) { - /* Dump output tensors to memory. */ + /* Dump final post processed output to memory. */ memDumpBytesWritten += DumpOutputDenoisedAudioFrame( denoisedAudioFrame, memDumpBaseAddr + memDumpBytesWritten, @@ -209,6 +184,7 @@ namespace app { valMemDumpBytesWritten, startDumpAddress); } + /* Finish by dumping the footer. */ DumpDenoisedAudioFooter(memDumpBaseAddr + memDumpBytesWritten, memDumpMaxLen - memDumpBytesWritten); info("All inferences for audio clip complete.\n"); @@ -216,15 +192,13 @@ namespace app { IncrementAppCtxClipIdx(ctx); std::string clearString{' '}; - hal_lcd_display_text( - clearString.c_str(), clearString.size(), + hal_lcd_display_text(clearString.c_str(), clearString.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); std::string completeMsg{"Inference complete!"}; /* Display message on the LCD - inference complete. */ - hal_lcd_display_text( - completeMsg.c_str(), completeMsg.size(), + hal_lcd_display_text(completeMsg.c_str(), completeMsg.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx); @@ -233,7 +207,7 @@ namespace app { } size_t DumpDenoisedAudioHeader(const char* filename, size_t dumpSize, - uint8_t *memAddress, size_t memSize){ + uint8_t* memAddress, size_t memSize){ if (memAddress == nullptr){ return 0; @@ -284,7 +258,7 @@ namespace app { return numBytesWritten; } - size_t DumpDenoisedAudioFooter(uint8_t *memAddress, size_t memSize){ + size_t DumpDenoisedAudioFooter(uint8_t* memAddress, size_t memSize){ if ((memAddress == nullptr) || (memSize < 4)) { return 0; } @@ -294,8 +268,8 @@ namespace app { return sizeof(int32_t); } - size_t DumpOutputDenoisedAudioFrame(const std::vector<int16_t> &audioFrame, - uint8_t *memAddress, size_t memSize) + size_t DumpOutputDenoisedAudioFrame(const std::vector<int16_t>& audioFrame, + uint8_t* memAddress, size_t memSize) { if (memAddress == nullptr) { return 0; @@ -324,7 +298,7 @@ namespace app { const TfLiteTensor* tensor = model.GetOutputTensor(i); const auto* tData = tflite::GetTensorData<uint8_t>(tensor); #if VERIFY_TEST_OUTPUT - arm::app::DumpTensor(tensor); + DumpTensor(tensor); #endif /* VERIFY_TEST_OUTPUT */ /* Ensure that we don't overflow the allowed limit. */ if (numBytesWritten + tensor->bytes <= memSize) { @@ -360,20 +334,5 @@ namespace app { ctx.Set<uint32_t>("clipIndex", curClipIdx); } - void QuantizeAndPopulateInput(rnn::vec1D32F& inputFeatures, - const float quantScale, const int quantOffset, TfLiteTensor* inputTensor) - { - const float minVal = std::numeric_limits<int8_t>::min(); - const float maxVal = std::numeric_limits<int8_t>::max(); - - auto* inputTensorData = tflite::GetTensorData<int8_t>(inputTensor); - - for (size_t i=0; i < inputFeatures.size(); ++i) { - float quantValue = ((inputFeatures[i] / quantScale) + quantOffset); - inputTensorData[i] = static_cast<int8_t>(std::min<float>(std::max<float>(quantValue, minVal), maxVal)); - } - } - - } /* namespace app */ } /* namespace arm */ |