author	Richard Burton <richard.burton@arm.com>	2022-05-04 09:45:02 +0100
committer	Richard Burton <richard.burton@arm.com>	2022-05-04 09:45:02 +0100
commit	4e002791bc6781b549c6951cfe44f918289d7e82 (patch)
tree	b639243b5fa433657c207783a384bad1ed248536
parent	dd6d07b24bbf9023ebe8e8927be8aac3291d0f58 (diff)
download	ml-embedded-evaluation-kit-4e002791bc6781b549c6951cfe44f918289d7e82.tar.gz
MLECO-3173: Add AD, KWS_ASR and Noise reduction use case APIs
Signed-off-by: Richard Burton <richard.burton@arm.com>
Change-Id: I36f61ce74bf17f7b327cdae9704a22ca54144f37
-rw-r--r-- source/application/main/include/BaseProcessing.hpp | 6
-rw-r--r-- source/use_case/ad/include/AdModel.hpp | 8
-rw-r--r-- source/use_case/ad/include/AdPostProcessing.hpp | 50
-rw-r--r-- source/use_case/ad/include/AdProcessing.hpp | 230
-rw-r--r-- source/use_case/ad/src/AdPostProcessing.cc | 115
-rw-r--r-- source/use_case/ad/src/AdProcessing.cc | 208
-rw-r--r-- source/use_case/ad/src/MainLoop.cc | 7
-rw-r--r-- source/use_case/ad/src/UseCaseHandler.cc | 291
-rw-r--r-- source/use_case/asr/include/Wav2LetterModel.hpp | 2
-rw-r--r-- source/use_case/kws_asr/include/KwsProcessing.hpp | 138
-rw-r--r-- source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp | 10
-rw-r--r-- source/use_case/kws_asr/include/Wav2LetterModel.hpp | 12
-rw-r--r-- source/use_case/kws_asr/include/Wav2LetterPostprocess.hpp | 117
-rw-r--r-- source/use_case/kws_asr/include/Wav2LetterPreprocess.hpp | 138
-rw-r--r-- source/use_case/kws_asr/src/KwsProcessing.cc | 212
-rw-r--r-- source/use_case/kws_asr/src/MainLoop.cc | 125
-rw-r--r-- source/use_case/kws_asr/src/UseCaseHandler.cc | 492
-rw-r--r-- source/use_case/kws_asr/src/Wav2LetterPostprocess.cc | 146
-rw-r--r-- source/use_case/kws_asr/src/Wav2LetterPreprocess.cc | 116
-rw-r--r-- source/use_case/kws_asr/usecase.cmake | 4
-rw-r--r-- source/use_case/noise_reduction/include/RNNoiseFeatureProcessor.hpp (renamed from source/use_case/noise_reduction/include/RNNoiseProcess.hpp) | 16
-rw-r--r-- source/use_case/noise_reduction/include/RNNoiseProcessing.hpp | 113
-rw-r--r-- source/use_case/noise_reduction/src/MainLoop.cc | 4
-rw-r--r-- source/use_case/noise_reduction/src/RNNoiseFeatureProcessor.cc (renamed from source/use_case/noise_reduction/src/RNNoiseProcess.cc) | 60
-rw-r--r-- source/use_case/noise_reduction/src/RNNoiseProcessing.cc | 100
-rw-r--r-- source/use_case/noise_reduction/src/UseCaseHandler.cc | 129
-rw-r--r-- tests/use_case/ad/PostProcessTests.cc | 53
-rw-r--r-- tests/use_case/kws_asr/MfccTests.cc | 8
-rw-r--r-- tests/use_case/kws_asr/Wav2LetterPostprocessingTest.cc | 142
-rw-r--r-- tests/use_case/kws_asr/Wav2LetterPreprocessingTest.cc | 126
-rw-r--r-- tests/use_case/noise_reduction/RNNoiseProcessingTests.cpp | 8
31 files changed, 1786 insertions, 1400 deletions
diff --git a/source/application/main/include/BaseProcessing.hpp b/source/application/main/include/BaseProcessing.hpp
index c1c3255..c099db2 100644
--- a/source/application/main/include/BaseProcessing.hpp
+++ b/source/application/main/include/BaseProcessing.hpp
@@ -41,9 +41,6 @@ namespace app {
* @return true if successful, false otherwise.
**/
virtual bool DoPreProcess(const void* input, size_t inputSize) = 0;
-
- protected:
- Model* m_model = nullptr;
};
/**
@@ -62,9 +59,6 @@ namespace app {
* @return true if successful, false otherwise.
**/
virtual bool DoPostProcess() = 0;
-
- protected:
- Model* m_model = nullptr;
};
} /* namespace app */
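With the protected m_model member gone from both base classes, concrete pre- and post-processors now take the tensors they need directly, as the new AD and KWS classes below do. A minimal sketch of a derived pre-processor under this contract (the class name here is illustrative, not part of the patch):

    /* Sketch only: derived classes now hold tensor pointers themselves. */
    class ExamplePreProcess : public arm::app::BasePreProcess {
    public:
        explicit ExamplePreProcess(TfLiteTensor* inputTensor)
            : m_inputTensor{inputTensor} {}

        bool DoPreProcess(const void* input, size_t inputSize) override
        {
            if (!input || !this->m_inputTensor) {
                return false;
            }
            /* Convert 'input' into features and write them to m_inputTensor. */
            return true;
        }

    private:
        TfLiteTensor* m_inputTensor; /* Supplied by the caller, not via a Model reference. */
    };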
diff --git a/source/use_case/ad/include/AdModel.hpp b/source/use_case/ad/include/AdModel.hpp
index 8d914c4..2195a7c 100644
--- a/source/use_case/ad/include/AdModel.hpp
+++ b/source/use_case/ad/include/AdModel.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -28,6 +28,12 @@ namespace arm {
namespace app {
class AdModel : public Model {
+
+ public:
+ /* Indices for the expected model - based on input tensor shape */
+ static constexpr uint32_t ms_inputRowsIdx = 1;
+ static constexpr uint32_t ms_inputColsIdx = 2;
+
protected:
/** @brief Gets the reference to op resolver interface class */
const tflite::MicroOpResolver& GetOpResolver() override;
diff --git a/source/use_case/ad/include/AdPostProcessing.hpp b/source/use_case/ad/include/AdPostProcessing.hpp
deleted file mode 100644
index 7eaec84..0000000
--- a/source/use_case/ad/include/AdPostProcessing.hpp
+++ /dev/null
@@ -1,50 +0,0 @@
-/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
- * SPDX-License-Identifier: Apache-2.0
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef ADPOSTPROCESSING_HPP
-#define ADPOSTPROCESSING_HPP
-
-#include "TensorFlowLiteMicro.hpp"
-
-#include <vector>
-
-namespace arm {
-namespace app {
-
- /** @brief Dequantize TensorFlow Lite Micro tensor.
- * @param[in] tensor Pointer to the TensorFlow Lite Micro tensor to be dequantized.
- * @return Vector with the dequantized tensor values.
- **/
- template<typename T>
- std::vector<float> Dequantize(TfLiteTensor* tensor);
-
- /**
- * @brief Calculates the softmax of vector in place. **/
- void Softmax(std::vector<float>& inputVector);
-
-
- /** @brief Given a wav file name return AD model output index.
- * @param[in] wavFileName Audio WAV filename.
- * File name should be in format anything_goes_XX_here.wav
- * where XX is the machine ID e.g. 00, 02, 04 or 06
- * @return AD model output index as 8 bit integer.
- **/
- int8_t OutputIndexFromFileName(std::string wavFileName);
-
-} /* namespace app */
-} /* namespace arm */
-
-#endif /* ADPOSTPROCESSING_HPP */
diff --git a/source/use_case/ad/include/AdProcessing.hpp b/source/use_case/ad/include/AdProcessing.hpp
new file mode 100644
index 0000000..9abf6f1
--- /dev/null
+++ b/source/use_case/ad/include/AdProcessing.hpp
@@ -0,0 +1,230 @@
+/*
+ * Copyright (c) 2022 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef AD_PROCESSING_HPP
+#define AD_PROCESSING_HPP
+
+#include "BaseProcessing.hpp"
+#include "AudioUtils.hpp"
+#include "AdMelSpectrogram.hpp"
+#include "log_macros.h"
+
+namespace arm {
+namespace app {
+
+ /**
+ * @brief Pre-processing class for anomaly detection use case.
+ * Implements methods declared by BasePreProcess and anything else needed
+ * to populate input tensors ready for inference.
+ */
+ class AdPreProcess : public BasePreProcess {
+
+ public:
+ /**
+ * @brief Constructor for AdPreProcess class objects
+ * @param[in] inputTensor input tensor pointer from the tensor arena.
+ * @param[in] melSpectrogramFrameLen MEL spectrogram's frame length
+ * @param[in] melSpectrogramFrameStride MEL spectrogram's frame stride
+ * @param[in] adModelTrainingMean Training mean for the Anomaly detection model being used.
+ */
+ explicit AdPreProcess(TfLiteTensor* inputTensor,
+ uint32_t melSpectrogramFrameLen,
+ uint32_t melSpectrogramFrameStride,
+ float adModelTrainingMean);
+
+ ~AdPreProcess() = default;
+
+ /**
+ * @brief Function to invoke pre-processing and populate the input vector
+ * @param input pointer to input data. For anomaly detection, this is the pointer to
+ * the audio data.
+ * @param inputSize Size of the data being passed in for pre-processing.
+ * @return True if successful, false otherwise.
+ */
+ bool DoPreProcess(const void* input, size_t inputSize) override;
+
+ /**
+ * @brief Getter function for audio window size computed when constructing
+ * the class object.
+ * @return Audio window size as 32 bit unsigned integer.
+ */
+ uint32_t GetAudioWindowSize();
+
+ /**
+ * @brief Getter function for audio window stride computed when constructing
+ * the class object.
+ * @return Audio window stride as 32 bit unsigned integer.
+ */
+ uint32_t GetAudioDataStride();
+
+ /**
+ * @brief Setter function for current audio index. This is only used for evaluating
+ * if previously computed features can be re-used from cache.
+ */
+ void SetAudioWindowIndex(uint32_t idx);
+
+ private:
+ bool m_validInstance{false}; /**< Indicates the current object is valid. */
+ uint32_t m_melSpectrogramFrameLen{}; /**< MEL spectrogram's window frame length */
+ uint32_t m_melSpectrogramFrameStride{}; /**< MEL spectrogram's window frame stride */
+ uint8_t m_inputResizeScale{}; /**< Downscaling factor for the MEL energy matrix. */
+ uint32_t m_numMelSpecVectorsInAudioStride{}; /**< Number of frames to move across the audio. */
+ uint32_t m_audioDataWindowSize{}; /**< Audio window size computed based on other parameters. */
+ uint32_t m_audioDataStride{}; /**< Audio window stride computed. */
+ uint32_t m_numReusedFeatureVectors{}; /**< Number of MEL vectors that can be re-used */
+ uint32_t m_audioWindowIndex{}; /**< Current audio window index (from audio's sliding window) */
+
+ audio::SlidingWindow<const int16_t> m_melWindowSlider; /**< Internal MEL spectrogram window slider */
+ audio::AdMelSpectrogram m_melSpec; /**< MEL spectrogram computation object */
+ std::function<void
+ (std::vector<int16_t>&, int, bool, size_t, size_t)> m_featureCalc; /**< Feature calculator object */
+ };
+
+ class AdPostProcess : public BasePostProcess {
+ public:
+ /**
+ * @brief Constructor for AdPostProcess object.
+ * @param[in] outputTensor Output tensor pointer.
+ */
+ explicit AdPostProcess(TfLiteTensor* outputTensor);
+
+ ~AdPostProcess() = default;
+
+ /**
+ * @brief Function to do the post-processing on the output tensor.
+ * @return True if successful, false otherwise.
+ */
+ bool DoPostProcess() override;
+
+ /**
+ * @brief Getter function for an element from the de-quantised output vector.
+ * @param index Index of the element to be retrieved.
+ * @return index represented as a 32 bit floating point number.
+ */
+ float GetOutputValue(uint32_t index);
+
+ private:
+ TfLiteTensor* m_outputTensor{}; /**< Output tensor pointer */
+ std::vector<float> m_dequantizedOutputVec{}; /**< Internal output vector */
+
+ /**
+ * @brief De-quantizes and flattens the output tensor into a vector.
+ * @tparam T template parameter to indicate data type.
+ * @return True if successful, false otherwise.
+ */
+ template<typename T>
+ bool Dequantize()
+ {
+ TfLiteTensor* tensor = this->m_outputTensor;
+ if (tensor == nullptr) {
+ printf_err("Invalid output tensor.\n");
+ return false;
+ }
+ T* tensorData = tflite::GetTensorData<T>(tensor);
+
+ uint32_t totalOutputSize = 1;
+ for (int inputDim = 0; inputDim < tensor->dims->size; inputDim++){
+ totalOutputSize *= tensor->dims->data[inputDim];
+ }
+
+ /* For getting the floating point values, we need quantization parameters */
+ QuantParams quantParams = GetTensorQuantParams(tensor);
+
+ this->m_dequantizedOutputVec = std::vector<float>(totalOutputSize, 0);
+
+ for (size_t i = 0; i < totalOutputSize; ++i) {
+ this->m_dequantizedOutputVec[i] = quantParams.scale * (tensorData[i] - quantParams.offset);
+ }
+
+ return true;
+ }
+ };
+
+ /* Templated instances available: */
+ template bool AdPostProcess::Dequantize<int8_t>();
+
+ /**
+ * @brief Generic feature calculator factory.
+ *
+ * Returns lambda function to compute features using features cache.
+ * Real features math is done by a lambda function provided as a parameter.
+ * Features are written to input tensor memory.
+ *
+ * @tparam T feature vector type.
+ * @param inputTensor model input tensor pointer.
+ * @param cacheSize number of feature vectors to cache. Defined by the sliding window overlap.
+ * @param compute features calculator function.
+ * @return lambda function to compute features.
+ */
+ template<class T>
+ std::function<void (std::vector<int16_t>&, size_t, bool, size_t, size_t)>
+ FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
+ std::function<std::vector<T> (std::vector<int16_t>& )> compute)
+ {
+ /* Feature cache to be captured by lambda function. */
+ static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize);
+
+ return [=](std::vector<int16_t>& audioDataWindow,
+ size_t index,
+ bool useCache,
+ size_t featuresOverlapIndex,
+ size_t resizeScale)
+ {
+ T* tensorData = tflite::GetTensorData<T>(inputTensor);
+ std::vector<T> features;
+
+ /* Reuse features from cache if cache is ready and sliding windows overlap.
+ * Overlap is in the beginning of sliding window with a size of a feature cache. */
+ if (useCache && index < featureCache.size()) {
+ features = std::move(featureCache[index]);
+ } else {
+ features = std::move(compute(audioDataWindow));
+ }
+ auto size = features.size() / resizeScale;
+ auto sizeBytes = sizeof(T);
+
+ /* Input should be transposed and "resized" by skipping elements. */
+ for (size_t outIndex = 0; outIndex < size; outIndex++) {
+ std::memcpy(tensorData + (outIndex*size) + index, &features[outIndex*resizeScale], sizeBytes);
+ }
+
+ /* Start renewing cache as soon as the iteration goes out of the window overlap. */
+ if (index >= featuresOverlapIndex / resizeScale) {
+ featureCache[index - featuresOverlapIndex / resizeScale] = std::move(features);
+ }
+ };
+ }
+
+ template std::function<void (std::vector<int16_t>&, size_t , bool, size_t, size_t)>
+ FeatureCalc<int8_t>(TfLiteTensor* inputTensor,
+ size_t cacheSize,
+ std::function<std::vector<int8_t> (std::vector<int16_t>&)> compute);
+
+ template std::function<void(std::vector<int16_t>&, size_t, bool, size_t, size_t)>
+ FeatureCalc<float>(TfLiteTensor *inputTensor,
+ size_t cacheSize,
+ std::function<std::vector<float>(std::vector<int16_t>&)> compute);
+
+ std::function<void (std::vector<int16_t>&, int, bool, size_t, size_t)>
+ GetFeatureCalculator(audio::AdMelSpectrogram& melSpec,
+ TfLiteTensor* inputTensor,
+ size_t cacheSize,
+ float trainingMean);
+
+} /* namespace app */
+} /* namespace arm */
+
+#endif /* AD_PROCESSING_HPP */
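Taken together, the intended call pattern for these two classes (mirroring how the AD use case handler later in this patch wires them up) is roughly:

    AdPreProcess preProcess{inputTensor, melSpecFrameLength, melSpecFrameStride, trainingMean};
    AdPostProcess postProcess{outputTensor};

    preProcess.SetAudioWindowIndex(audioDataSlider.Index());
    preProcess.DoPreProcess(inferenceWindow, preProcess.GetAudioWindowSize());
    /* ... run inference on the model ... */
    postProcess.DoPostProcess();
    const float score = postProcess.GetOutputValue(machineOutputIndex);

Here inputTensor and outputTensor come from the model, and the remaining names follow the handler code shown below.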
diff --git a/source/use_case/ad/src/AdPostProcessing.cc b/source/use_case/ad/src/AdPostProcessing.cc
deleted file mode 100644
index c461875..0000000
--- a/source/use_case/ad/src/AdPostProcessing.cc
+++ /dev/null
@@ -1,115 +0,0 @@
-/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
- * SPDX-License-Identifier: Apache-2.0
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "AdPostProcessing.hpp"
-#include "log_macros.h"
-
-#include <numeric>
-#include <cmath>
-#include <string>
-
-namespace arm {
-namespace app {
-
- template<typename T>
- std::vector<float> Dequantize(TfLiteTensor* tensor) {
-
- if (tensor == nullptr) {
- printf_err("Tensor is null pointer can not dequantize.\n");
- return std::vector<float>();
- }
- T* tensorData = tflite::GetTensorData<T>(tensor);
-
- uint32_t totalOutputSize = 1;
- for (int inputDim = 0; inputDim < tensor->dims->size; inputDim++){
- totalOutputSize *= tensor->dims->data[inputDim];
- }
-
- /* For getting the floating point values, we need quantization parameters */
- QuantParams quantParams = GetTensorQuantParams(tensor);
-
- std::vector<float> dequantizedOutput(totalOutputSize);
-
- for (size_t i = 0; i < totalOutputSize; ++i) {
- dequantizedOutput[i] = quantParams.scale * (tensorData[i] - quantParams.offset);
- }
-
- return dequantizedOutput;
- }
-
- void Softmax(std::vector<float>& inputVector) {
- auto start = inputVector.begin();
- auto end = inputVector.end();
-
- /* Fix for numerical stability and apply exp. */
- float maxValue = *std::max_element(start, end);
- for (auto it = start; it!=end; ++it) {
- *it = std::exp((*it) - maxValue);
- }
-
- float sumExp = std::accumulate(start, end, 0.0f);
-
- for (auto it = start; it!=end; ++it) {
- *it = (*it)/sumExp;
- }
- }
-
- int8_t OutputIndexFromFileName(std::string wavFileName) {
- /* Filename is assumed in the form machine_id_00.wav */
- std::string delimiter = "_"; /* First character used to split the file name up. */
- size_t delimiterStart;
- std::string subString;
- size_t machineIdxInString = 3; /* Which part of the file name the machine id should be at. */
-
- for (size_t i = 0; i < machineIdxInString; ++i) {
- delimiterStart = wavFileName.find(delimiter);
- subString = wavFileName.substr(0, delimiterStart);
- wavFileName.erase(0, delimiterStart + delimiter.length());
- }
-
- /* At this point substring should be 00.wav */
- delimiter = "."; /* Second character used to split the file name up. */
- delimiterStart = subString.find(delimiter);
- subString = (delimiterStart != std::string::npos) ? subString.substr(0, delimiterStart) : subString;
-
- auto is_number = [](const std::string& str) -> bool
- {
- std::string::const_iterator it = str.begin();
- while (it != str.end() && std::isdigit(*it)) ++it;
- return !str.empty() && it == str.end();
- };
-
- const int8_t machineIdx = is_number(subString) ? std::stoi(subString) : -1;
-
- /* Return corresponding index in the output vector. */
- if (machineIdx == 0) {
- return 0;
- } else if (machineIdx == 2) {
- return 1;
- } else if (machineIdx == 4) {
- return 2;
- } else if (machineIdx == 6) {
- return 3;
- } else {
- printf_err("%d is an invalid machine index \n", machineIdx);
- return -1;
- }
- }
-
- template std::vector<float> Dequantize<uint8_t>(TfLiteTensor* tensor);
- template std::vector<float> Dequantize<int8_t>(TfLiteTensor* tensor);
-} /* namespace app */
-} /* namespace arm */ \ No newline at end of file
diff --git a/source/use_case/ad/src/AdProcessing.cc b/source/use_case/ad/src/AdProcessing.cc
new file mode 100644
index 0000000..a33131c
--- /dev/null
+++ b/source/use_case/ad/src/AdProcessing.cc
@@ -0,0 +1,208 @@
+/*
+ * Copyright (c) 2022 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "AdProcessing.hpp"
+
+#include "AdModel.hpp"
+
+namespace arm {
+namespace app {
+
+AdPreProcess::AdPreProcess(TfLiteTensor* inputTensor,
+ uint32_t melSpectrogramFrameLen,
+ uint32_t melSpectrogramFrameStride,
+ float adModelTrainingMean):
+ m_validInstance{false},
+ m_melSpectrogramFrameLen{melSpectrogramFrameLen},
+ m_melSpectrogramFrameStride{melSpectrogramFrameStride},
+ /**< Model is trained on features downsampled 2x */
+ m_inputResizeScale{2},
+ /**< We are choosing to move by 20 frames across the audio for each inference. */
+ m_numMelSpecVectorsInAudioStride{20},
+ m_audioDataStride{m_numMelSpecVectorsInAudioStride * melSpectrogramFrameStride},
+ m_melSpec{melSpectrogramFrameLen}
+{
+ if (!inputTensor) {
+ printf_err("Invalid input tensor provided to pre-process\n");
+ return;
+ }
+
+ TfLiteIntArray* inputShape = inputTensor->dims;
+
+ if (!inputShape) {
+ printf_err("Invalid input tensor dims\n");
+ return;
+ }
+
+ const uint32_t kNumRows = inputShape->data[AdModel::ms_inputRowsIdx];
+ const uint32_t kNumCols = inputShape->data[AdModel::ms_inputColsIdx];
+
+ /* Deduce the data length required for 1 inference from the network parameters. */
+ this->m_audioDataWindowSize = (((this->m_inputResizeScale * kNumCols) - 1) *
+ melSpectrogramFrameStride) +
+ melSpectrogramFrameLen;
+ this->m_numReusedFeatureVectors = kNumRows -
+ (this->m_numMelSpecVectorsInAudioStride /
+ this->m_inputResizeScale);
+ this->m_melSpec.Init();
+
+ /* Creating a Mel Spectrogram sliding window for the data required for 1 inference.
+ * "resizing" done here by multiplying stride by resize scale. */
+ this->m_melWindowSlider = audio::SlidingWindow<const int16_t>(
+ nullptr, /* to be populated later. */
+ this->m_audioDataWindowSize,
+ melSpectrogramFrameLen,
+ melSpectrogramFrameStride * this->m_inputResizeScale);
+
+ /* Construct feature calculation function. */
+ this->m_featureCalc = GetFeatureCalculator(this->m_melSpec, inputTensor,
+ this->m_numReusedFeatureVectors,
+ adModelTrainingMean);
+ this->m_validInstance = true;
+}
+
+bool AdPreProcess::DoPreProcess(const void* input, size_t inputSize)
+{
+ /* Check that we have a valid instance. */
+ if (!this->m_validInstance) {
+ printf_err("Invalid pre-processor instance\n");
+ return false;
+ }
+
+ /* We expect to be able to traverse the size with which the MEL spectrogram
+ * sliding window was initialised. */
+ if (!input || inputSize < this->m_audioDataWindowSize) {
+ printf_err("Invalid input provided for pre-processing\n");
+ return false;
+ }
+
+ /* We moved to the next window - set the features sliding to the new address. */
+ this->m_melWindowSlider.Reset(static_cast<const int16_t*>(input));
+
+ /* The first window does not have cache ready. */
+ const bool useCache = this->m_audioWindowIndex > 0 && this->m_numReusedFeatureVectors > 0;
+
+ /* Start calculating features inside one audio sliding window. */
+ while (this->m_melWindowSlider.HasNext()) {
+ const int16_t* melSpecWindow = this->m_melWindowSlider.Next();
+ std::vector<int16_t> melSpecAudioData = std::vector<int16_t>(
+ melSpecWindow,
+ melSpecWindow + this->m_melSpectrogramFrameLen);
+
+ /* Compute features for this window and write them to input tensor. */
+ this->m_featureCalc(melSpecAudioData,
+ this->m_melWindowSlider.Index(),
+ useCache,
+ this->m_numMelSpecVectorsInAudioStride,
+ this->m_inputResizeScale);
+ }
+
+ return true;
+}
+
+uint32_t AdPreProcess::GetAudioWindowSize()
+{
+ return this->m_audioDataWindowSize;
+}
+
+uint32_t AdPreProcess::GetAudioDataStride()
+{
+ return this->m_audioDataStride;
+}
+
+void AdPreProcess::SetAudioWindowIndex(uint32_t idx)
+{
+ this->m_audioWindowIndex = idx;
+}
+
+AdPostProcess::AdPostProcess(TfLiteTensor* outputTensor) :
+ m_outputTensor {outputTensor}
+{}
+
+bool AdPostProcess::DoPostProcess()
+{
+ switch (this->m_outputTensor->type) {
+ case kTfLiteInt8:
+ this->Dequantize<int8_t>();
+ break;
+ default:
+ printf_err("Unsupported tensor type");
+ return false;
+ }
+
+ math::MathUtils::SoftmaxF32(this->m_dequantizedOutputVec);
+ return true;
+}
+
+float AdPostProcess::GetOutputValue(uint32_t index)
+{
+ if (index < this->m_dequantizedOutputVec.size()) {
+ return this->m_dequantizedOutputVec[index];
+ }
+ printf_err("Invalid index for output\n");
+ return 0.0;
+}
+
+std::function<void (std::vector<int16_t>&, int, bool, size_t, size_t)>
+GetFeatureCalculator(audio::AdMelSpectrogram& melSpec,
+ TfLiteTensor* inputTensor,
+ size_t cacheSize,
+ float trainingMean)
+{
+ std::function<void (std::vector<int16_t>&, size_t, bool, size_t, size_t)> melSpecFeatureCalc;
+
+ TfLiteQuantization quant = inputTensor->quantization;
+
+ if (kTfLiteAffineQuantization == quant.type) {
+
+ auto* quantParams = static_cast<TfLiteAffineQuantization*>(quant.params);
+ const float quantScale = quantParams->scale->data[0];
+ const int quantOffset = quantParams->zero_point->data[0];
+
+ switch (inputTensor->type) {
+ case kTfLiteInt8: {
+ melSpecFeatureCalc = FeatureCalc<int8_t>(
+ inputTensor,
+ cacheSize,
+ [=, &melSpec](std::vector<int16_t>& audioDataWindow) {
+ return melSpec.MelSpecComputeQuant<int8_t>(
+ audioDataWindow,
+ quantScale,
+ quantOffset,
+ trainingMean);
+ }
+ );
+ break;
+ }
+ default:
+ printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type));
+ }
+ } else {
+ melSpecFeatureCalc = FeatureCalc<float>(
+ inputTensor,
+ cacheSize,
+ [=, &melSpec](
+ std::vector<int16_t>& audioDataWindow) {
+ return melSpec.ComputeMelSpec(
+ audioDataWindow,
+ trainingMean);
+ });
+ }
+ return melSpecFeatureCalc;
+}
+
+} /* namespace app */
+} /* namespace arm */
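As a worked example of the constructor's window arithmetic (the concrete numbers are illustrative, not taken from the actual model): with kNumCols = 32 input columns, a MEL frame length of 1024 samples, a frame stride of 512 and the 2x resize scale, one inference needs ((2 * 32) - 1) * 512 + 1024 = 33280 audio samples, while the stride between inferences is 20 * 512 = 10240 samples. Because the stride is much smaller than the window, consecutive windows overlap, which is why kNumRows - (20 / 2) feature vectors can be served from the cache on every window after the first.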
diff --git a/source/use_case/ad/src/MainLoop.cc b/source/use_case/ad/src/MainLoop.cc
index 23d1e51..140359b 100644
--- a/source/use_case/ad/src/MainLoop.cc
+++ b/source/use_case/ad/src/MainLoop.cc
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,7 +14,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-#include "hal.h" /* Brings in platform definitions */
#include "InputFiles.hpp" /* For input data */
#include "AdModel.hpp" /* Model class for running inference */
#include "UseCaseCommonUtils.hpp" /* Utils functions */
@@ -63,8 +62,8 @@ void main_loop()
caseContext.Set<arm::app::Profiler&>("profiler", profiler);
caseContext.Set<arm::app::Model&>("model", model);
caseContext.Set<uint32_t>("clipIndex", 0);
- caseContext.Set<int>("frameLength", g_FrameLength);
- caseContext.Set<int>("frameStride", g_FrameStride);
+ caseContext.Set<uint32_t>("frameLength", g_FrameLength);
+ caseContext.Set<uint32_t>("frameStride", g_FrameStride);
caseContext.Set<float>("scoreThreshold", g_ScoreThreshold);
caseContext.Set<float>("trainingMean", g_TrainingMean);
diff --git a/source/use_case/ad/src/UseCaseHandler.cc b/source/use_case/ad/src/UseCaseHandler.cc
index 5585f36..0179d6b 100644
--- a/source/use_case/ad/src/UseCaseHandler.cc
+++ b/source/use_case/ad/src/UseCaseHandler.cc
@@ -24,8 +24,8 @@
#include "AudioUtils.hpp"
#include "ImageUtils.hpp"
#include "UseCaseCommonUtils.hpp"
-#include "AdPostProcessing.hpp"
#include "log_macros.h"
+#include "AdProcessing.hpp"
namespace arm {
namespace app {
@@ -39,32 +39,17 @@ namespace app {
**/
static bool PresentInferenceResult(float result, float threshold);
- /**
- * @brief Returns a function to perform feature calculation and populates input tensor data with
- * MelSpe data.
- *
- * Input tensor data type check is performed to choose correct MFCC feature data type.
- * If tensor has an integer data type then original features are quantised.
- *
- * Warning: mfcc calculator provided as input must have the same life scope as returned function.
- *
- * @param[in] melSpec MFCC feature calculator.
- * @param[in,out] inputTensor Input tensor pointer to store calculated features.
- * @param[in] cacheSize Size of the feture vectors cache (number of feature vectors).
- * @param[in] trainingMean Training mean.
- * @return function function to be called providing audio sample and sliding window index.
- */
- static std::function<void (std::vector<int16_t>&, int, bool, size_t, size_t)>
- GetFeatureCalculator(audio::AdMelSpectrogram& melSpec,
- TfLiteTensor* inputTensor,
- size_t cacheSize,
- float trainingMean);
-
- /* Vibration classification handler */
+ /** @brief Given a wav file name return AD model output index.
+ * @param[in] wavFileName Audio WAV filename.
+ * File name should be in format anything_goes_XX_here.wav
+ * where XX is the machine ID e.g. 00, 02, 04 or 06
+ * @return AD model output index as 8 bit integer.
+ **/
+ static int8_t OutputIndexFromFileName(std::string wavFileName);
+
+ /* Anomaly Detection inference handler */
bool ClassifyVibrationHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
{
- auto& profiler = ctx.Get<Profiler&>("profiler");
-
constexpr uint32_t dataPsnTxtInfStartX = 20;
constexpr uint32_t dataPsnTxtInfStartY = 40;
@@ -81,8 +66,9 @@ namespace app {
return false;
}
- const auto frameLength = ctx.Get<int>("frameLength");
- const auto frameStride = ctx.Get<int>("frameStride");
+ auto& profiler = ctx.Get<Profiler&>("profiler");
+ const auto melSpecFrameLength = ctx.Get<uint32_t>("frameLength");
+ const auto melSpecFrameStride = ctx.Get<uint32_t>("frameStride");
const auto scoreThreshold = ctx.Get<float>("scoreThreshold");
const auto trainingMean = ctx.Get<float>("trainingMean");
auto startClipIdx = ctx.Get<uint32_t>("clipIndex");
@@ -95,21 +81,13 @@ namespace app {
return false;
}
- TfLiteIntArray* inputShape = model.GetInputShape(0);
- const uint32_t kNumRows = inputShape->data[1];
- const uint32_t kNumCols = inputShape->data[2];
+ AdPreProcess preProcess{
+ inputTensor,
+ melSpecFrameLength,
+ melSpecFrameStride,
+ trainingMean};
- audio::AdMelSpectrogram melSpec = audio::AdMelSpectrogram(frameLength);
- melSpec.Init();
-
- /* Deduce the data length required for 1 inference from the network parameters. */
- const uint8_t inputResizeScale = 2;
- const uint32_t audioDataWindowSize = (((inputResizeScale * kNumCols) - 1) * frameStride) + frameLength;
-
- /* We are choosing to move by 20 frames across the audio for each inference. */
- const uint8_t nMelSpecVectorsInAudioStride = 20;
-
- auto audioDataStride = nMelSpecVectorsInAudioStride * frameStride;
+ AdPostProcess postProcess{outputTensor};
do {
hal_lcd_clear(COLOR_BLACK);
@@ -122,29 +100,12 @@ namespace app {
return false;
}
- /* Creating a Mel Spectrogram sliding window for the data required for 1 inference.
- * "resizing" done here by multiplying stride by resize scale. */
- auto audioMelSpecWindowSlider = audio::SlidingWindow<const int16_t>(
- get_audio_array(currentIndex),
- audioDataWindowSize, frameLength,
- frameStride * inputResizeScale);
-
/* Creating a sliding window through the whole audio clip. */
auto audioDataSlider = audio::SlidingWindow<const int16_t>(
- get_audio_array(currentIndex),
- get_audio_array_size(currentIndex),
- audioDataWindowSize, audioDataStride);
-
- /* Calculate number of the feature vectors in the window overlap region taking into account resizing.
- * These feature vectors will be reused.*/
- auto numberOfReusedFeatureVectors = kNumRows - (nMelSpecVectorsInAudioStride / inputResizeScale);
-
- /* Construct feature calculation function. */
- auto melSpecFeatureCalc = GetFeatureCalculator(melSpec, inputTensor,
- numberOfReusedFeatureVectors, trainingMean);
- if (!melSpecFeatureCalc){
- return false;
- }
+ get_audio_array(currentIndex),
+ get_audio_array_size(currentIndex),
+ preProcess.GetAudioWindowSize(),
+ preProcess.GetAudioDataStride());
/* Result is an averaged sum over inferences. */
float result = 0;
@@ -152,30 +113,18 @@ namespace app {
/* Display message on the LCD - inference running. */
std::string str_inf{"Running inference... "};
hal_lcd_display_text(
- str_inf.c_str(), str_inf.size(),
- dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
- info("Running inference on audio clip %" PRIu32 " => %s\n", currentIndex, get_filename(currentIndex));
+ str_inf.c_str(), str_inf.size(),
+ dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
+
+ info("Running inference on audio clip %" PRIu32 " => %s\n",
+ currentIndex, get_filename(currentIndex));
/* Start sliding through audio clip. */
while (audioDataSlider.HasNext()) {
- const int16_t *inferenceWindow = audioDataSlider.Next();
-
- /* We moved to the next window - set the features sliding to the new address. */
- audioMelSpecWindowSlider.Reset(inferenceWindow);
+ const int16_t* inferenceWindow = audioDataSlider.Next();
- /* The first window does not have cache ready. */
- bool useCache = audioDataSlider.Index() > 0 && numberOfReusedFeatureVectors > 0;
-
- /* Start calculating features inside one audio sliding window. */
- while (audioMelSpecWindowSlider.HasNext()) {
- const int16_t *melSpecWindow = audioMelSpecWindowSlider.Next();
- std::vector<int16_t> melSpecAudioData = std::vector<int16_t>(melSpecWindow,
- melSpecWindow + frameLength);
-
- /* Compute features for this window and write them to input tensor. */
- melSpecFeatureCalc(melSpecAudioData, audioMelSpecWindowSlider.Index(),
- useCache, nMelSpecVectorsInAudioStride, inputResizeScale);
- }
+ preProcess.SetAudioWindowIndex(audioDataSlider.Index());
+ preProcess.DoPreProcess(inferenceWindow, preProcess.GetAudioWindowSize());
info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
audioDataSlider.TotalStrides() + 1);
@@ -185,13 +134,11 @@ namespace app {
return false;
}
- /* Use the negative softmax score of the corresponding index as the outlier score */
- std::vector<float> dequantOutput = Dequantize<int8_t>(outputTensor);
- Softmax(dequantOutput);
- result += -dequantOutput[machineOutputIndex];
+ postProcess.DoPostProcess();
+ result += 0 - postProcess.GetOutputValue(machineOutputIndex);
#if VERIFY_TEST_OUTPUT
- arm::app::DumpTensor(outputTensor);
+ DumpTensor(outputTensor);
#endif /* VERIFY_TEST_OUTPUT */
} /* while (audioDataSlider.HasNext()) */
@@ -218,7 +165,6 @@ namespace app {
return true;
}
-
static bool PresentInferenceResult(float result, float threshold)
{
constexpr uint32_t dataPsnTxtStartX1 = 20;
@@ -251,148 +197,47 @@ namespace app {
return true;
}
- /**
- * @brief Generic feature calculator factory.
- *
- * Returns lambda function to compute features using features cache.
- * Real features math is done by a lambda function provided as a parameter.
- * Features are written to input tensor memory.
- *
- * @tparam T feature vector type.
- * @param inputTensor model input tensor pointer.
- * @param cacheSize number of feature vectors to cache. Defined by the sliding window overlap.
- * @param compute features calculator function.
- * @return lambda function to compute features.
- */
- template<class T>
- std::function<void (std::vector<int16_t>&, size_t, bool, size_t, size_t)>
- FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
- std::function<std::vector<T> (std::vector<int16_t>& )> compute)
+ static int8_t OutputIndexFromFileName(std::string wavFileName)
{
- /* Feature cache to be captured by lambda function*/
- static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize);
-
- return [=](std::vector<int16_t>& audioDataWindow,
- size_t index,
- bool useCache,
- size_t featuresOverlapIndex,
- size_t resizeScale)
- {
- T *tensorData = tflite::GetTensorData<T>(inputTensor);
- std::vector<T> features;
-
- /* Reuse features from cache if cache is ready and sliding windows overlap.
- * Overlap is in the beginning of sliding window with a size of a feature cache. */
- if (useCache && index < featureCache.size()) {
- features = std::move(featureCache[index]);
- } else {
- features = std::move(compute(audioDataWindow));
- }
- auto size = features.size() / resizeScale;
- auto sizeBytes = sizeof(T);
+ /* Filename is assumed in the form machine_id_00.wav */
+ std::string delimiter = "_"; /* First character used to split the file name up. */
+ size_t delimiterStart;
+ std::string subString;
+ size_t machineIdxInString = 3; /* Which part of the file name the machine id should be at. */
+
+ for (size_t i = 0; i < machineIdxInString; ++i) {
+ delimiterStart = wavFileName.find(delimiter);
+ subString = wavFileName.substr(0, delimiterStart);
+ wavFileName.erase(0, delimiterStart + delimiter.length());
+ }
- /* Input should be transposed and "resized" by skipping elements. */
- for (size_t outIndex = 0; outIndex < size; outIndex++) {
- std::memcpy(tensorData + (outIndex*size) + index, &features[outIndex*resizeScale], sizeBytes);
- }
+ /* At this point substring should be 00.wav */
+ delimiter = "."; /* Second character used to split the file name up. */
+ delimiterStart = subString.find(delimiter);
+ subString = (delimiterStart != std::string::npos) ? subString.substr(0, delimiterStart) : subString;
- /* Start renewing cache as soon iteration goes out of the windows overlap. */
- if (index >= featuresOverlapIndex / resizeScale) {
- featureCache[index - featuresOverlapIndex / resizeScale] = std::move(features);
- }
+ auto is_number = [](const std::string& str) -> bool
+ {
+ std::string::const_iterator it = str.begin();
+ while (it != str.end() && std::isdigit(*it)) ++it;
+ return !str.empty() && it == str.end();
};
- }
-
- template std::function<void (std::vector<int16_t>&, size_t , bool, size_t, size_t)>
- FeatureCalc<int8_t>(TfLiteTensor* inputTensor,
- size_t cacheSize,
- std::function<std::vector<int8_t> (std::vector<int16_t>&)> compute);
-
- template std::function<void (std::vector<int16_t>&, size_t , bool, size_t, size_t)>
- FeatureCalc<uint8_t>(TfLiteTensor* inputTensor,
- size_t cacheSize,
- std::function<std::vector<uint8_t> (std::vector<int16_t>&)> compute);
-
- template std::function<void (std::vector<int16_t>&, size_t , bool, size_t, size_t)>
- FeatureCalc<int16_t>(TfLiteTensor* inputTensor,
- size_t cacheSize,
- std::function<std::vector<int16_t> (std::vector<int16_t>&)> compute);
-
- template std::function<void(std::vector<int16_t>&, size_t, bool, size_t, size_t)>
- FeatureCalc<float>(TfLiteTensor *inputTensor,
- size_t cacheSize,
- std::function<std::vector<float>(std::vector<int16_t>&)> compute);
-
-
- static std::function<void (std::vector<int16_t>&, int, bool, size_t, size_t)>
- GetFeatureCalculator(audio::AdMelSpectrogram& melSpec, TfLiteTensor* inputTensor, size_t cacheSize, float trainingMean)
- {
- std::function<void (std::vector<int16_t>&, size_t, bool, size_t, size_t)> melSpecFeatureCalc;
-
- TfLiteQuantization quant = inputTensor->quantization;
-
- if (kTfLiteAffineQuantization == quant.type) {
-
- auto *quantParams = (TfLiteAffineQuantization *) quant.params;
- const float quantScale = quantParams->scale->data[0];
- const int quantOffset = quantParams->zero_point->data[0];
-
- switch (inputTensor->type) {
- case kTfLiteInt8: {
- melSpecFeatureCalc = FeatureCalc<int8_t>(inputTensor,
- cacheSize,
- [=, &melSpec](std::vector<int16_t>& audioDataWindow) {
- return melSpec.MelSpecComputeQuant<int8_t>(
- audioDataWindow,
- quantScale,
- quantOffset,
- trainingMean);
- }
- );
- break;
- }
- case kTfLiteUInt8: {
- melSpecFeatureCalc = FeatureCalc<uint8_t>(inputTensor,
- cacheSize,
- [=, &melSpec](std::vector<int16_t>& audioDataWindow) {
- return melSpec.MelSpecComputeQuant<uint8_t>(
- audioDataWindow,
- quantScale,
- quantOffset,
- trainingMean);
- }
- );
- break;
- }
- case kTfLiteInt16: {
- melSpecFeatureCalc = FeatureCalc<int16_t>(inputTensor,
- cacheSize,
- [=, &melSpec](std::vector<int16_t>& audioDataWindow) {
- return melSpec.MelSpecComputeQuant<int16_t>(
- audioDataWindow,
- quantScale,
- quantOffset,
- trainingMean);
- }
- );
- break;
- }
- default:
- printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type));
- }
-
+ const int8_t machineIdx = is_number(subString) ? std::stoi(subString) : -1;
+
+ /* Return corresponding index in the output vector. */
+ if (machineIdx == 0) {
+ return 0;
+ } else if (machineIdx == 2) {
+ return 1;
+ } else if (machineIdx == 4) {
+ return 2;
+ } else if (machineIdx == 6) {
+ return 3;
} else {
- melSpecFeatureCalc = melSpecFeatureCalc = FeatureCalc<float>(inputTensor,
- cacheSize,
- [=, &melSpec](
- std::vector<int16_t>& audioDataWindow) {
- return melSpec.ComputeMelSpec(
- audioDataWindow,
- trainingMean);
- });
+ printf_err("%d is an invalid machine index \n", machineIdx);
+ return -1;
}
- return melSpecFeatureCalc;
}
} /* namespace app */
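For clarity on the filename convention handled above: OutputIndexFromFileName extracts the third underscore-separated token, so a clip named anomaly_id_00_sample.wav yields machine ID 00 and output index 0, while IDs 02, 04 and 06 map to output indices 1, 2 and 3 respectively; any other machine ID logs an error and returns -1.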
diff --git a/source/use_case/asr/include/Wav2LetterModel.hpp b/source/use_case/asr/include/Wav2LetterModel.hpp
index 0078e44..bec70ab 100644
--- a/source/use_case/asr/include/Wav2LetterModel.hpp
+++ b/source/use_case/asr/include/Wav2LetterModel.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.rved.
+ * Copyright (c) 2021 Arm Limited. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/source/use_case/kws_asr/include/KwsProcessing.hpp b/source/use_case/kws_asr/include/KwsProcessing.hpp
new file mode 100644
index 0000000..d3de3b3
--- /dev/null
+++ b/source/use_case/kws_asr/include/KwsProcessing.hpp
@@ -0,0 +1,138 @@
+/*
+ * Copyright (c) 2022 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef KWS_PROCESSING_HPP
+#define KWS_PROCESSING_HPP
+
+#include <AudioUtils.hpp>
+#include "BaseProcessing.hpp"
+#include "Model.hpp"
+#include "Classifier.hpp"
+#include "MicroNetKwsMfcc.hpp"
+
+#include <functional>
+
+namespace arm {
+namespace app {
+
+ /**
+ * @brief Pre-processing class for Keyword Spotting use case.
+ * Implements methods declared by BasePreProcess and anything else needed
+ * to populate input tensors ready for inference.
+ */
+ class KwsPreProcess : public BasePreProcess {
+
+ public:
+ /**
+ * @brief Constructor
+ * @param[in] inputTensor Pointer to the TFLite Micro input Tensor.
+ * @param[in] numFeatures How many MFCC features to use.
+ * @param[in] numFeatureFrames Number of MFCC vectors that need to be calculated
+ * for an inference.
+ * @param[in] mfccFrameLength Number of audio samples used to calculate one set of MFCC values when
+ * sliding a window through the audio sample.
+ * @param[in] mfccFrameStride Number of audio samples between consecutive windows.
+ **/
+ explicit KwsPreProcess(TfLiteTensor* inputTensor, size_t numFeatures, size_t numFeatureFrames,
+ int mfccFrameLength, int mfccFrameStride);
+
+ /**
+ * @brief Should perform pre-processing of 'raw' input audio data and load it into
+ * TFLite Micro input tensors ready for inference.
+ * @param[in] input Pointer to the data that pre-processing will work on.
+ * @param[in] inputSize Size of the input data.
+ * @return true if successful, false otherwise.
+ **/
+ bool DoPreProcess(const void* input, size_t inputSize) override;
+
+ size_t m_audioWindowIndex = 0; /* Index of audio slider, used when caching features in longer clips. */
+ size_t m_audioDataWindowSize; /* Amount of audio needed for 1 inference. */
+ size_t m_audioDataStride; /* Amount of audio to stride across if doing >1 inference in longer clips. */
+
+ private:
+ TfLiteTensor* m_inputTensor; /* Model input tensor. */
+ const int m_mfccFrameLength;
+ const int m_mfccFrameStride;
+ const size_t m_numMfccFrames; /* How many sets of m_numMfccFeats. */
+
+ audio::MicroNetKwsMFCC m_mfcc;
+ audio::SlidingWindow<const int16_t> m_mfccSlidingWindow;
+ size_t m_numMfccVectorsInAudioStride;
+ size_t m_numReusedMfccVectors;
+ std::function<void (std::vector<int16_t>&, int, bool, size_t)> m_mfccFeatureCalculator;
+
+ /**
+ * @brief Returns a function to perform feature calculation and populates input tensor data with
+ * MFCC data.
+ *
+ * Input tensor data type check is performed to choose correct MFCC feature data type.
+ * If tensor has an integer data type then original features are quantised.
+ *
+ * Warning: MFCC calculator provided as input must have the same life scope as returned function.
+ *
+ * @param[in] mfcc MFCC feature calculator.
+ * @param[in,out] inputTensor Input tensor pointer to store calculated features.
+ * @param[in] cacheSize Size of the feature vectors cache (number of feature vectors).
+ * @return Function to be called providing audio sample and sliding window index.
+ */
+ std::function<void (std::vector<int16_t>&, int, bool, size_t)>
+ GetFeatureCalculator(audio::MicroNetKwsMFCC& mfcc,
+ TfLiteTensor* inputTensor,
+ size_t cacheSize);
+
+ template<class T>
+ std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
+ FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
+ std::function<std::vector<T> (std::vector<int16_t>& )> compute);
+ };
+
+ /**
+ * @brief Post-processing class for Keyword Spotting use case.
+ * Implements methods declared by BasePostProcess and anything else needed
+ * to populate result vector.
+ */
+ class KwsPostProcess : public BasePostProcess {
+
+ private:
+ TfLiteTensor* m_outputTensor; /* Model output tensor. */
+ Classifier& m_kwsClassifier; /* KWS Classifier object. */
+ const std::vector<std::string>& m_labels; /* KWS Labels. */
+ std::vector<ClassificationResult>& m_results; /* Results vector for a single inference. */
+
+ public:
+ /**
+ * @brief Constructor
+ * @param[in] outputTensor Pointer to the TFLite Micro output Tensor.
+ * @param[in] classifier Classifier object used to get top N results from classification.
+ * @param[in] labels Vector of string labels to identify each output of the model.
+ * @param[in/out] results Vector of classification results to store decoded outputs.
+ **/
+ KwsPostProcess(TfLiteTensor* outputTensor, Classifier& classifier,
+ const std::vector<std::string>& labels,
+ std::vector<ClassificationResult>& results);
+
+ /**
+ * @brief Should perform post-processing of the result of inference then
+ * populate KWS result data for any later use.
+ * @return true if successful, false otherwise.
+ **/
+ bool DoPostProcess() override;
+ };
+
+} /* namespace app */
+} /* namespace arm */
+
+#endif /* KWS_PROCESSING_HPP */ \ No newline at end of file
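A usage sketch for the pair of classes declared above (the tensor pointers, labels, classifier and feature parameters are placeholders to be supplied by the caller):

    std::vector<ClassificationResult> singleInfResult;
    KwsPreProcess preProcess{inputTensor, numMfccFeatures, numMfccFrames,
                             mfccFrameLength, mfccFrameStride};
    KwsPostProcess postProcess{outputTensor, classifier, labels, singleInfResult};

    preProcess.m_audioWindowIndex = audioDataSlider.Index();
    preProcess.DoPreProcess(inferenceWindow, preProcess.m_audioDataWindowSize);
    /* ... run inference on the model ... */
    postProcess.DoPostProcess(); /* Populates singleInfResult. */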
diff --git a/source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp b/source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp
index 43bd390..af6ba5f 100644
--- a/source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp
+++ b/source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -24,7 +24,7 @@ namespace app {
namespace audio {
/* Class to provide MicroNet specific MFCC calculation requirements. */
- class MicroNetMFCC : public MFCC {
+ class MicroNetKwsMFCC : public MFCC {
public:
static constexpr uint32_t ms_defaultSamplingFreq = 16000;
@@ -34,14 +34,14 @@ namespace audio {
static constexpr bool ms_defaultUseHtkMethod = true;
- explicit MicroNetMFCC(const size_t numFeats, const size_t frameLen)
+ explicit MicroNetKwsMFCC(const size_t numFeats, const size_t frameLen)
: MFCC(MfccParams(
ms_defaultSamplingFreq, ms_defaultNumFbankBins,
ms_defaultMelLoFreq, ms_defaultMelHiFreq,
numFeats, frameLen, ms_defaultUseHtkMethod))
{}
- MicroNetMFCC() = delete;
- ~MicroNetMFCC() = default;
+ MicroNetKwsMFCC() = delete;
+ ~MicroNetKwsMFCC() = default;
};
} /* namespace audio */
diff --git a/source/use_case/kws_asr/include/Wav2LetterModel.hpp b/source/use_case/kws_asr/include/Wav2LetterModel.hpp
index 7c327b3..0e1adc5 100644
--- a/source/use_case/kws_asr/include/Wav2LetterModel.hpp
+++ b/source/use_case/kws_asr/include/Wav2LetterModel.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -34,14 +34,18 @@ namespace arm {
namespace app {
class Wav2LetterModel : public Model {
-
+
public:
/* Indices for the expected model - based on input and output tensor shapes */
- static constexpr uint32_t ms_inputRowsIdx = 1;
- static constexpr uint32_t ms_inputColsIdx = 2;
+ static constexpr uint32_t ms_inputRowsIdx = 1;
+ static constexpr uint32_t ms_inputColsIdx = 2;
static constexpr uint32_t ms_outputRowsIdx = 2;
static constexpr uint32_t ms_outputColsIdx = 3;
+ /* Model specific constants. */
+ static constexpr uint32_t ms_blankTokenIdx = 28;
+ static constexpr uint32_t ms_numMfccFeatures = 13;
+
protected:
/** @brief Gets the reference to op resolver interface class. */
const tflite::MicroOpResolver& GetOpResolver() override;
diff --git a/source/use_case/kws_asr/include/Wav2LetterPostprocess.hpp b/source/use_case/kws_asr/include/Wav2LetterPostprocess.hpp
index 029a641..d1bc9a2 100644
--- a/source/use_case/kws_asr/include/Wav2LetterPostprocess.hpp
+++ b/source/use_case/kws_asr/include/Wav2LetterPostprocess.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,88 +14,95 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-#ifndef KWS_ASR_WAV2LET_POSTPROC_HPP
-#define KWS_ASR_WAV2LET_POSTPROC_HPP
+#ifndef KWS_ASR_WAV2LETTER_POSTPROCESS_HPP
+#define KWS_ASR_WAV2LETTER_POSTPROCESS_HPP
-#include "TensorFlowLiteMicro.hpp" /* TensorFlow headers */
+#include "TensorFlowLiteMicro.hpp" /* TensorFlow headers. */
+#include "BaseProcessing.hpp"
+#include "AsrClassifier.hpp"
+#include "AsrResult.hpp"
+#include "log_macros.h"
namespace arm {
namespace app {
-namespace audio {
-namespace asr {
/**
* @brief Helper class to manage tensor post-processing for "wav2letter"
* output.
*/
- class Postprocess {
+ class AsrPostProcess : public BasePostProcess {
public:
+ bool m_lastIteration = false; /* Flag to set if processing the last set of data for a clip. */
+
/**
- * @brief Constructor
- * @param[in] contextLen Left and right context length for
- * output tensor.
- * @param[in] innerLen This is the length of the section
- * between left and right context.
- * @param[in] blankTokenIdx Blank token index.
+ * @brief Constructor
+ * @param[in] outputTensor Pointer to the TFLite Micro output Tensor.
+ * @param[in] classifier Object used to get top N results from classification.
+ * @param[in] labels Vector of string labels to identify each output of the model.
+ * @param[in/out] result Vector of classification results to store decoded outputs.
+ * @param[in] outputContextLen Left/right context length for output tensor.
+ * @param[in] blankTokenIdx Index in the labels that the "Blank token" takes.
+ * @param[in] reductionAxis The axis that the logits of each time step are on.
**/
- Postprocess(uint32_t contextLen,
- uint32_t innerLen,
- uint32_t blankTokenIdx);
-
- Postprocess() = delete;
- ~Postprocess() = default;
+ AsrPostProcess(TfLiteTensor* outputTensor, AsrClassifier& classifier,
+ const std::vector<std::string>& labels, asr::ResultVec& result,
+ uint32_t outputContextLen,
+ uint32_t blankTokenIdx, uint32_t reductionAxis);
/**
- * @brief Erases the required part of the tensor based
- * on context lengths set up during initialisation
- * @param[in] tensor Pointer to the tensor
- * @param[in] axisIdx Index of the axis on which erase is
- * performed.
- * @param[in] lastIteration Flag to signal is this is the
- * last iteration in which case
- * the right context is preserved.
- * @return true if successful, false otherwise.
- */
- bool Invoke(TfLiteTensor* tensor,
- uint32_t axisIdx,
- bool lastIteration = false);
+ * @brief Should perform post-processing of the result of inference then
+ * populate ASR result data for any later use.
+ * @return true if successful, false otherwise.
+ **/
+ bool DoPostProcess() override;
+
+ /** @brief Gets the output inner length for post-processing. */
+ static uint32_t GetOutputInnerLen(const TfLiteTensor*, uint32_t outputCtxLen);
+
+ /** @brief Gets the output context length (left/right) for post-processing. */
+ static uint32_t GetOutputContextLen(const Model& model, uint32_t inputCtxLen);
+
+ /** @brief Gets the number of feature vectors to be computed. */
+ static uint32_t GetNumFeatureVectors(const Model& model);
private:
- uint32_t m_contextLen; /* Lengths of left and right contexts. */
- uint32_t m_innerLen; /* Length of inner context. */
- uint32_t m_totalLen; /* Total length of the required axis. */
- uint32_t m_countIterations; /* Current number of iterations. */
- uint32_t m_blankTokenIdx; /* Index of the labels blank token. */
+ AsrClassifier& m_classifier; /* ASR Classifier object. */
+ TfLiteTensor* m_outputTensor; /* Model output tensor. */
+ const std::vector<std::string>& m_labels; /* ASR Labels. */
+ asr::ResultVec & m_results; /* Results vector for a single inference. */
+ uint32_t m_outputContextLen; /* lengths of left/right contexts for output. */
+ uint32_t m_outputInnerLen; /* Length of output inner context. */
+ uint32_t m_totalLen; /* Total length of the required axis. */
+ uint32_t m_countIterations; /* Current number of iterations. */
+ uint32_t m_blankTokenIdx; /* Index of the labels blank token. */
+ uint32_t m_reductionAxisIdx; /* Axis containing output logits for a single step. */
+
/**
- * @brief Checks if the tensor and axis index are valid
- * inputs to the object - based on how it has been
- * initialised.
- * @return true if valid, false otherwise.
+ * @brief Checks if the tensor and axis index are valid
+ * inputs to the object - based on how it has been initialised.
+ * @return true if valid, false otherwise.
*/
bool IsInputValid(TfLiteTensor* tensor,
- const uint32_t axisIdx) const;
+ uint32_t axisIdx) const;
/**
- * @brief Gets the tensor data element size in bytes based
- * on the tensor type.
- * @return Size in bytes, 0 if not supported.
+ * @brief Gets the tensor data element size in bytes based
+ * on the tensor type.
+ * @return Size in bytes, 0 if not supported.
*/
- uint32_t GetTensorElementSize(TfLiteTensor* tensor);
+ static uint32_t GetTensorElementSize(TfLiteTensor* tensor);
/**
- * @brief Erases sections from the data assuming row-wise
- * arrangement along the context axis.
- * @return true if successful, false otherwise.
+ * @brief Erases sections from the data assuming row-wise
+ * arrangement along the context axis.
+ * @return true if successful, false otherwise.
*/
bool EraseSectionsRowWise(uint8_t* ptrData,
- const uint32_t strideSzBytes,
- const bool lastIteration);
-
+ uint32_t strideSzBytes,
+ bool lastIteration);
};
-} /* namespace asr */
-} /* namespace audio */
} /* namespace app */
} /* namespace arm */
-#endif /* KWS_ASR_WAV2LET_POSTPROC_HPP */ \ No newline at end of file
+#endif /* KWS_ASR_WAV2LETTER_POSTPROCESS_HPP */ \ No newline at end of file
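A construction sketch for the new class, using the Wav2LetterModel constants added earlier in this patch (the outputCtxLen value and the choice of ms_outputRowsIdx as the reduction axis are assumptions for illustration):

    asr::ResultVec asrResults;
    AsrPostProcess postProcess{outputTensor, asrClassifier, asrLabels, asrResults,
                               outputCtxLen, /* e.g. derived via AsrPostProcess::GetOutputContextLen() */
                               Wav2LetterModel::ms_blankTokenIdx,
                               Wav2LetterModel::ms_outputRowsIdx};
    postProcess.m_lastIteration = !audioDataSlider.HasNext(); /* Preserve right context on the final window. */
    if (!postProcess.DoPostProcess()) { /* handle error */ }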
diff --git a/source/use_case/kws_asr/include/Wav2LetterPreprocess.hpp b/source/use_case/kws_asr/include/Wav2LetterPreprocess.hpp
index 3609c49..1224c23 100644
--- a/source/use_case/kws_asr/include/Wav2LetterPreprocess.hpp
+++ b/source/use_case/kws_asr/include/Wav2LetterPreprocess.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,56 +14,51 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-#ifndef KWS_ASR_WAV2LET_PREPROC_HPP
-#define KWS_ASR_WAV2LET_PREPROC_HPP
+#ifndef KWS_ASR_WAV2LETTER_PREPROCESS_HPP
+#define KWS_ASR_WAV2LETTER_PREPROCESS_HPP
#include "Wav2LetterModel.hpp"
#include "Wav2LetterMfcc.hpp"
#include "AudioUtils.hpp"
#include "DataStructures.hpp"
+#include "BaseProcessing.hpp"
#include "log_macros.h"
namespace arm {
namespace app {
-namespace audio {
-namespace asr {
/* Class to facilitate pre-processing calculation for Wav2Letter model
* for ASR. */
- using AudioWindow = SlidingWindow <const int16_t>;
+ using AudioWindow = audio::SlidingWindow<const int16_t>;
- class Preprocess {
+ class AsrPreProcess : public BasePreProcess {
public:
/**
- * @brief Constructor
- * @param[in] numMfccFeatures Number of MFCC features per window.
- * @param[in] windowLen Number of elements in a window.
- * @param[in] windowStride Stride (in number of elements) for
- * moving the window.
- * @param[in] numMfccVectors Number of MFCC vectors per window.
- */
- Preprocess(
- uint32_t numMfccFeatures,
- uint32_t windowLen,
- uint32_t windowStride,
- uint32_t numMfccVectors);
- Preprocess() = delete;
- ~Preprocess() = default;
+ * @brief Constructor.
+ * @param[in] inputTensor Pointer to the TFLite Micro input Tensor.
+ * @param[in] numMfccFeatures Number of MFCC features per window.
+ * @param[in] numFeatureFrames Number of MFCC vectors that need to be calculated
+ * for an inference.
+ * @param[in] mfccWindowLen Number of audio elements to calculate MFCC features per window.
+ * @param[in] mfccWindowStride Stride (in number of elements) for moving the MFCC window.
+ */
+ AsrPreProcess(TfLiteTensor* inputTensor,
+ uint32_t numMfccFeatures,
+ uint32_t numFeatureFrames,
+ uint32_t mfccWindowLen,
+ uint32_t mfccWindowStride);
/**
* @brief Calculates the features required from audio data. This
* includes MFCC, first and second order deltas,
* normalisation and finally, quantisation. The tensor is
- * populated with feature from a given window placed along
+ * populated with features from a given window placed along
* in a single row.
* @param[in] audioData Pointer to the first element of audio data.
* @param[in] audioDataLen Number of elements in the audio data.
- * @param[in] tensor Tensor to be populated.
* @return true if successful, false in case of error.
*/
- bool Invoke(const int16_t * audioData,
- uint32_t audioDataLen,
- TfLiteTensor * tensor);
+ bool DoPreProcess(const void* audioData, size_t audioDataLen) override;
protected:
/**
@@ -73,49 +68,32 @@ namespace asr {
* @param[in] mfcc MFCC buffers.
* @param[out] delta1 Result of the first diff computation.
* @param[out] delta2 Result of the second diff computation.
- *
- * @return true if successful, false otherwise.
+ * @return true if successful, false otherwise.
*/
static bool ComputeDeltas(Array2d<float>& mfcc,
Array2d<float>& delta1,
Array2d<float>& delta2);
/**
- * @brief Given a 2D vector of floats, computes the mean.
- * @param[in] vec Vector of vector of floats.
- * @return Mean value.
- */
- static float GetMean(Array2d<float>& vec);
-
- /**
- * @brief Given a 2D vector of floats, computes the stddev.
- * @param[in] vec Vector of vector of floats.
- * @param[in] mean Mean value of the vector passed in.
- * @return stddev value.
- */
- static float GetStdDev(Array2d<float>& vec,
- const float mean);
-
- /**
- * @brief Given a 2D vector of floats, normalises it using
- * the mean and the stddev
+ * @brief Given a 2D vector of floats, rescale it to have mean of 0 and
+ * standard deviation of 1.
* @param[in,out] vec Vector of vector of floats.
*/
- static void NormaliseVec(Array2d<float>& vec);
+ static void StandardizeVecF32(Array2d<float>& vec);
/**
- * @brief Normalises the MFCC and delta buffers.
+ * @brief Standardizes all the MFCC and delta buffers to have mean 0 and std. dev 1.
*/
- void Normalise();
+ void Standarize();
/**
* @brief Given the quantisation and data type limits, computes
* the quantised values of a floating point input data.
- * @param[in] elem Element to be quantised.
- * @param[in] quantScale Scale.
- * @param[in] quantOffset Offset.
- * @param[in] minVal Numerical limit - minimum.
- * @param[in] maxVal Numerical limit - maximum.
+ * @param[in] elem Element to be quantised.
+ * @param[in] quantScale Scale.
+ * @param[in] quantOffset Offset.
+ * @param[in] minVal Numerical limit - minimum.
+ * @param[in] maxVal Numerical limit - maximum.
* @return Floating point quantised value.
*/
static float GetQuantElem(
@@ -133,44 +111,43 @@ namespace asr {
* this being the convolution speed up (as we can use
* contiguous memory). The output, however, requires the
* time axis to be in column major arrangement.
- * @param[in] outputBuf Pointer to the output buffer.
- * @param[in] outputBufSz Output buffer's size.
- * @param[in] quantScale Quantisation scale.
- * @param[in] quantOffset Quantisation offset.
+ * @param[in] outputBuf Pointer to the output buffer.
+ * @param[in] outputBufSz Output buffer's size.
+ * @param[in] quantScale Quantisation scale.
+ * @param[in] quantOffset Quantisation offset.
*/
template <typename T>
bool Quantise(
- T * outputBuf,
+ T* outputBuf,
const uint32_t outputBufSz,
const float quantScale,
const int quantOffset)
{
- /* Check the output size will for everything. */
+ /* Check the output size will fit everything. */
if (outputBufSz < (this->m_mfccBuf.size(0) * 3 * sizeof(T))) {
printf_err("Tensor size too small for features\n");
return false;
}
/* Populate. */
- T * outputBufMfcc = outputBuf;
- T * outputBufD1 = outputBuf + this->m_numMfccFeats;
- T * outputBufD2 = outputBufD1 + this->m_numMfccFeats;
+ T* outputBufMfcc = outputBuf;
+ T* outputBufD1 = outputBuf + this->m_numMfccFeats;
+ T* outputBufD2 = outputBufD1 + this->m_numMfccFeats;
const uint32_t ptrIncr = this->m_numMfccFeats * 2; /* (3 vectors - 1 vector) */
const float minVal = std::numeric_limits<T>::min();
const float maxVal = std::numeric_limits<T>::max();
- /* We need to do a transpose while copying and concatenating
- * the tensor. */
- for (uint32_t j = 0; j < this->m_numFeatVectors; ++j) {
+ /* Need to transpose while copying and concatenating the tensor. */
+ for (uint32_t j = 0; j < this->m_numFeatureFrames; ++j) {
for (uint32_t i = 0; i < this->m_numMfccFeats; ++i) {
- *outputBufMfcc++ = static_cast<T>(this->GetQuantElem(
+ *outputBufMfcc++ = static_cast<T>(AsrPreProcess::GetQuantElem(
this->m_mfccBuf(i, j), quantScale,
quantOffset, minVal, maxVal));
- *outputBufD1++ = static_cast<T>(this->GetQuantElem(
+ *outputBufD1++ = static_cast<T>(AsrPreProcess::GetQuantElem(
this->m_delta1Buf(i, j), quantScale,
quantOffset, minVal, maxVal));
- *outputBufD2++ = static_cast<T>(this->GetQuantElem(
+ *outputBufD2++ = static_cast<T>(AsrPreProcess::GetQuantElem(
this->m_delta2Buf(i, j), quantScale,
quantOffset, minVal, maxVal));
}
@@ -183,24 +160,23 @@ namespace asr {
}
private:
- Wav2LetterMFCC m_mfcc; /* MFCC instance. */
+ audio::Wav2LetterMFCC m_mfcc; /* MFCC instance. */
+ TfLiteTensor* m_inputTensor; /* Model input tensor. */
/* Actual buffers to be populated. */
- Array2d<float> m_mfccBuf; /* Contiguous buffer 1D: MFCC */
- Array2d<float> m_delta1Buf; /* Contiguous buffer 1D: Delta 1 */
- Array2d<float> m_delta2Buf; /* Contiguous buffer 1D: Delta 2 */
+ Array2d<float> m_mfccBuf; /* Contiguous buffer 1D: MFCC */
+ Array2d<float> m_delta1Buf; /* Contiguous buffer 1D: Delta 1 */
+ Array2d<float> m_delta2Buf; /* Contiguous buffer 1D: Delta 2 */
- uint32_t m_windowLen; /* Window length for MFCC. */
- uint32_t m_windowStride; /* Window stride len for MFCC. */
- uint32_t m_numMfccFeats; /* Number of MFCC features per window. */
- uint32_t m_numFeatVectors; /* Number of m_numMfccFeats. */
- AudioWindow m_window; /* Sliding window. */
+ uint32_t m_mfccWindowLen; /* Window length for MFCC. */
+ uint32_t m_mfccWindowStride; /* Window stride len for MFCC. */
+ uint32_t m_numMfccFeats; /* Number of MFCC features per window. */
+ uint32_t m_numFeatureFrames; /* How many sets of m_numMfccFeats. */
+ AudioWindow m_mfccSlidingWindow; /* Sliding window to calculate MFCCs. */
};
-} /* namespace asr */
-} /* namespace audio */
} /* namespace app */
} /* namespace arm */
-#endif /* KWS_ASR_WAV2LET_PREPROC_HPP */ \ No newline at end of file
+#endif /* KWS_ASR_WAV2LETTER_PREPROCESS_HPP */ \ No newline at end of file
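
The Quantise() helper above interleaves the MFCC, delta1 and delta2 vectors row by row,
quantising each element on the way. The body of GetQuantElem is not part of this hunk, so
the following is an assumption based on its declaration and the usual affine
scale/offset scheme:

    #include <algorithm>
    #include <cmath>

    /* Hypothetical stand-in for AsrPreProcess::GetQuantElem. */
    static float QuantElemSketch(float elem, float quantScale, int quantOffset,
                                 float minVal, float maxVal)
    {
        const float quantised = std::round(elem / quantScale) + quantOffset;
        return std::min(std::max(quantised, minVal), maxVal); /* Clamp to type limits. */
    }

For example, with quantScale = 0.1 and quantOffset = -128, an input of 0.37 maps to
round(3.7) - 128 = -124, which sits inside the int8 limits that Quantise() passes in via
std::numeric_limits.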
diff --git a/source/use_case/kws_asr/src/KwsProcessing.cc b/source/use_case/kws_asr/src/KwsProcessing.cc
new file mode 100644
index 0000000..328709d
--- /dev/null
+++ b/source/use_case/kws_asr/src/KwsProcessing.cc
@@ -0,0 +1,212 @@
+/*
+ * Copyright (c) 2022 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "KwsProcessing.hpp"
+#include "ImageUtils.hpp"
+#include "log_macros.h"
+#include "MicroNetKwsModel.hpp"
+
+namespace arm {
+namespace app {
+
+ KwsPreProcess::KwsPreProcess(TfLiteTensor* inputTensor, size_t numFeatures, size_t numMfccFrames,
+ int mfccFrameLength, int mfccFrameStride
+ ):
+ m_inputTensor{inputTensor},
+ m_mfccFrameLength{mfccFrameLength},
+ m_mfccFrameStride{mfccFrameStride},
+ m_numMfccFrames{numMfccFrames},
+ m_mfcc{audio::MicroNetKwsMFCC(numFeatures, mfccFrameLength)}
+ {
+ this->m_mfcc.Init();
+
+ /* Deduce the data length required for 1 inference from the network parameters. */
+ this->m_audioDataWindowSize = this->m_numMfccFrames * this->m_mfccFrameStride +
+ (this->m_mfccFrameLength - this->m_mfccFrameStride);
+
+ /* Creating an MFCC feature sliding window for the data required for 1 inference. */
+ this->m_mfccSlidingWindow = audio::SlidingWindow<const int16_t>(nullptr, this->m_audioDataWindowSize,
+ this->m_mfccFrameLength, this->m_mfccFrameStride);
+
+ /* For longer audio clips we choose to move by half the audio window size
+ * => for a 1 second window size there is an overlap of 0.5 seconds. */
+ this->m_audioDataStride = this->m_audioDataWindowSize / 2;
+
+        /* For the previously calculated features to be reusable, the stride must be a
+         * multiple of the MFCC window stride. Reduce the audio stride if needed. */
+ if (0 != this->m_audioDataStride % this->m_mfccFrameStride) {
+ this->m_audioDataStride -= this->m_audioDataStride % this->m_mfccFrameStride;
+ }
+
+ this->m_numMfccVectorsInAudioStride = this->m_audioDataStride / this->m_mfccFrameStride;
+
+        /* Calculate the number of feature vectors in the window overlap region.
+         * These feature vectors will be reused. */
+ this->m_numReusedMfccVectors = this->m_mfccSlidingWindow.TotalStrides() + 1
+ - this->m_numMfccVectorsInAudioStride;
+
+ /* Construct feature calculation function. */
+ this->m_mfccFeatureCalculator = GetFeatureCalculator(this->m_mfcc, this->m_inputTensor,
+ this->m_numReusedMfccVectors);
+
+ if (!this->m_mfccFeatureCalculator) {
+ printf_err("Feature calculator not initialized.");
+ }
+ }
+
+ bool KwsPreProcess::DoPreProcess(const void* data, size_t inputSize)
+ {
+ UNUSED(inputSize);
+        if (data == nullptr) {
+            printf_err("Data pointer is null\n");
+            return false;
+        }
+
+ /* Set the features sliding window to the new address. */
+ auto input = static_cast<const int16_t*>(data);
+ this->m_mfccSlidingWindow.Reset(input);
+
+ /* Cache is only usable if we have more than 1 inference in an audio clip. */
+ bool useCache = this->m_audioWindowIndex > 0 && this->m_numReusedMfccVectors > 0;
+
+ /* Use a sliding window to calculate MFCC features frame by frame. */
+ while (this->m_mfccSlidingWindow.HasNext()) {
+ const int16_t* mfccWindow = this->m_mfccSlidingWindow.Next();
+
+ std::vector<int16_t> mfccFrameAudioData = std::vector<int16_t>(mfccWindow,
+ mfccWindow + this->m_mfccFrameLength);
+
+ /* Compute features for this window and write them to input tensor. */
+ this->m_mfccFeatureCalculator(mfccFrameAudioData, this->m_mfccSlidingWindow.Index(),
+ useCache, this->m_numMfccVectorsInAudioStride);
+ }
+
+ debug("Input tensor populated \n");
+
+ return true;
+ }
+
+ /**
+ * @brief Generic feature calculator factory.
+ *
+     * Returns a lambda function that computes features using a feature cache.
+     * The actual feature computation is done by the lambda provided as a parameter.
+     * Features are written into the input tensor memory.
+ *
+ * @tparam T Feature vector type.
+ * @param[in] inputTensor Model input tensor pointer.
+ * @param[in] cacheSize Number of feature vectors to cache. Defined by the sliding window overlap.
+ * @param[in] compute Features calculator function.
+ * @return Lambda function to compute features.
+ */
+ template<class T>
+ std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
+ KwsPreProcess::FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
+ std::function<std::vector<T> (std::vector<int16_t>& )> compute)
+ {
+ /* Feature cache to be captured by lambda function. */
+ static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize);
+
+ return [=](std::vector<int16_t>& audioDataWindow,
+ size_t index,
+ bool useCache,
+ size_t featuresOverlapIndex)
+ {
+ T* tensorData = tflite::GetTensorData<T>(inputTensor);
+ std::vector<T> features;
+
+            /* Reuse features from the cache if the cache is ready and the sliding windows overlap.
+             * The overlap sits at the beginning of the sliding window and its size equals the feature cache. */
+ if (useCache && index < featureCache.size()) {
+ features = std::move(featureCache[index]);
+ } else {
+ features = std::move(compute(audioDataWindow));
+ }
+ auto size = features.size();
+ auto sizeBytes = sizeof(T) * size;
+ std::memcpy(tensorData + (index * size), features.data(), sizeBytes);
+
+            /* Start renewing the cache as soon as the iteration moves out of the window overlap. */
+ if (index >= featuresOverlapIndex) {
+ featureCache[index - featuresOverlapIndex] = std::move(features);
+ }
+ };
+ }
+
+    template std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
+ KwsPreProcess::FeatureCalc<int8_t>(TfLiteTensor* inputTensor,
+ size_t cacheSize,
+ std::function<std::vector<int8_t> (std::vector<int16_t>&)> compute);
+
+ template std::function<void(std::vector<int16_t>&, size_t, bool, size_t)>
+ KwsPreProcess::FeatureCalc<float>(TfLiteTensor* inputTensor,
+ size_t cacheSize,
+ std::function<std::vector<float>(std::vector<int16_t>&)> compute);
+
+
+ std::function<void (std::vector<int16_t>&, int, bool, size_t)>
+ KwsPreProcess::GetFeatureCalculator(audio::MicroNetKwsMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize)
+ {
+ std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> mfccFeatureCalc;
+
+ TfLiteQuantization quant = inputTensor->quantization;
+
+ if (kTfLiteAffineQuantization == quant.type) {
+ auto *quantParams = (TfLiteAffineQuantization *) quant.params;
+ const float quantScale = quantParams->scale->data[0];
+ const int quantOffset = quantParams->zero_point->data[0];
+
+ switch (inputTensor->type) {
+ case kTfLiteInt8: {
+ mfccFeatureCalc = this->FeatureCalc<int8_t>(inputTensor,
+ cacheSize,
+ [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
+ return mfcc.MfccComputeQuant<int8_t>(audioDataWindow,
+ quantScale,
+ quantOffset);
+ }
+ );
+ break;
+ }
+ default:
+ printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type));
+ }
+ } else {
+ mfccFeatureCalc = this->FeatureCalc<float>(inputTensor, cacheSize,
+ [&mfcc](std::vector<int16_t>& audioDataWindow) {
+ return mfcc.MfccCompute(audioDataWindow); }
+ );
+ }
+ return mfccFeatureCalc;
+ }
+
+ KwsPostProcess::KwsPostProcess(TfLiteTensor* outputTensor, Classifier& classifier,
+ const std::vector<std::string>& labels,
+ std::vector<ClassificationResult>& results)
+ :m_outputTensor{outputTensor},
+ m_kwsClassifier{classifier},
+ m_labels{labels},
+ m_results{results}
+ {}
+
+ bool KwsPostProcess::DoPostProcess()
+ {
+ return this->m_kwsClassifier.GetClassificationResults(
+ this->m_outputTensor, this->m_results,
+ this->m_labels, 1, true);
+ }
+
+} /* namespace app */
+} /* namespace arm */ \ No newline at end of file
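
Taken together, KwsPreProcess and KwsPostProcess replace the free-standing
feature-calculation plumbing this patch deletes from UseCaseHandler.cc below. The calling
pattern, as doKws uses it later in this diff (error handling elided, variable names
abbreviated):

    KwsPreProcess preProcess(kwsInputTensor, numMfccFeatures, numMfccFrames,
                             mfccFrameLength, mfccFrameStride);
    std::vector<ClassificationResult> singleInfResult;
    KwsPostProcess postProcess(kwsOutputTensor, classifier, labels, singleInfResult);

    while (audioDataSlider.HasNext()) {
        const int16_t* window = audioDataSlider.Next();
        preProcess.m_audioWindowIndex = audioDataSlider.Index(); /* Enables the feature cache. */
        if (!preProcess.DoPreProcess(window, windowLen) ||
            !RunInference(kwsModel, profiler) ||
            !postProcess.DoPostProcess()) {
            break; /* Each stage logs its own failure. */
        }
    }

Note that windowLen is a placeholder here: DoPreProcess ignores its size argument, since
the window size was already fixed in the constructor.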
diff --git a/source/use_case/kws_asr/src/MainLoop.cc b/source/use_case/kws_asr/src/MainLoop.cc
index 5c1d0e0..f1d97a0 100644
--- a/source/use_case/kws_asr/src/MainLoop.cc
+++ b/source/use_case/kws_asr/src/MainLoop.cc
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,7 +14,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-#include "hal.h" /* Brings in platform definitions. */
#include "InputFiles.hpp" /* For input images. */
#include "Labels_micronetkws.hpp" /* For MicroNetKws label strings. */
#include "Labels_wav2letter.hpp" /* For Wav2Letter label strings. */
@@ -24,8 +23,6 @@
#include "Wav2LetterModel.hpp" /* ASR model class for running inference. */
#include "UseCaseCommonUtils.hpp" /* Utils functions. */
#include "UseCaseHandler.hpp" /* Handlers for different user options. */
-#include "Wav2LetterPreprocess.hpp" /* ASR pre-processing class. */
-#include "Wav2LetterPostprocess.hpp"/* ASR post-processing class. */
#include "log_macros.h"
using KwsClassifier = arm::app::Classifier;
@@ -53,19 +50,8 @@ static void DisplayMenu()
fflush(stdout);
}
-/** @brief Gets the number of MFCC features for a single window. */
-static uint32_t GetNumMfccFeatures(const arm::app::Model& model);
-
-/** @brief Gets the number of MFCC feature vectors to be computed. */
-static uint32_t GetNumMfccFeatureVectors(const arm::app::Model& model);
-
-/** @brief Gets the output context length (left and right) for post-processing. */
-static uint32_t GetOutputContextLen(const arm::app::Model& model,
- uint32_t inputCtxLen);
-
-/** @brief Gets the output inner length for post-processing. */
-static uint32_t GetOutputInnerLen(const arm::app::Model& model,
- uint32_t outputCtxLen);
+/** @brief Verify input and output tensor are of certain min dimensions. */
+static bool VerifyTensorDimensions(const arm::app::Model& model);
void main_loop()
{
@@ -84,61 +70,46 @@ void main_loop()
if (!asrModel.Init(kwsModel.GetAllocator())) {
printf_err("Failed to initialise ASR model\n");
return;
+ } else if (!VerifyTensorDimensions(asrModel)) {
+ printf_err("Model's input or output dimension verification failed\n");
+ return;
}
- /* Initialise ASR pre-processing. */
- arm::app::audio::asr::Preprocess prep(
- GetNumMfccFeatures(asrModel),
- arm::app::asr::g_FrameLength,
- arm::app::asr::g_FrameStride,
- GetNumMfccFeatureVectors(asrModel));
-
- /* Initialise ASR post-processing. */
- const uint32_t outputCtxLen = GetOutputContextLen(asrModel, arm::app::asr::g_ctxLen);
- const uint32_t blankTokenIdx = 28;
- arm::app::audio::asr::Postprocess postp(
- outputCtxLen,
- GetOutputInnerLen(asrModel, outputCtxLen),
- blankTokenIdx);
-
/* Instantiate application context. */
arm::app::ApplicationContext caseContext;
arm::app::Profiler profiler{"kws_asr"};
caseContext.Set<arm::app::Profiler&>("profiler", profiler);
- caseContext.Set<arm::app::Model&>("kwsmodel", kwsModel);
- caseContext.Set<arm::app::Model&>("asrmodel", asrModel);
+ caseContext.Set<arm::app::Model&>("kwsModel", kwsModel);
+ caseContext.Set<arm::app::Model&>("asrModel", asrModel);
caseContext.Set<uint32_t>("clipIndex", 0);
caseContext.Set<uint32_t>("ctxLen", arm::app::asr::g_ctxLen); /* Left and right context length (MFCC feat vectors). */
- caseContext.Set<int>("kwsframeLength", arm::app::kws::g_FrameLength);
- caseContext.Set<int>("kwsframeStride", arm::app::kws::g_FrameStride);
- caseContext.Set<float>("kwsscoreThreshold", arm::app::kws::g_ScoreThreshold); /* Normalised score threshold. */
+ caseContext.Set<int>("kwsFrameLength", arm::app::kws::g_FrameLength);
+ caseContext.Set<int>("kwsFrameStride", arm::app::kws::g_FrameStride);
+ caseContext.Set<float>("kwsScoreThreshold", arm::app::kws::g_ScoreThreshold); /* Normalised score threshold. */
caseContext.Set<uint32_t >("kwsNumMfcc", arm::app::kws::g_NumMfcc);
caseContext.Set<uint32_t >("kwsNumAudioWins", arm::app::kws::g_NumAudioWins);
- caseContext.Set<int>("asrframeLength", arm::app::asr::g_FrameLength);
- caseContext.Set<int>("asrframeStride", arm::app::asr::g_FrameStride);
- caseContext.Set<float>("asrscoreThreshold", arm::app::asr::g_ScoreThreshold); /* Normalised score threshold. */
+ caseContext.Set<int>("asrFrameLength", arm::app::asr::g_FrameLength);
+ caseContext.Set<int>("asrFrameStride", arm::app::asr::g_FrameStride);
+ caseContext.Set<float>("asrScoreThreshold", arm::app::asr::g_ScoreThreshold); /* Normalised score threshold. */
KwsClassifier kwsClassifier; /* Classifier wrapper object. */
arm::app::AsrClassifier asrClassifier; /* Classifier wrapper object. */
- caseContext.Set<arm::app::Classifier&>("kwsclassifier", kwsClassifier);
- caseContext.Set<arm::app::AsrClassifier&>("asrclassifier", asrClassifier);
-
- caseContext.Set<arm::app::audio::asr::Preprocess&>("preprocess", prep);
- caseContext.Set<arm::app::audio::asr::Postprocess&>("postprocess", postp);
+ caseContext.Set<arm::app::Classifier&>("kwsClassifier", kwsClassifier);
+ caseContext.Set<arm::app::AsrClassifier&>("asrClassifier", asrClassifier);
std::vector<std::string> asrLabels;
arm::app::asr::GetLabelsVector(asrLabels);
std::vector<std::string> kwsLabels;
arm::app::kws::GetLabelsVector(kwsLabels);
- caseContext.Set<const std::vector <std::string>&>("asrlabels", asrLabels);
- caseContext.Set<const std::vector <std::string>&>("kwslabels", kwsLabels);
+ caseContext.Set<const std::vector <std::string>&>("asrLabels", asrLabels);
+ caseContext.Set<const std::vector <std::string>&>("kwsLabels", kwsLabels);
/* KWS keyword that triggers ASR and associated checks */
- std::string triggerKeyword = std::string("yes");
+ std::string triggerKeyword = std::string("no");
if (std::find(kwsLabels.begin(), kwsLabels.end(), triggerKeyword) != kwsLabels.end()) {
- caseContext.Set<const std::string &>("triggerkeyword", triggerKeyword);
+ caseContext.Set<const std::string &>("triggerKeyword", triggerKeyword);
}
else {
printf_err("Selected trigger keyword not found in labels file\n");
@@ -196,50 +167,26 @@ void main_loop()
info("Main loop terminated.\n");
}
-static uint32_t GetNumMfccFeatures(const arm::app::Model& model)
-{
- TfLiteTensor* inputTensor = model.GetInputTensor(0);
- const int inputCols = inputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputColsIdx];
- if (0 != inputCols % 3) {
- printf_err("Number of input columns is not a multiple of 3\n");
- }
- return std::max(inputCols/3, 0);
-}
-
-static uint32_t GetNumMfccFeatureVectors(const arm::app::Model& model)
+static bool VerifyTensorDimensions(const arm::app::Model& model)
{
+ /* Populate tensor related parameters. */
TfLiteTensor* inputTensor = model.GetInputTensor(0);
- const int inputRows = inputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx];
- return std::max(inputRows, 0);
-}
-
-static uint32_t GetOutputContextLen(const arm::app::Model& model, const uint32_t inputCtxLen)
-{
- const uint32_t inputRows = GetNumMfccFeatureVectors(model);
- const uint32_t inputInnerLen = inputRows - (2 * inputCtxLen);
- constexpr uint32_t ms_outputRowsIdx = arm::app::Wav2LetterModel::ms_outputRowsIdx;
-
- /* Check to make sure that the input tensor supports the above context and inner lengths. */
- if (inputRows <= 2 * inputCtxLen || inputRows <= inputInnerLen) {
- printf_err("Input rows not compatible with ctx of %" PRIu32 "\n",
- inputCtxLen);
- return 0;
+ if (!inputTensor->dims) {
+ printf_err("Invalid input tensor dims\n");
+ return false;
+ } else if (inputTensor->dims->size < 3) {
+ printf_err("Input tensor dimension should be >= 3\n");
+ return false;
}
TfLiteTensor* outputTensor = model.GetOutputTensor(0);
- const uint32_t outputRows = std::max(outputTensor->dims->data[ms_outputRowsIdx], 0);
-
- const float tensorColRatio = static_cast<float>(inputRows)/
- static_cast<float>(outputRows);
-
- return std::round(static_cast<float>(inputCtxLen)/tensorColRatio);
-}
+ if (!outputTensor->dims) {
+ printf_err("Invalid output tensor dims\n");
+ return false;
+ } else if (outputTensor->dims->size < 3) {
+ printf_err("Output tensor dimension should be >= 3\n");
+ return false;
+ }
-static uint32_t GetOutputInnerLen(const arm::app::Model& model,
- const uint32_t outputCtxLen)
-{
- constexpr uint32_t ms_outputRowsIdx = arm::app::Wav2LetterModel::ms_outputRowsIdx;
- TfLiteTensor* outputTensor = model.GetOutputTensor(0);
- const uint32_t outputRows = std::max(outputTensor->dims->data[ms_outputRowsIdx], 0);
- return (outputRows - (2 * outputCtxLen));
+ return true;
}
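
VerifyTensorDimensions only asserts that both tensors carry at least three dimensions;
the context/inner-length deductions it replaces now live in AsrPostProcess (see
Wav2LetterPostprocess.cc further down). The three-dimension floor matches how the rest of
the code indexes these tensors. The sketch below infers a [batch, time, features] layout
from the AxisRow and AxisCol values this patch deletes from UseCaseHandler.cc; the patch
itself does not state the layout:

    /* Accesses of this shape are what make dims->size >= 3 necessary. */
    TfLiteIntArray* dims = model.GetInputTensor(0)->dims;
    const int rows = dims->data[1]; /* Time axis, cf. AxisRow = 1.    */
    const int cols = dims->data[2]; /* Feature axis, cf. AxisCol = 2. */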
diff --git a/source/use_case/kws_asr/src/UseCaseHandler.cc b/source/use_case/kws_asr/src/UseCaseHandler.cc
index 1e1a400..01aefae 100644
--- a/source/use_case/kws_asr/src/UseCaseHandler.cc
+++ b/source/use_case/kws_asr/src/UseCaseHandler.cc
@@ -28,6 +28,7 @@
#include "Wav2LetterMfcc.hpp"
#include "Wav2LetterPreprocess.hpp"
#include "Wav2LetterPostprocess.hpp"
+#include "KwsProcessing.hpp"
#include "AsrResult.hpp"
#include "AsrClassifier.hpp"
#include "OutputDecode.hpp"
@@ -39,11 +40,6 @@ using KwsClassifier = arm::app::Classifier;
namespace arm {
namespace app {
- enum AsrOutputReductionAxis {
- AxisRow = 1,
- AxisCol = 2
- };
-
struct KWSOutput {
bool executionSuccess = false;
const int16_t* asrAudioStart = nullptr;
@@ -51,73 +47,53 @@ namespace app {
};
/**
- * @brief Presents kws inference results using the data presentation
- * object.
- * @param[in] results vector of classification results to be displayed
- * @return true if successful, false otherwise
+ * @brief Presents KWS inference results.
+ * @param[in] results Vector of KWS classification results to be displayed.
+ * @return true if successful, false otherwise.
**/
- static bool PresentInferenceResult(std::vector<arm::app::kws::KwsResult>& results);
+ static bool PresentInferenceResult(std::vector<kws::KwsResult>& results);
/**
- * @brief Presents asr inference results using the data presentation
- * object.
- * @param[in] platform reference to the hal platform object
- * @param[in] results vector of classification results to be displayed
- * @return true if successful, false otherwise
+ * @brief Presents ASR inference results.
+ * @param[in] results Vector of ASR classification results to be displayed.
+ * @return true if successful, false otherwise.
**/
- static bool PresentInferenceResult(std::vector<arm::app::asr::AsrResult>& results);
+ static bool PresentInferenceResult(std::vector<asr::AsrResult>& results);
/**
- * @brief Returns a function to perform feature calculation and populates input tensor data with
- * MFCC data.
- *
- * Input tensor data type check is performed to choose correct MFCC feature data type.
- * If tensor has an integer data type then original features are quantised.
- *
- * Warning: mfcc calculator provided as input must have the same life scope as returned function.
- *
- * @param[in] mfcc MFCC feature calculator.
- * @param[in,out] inputTensor Input tensor pointer to store calculated features.
- * @param[in] cacheSize Size of the feature vectors cache (number of feature vectors).
- *
- * @return function function to be called providing audio sample and sliding window index.
+ * @brief Performs the KWS pipeline.
+     * @param[in,out] ctx Pointer to the application context object.
+     * @return Struct containing a pointer to the audio data where ASR should
+     *         begin and how much data to process.
**/
- static std::function<void (std::vector<int16_t>&, int, bool, size_t)>
- GetFeatureCalculator(audio::MicroNetMFCC& mfcc,
- TfLiteTensor* inputTensor,
- size_t cacheSize);
+ static KWSOutput doKws(ApplicationContext& ctx)
+ {
+ auto& profiler = ctx.Get<Profiler&>("profiler");
+ auto& kwsModel = ctx.Get<Model&>("kwsModel");
+ const auto kwsMfccFrameLength = ctx.Get<int>("kwsFrameLength");
+ const auto kwsMfccFrameStride = ctx.Get<int>("kwsFrameStride");
+ const auto kwsScoreThreshold = ctx.Get<float>("kwsScoreThreshold");
+
+ auto currentIndex = ctx.Get<uint32_t>("clipIndex");
- /**
- * @brief Performs the KWS pipeline.
- * @param[in,out] ctx pointer to the application context object
- *
- * @return KWSOutput struct containing pointer to audio data where ASR should begin
- * and how much data to process.
- */
- static KWSOutput doKws(ApplicationContext& ctx) {
constexpr uint32_t dataPsnTxtInfStartX = 20;
constexpr uint32_t dataPsnTxtInfStartY = 40;
constexpr int minTensorDims = static_cast<int>(
- (arm::app::MicroNetKwsModel::ms_inputRowsIdx > arm::app::MicroNetKwsModel::ms_inputColsIdx)?
- arm::app::MicroNetKwsModel::ms_inputRowsIdx : arm::app::MicroNetKwsModel::ms_inputColsIdx);
+ (MicroNetKwsModel::ms_inputRowsIdx > MicroNetKwsModel::ms_inputColsIdx)?
+ MicroNetKwsModel::ms_inputRowsIdx : MicroNetKwsModel::ms_inputColsIdx);
- KWSOutput output;
+ /* Output struct from doing KWS. */
+ KWSOutput output {};
- auto& profiler = ctx.Get<Profiler&>("profiler");
- auto& kwsModel = ctx.Get<Model&>("kwsmodel");
if (!kwsModel.IsInited()) {
printf_err("KWS model has not been initialised\n");
return output;
}
- const int kwsFrameLength = ctx.Get<int>("kwsframeLength");
- const int kwsFrameStride = ctx.Get<int>("kwsframeStride");
- const float kwsScoreThreshold = ctx.Get<float>("kwsscoreThreshold");
-
- TfLiteTensor* kwsOutputTensor = kwsModel.GetOutputTensor(0);
+ /* Get Input and Output tensors for pre/post processing. */
TfLiteTensor* kwsInputTensor = kwsModel.GetInputTensor(0);
-
+ TfLiteTensor* kwsOutputTensor = kwsModel.GetOutputTensor(0);
if (!kwsInputTensor->dims) {
printf_err("Invalid input tensor dims\n");
return output;
@@ -126,63 +102,32 @@ namespace app {
return output;
}
- const uint32_t kwsNumMfccFeats = ctx.Get<uint32_t>("kwsNumMfcc");
- const uint32_t kwsNumAudioWindows = ctx.Get<uint32_t>("kwsNumAudioWins");
-
- audio::MicroNetMFCC kwsMfcc = audio::MicroNetMFCC(kwsNumMfccFeats, kwsFrameLength);
- kwsMfcc.Init();
-
- /* Deduce the data length required for 1 KWS inference from the network parameters. */
- auto kwsAudioDataWindowSize = kwsNumAudioWindows * kwsFrameStride +
- (kwsFrameLength - kwsFrameStride);
- auto kwsMfccWindowSize = kwsFrameLength;
- auto kwsMfccWindowStride = kwsFrameStride;
-
- /* We are choosing to move by half the window size => for a 1 second window size,
- * this means an overlap of 0.5 seconds. */
- auto kwsAudioDataStride = kwsAudioDataWindowSize / 2;
-
- info("KWS audio data window size %" PRIu32 "\n", kwsAudioDataWindowSize);
-
- /* Stride must be multiple of mfcc features window stride to re-use features. */
- if (0 != kwsAudioDataStride % kwsMfccWindowStride) {
- kwsAudioDataStride -= kwsAudioDataStride % kwsMfccWindowStride;
- }
-
- auto kwsMfccVectorsInAudioStride = kwsAudioDataStride/kwsMfccWindowStride;
+ /* Get input shape for feature extraction. */
+ TfLiteIntArray* inputShape = kwsModel.GetInputShape(0);
+ const uint32_t numMfccFeatures = inputShape->data[MicroNetKwsModel::ms_inputColsIdx];
+ const uint32_t numMfccFrames = inputShape->data[MicroNetKwsModel::ms_inputRowsIdx];
/* We expect to be sampling 1 second worth of data at a time
* NOTE: This is only used for time stamp calculation. */
- const float kwsAudioParamsSecondsPerSample = 1.0/audio::MicroNetMFCC::ms_defaultSamplingFreq;
+ const float kwsAudioParamsSecondsPerSample = 1.0 / audio::MicroNetKwsMFCC::ms_defaultSamplingFreq;
- auto currentIndex = ctx.Get<uint32_t>("clipIndex");
+ /* Set up pre and post-processing. */
+ KwsPreProcess preProcess = KwsPreProcess(kwsInputTensor, numMfccFeatures, numMfccFrames,
+ kwsMfccFrameLength, kwsMfccFrameStride);
- /* Creating a mfcc features sliding window for the data required for 1 inference. */
- auto kwsAudioMFCCWindowSlider = audio::SlidingWindow<const int16_t>(
- get_audio_array(currentIndex),
- kwsAudioDataWindowSize, kwsMfccWindowSize,
- kwsMfccWindowStride);
+ std::vector<ClassificationResult> singleInfResult;
+ KwsPostProcess postProcess = KwsPostProcess(kwsOutputTensor, ctx.Get<KwsClassifier &>("kwsClassifier"),
+ ctx.Get<std::vector<std::string>&>("kwsLabels"),
+ singleInfResult);
/* Creating a sliding window through the whole audio clip. */
auto audioDataSlider = audio::SlidingWindow<const int16_t>(
get_audio_array(currentIndex),
get_audio_array_size(currentIndex),
- kwsAudioDataWindowSize, kwsAudioDataStride);
-
- /* Calculate number of the feature vectors in the window overlap region.
- * These feature vectors will be reused.*/
- size_t numberOfReusedFeatureVectors = kwsAudioMFCCWindowSlider.TotalStrides() + 1
- - kwsMfccVectorsInAudioStride;
+ preProcess.m_audioDataWindowSize, preProcess.m_audioDataStride);
- auto kwsMfccFeatureCalc = GetFeatureCalculator(kwsMfcc, kwsInputTensor,
- numberOfReusedFeatureVectors);
-
- if (!kwsMfccFeatureCalc){
- return output;
- }
-
- /* Container for KWS results. */
- std::vector<arm::app::kws::KwsResult> kwsResults;
+ /* Declare a container to hold kws results from across the whole audio clip. */
+ std::vector<kws::KwsResult> finalResults;
/* Display message on the LCD - inference running. */
std::string str_inf{"Running KWS inference... "};
@@ -197,70 +142,56 @@ namespace app {
while (audioDataSlider.HasNext()) {
const int16_t* inferenceWindow = audioDataSlider.Next();
- /* We moved to the next window - set the features sliding to the new address. */
- kwsAudioMFCCWindowSlider.Reset(inferenceWindow);
-
/* The first window does not have cache ready. */
- bool useCache = audioDataSlider.Index() > 0 && numberOfReusedFeatureVectors > 0;
-
- /* Start calculating features inside one audio sliding window. */
- while (kwsAudioMFCCWindowSlider.HasNext()) {
- const int16_t* kwsMfccWindow = kwsAudioMFCCWindowSlider.Next();
- std::vector<int16_t> kwsMfccAudioData =
- std::vector<int16_t>(kwsMfccWindow, kwsMfccWindow + kwsMfccWindowSize);
-
- /* Compute features for this window and write them to input tensor. */
- kwsMfccFeatureCalc(kwsMfccAudioData,
- kwsAudioMFCCWindowSlider.Index(),
- useCache,
- kwsMfccVectorsInAudioStride);
- }
+ preProcess.m_audioWindowIndex = audioDataSlider.Index();
- info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
- audioDataSlider.TotalStrides() + 1);
+ /* Run the pre-processing, inference and post-processing. */
+ if (!preProcess.DoPreProcess(inferenceWindow, audio::MicroNetKwsMFCC::ms_defaultSamplingFreq)) {
+ printf_err("KWS Pre-processing failed.");
+ return output;
+ }
- /* Run inference over this audio clip sliding window. */
if (!RunInference(kwsModel, profiler)) {
- printf_err("KWS inference failed\n");
+ printf_err("KWS Inference failed.");
return output;
}
- std::vector<ClassificationResult> kwsClassificationResult;
- auto& kwsClassifier = ctx.Get<KwsClassifier&>("kwsclassifier");
+ if (!postProcess.DoPostProcess()) {
+ printf_err("KWS Post-processing failed.");
+ return output;
+ }
- kwsClassifier.GetClassificationResults(
- kwsOutputTensor, kwsClassificationResult,
- ctx.Get<std::vector<std::string>&>("kwslabels"), 1, true);
+ info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
+ audioDataSlider.TotalStrides() + 1);
- kwsResults.emplace_back(
- kws::KwsResult(
- kwsClassificationResult,
- audioDataSlider.Index() * kwsAudioParamsSecondsPerSample * kwsAudioDataStride,
- audioDataSlider.Index(), kwsScoreThreshold)
- );
+ /* Add results from this window to our final results vector. */
+ finalResults.emplace_back(
+ kws::KwsResult(singleInfResult,
+ audioDataSlider.Index() * kwsAudioParamsSecondsPerSample * preProcess.m_audioDataStride,
+ audioDataSlider.Index(), kwsScoreThreshold));
- /* Keyword detected. */
- if (kwsClassificationResult[0].m_label == ctx.Get<const std::string&>("triggerkeyword")) {
- output.asrAudioStart = inferenceWindow + kwsAudioDataWindowSize;
+ /* Break out when trigger keyword is detected. */
+ if (singleInfResult[0].m_label == ctx.Get<const std::string&>("triggerKeyword")
+ && singleInfResult[0].m_normalisedVal > kwsScoreThreshold) {
+ output.asrAudioStart = inferenceWindow + preProcess.m_audioDataWindowSize;
output.asrAudioSamples = get_audio_array_size(currentIndex) -
(audioDataSlider.NextWindowStartIndex() -
- kwsAudioDataStride + kwsAudioDataWindowSize);
+ preProcess.m_audioDataStride + preProcess.m_audioDataWindowSize);
break;
}
#if VERIFY_TEST_OUTPUT
- arm::app::DumpTensor(kwsOutputTensor);
+ DumpTensor(kwsOutputTensor);
#endif /* VERIFY_TEST_OUTPUT */
} /* while (audioDataSlider.HasNext()) */
/* Erase. */
str_inf = std::string(str_inf.size(), ' ');
- hal_lcd_display_text(
- str_inf.c_str(), str_inf.size(),
- dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
+ hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
+ dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
- if (!PresentInferenceResult(kwsResults)) {
+ if (!PresentInferenceResult(finalResults)) {
return output;
}
@@ -271,41 +202,41 @@ namespace app {
}
/**
- * @brief Performs the ASR pipeline.
- *
- * @param[in,out] ctx pointer to the application context object
- * @param[in] kwsOutput struct containing pointer to audio data where ASR should begin
- * and how much data to process
- * @return bool true if pipeline executed without failure
- */
- static bool doAsr(ApplicationContext& ctx, const KWSOutput& kwsOutput) {
+ * @brief Performs the ASR pipeline.
+ * @param[in,out] ctx Pointer to the application context object.
+ * @param[in] kwsOutput Struct containing pointer to audio data where ASR should begin
+ * and how much data to process.
+ * @return true if pipeline executed without failure.
+ **/
+ static bool doAsr(ApplicationContext& ctx, const KWSOutput& kwsOutput)
+ {
+ auto& asrModel = ctx.Get<Model&>("asrModel");
+ auto& profiler = ctx.Get<Profiler&>("profiler");
+ auto asrMfccFrameLen = ctx.Get<uint32_t>("asrFrameLength");
+ auto asrMfccFrameStride = ctx.Get<uint32_t>("asrFrameStride");
+ auto asrScoreThreshold = ctx.Get<float>("asrScoreThreshold");
+ auto asrInputCtxLen = ctx.Get<uint32_t>("ctxLen");
+
constexpr uint32_t dataPsnTxtInfStartX = 20;
constexpr uint32_t dataPsnTxtInfStartY = 40;
- auto& profiler = ctx.Get<Profiler&>("profiler");
- hal_lcd_clear(COLOR_BLACK);
-
- /* Get model reference. */
- auto& asrModel = ctx.Get<Model&>("asrmodel");
if (!asrModel.IsInited()) {
printf_err("ASR model has not been initialised\n");
return false;
}
- /* Get score threshold to be applied for the classifier (post-inference). */
- auto asrScoreThreshold = ctx.Get<float>("asrscoreThreshold");
+ hal_lcd_clear(COLOR_BLACK);
- /* Dimensions of the tensor should have been verified by the callee. */
+ /* Get Input and Output tensors for pre/post processing. */
TfLiteTensor* asrInputTensor = asrModel.GetInputTensor(0);
TfLiteTensor* asrOutputTensor = asrModel.GetOutputTensor(0);
- const uint32_t asrInputRows = asrInputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx];
- /* Populate ASR MFCC related parameters. */
- auto asrMfccParamsWinLen = ctx.Get<uint32_t>("asrframeLength");
- auto asrMfccParamsWinStride = ctx.Get<uint32_t>("asrframeStride");
+        /* Get input shape. Dimensions of the tensor should have been verified by
+         * the caller. */
+ TfLiteIntArray* inputShape = asrModel.GetInputShape(0);
- /* Populate ASR inference context and inner lengths for input. */
- auto asrInputCtxLen = ctx.Get<uint32_t>("ctxLen");
+
+ const uint32_t asrInputRows = asrInputTensor->dims->data[Wav2LetterModel::ms_inputRowsIdx];
const uint32_t asrInputInnerLen = asrInputRows - (2 * asrInputCtxLen);
/* Make sure the input tensor supports the above context and inner lengths. */
@@ -316,18 +247,9 @@ namespace app {
}
/* Audio data stride corresponds to inputInnerLen feature vectors. */
- const uint32_t asrAudioParamsWinLen = (asrInputRows - 1) *
- asrMfccParamsWinStride + (asrMfccParamsWinLen);
- const uint32_t asrAudioParamsWinStride = asrInputInnerLen * asrMfccParamsWinStride;
- const float asrAudioParamsSecondsPerSample =
- (1.0/audio::Wav2LetterMFCC::ms_defaultSamplingFreq);
-
- /* Get pre/post-processing objects */
- auto& asrPrep = ctx.Get<audio::asr::Preprocess&>("preprocess");
- auto& asrPostp = ctx.Get<audio::asr::Postprocess&>("postprocess");
-
- /* Set default reduction axis for post-processing. */
- const uint32_t reductionAxis = arm::app::Wav2LetterModel::ms_outputRowsIdx;
+ const uint32_t asrAudioDataWindowLen = (asrInputRows - 1) * asrMfccFrameStride + (asrMfccFrameLen);
+ const uint32_t asrAudioDataWindowStride = asrInputInnerLen * asrMfccFrameStride;
+ const float asrAudioParamsSecondsPerSample = 1.0 / audio::Wav2LetterMFCC::ms_defaultSamplingFreq;
/* Get the remaining audio buffer and respective size from KWS results. */
const int16_t* audioArr = kwsOutput.asrAudioStart;
@@ -335,9 +257,9 @@ namespace app {
/* Audio clip must have enough samples to produce 1 MFCC feature. */
std::vector<int16_t> audioBuffer = std::vector<int16_t>(audioArr, audioArr + audioArrSize);
- if (audioArrSize < asrMfccParamsWinLen) {
+ if (audioArrSize < asrMfccFrameLen) {
printf_err("Not enough audio samples, minimum needed is %" PRIu32 "\n",
- asrMfccParamsWinLen);
+ asrMfccFrameLen);
return false;
}
@@ -345,26 +267,38 @@ namespace app {
auto audioDataSlider = audio::FractionalSlidingWindow<const int16_t>(
audioBuffer.data(),
audioBuffer.size(),
- asrAudioParamsWinLen,
- asrAudioParamsWinStride);
+ asrAudioDataWindowLen,
+ asrAudioDataWindowStride);
/* Declare a container for results. */
- std::vector<arm::app::asr::AsrResult> asrResults;
+ std::vector<asr::AsrResult> asrResults;
/* Display message on the LCD - inference running. */
std::string str_inf{"Running ASR inference... "};
- hal_lcd_display_text(
- str_inf.c_str(), str_inf.size(),
+ hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
- size_t asrInferenceWindowLen = asrAudioParamsWinLen;
-
+ size_t asrInferenceWindowLen = asrAudioDataWindowLen;
+
+ /* Set up pre and post-processing objects. */
+ AsrPreProcess asrPreProcess = AsrPreProcess(asrInputTensor, arm::app::Wav2LetterModel::ms_numMfccFeatures,
+ inputShape->data[Wav2LetterModel::ms_inputRowsIdx],
+ asrMfccFrameLen, asrMfccFrameStride);
+
+ std::vector<ClassificationResult> singleInfResult;
+ const uint32_t outputCtxLen = AsrPostProcess::GetOutputContextLen(asrModel, asrInputCtxLen);
+ AsrPostProcess asrPostProcess = AsrPostProcess(
+ asrOutputTensor, ctx.Get<AsrClassifier&>("asrClassifier"),
+ ctx.Get<std::vector<std::string>&>("asrLabels"),
+ singleInfResult, outputCtxLen,
+ Wav2LetterModel::ms_blankTokenIdx, Wav2LetterModel::ms_outputRowsIdx
+ );
/* Start sliding through audio clip. */
while (audioDataSlider.HasNext()) {
/* If not enough audio see how much can be sent for processing. */
size_t nextStartIndex = audioDataSlider.NextWindowStartIndex();
- if (nextStartIndex + asrAudioParamsWinLen > audioBuffer.size()) {
+ if (nextStartIndex + asrAudioDataWindowLen > audioBuffer.size()) {
asrInferenceWindowLen = audioBuffer.size() - nextStartIndex;
}
@@ -373,8 +307,11 @@ namespace app {
info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1)));
- /* Calculate MFCCs, deltas and populate the input tensor. */
- asrPrep.Invoke(asrInferenceWindow, asrInferenceWindowLen, asrInputTensor);
+ /* Run the pre-processing, inference and post-processing. */
+ if (!asrPreProcess.DoPreProcess(asrInferenceWindow, asrInferenceWindowLen)) {
+ printf_err("ASR pre-processing failed.");
+ return false;
+ }
/* Run inference over this audio clip sliding window. */
if (!RunInference(asrModel, profiler)) {
@@ -382,24 +319,28 @@ namespace app {
return false;
}
- /* Post-process. */
- asrPostp.Invoke(asrOutputTensor, reductionAxis, !audioDataSlider.HasNext());
+ /* Post processing needs to know if we are on the last audio window. */
+ asrPostProcess.m_lastIteration = !audioDataSlider.HasNext();
+ if (!asrPostProcess.DoPostProcess()) {
+ printf_err("ASR post-processing failed.");
+ return false;
+ }
/* Get results. */
std::vector<ClassificationResult> asrClassificationResult;
- auto& asrClassifier = ctx.Get<AsrClassifier&>("asrclassifier");
+ auto& asrClassifier = ctx.Get<AsrClassifier&>("asrClassifier");
asrClassifier.GetClassificationResults(
asrOutputTensor, asrClassificationResult,
- ctx.Get<std::vector<std::string>&>("asrlabels"), 1);
+ ctx.Get<std::vector<std::string>&>("asrLabels"), 1);
asrResults.emplace_back(asr::AsrResult(asrClassificationResult,
(audioDataSlider.Index() *
asrAudioParamsSecondsPerSample *
- asrAudioParamsWinStride),
+ asrAudioDataWindowStride),
audioDataSlider.Index(), asrScoreThreshold));
#if VERIFY_TEST_OUTPUT
- arm::app::DumpTensor(asrOutputTensor, asrOutputTensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx]);
+            DumpTensor(asrOutputTensor, asrOutputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx]);
#endif /* VERIFY_TEST_OUTPUT */
/* Erase */
@@ -417,7 +358,7 @@ namespace app {
return true;
}
- /* Audio inference classification handler. */
+ /* KWS and ASR inference handler. */
bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
{
hal_lcd_clear(COLOR_BLACK);
@@ -434,13 +375,14 @@ namespace app {
do {
KWSOutput kwsOutput = doKws(ctx);
if (!kwsOutput.executionSuccess) {
+ printf_err("KWS failed\n");
return false;
}
if (kwsOutput.asrAudioStart != nullptr && kwsOutput.asrAudioSamples > 0) {
- info("Keyword spotted\n");
+ info("Trigger keyword spotted\n");
if(!doAsr(ctx, kwsOutput)) {
- printf_err("ASR failed");
+ printf_err("ASR failed\n");
return false;
}
}
@@ -452,7 +394,6 @@ namespace app {
return true;
}
-
static bool PresentInferenceResult(std::vector<arm::app::kws::KwsResult>& results)
{
constexpr uint32_t dataPsnTxtStartX1 = 20;
@@ -464,33 +405,31 @@ namespace app {
/* Display each result. */
uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr;
- for (uint32_t i = 0; i < results.size(); ++i) {
-
+ for (auto & result : results) {
std::string topKeyword{"<none>"};
float score = 0.f;
- if (!results[i].m_resultVec.empty()) {
- topKeyword = results[i].m_resultVec[0].m_label;
- score = results[i].m_resultVec[0].m_normalisedVal;
+ if (!result.m_resultVec.empty()) {
+ topKeyword = result.m_resultVec[0].m_label;
+ score = result.m_resultVec[0].m_normalisedVal;
}
std::string resultStr =
- std::string{"@"} + std::to_string(results[i].m_timeStamp) +
+ std::string{"@"} + std::to_string(result.m_timeStamp) +
std::string{"s: "} + topKeyword + std::string{" ("} +
std::to_string(static_cast<int>(score * 100)) + std::string{"%)"};
- hal_lcd_display_text(
- resultStr.c_str(), resultStr.size(),
- dataPsnTxtStartX1, rowIdx1, 0);
+ hal_lcd_display_text(resultStr.c_str(), resultStr.size(),
+ dataPsnTxtStartX1, rowIdx1, 0);
rowIdx1 += dataPsnTxtYIncr;
info("For timestamp: %f (inference #: %" PRIu32 "); threshold: %f\n",
- results[i].m_timeStamp, results[i].m_inferenceNumber,
- results[i].m_threshold);
- for (uint32_t j = 0; j < results[i].m_resultVec.size(); ++j) {
+ result.m_timeStamp, result.m_inferenceNumber,
+ result.m_threshold);
+ for (uint32_t j = 0; j < result.m_resultVec.size(); ++j) {
info("\t\tlabel @ %" PRIu32 ": %s, score: %f\n", j,
- results[i].m_resultVec[j].m_label.c_str(),
- results[i].m_resultVec[j].m_normalisedVal);
+ result.m_resultVec[j].m_label.c_str(),
+ result.m_resultVec[j].m_normalisedVal);
}
}
@@ -523,143 +462,12 @@ namespace app {
std::string finalResultStr = audio::asr::DecodeOutput(combinedResults);
- hal_lcd_display_text(
- finalResultStr.c_str(), finalResultStr.size(),
- dataPsnTxtStartX1, dataPsnTxtStartY1, allow_multiple_lines);
+ hal_lcd_display_text(finalResultStr.c_str(), finalResultStr.size(),
+ dataPsnTxtStartX1, dataPsnTxtStartY1, allow_multiple_lines);
info("Final result: %s\n", finalResultStr.c_str());
return true;
}
- /**
- * @brief Generic feature calculator factory.
- *
- * Returns lambda function to compute features using features cache.
- * Real features math is done by a lambda function provided as a parameter.
- * Features are written to input tensor memory.
- *
- * @tparam T feature vector type.
- * @param inputTensor model input tensor pointer.
- * @param cacheSize number of feature vectors to cache. Defined by the sliding window overlap.
- * @param compute features calculator function.
- * @return lambda function to compute features.
- **/
- template<class T>
- std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
- FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
- std::function<std::vector<T> (std::vector<int16_t>& )> compute)
- {
- /* Feature cache to be captured by lambda function. */
- static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize);
-
- return [=](std::vector<int16_t>& audioDataWindow,
- size_t index,
- bool useCache,
- size_t featuresOverlapIndex)
- {
- T* tensorData = tflite::GetTensorData<T>(inputTensor);
- std::vector<T> features;
-
- /* Reuse features from cache if cache is ready and sliding windows overlap.
- * Overlap is in the beginning of sliding window with a size of a feature cache.
- */
- if (useCache && index < featureCache.size()) {
- features = std::move(featureCache[index]);
- } else {
- features = std::move(compute(audioDataWindow));
- }
- auto size = features.size();
- auto sizeBytes = sizeof(T) * size;
- std::memcpy(tensorData + (index * size), features.data(), sizeBytes);
-
- /* Start renewing cache as soon iteration goes out of the windows overlap. */
- if (index >= featuresOverlapIndex) {
- featureCache[index - featuresOverlapIndex] = std::move(features);
- }
- };
- }
-
- template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)>
- FeatureCalc<int8_t>(TfLiteTensor* inputTensor,
- size_t cacheSize,
- std::function<std::vector<int8_t> (std::vector<int16_t>& )> compute);
-
- template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)>
- FeatureCalc<uint8_t>(TfLiteTensor* inputTensor,
- size_t cacheSize,
- std::function<std::vector<uint8_t> (std::vector<int16_t>& )> compute);
-
- template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)>
- FeatureCalc<int16_t>(TfLiteTensor* inputTensor,
- size_t cacheSize,
- std::function<std::vector<int16_t> (std::vector<int16_t>& )> compute);
-
- template std::function<void(std::vector<int16_t>&, size_t, bool, size_t)>
- FeatureCalc<float>(TfLiteTensor* inputTensor,
- size_t cacheSize,
- std::function<std::vector<float>(std::vector<int16_t>&)> compute);
-
-
- static std::function<void (std::vector<int16_t>&, int, bool, size_t)>
- GetFeatureCalculator(audio::MicroNetMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize)
- {
- std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> mfccFeatureCalc;
-
- TfLiteQuantization quant = inputTensor->quantization;
-
- if (kTfLiteAffineQuantization == quant.type) {
-
- auto* quantParams = (TfLiteAffineQuantization*) quant.params;
- const float quantScale = quantParams->scale->data[0];
- const int quantOffset = quantParams->zero_point->data[0];
-
- switch (inputTensor->type) {
- case kTfLiteInt8: {
- mfccFeatureCalc = FeatureCalc<int8_t>(inputTensor,
- cacheSize,
- [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
- return mfcc.MfccComputeQuant<int8_t>(audioDataWindow,
- quantScale,
- quantOffset);
- }
- );
- break;
- }
- case kTfLiteUInt8: {
- mfccFeatureCalc = FeatureCalc<uint8_t>(inputTensor,
- cacheSize,
- [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
- return mfcc.MfccComputeQuant<uint8_t>(audioDataWindow,
- quantScale,
- quantOffset);
- }
- );
- break;
- }
- case kTfLiteInt16: {
- mfccFeatureCalc = FeatureCalc<int16_t>(inputTensor,
- cacheSize,
- [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
- return mfcc.MfccComputeQuant<int16_t>(audioDataWindow,
- quantScale,
- quantOffset);
- }
- );
- break;
- }
- default:
- printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type));
- }
-
-
- } else {
- mfccFeatureCalc = mfccFeatureCalc = FeatureCalc<float>(inputTensor,
- cacheSize,
- [&mfcc](std::vector<int16_t>& audioDataWindow) {
- return mfcc.MfccCompute(audioDataWindow);
- });
- }
- return mfccFeatureCalc;
- }
} /* namespace app */
} /* namespace arm */ \ No newline at end of file
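
The KWS-to-ASR handoff arithmetic in doKws is easiest to check with concrete numbers. The
values below are illustrative only, chosen to match a 1 second window at 16 kHz with the
half-window stride that the KwsPreProcess constructor derives:

    const uint32_t windowSize = 16000; /* preProcess.m_audioDataWindowSize */
    const uint32_t stride     = 8000;  /* preProcess.m_audioDataStride     */
    const uint32_t total      = 64000; /* Whole clip: 4 seconds.           */

    /* Suppose the keyword fires on the window starting at sample 24000.
     * ASR begins right after that window, at sample 24000 + 16000 = 40000.
     * NextWindowStartIndex() is then 24000 + stride = 32000, so: */
    const uint32_t asrSamples = total - (32000 - stride + windowSize); /* 64000 - 40000 = 24000 */

That is, asrAudioSamples counts exactly the samples from asrAudioStart to the end of the
clip.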
diff --git a/source/use_case/kws_asr/src/Wav2LetterPostprocess.cc b/source/use_case/kws_asr/src/Wav2LetterPostprocess.cc
index 2a76b1b..42f434e 100644
--- a/source/use_case/kws_asr/src/Wav2LetterPostprocess.cc
+++ b/source/use_case/kws_asr/src/Wav2LetterPostprocess.cc
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,62 +15,71 @@
* limitations under the License.
*/
#include "Wav2LetterPostprocess.hpp"
+
#include "Wav2LetterModel.hpp"
#include "log_macros.h"
+#include <cmath>
+
namespace arm {
namespace app {
-namespace audio {
-namespace asr {
-
- Postprocess::Postprocess(const uint32_t contextLen,
- const uint32_t innerLen,
- const uint32_t blankTokenIdx)
- : m_contextLen(contextLen),
- m_innerLen(innerLen),
- m_totalLen(2 * this->m_contextLen + this->m_innerLen),
+
+ AsrPostProcess::AsrPostProcess(TfLiteTensor* outputTensor, AsrClassifier& classifier,
+ const std::vector<std::string>& labels, std::vector<ClassificationResult>& results,
+ const uint32_t outputContextLen,
+ const uint32_t blankTokenIdx, const uint32_t reductionAxisIdx
+ ):
+ m_classifier(classifier),
+ m_outputTensor(outputTensor),
+ m_labels{labels},
+ m_results(results),
+ m_outputContextLen(outputContextLen),
m_countIterations(0),
- m_blankTokenIdx(blankTokenIdx)
- {}
+ m_blankTokenIdx(blankTokenIdx),
+ m_reductionAxisIdx(reductionAxisIdx)
+ {
+ this->m_outputInnerLen = AsrPostProcess::GetOutputInnerLen(this->m_outputTensor, this->m_outputContextLen);
+ this->m_totalLen = (2 * this->m_outputContextLen + this->m_outputInnerLen);
+ }
- bool Postprocess::Invoke(TfLiteTensor* tensor,
- const uint32_t axisIdx,
- const bool lastIteration)
+ bool AsrPostProcess::DoPostProcess()
{
/* Basic checks. */
- if (!this->IsInputValid(tensor, axisIdx)) {
+ if (!this->IsInputValid(this->m_outputTensor, this->m_reductionAxisIdx)) {
return false;
}
/* Irrespective of tensor type, we use unsigned "byte" */
- uint8_t* ptrData = tflite::GetTensorData<uint8_t>(tensor);
- const uint32_t elemSz = this->GetTensorElementSize(tensor);
+ auto* ptrData = tflite::GetTensorData<uint8_t>(this->m_outputTensor);
+ const uint32_t elemSz = AsrPostProcess::GetTensorElementSize(this->m_outputTensor);
/* Other sanity checks. */
if (0 == elemSz) {
printf_err("Tensor type not supported for post processing\n");
return false;
- } else if (elemSz * this->m_totalLen > tensor->bytes) {
+ } else if (elemSz * this->m_totalLen > this->m_outputTensor->bytes) {
printf_err("Insufficient number of tensor bytes\n");
return false;
}
/* Which axis do we need to process? */
- switch (axisIdx) {
- case arm::app::Wav2LetterModel::ms_outputRowsIdx:
- return this->EraseSectionsRowWise(ptrData,
- elemSz *
- tensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx],
- lastIteration);
+ switch (this->m_reductionAxisIdx) {
+ case Wav2LetterModel::ms_outputRowsIdx:
+ this->EraseSectionsRowWise(
+ ptrData, elemSz * this->m_outputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx],
+ this->m_lastIteration);
+ break;
default:
- printf_err("Unsupported axis index: %" PRIu32 "\n", axisIdx);
+ printf_err("Unsupported axis index: %" PRIu32 "\n", this->m_reductionAxisIdx);
+ return false;
}
+ this->m_classifier.GetClassificationResults(this->m_outputTensor,
+ this->m_results, this->m_labels, 1);
- return false;
+ return true;
}
- bool Postprocess::IsInputValid(TfLiteTensor* tensor,
- const uint32_t axisIdx) const
+ bool AsrPostProcess::IsInputValid(TfLiteTensor* tensor, const uint32_t axisIdx) const
{
if (nullptr == tensor) {
return false;
@@ -84,25 +93,23 @@ namespace asr {
if (static_cast<int>(this->m_totalLen) !=
tensor->dims->data[axisIdx]) {
- printf_err("Unexpected tensor dimension for axis %d, \n",
- tensor->dims->data[axisIdx]);
+ printf_err("Unexpected tensor dimension for axis %d, got %d, \n",
+ axisIdx, tensor->dims->data[axisIdx]);
return false;
}
return true;
}
- uint32_t Postprocess::GetTensorElementSize(TfLiteTensor* tensor)
+ uint32_t AsrPostProcess::GetTensorElementSize(TfLiteTensor* tensor)
{
switch(tensor->type) {
case kTfLiteUInt8:
- return 1;
case kTfLiteInt8:
return 1;
case kTfLiteInt16:
return 2;
case kTfLiteInt32:
- return 4;
case kTfLiteFloat32:
return 4;
default:
@@ -113,30 +120,30 @@ namespace asr {
return 0;
}
- bool Postprocess::EraseSectionsRowWise(
- uint8_t* ptrData,
- const uint32_t strideSzBytes,
- const bool lastIteration)
+ bool AsrPostProcess::EraseSectionsRowWise(
+ uint8_t* ptrData,
+ const uint32_t strideSzBytes,
+ const bool lastIteration)
{
/* In this case, the "zero-ing" is quite simple as the region
* to be zeroed sits in contiguous memory (row-major). */
- const uint32_t eraseLen = strideSzBytes * this->m_contextLen;
+ const uint32_t eraseLen = strideSzBytes * this->m_outputContextLen;
/* Erase left context? */
if (this->m_countIterations > 0) {
/* Set output of each classification window to the blank token. */
std::memset(ptrData, 0, eraseLen);
- for (size_t windowIdx = 0; windowIdx < this->m_contextLen; windowIdx++) {
+ for (size_t windowIdx = 0; windowIdx < this->m_outputContextLen; windowIdx++) {
ptrData[windowIdx*strideSzBytes + this->m_blankTokenIdx] = 1;
}
}
/* Erase right context? */
if (false == lastIteration) {
- uint8_t * rightCtxPtr = ptrData + (strideSzBytes * (this->m_contextLen + this->m_innerLen));
+ uint8_t* rightCtxPtr = ptrData + (strideSzBytes * (this->m_outputContextLen + this->m_outputInnerLen));
/* Set output of each classification window to the blank token. */
std::memset(rightCtxPtr, 0, eraseLen);
- for (size_t windowIdx = 0; windowIdx < this->m_contextLen; windowIdx++) {
+ for (size_t windowIdx = 0; windowIdx < this->m_outputContextLen; windowIdx++) {
rightCtxPtr[windowIdx*strideSzBytes + this->m_blankTokenIdx] = 1;
}
}
@@ -150,7 +157,58 @@ namespace asr {
return true;
}
-} /* namespace asr */
-} /* namespace audio */
+ uint32_t AsrPostProcess::GetNumFeatureVectors(const Model& model)
+ {
+ TfLiteTensor* inputTensor = model.GetInputTensor(0);
+ const int inputRows = std::max(inputTensor->dims->data[Wav2LetterModel::ms_inputRowsIdx], 0);
+ if (inputRows == 0) {
+ printf_err("Error getting number of input rows for axis: %" PRIu32 "\n",
+ Wav2LetterModel::ms_inputRowsIdx);
+ }
+ return inputRows;
+ }
+
+ uint32_t AsrPostProcess::GetOutputInnerLen(const TfLiteTensor* outputTensor, const uint32_t outputCtxLen)
+ {
+ const uint32_t outputRows = std::max(outputTensor->dims->data[Wav2LetterModel::ms_outputRowsIdx], 0);
+ if (outputRows == 0) {
+ printf_err("Error getting number of output rows for axis: %" PRIu32 "\n",
+ Wav2LetterModel::ms_outputRowsIdx);
+ }
+
+        /* Guard against underflow. */
+ int innerLen = (outputRows - (2 * outputCtxLen));
+
+ return std::max(innerLen, 0);
+ }
+
+ uint32_t AsrPostProcess::GetOutputContextLen(const Model& model, const uint32_t inputCtxLen)
+ {
+ const uint32_t inputRows = AsrPostProcess::GetNumFeatureVectors(model);
+ const uint32_t inputInnerLen = inputRows - (2 * inputCtxLen);
+ constexpr uint32_t ms_outputRowsIdx = Wav2LetterModel::ms_outputRowsIdx;
+
+ /* Check to make sure that the input tensor supports the above
+ * context and inner lengths. */
+ if (inputRows <= 2 * inputCtxLen || inputRows <= inputInnerLen) {
+ printf_err("Input rows not compatible with ctx of %" PRIu32 "\n",
+ inputCtxLen);
+ return 0;
+ }
+
+ TfLiteTensor* outputTensor = model.GetOutputTensor(0);
+ const uint32_t outputRows = std::max(outputTensor->dims->data[ms_outputRowsIdx], 0);
+ if (outputRows == 0) {
+ printf_err("Error getting number of output rows for axis: %" PRIu32 "\n",
+ Wav2LetterModel::ms_outputRowsIdx);
+ return 0;
+ }
+
+ const float inOutRowRatio = static_cast<float>(inputRows) /
+ static_cast<float>(outputRows);
+
+ return std::round(static_cast<float>(inputCtxLen) / inOutRowRatio);
+ }
+
} /* namespace app */
} /* namespace arm */ \ No newline at end of file
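The new AsrPostProcess::GetOutputContextLen helper maps the caller-supplied input context length onto the model's output time axis by dividing through the input-rows/output-rows ratio. A standalone sketch of that mapping; the shapes used below (296 input rows, 148 output rows, input context 98) are illustrative assumptions, not values read from the Wav2Letter model:

    #include <cinttypes>
    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    static uint32_t OutputContextLen(uint32_t inputRows, uint32_t outputRows,
                                     uint32_t inputCtxLen)
    {
        /* Input rows per output row (Wav2Letter strides over time). */
        const float inOutRowRatio = static_cast<float>(inputRows) /
                                    static_cast<float>(outputRows);
        /* Scale the input-side context down to the output time axis. */
        return static_cast<uint32_t>(
            std::round(static_cast<float>(inputCtxLen) / inOutRowRatio));
    }

    int main()
    {
        /* 296 / 148 = 2 input rows per output row, so a 98-row input
         * context becomes a 49-row output context. */
        printf("output ctx len: %" PRIu32 "\n", OutputContextLen(296, 148, 98));
        return 0;
    }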
diff --git a/source/use_case/kws_asr/src/Wav2LetterPreprocess.cc b/source/use_case/kws_asr/src/Wav2LetterPreprocess.cc
index d3f3579..92b0631 100644
--- a/source/use_case/kws_asr/src/Wav2LetterPreprocess.cc
+++ b/source/use_case/kws_asr/src/Wav2LetterPreprocess.cc
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -20,41 +20,35 @@
#include "TensorFlowLiteMicro.hpp"
#include <algorithm>
-#include <math.h>
+#include <cmath>
namespace arm {
namespace app {
-namespace audio {
-namespace asr {
-
- Preprocess::Preprocess(
- const uint32_t numMfccFeatures,
- const uint32_t windowLen,
- const uint32_t windowStride,
- const uint32_t numMfccVectors):
- m_mfcc(numMfccFeatures, windowLen),
- m_mfccBuf(numMfccFeatures, numMfccVectors),
- m_delta1Buf(numMfccFeatures, numMfccVectors),
- m_delta2Buf(numMfccFeatures, numMfccVectors),
- m_windowLen(windowLen),
- m_windowStride(windowStride),
+
+ AsrPreProcess::AsrPreProcess(TfLiteTensor* inputTensor, const uint32_t numMfccFeatures,
+ const uint32_t numFeatureFrames, const uint32_t mfccWindowLen,
+ const uint32_t mfccWindowStride
+ ):
+ m_mfcc(numMfccFeatures, mfccWindowLen),
+ m_inputTensor(inputTensor),
+ m_mfccBuf(numMfccFeatures, numFeatureFrames),
+ m_delta1Buf(numMfccFeatures, numFeatureFrames),
+ m_delta2Buf(numMfccFeatures, numFeatureFrames),
+ m_mfccWindowLen(mfccWindowLen),
+ m_mfccWindowStride(mfccWindowStride),
m_numMfccFeats(numMfccFeatures),
- m_numFeatVectors(numMfccVectors),
- m_window()
+ m_numFeatureFrames(numFeatureFrames)
{
- if (numMfccFeatures > 0 && windowLen > 0) {
+ if (numMfccFeatures > 0 && mfccWindowLen > 0) {
this->m_mfcc.Init();
}
}
- bool Preprocess::Invoke(
- const int16_t* audioData,
- const uint32_t audioDataLen,
- TfLiteTensor* tensor)
+ bool AsrPreProcess::DoPreProcess(const void* audioData, const size_t audioDataLen)
{
- this->m_window = SlidingWindow<const int16_t>(
- audioData, audioDataLen,
- this->m_windowLen, this->m_windowStride);
+ this->m_mfccSlidingWindow = audio::SlidingWindow<const int16_t>(
+ static_cast<const int16_t*>(audioData), audioDataLen,
+ this->m_mfccWindowLen, this->m_mfccWindowStride);
uint32_t mfccBufIdx = 0;
@@ -62,12 +56,12 @@ namespace asr {
std::fill(m_delta1Buf.begin(), m_delta1Buf.end(), 0.f);
std::fill(m_delta2Buf.begin(), m_delta2Buf.end(), 0.f);
- /* While we can slide over the window. */
- while (this->m_window.HasNext()) {
- const int16_t* mfccWindow = this->m_window.Next();
+ /* While we can slide over the audio. */
+ while (this->m_mfccSlidingWindow.HasNext()) {
+ const int16_t* mfccWindow = this->m_mfccSlidingWindow.Next();
auto mfccAudioData = std::vector<int16_t>(
mfccWindow,
- mfccWindow + this->m_windowLen);
+ mfccWindow + this->m_mfccWindowLen);
auto mfcc = this->m_mfcc.MfccCompute(mfccAudioData);
for (size_t i = 0; i < this->m_mfccBuf.size(0); ++i) {
this->m_mfccBuf(i, mfccBufIdx) = mfcc[i];
@@ -76,11 +70,11 @@ namespace asr {
}
/* Pad MFCC if needed by adding MFCC for zeros. */
- if (mfccBufIdx != this->m_numFeatVectors) {
- std::vector<int16_t> zerosWindow = std::vector<int16_t>(this->m_windowLen, 0);
+ if (mfccBufIdx != this->m_numFeatureFrames) {
+ std::vector<int16_t> zerosWindow = std::vector<int16_t>(this->m_mfccWindowLen, 0);
std::vector<float> mfccZeros = this->m_mfcc.MfccCompute(zerosWindow);
- while (mfccBufIdx != this->m_numFeatVectors) {
+ while (mfccBufIdx != this->m_numFeatureFrames) {
memcpy(&this->m_mfccBuf(0, mfccBufIdx),
mfccZeros.data(), sizeof(float) * m_numMfccFeats);
++mfccBufIdx;
@@ -88,41 +82,39 @@ namespace asr {
}
/* Compute first and second order deltas from MFCCs. */
- this->ComputeDeltas(this->m_mfccBuf,
- this->m_delta1Buf,
- this->m_delta2Buf);
+ AsrPreProcess::ComputeDeltas(this->m_mfccBuf, this->m_delta1Buf, this->m_delta2Buf);
- /* Normalise. */
- this->Normalise();
+ /* Standardize calculated features. */
+ this->Standarize();
/* Quantise. */
- QuantParams quantParams = GetTensorQuantParams(tensor);
+ QuantParams quantParams = GetTensorQuantParams(this->m_inputTensor);
if (0 == quantParams.scale) {
printf_err("Quantisation scale can't be 0\n");
return false;
}
- switch(tensor->type) {
+ switch(this->m_inputTensor->type) {
case kTfLiteUInt8:
return this->Quantise<uint8_t>(
- tflite::GetTensorData<uint8_t>(tensor), tensor->bytes,
+ tflite::GetTensorData<uint8_t>(this->m_inputTensor), this->m_inputTensor->bytes,
quantParams.scale, quantParams.offset);
case kTfLiteInt8:
return this->Quantise<int8_t>(
- tflite::GetTensorData<int8_t>(tensor), tensor->bytes,
+ tflite::GetTensorData<int8_t>(this->m_inputTensor), this->m_inputTensor->bytes,
quantParams.scale, quantParams.offset);
default:
printf_err("Unsupported tensor type %s\n",
- TfLiteTypeGetName(tensor->type));
+ TfLiteTypeGetName(this->m_inputTensor->type));
}
return false;
}
- bool Preprocess::ComputeDeltas(Array2d<float>& mfcc,
- Array2d<float>& delta1,
- Array2d<float>& delta2)
+ bool AsrPreProcess::ComputeDeltas(Array2d<float>& mfcc,
+ Array2d<float>& delta1,
+ Array2d<float>& delta2)
{
const std::vector <float> delta1Coeffs =
{6.66666667e-02, 5.00000000e-02, 3.33333333e-02,
@@ -148,11 +140,11 @@ namespace asr {
/* Iterate through features in MFCC vector. */
for (size_t i = 0; i < numFeatures; ++i) {
/* For each feature, iterate through time (t) samples representing feature evolution and
- * calculate d/dt and d^2/dt^2, using 1d convolution with differential kernels.
+ * calculate d/dt and d^2/dt^2, using 1D convolution with differential kernels.
* Convolution padding = valid, result size is `time length - kernel length + 1`.
* The result is padded with 0 from both sides to match the size of initial time samples data.
*
- * For the small filter, conv1d implementation as a simple loop is efficient enough.
+ * For the small filter, conv1D implementation as a simple loop is efficient enough.
* Filters of a greater size would need CMSIS-DSP functions to be used, like arm_fir_f32.
*/
@@ -175,20 +167,10 @@ namespace asr {
return true;
}
- float Preprocess::GetMean(Array2d<float>& vec)
- {
- return math::MathUtils::MeanF32(vec.begin(), vec.totalSize());
- }
-
- float Preprocess::GetStdDev(Array2d<float>& vec, const float mean)
- {
- return math::MathUtils::StdDevF32(vec.begin(), vec.totalSize(), mean);
- }
-
- void Preprocess::NormaliseVec(Array2d<float>& vec)
+ void AsrPreProcess::StandardizeVecF32(Array2d<float>& vec)
{
- auto mean = Preprocess::GetMean(vec);
- auto stddev = Preprocess::GetStdDev(vec, mean);
+ auto mean = math::MathUtils::MeanF32(vec.begin(), vec.totalSize());
+ auto stddev = math::MathUtils::StdDevF32(vec.begin(), vec.totalSize(), mean);
debug("Mean: %f, Stddev: %f\n", mean, stddev);
if (stddev == 0) {
@@ -204,14 +186,14 @@ namespace asr {
}
}
- void Preprocess::Normalise()
+ void AsrPreProcess::Standarize()
{
- Preprocess::NormaliseVec(this->m_mfccBuf);
- Preprocess::NormaliseVec(this->m_delta1Buf);
- Preprocess::NormaliseVec(this->m_delta2Buf);
+ AsrPreProcess::StandardizeVecF32(this->m_mfccBuf);
+ AsrPreProcess::StandardizeVecF32(this->m_delta1Buf);
+ AsrPreProcess::StandardizeVecF32(this->m_delta2Buf);
}
- float Preprocess::GetQuantElem(
+ float AsrPreProcess::GetQuantElem(
const float elem,
const float quantScale,
const int quantOffset,
@@ -222,7 +204,5 @@ namespace asr {
return std::min<float>(std::max<float>(val, minVal), maxVal);
}
-} /* namespace asr */
-} /* namespace audio */
} /* namespace app */
} /* namespace arm */ \ No newline at end of file
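Only the clamping line of AsrPreProcess::GetQuantElem is visible in this hunk; the value being clamped is, by the usual affine scheme, round(elem / quantScale) + quantOffset. A self-contained sketch of the full step, with the rounding line inferred and made-up scale/offset values rather than the model's parameters:

    #include <algorithm>
    #include <cmath>
    #include <cstdio>

    /* Clamped affine quantisation of one float feature. */
    static float QuantElem(float elem, float quantScale, int quantOffset,
                           float minVal, float maxVal)
    {
        const float val = std::round(elem / quantScale) + quantOffset; /* inferred */
        return std::min(std::max(val, minVal), maxVal);                /* saturate */
    }

    int main()
    {
        /* Illustrative int8 parameters: 0.5 / 0.141 rounds to 4, minus 11 = -7. */
        printf("%.1f\n", QuantElem(0.5f, 0.141f, -11, -128.0f, 127.0f));
        return 0;
    }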
diff --git a/source/use_case/kws_asr/usecase.cmake b/source/use_case/kws_asr/usecase.cmake
index b3fe020..40df4d7 100644
--- a/source/use_case/kws_asr/usecase.cmake
+++ b/source/use_case/kws_asr/usecase.cmake
@@ -1,5 +1,5 @@
#----------------------------------------------------------------------------
-# Copyright (c) 2021 Arm Limited. All rights reserved.
+# Copyright (c) 2021-2022 Arm Limited. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -59,7 +59,7 @@ USER_OPTION(${use_case}_ACTIVATION_BUF_SZ "Activation buffer size for the chosen
STRING)
USER_OPTION(${use_case}_MODEL_SCORE_THRESHOLD_KWS "Specify the score threshold [0.0, 1.0) that must be applied to the KWS results for a label to be deemed valid."
- 0.9
+ 0.7
STRING)
USER_OPTION(${use_case}_MODEL_SCORE_THRESHOLD_ASR "Specify the score threshold [0.0, 1.0) that must be applied to the ASR results for a label to be deemed valid."
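Note that relaxing the KWS default from 0.9 to 0.7 means more keyword detections are accepted out of the box. Builds that relied on the stricter behaviour can restore it at configure time through the USER_OPTION cache variable; a hypothetical invocation, with the build directory and other flags elided:

    cmake -Dkws_asr_MODEL_SCORE_THRESHOLD_KWS=0.9 <source-dir>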
diff --git a/source/use_case/noise_reduction/include/RNNoiseProcess.hpp b/source/use_case/noise_reduction/include/RNNoiseFeatureProcessor.hpp
index c188e42..cbf0e4e 100644
--- a/source/use_case/noise_reduction/include/RNNoiseProcess.hpp
+++ b/source/use_case/noise_reduction/include/RNNoiseFeatureProcessor.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,6 +14,9 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+#ifndef RNNOISE_FEATURE_PROCESSOR_HPP
+#define RNNOISE_FEATURE_PROCESSOR_HPP
+
#include "PlatformMath.hpp"
#include <cstdint>
#include <vector>
@@ -47,11 +50,11 @@ namespace rnn {
* - https://jmvalin.ca/demo/rnnoise/
* - https://arxiv.org/abs/1709.08243
**/
- class RNNoiseProcess {
+ class RNNoiseFeatureProcessor {
/* Public interface */
public:
- RNNoiseProcess();
- ~RNNoiseProcess() = default;
+ RNNoiseFeatureProcessor();
+ ~RNNoiseFeatureProcessor() = default;
/**
* @brief Calculates the features from a given audio buffer ready to be sent to RNNoise model.
@@ -328,10 +331,11 @@ namespace rnn {
const std::array <uint32_t, NB_BANDS> m_eband5ms {
0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12,
14, 16, 20, 24, 28, 34, 40, 48, 60, 78, 100};
-
};
} /* namespace rnn */
-} /* namspace app */
+} /* namespace app */
} /* namespace arm */
+
+#endif /* RNNOISE_FEATURE_PROCESSOR_HPP */
diff --git a/source/use_case/noise_reduction/include/RNNoiseProcessing.hpp b/source/use_case/noise_reduction/include/RNNoiseProcessing.hpp
new file mode 100644
index 0000000..15e62d9
--- /dev/null
+++ b/source/use_case/noise_reduction/include/RNNoiseProcessing.hpp
@@ -0,0 +1,113 @@
+/*
+ * Copyright (c) 2022 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef RNNOISE_PROCESSING_HPP
+#define RNNOISE_PROCESSING_HPP
+
+#include "BaseProcessing.hpp"
+#include "Model.hpp"
+#include "RNNoiseFeatureProcessor.hpp"
+
+namespace arm {
+namespace app {
+
+ /**
+ * @brief Pre-processing class for Noise Reduction use case.
+ * Implements methods declared by BasePreProcess and anything else needed
+ * to populate input tensors ready for inference.
+ */
+ class RNNoisePreProcess : public BasePreProcess {
+
+ public:
+ /**
+ * @brief Constructor
+ * @param[in] inputTensor Pointer to the TFLite Micro input Tensor.
+         * @param[in,out]   featureProcessor   RNNoise specific feature extractor object.
+         * @param[in,out]   frameFeatures      RNNoise specific features shared between pre & post-processing.
+ *
+ **/
+ explicit RNNoisePreProcess(TfLiteTensor* inputTensor,
+ std::shared_ptr<rnn::RNNoiseFeatureProcessor> featureProcessor,
+ std::shared_ptr<rnn::FrameFeatures> frameFeatures);
+
+ /**
+         * @brief       Should perform pre-processing of 'raw' input audio data and load it into
+         *              TFLite Micro input tensors ready for inference.
+ * @param[in] input Pointer to the data that pre-processing will work on.
+ * @param[in] inputSize Size of the input data.
+ * @return true if successful, false otherwise.
+ **/
+ bool DoPreProcess(const void* input, size_t inputSize) override;
+
+ private:
+ TfLiteTensor* m_inputTensor; /* Model input tensor. */
+ std::shared_ptr<rnn::RNNoiseFeatureProcessor> m_featureProcessor; /* RNNoise feature processor shared between pre & post-processing. */
+ std::shared_ptr<rnn::FrameFeatures> m_frameFeatures; /* RNNoise features shared between pre & post-processing. */
+ rnn::vec1D32F m_audioFrame; /* Audio frame cast to FP32 */
+
+ /**
+ * @brief Quantize the given features and populate the input Tensor.
+ * @param[in] inputFeatures Vector of floating point features to quantize.
+ * @param[in] quantScale Quantization scale for the inputTensor.
+ * @param[in] quantOffset Quantization offset for the inputTensor.
+ * @param[in,out] inputTensor TFLite micro tensor to populate.
+ **/
+ static void QuantizeAndPopulateInput(rnn::vec1D32F& inputFeatures,
+ float quantScale, int quantOffset,
+ TfLiteTensor* inputTensor);
+ };
+
+ /**
+ * @brief Post-processing class for Noise Reduction use case.
+ * Implements methods declared by BasePostProcess and anything else needed
+ * to populate result vector.
+ */
+ class RNNoisePostProcess : public BasePostProcess {
+
+ public:
+ /**
+ * @brief Constructor
+ * @param[in] outputTensor Pointer to the TFLite Micro output Tensor.
+ * @param[out] denoisedAudioFrame Vector to store the final denoised audio frame.
+         * @param[in,out]   featureProcessor    RNNoise specific feature extractor object.
+         * @param[in,out]   frameFeatures       RNNoise specific features shared between pre & post-processing.
+ **/
+ RNNoisePostProcess(TfLiteTensor* outputTensor,
+ std::vector<int16_t>& denoisedAudioFrame,
+ std::shared_ptr<rnn::RNNoiseFeatureProcessor> featureProcessor,
+ std::shared_ptr<rnn::FrameFeatures> frameFeatures);
+
+ /**
+ * @brief Should perform post-processing of the result of inference then
+ * populate result data for any later use.
+ * @return true if successful, false otherwise.
+ **/
+ bool DoPostProcess() override;
+
+ private:
+ TfLiteTensor* m_outputTensor; /* Model output tensor. */
+ std::vector<int16_t>& m_denoisedAudioFrame; /* Vector to store the final denoised frame. */
+ rnn::vec1D32F m_denoisedAudioFrameFloat; /* Internal vector to store the final denoised frame (FP32). */
+ std::shared_ptr<rnn::RNNoiseFeatureProcessor> m_featureProcessor; /* RNNoise feature processor shared between pre & post-processing. */
+ std::shared_ptr<rnn::FrameFeatures> m_frameFeatures; /* RNNoise features shared between pre & post-processing. */
+ std::vector<float> m_modelOutputFloat; /* Internal vector to store de-quantized model output. */
+
+ };
+
+} /* namespace app */
+} /* namespace arm */
+
+#endif /* RNNOISE_PROCESSING_HPP */ \ No newline at end of file
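The pre- and post-processing objects intentionally share the feature processor and the frame features through shared_ptr: RNNoise post-processing reuses the FFT and band data computed while pre-processing the same frame. A sketch of the intended wiring, assuming an already-initialised RNNoiseModel and a caller-sized output vector:

    #include "RNNoiseModel.hpp"
    #include "RNNoiseProcessing.hpp"

    #include <cstddef>
    #include <cstdint>
    #include <memory>
    #include <vector>

    /* Run one denoising step with the new API (sketch only). */
    static bool DenoiseOneFrame(arm::app::RNNoiseModel& model,
                                const int16_t* audioFrame, size_t audioFrameLen,
                                std::vector<int16_t>& denoisedAudioFrame)
    {
        using namespace arm::app;

        auto featureProcessor = std::make_shared<rnn::RNNoiseFeatureProcessor>();
        auto frameFeatures    = std::make_shared<rnn::FrameFeatures>();

        RNNoisePreProcess preProcess(model.GetInputTensor(0),
                                     featureProcessor, frameFeatures);
        RNNoisePostProcess postProcess(model.GetOutputTensor(0),
                                       denoisedAudioFrame,
                                       featureProcessor, frameFeatures);

        return preProcess.DoPreProcess(audioFrame, audioFrameLen) &&
               model.RunInference() &&
               postProcess.DoPostProcess();
    }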
diff --git a/source/use_case/noise_reduction/src/MainLoop.cc b/source/use_case/noise_reduction/src/MainLoop.cc
index 5fd7823..fd72127 100644
--- a/source/use_case/noise_reduction/src/MainLoop.cc
+++ b/source/use_case/noise_reduction/src/MainLoop.cc
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,12 +14,10 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-#include "hal.h" /* Brings in platform definitions. */
#include "UseCaseHandler.hpp" /* Handlers for different user options. */
#include "UseCaseCommonUtils.hpp" /* Utils functions. */
#include "RNNoiseModel.hpp" /* Model class for running inference. */
#include "InputFiles.hpp" /* For input audio clips. */
-#include "RNNoiseProcess.hpp" /* Pre-processing class */
#include "log_macros.h"
enum opcodes
diff --git a/source/use_case/noise_reduction/src/RNNoiseProcess.cc b/source/use_case/noise_reduction/src/RNNoiseFeatureProcessor.cc
index 4c568fa..036894c 100644
--- a/source/use_case/noise_reduction/src/RNNoiseProcess.cc
+++ b/source/use_case/noise_reduction/src/RNNoiseFeatureProcessor.cc
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,7 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-#include "RNNoiseProcess.hpp"
+#include "RNNoiseFeatureProcessor.hpp"
#include "log_macros.h"
#include <algorithm>
@@ -33,7 +33,7 @@ do { \
} \
} while(0)
-RNNoiseProcess::RNNoiseProcess() :
+RNNoiseFeatureProcessor::RNNoiseFeatureProcessor() :
m_halfWindow(FRAME_SIZE, 0),
m_dctTable(NB_BANDS * NB_BANDS),
m_analysisMem(FRAME_SIZE, 0),
@@ -54,9 +54,9 @@ RNNoiseProcess::RNNoiseProcess() :
this->InitTables();
}
-void RNNoiseProcess::PreprocessFrame(const float* audioData,
- const size_t audioLen,
- FrameFeatures& features)
+void RNNoiseFeatureProcessor::PreprocessFrame(const float* audioData,
+ const size_t audioLen,
+ FrameFeatures& features)
{
/* Note audioWindow is modified in place */
const arrHp aHp {-1.99599, 0.99600 };
@@ -68,7 +68,7 @@ void RNNoiseProcess::PreprocessFrame(const float* audioData,
this->ComputeFrameFeatures(audioWindow, features);
}
-void RNNoiseProcess::PostProcessFrame(vec1D32F& modelOutput, FrameFeatures& features, vec1D32F& outFrame)
+void RNNoiseFeatureProcessor::PostProcessFrame(vec1D32F& modelOutput, FrameFeatures& features, vec1D32F& outFrame)
{
std::vector<float> outputBands = modelOutput;
std::vector<float> gain(FREQ_SIZE, 0);
@@ -92,7 +92,7 @@ void RNNoiseProcess::PostProcessFrame(vec1D32F& modelOutput, FrameFeatures& feat
FrameSynthesis(outFrame, features.m_fftX);
}
-void RNNoiseProcess::InitTables()
+void RNNoiseFeatureProcessor::InitTables()
{
constexpr float pi = M_PI;
constexpr float halfPi = M_PI / 2;
@@ -111,7 +111,7 @@ void RNNoiseProcess::InitTables()
}
}
-void RNNoiseProcess::BiQuad(
+void RNNoiseFeatureProcessor::BiQuad(
const arrHp& bHp,
const arrHp& aHp,
arrHp& memHpX,
@@ -126,8 +126,8 @@ void RNNoiseProcess::BiQuad(
}
}
-void RNNoiseProcess::ComputeFrameFeatures(vec1D32F& audioWindow,
- FrameFeatures& features)
+void RNNoiseFeatureProcessor::ComputeFrameFeatures(vec1D32F& audioWindow,
+ FrameFeatures& features)
{
this->FrameAnalysis(audioWindow,
features.m_fftX,
@@ -264,7 +264,7 @@ void RNNoiseProcess::ComputeFrameFeatures(vec1D32F& audioWindow,
features.m_featuresVec[NB_BANDS + 3 * NB_DELTA_CEPS + 1] = specVariability / CEPS_MEM - 2.1;
}
-void RNNoiseProcess::FrameAnalysis(
+void RNNoiseFeatureProcessor::FrameAnalysis(
const vec1D32F& audioWindow,
vec1D32F& fft,
vec1D32F& energy,
@@ -289,7 +289,7 @@ void RNNoiseProcess::FrameAnalysis(
ComputeBandEnergy(fft, energy);
}
-void RNNoiseProcess::ApplyWindow(vec1D32F& x)
+void RNNoiseFeatureProcessor::ApplyWindow(vec1D32F& x)
{
if (WINDOW_SIZE != x.size()) {
printf_err("Invalid size for vector to be windowed\n");
@@ -305,7 +305,7 @@ void RNNoiseProcess::ApplyWindow(vec1D32F& x)
}
}
-void RNNoiseProcess::ForwardTransform(
+void RNNoiseFeatureProcessor::ForwardTransform(
vec1D32F& x,
vec1D32F& fft)
{
@@ -327,7 +327,7 @@ void RNNoiseProcess::ForwardTransform(
* first half of the FFT's. The conjugates are not present. */
}
-void RNNoiseProcess::ComputeBandEnergy(const vec1D32F& fftX, vec1D32F& bandE)
+void RNNoiseFeatureProcessor::ComputeBandEnergy(const vec1D32F& fftX, vec1D32F& bandE)
{
bandE = vec1D32F(NB_BANDS, 0);
@@ -351,7 +351,7 @@ void RNNoiseProcess::ComputeBandEnergy(const vec1D32F& fftX, vec1D32F& bandE)
bandE[NB_BANDS - 1] *= 2;
}
-void RNNoiseProcess::ComputeBandCorr(const vec1D32F& X, const vec1D32F& P, vec1D32F& bandC)
+void RNNoiseFeatureProcessor::ComputeBandCorr(const vec1D32F& X, const vec1D32F& P, vec1D32F& bandC)
{
bandC = vec1D32F(NB_BANDS, 0);
VERIFY(this->m_eband5ms.size() >= NB_BANDS);
@@ -374,7 +374,7 @@ void RNNoiseProcess::ComputeBandCorr(const vec1D32F& X, const vec1D32F& P, vec1D
bandC[NB_BANDS - 1] *= 2;
}
-void RNNoiseProcess::DCT(vec1D32F& input, vec1D32F& output)
+void RNNoiseFeatureProcessor::DCT(vec1D32F& input, vec1D32F& output)
{
VERIFY(this->m_dctTable.size() >= NB_BANDS * NB_BANDS);
for (uint32_t i = 0; i < NB_BANDS; ++i) {
@@ -387,7 +387,7 @@ void RNNoiseProcess::DCT(vec1D32F& input, vec1D32F& output)
}
}
-void RNNoiseProcess::PitchDownsample(vec1D32F& pitchBuf, size_t pitchBufSz) {
+void RNNoiseFeatureProcessor::PitchDownsample(vec1D32F& pitchBuf, size_t pitchBufSz) {
for (size_t i = 1; i < (pitchBufSz >> 1); ++i) {
pitchBuf[i] = 0.5 * (
0.5 * (this->m_pitchBuf[2 * i - 1] + this->m_pitchBuf[2 * i + 1])
@@ -431,7 +431,7 @@ void RNNoiseProcess::PitchDownsample(vec1D32F& pitchBuf, size_t pitchBufSz) {
this->Fir5(lpc2, pitchBufSz >> 1, pitchBuf);
}
-int RNNoiseProcess::PitchSearch(vec1D32F& xLp, vec1D32F& y, uint32_t len, uint32_t maxPitch) {
+int RNNoiseFeatureProcessor::PitchSearch(vec1D32F& xLp, vec1D32F& y, uint32_t len, uint32_t maxPitch) {
uint32_t lag = len + maxPitch;
vec1D32F xLp4(len >> 2, 0);
vec1D32F yLp4(lag >> 2, 0);
@@ -488,7 +488,7 @@ int RNNoiseProcess::PitchSearch(vec1D32F& xLp, vec1D32F& y, uint32_t len, uint32
return 2*bestPitch[0] - offset;
}
-arrHp RNNoiseProcess::FindBestPitch(vec1D32F& xCorr, vec1D32F& y, uint32_t len, uint32_t maxPitch)
+arrHp RNNoiseFeatureProcessor::FindBestPitch(vec1D32F& xCorr, vec1D32F& y, uint32_t len, uint32_t maxPitch)
{
float Syy = 1;
arrHp bestNum {-1, -1};
@@ -527,7 +527,7 @@ arrHp RNNoiseProcess::FindBestPitch(vec1D32F& xCorr, vec1D32F& y, uint32_t len,
return bestPitch;
}
-int RNNoiseProcess::RemoveDoubling(
+int RNNoiseFeatureProcessor::RemoveDoubling(
vec1D32F& pitchBuf,
uint32_t maxPeriod,
uint32_t minPeriod,
@@ -679,12 +679,12 @@ int RNNoiseProcess::RemoveDoubling(
return this->m_lastPeriod;
}
-float RNNoiseProcess::ComputePitchGain(float xy, float xx, float yy)
+float RNNoiseFeatureProcessor::ComputePitchGain(float xy, float xx, float yy)
{
return xy / math::MathUtils::SqrtF32(1+xx*yy);
}
-void RNNoiseProcess::AutoCorr(
+void RNNoiseFeatureProcessor::AutoCorr(
const vec1D32F& x,
vec1D32F& ac,
size_t lag,
@@ -711,7 +711,7 @@ void RNNoiseProcess::AutoCorr(
}
-void RNNoiseProcess::PitchXCorr(
+void RNNoiseFeatureProcessor::PitchXCorr(
const vec1D32F& x,
const vec1D32F& y,
vec1D32F& xCorr,
@@ -728,7 +728,7 @@ void RNNoiseProcess::PitchXCorr(
}
/* Linear predictor coefficients */
-void RNNoiseProcess::LPC(
+void RNNoiseFeatureProcessor::LPC(
const vec1D32F& correlation,
int32_t p,
vec1D32F& lpc)
@@ -766,7 +766,7 @@ void RNNoiseProcess::LPC(
}
}
-void RNNoiseProcess::Fir5(
+void RNNoiseFeatureProcessor::Fir5(
const vec1D32F &num,
uint32_t N,
vec1D32F &x)
@@ -794,7 +794,7 @@ void RNNoiseProcess::Fir5(
}
}
-void RNNoiseProcess::PitchFilter(FrameFeatures &features, vec1D32F &gain) {
+void RNNoiseFeatureProcessor::PitchFilter(FrameFeatures &features, vec1D32F &gain) {
std::vector<float> r(NB_BANDS, 0);
std::vector<float> rf(FREQ_SIZE, 0);
std::vector<float> newE(NB_BANDS);
@@ -835,7 +835,7 @@ void RNNoiseProcess::PitchFilter(FrameFeatures &features, vec1D32F &gain) {
}
}
-void RNNoiseProcess::FrameSynthesis(vec1D32F& outFrame, vec1D32F& fftY) {
+void RNNoiseFeatureProcessor::FrameSynthesis(vec1D32F& outFrame, vec1D32F& fftY) {
std::vector<float> x(WINDOW_SIZE, 0);
InverseTransform(x, fftY);
ApplyWindow(x);
@@ -845,7 +845,7 @@ void RNNoiseProcess::FrameSynthesis(vec1D32F& outFrame, vec1D32F& fftY) {
memcpy((m_synthesisMem.data()), &x[FRAME_SIZE], FRAME_SIZE*sizeof(float));
}
-void RNNoiseProcess::InterpBandGain(vec1D32F& g, vec1D32F& bandE) {
+void RNNoiseFeatureProcessor::InterpBandGain(vec1D32F& g, vec1D32F& bandE) {
for (size_t i = 0; i < NB_BANDS - 1; i++) {
int bandSize = (m_eband5ms[i + 1] - m_eband5ms[i]) << FRAME_SIZE_SHIFT;
for (int j = 0; j < bandSize; j++) {
@@ -855,7 +855,7 @@ void RNNoiseProcess::InterpBandGain(vec1D32F& g, vec1D32F& bandE) {
}
}
-void RNNoiseProcess::InverseTransform(vec1D32F& out, vec1D32F& fftXIn) {
+void RNNoiseFeatureProcessor::InverseTransform(vec1D32F& out, vec1D32F& fftXIn) {
std::vector<float> x(WINDOW_SIZE * 2); /* This is complex. */
vec1D32F newFFT; /* This is complex. */
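For orientation on the renamed class: PreprocessFrame first runs the audio window through a two-pole high-pass (the BiQuad above, with aHp = {-1.99599, 0.99600}) before windowing, FFT and band-energy extraction. Below is a generic direct-form biquad matching the visible signature, offered purely as a sketch; the bHp coefficients and the exact state update inside the kit's BiQuad are assumptions here:

    #include <array>
    #include <vector>

    using arrHp = std::array<float, 2>;   /* assumed alias, matching usage above */

    /* y[n] = x[n] + b1*x[n-1] + b2*x[n-2] - a1*y[n-1] - a2*y[n-2], with an
     * implicit b0 of 1. aHp = {-1.99599, 0.99600} places the poles for a
     * high-pass response. */
    static void BiQuadSketch(const arrHp& bHp, const arrHp& aHp,
                             arrHp& memX, arrHp& memY, std::vector<float>& audio)
    {
        for (float& sample : audio) {
            const float xi = sample;
            const float yi = xi + bHp[0] * memX[0] + bHp[1] * memX[1]
                                - aHp[0] * memY[0] - aHp[1] * memY[1];
            memX[1] = memX[0]; memX[0] = xi;   /* shift input history */
            memY[1] = memY[0]; memY[0] = yi;   /* shift output history */
            sample = yi;
        }
    }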
diff --git a/source/use_case/noise_reduction/src/RNNoiseProcessing.cc b/source/use_case/noise_reduction/src/RNNoiseProcessing.cc
new file mode 100644
index 0000000..f6a3ec4
--- /dev/null
+++ b/source/use_case/noise_reduction/src/RNNoiseProcessing.cc
@@ -0,0 +1,100 @@
+/*
+ * Copyright (c) 2022 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "RNNoiseProcessing.hpp"
+#include "log_macros.h"
+
+namespace arm {
+namespace app {
+
+ RNNoisePreProcess::RNNoisePreProcess(TfLiteTensor* inputTensor,
+ std::shared_ptr<rnn::RNNoiseFeatureProcessor> featureProcessor, std::shared_ptr<rnn::FrameFeatures> frameFeatures)
+ : m_inputTensor{inputTensor},
+ m_featureProcessor{featureProcessor},
+ m_frameFeatures{frameFeatures}
+ {}
+
+ bool RNNoisePreProcess::DoPreProcess(const void* data, size_t inputSize)
+ {
+ if (data == nullptr) {
+            printf_err("Data pointer is null\n");
+ return false;
+ }
+
+ auto input = static_cast<const int16_t*>(data);
+ this->m_audioFrame = rnn::vec1D32F(input, input + inputSize);
+ m_featureProcessor->PreprocessFrame(this->m_audioFrame.data(), inputSize, *this->m_frameFeatures);
+
+ QuantizeAndPopulateInput(this->m_frameFeatures->m_featuresVec,
+ this->m_inputTensor->params.scale, this->m_inputTensor->params.zero_point,
+ this->m_inputTensor);
+
+        debug("Input tensor populated\n");
+
+ return true;
+ }
+
+ void RNNoisePreProcess::QuantizeAndPopulateInput(rnn::vec1D32F& inputFeatures,
+ const float quantScale, const int quantOffset,
+ TfLiteTensor* inputTensor)
+ {
+ const float minVal = std::numeric_limits<int8_t>::min();
+ const float maxVal = std::numeric_limits<int8_t>::max();
+
+ auto* inputTensorData = tflite::GetTensorData<int8_t>(inputTensor);
+
+ for (size_t i=0; i < inputFeatures.size(); ++i) {
+ float quantValue = ((inputFeatures[i] / quantScale) + quantOffset);
+ inputTensorData[i] = static_cast<int8_t>(std::min<float>(std::max<float>(quantValue, minVal), maxVal));
+ }
+ }
+
+ RNNoisePostProcess::RNNoisePostProcess(TfLiteTensor* outputTensor,
+ std::vector<int16_t>& denoisedAudioFrame,
+ std::shared_ptr<rnn::RNNoiseFeatureProcessor> featureProcessor,
+ std::shared_ptr<rnn::FrameFeatures> frameFeatures)
+ : m_outputTensor{outputTensor},
+ m_denoisedAudioFrame{denoisedAudioFrame},
+ m_featureProcessor{featureProcessor},
+ m_frameFeatures{frameFeatures}
+ {
+ this->m_denoisedAudioFrameFloat.reserve(denoisedAudioFrame.size());
+ this->m_modelOutputFloat.resize(outputTensor->bytes);
+ }
+
+ bool RNNoisePostProcess::DoPostProcess()
+ {
+ const auto* outputData = tflite::GetTensorData<int8_t>(this->m_outputTensor);
+ auto outputQuantParams = GetTensorQuantParams(this->m_outputTensor);
+
+ for (size_t i = 0; i < this->m_outputTensor->bytes; ++i) {
+ this->m_modelOutputFloat[i] = (static_cast<float>(outputData[i]) - outputQuantParams.offset)
+ * outputQuantParams.scale;
+ }
+
+ this->m_featureProcessor->PostProcessFrame(this->m_modelOutputFloat,
+ *this->m_frameFeatures, this->m_denoisedAudioFrameFloat);
+
+ for (size_t i = 0; i < this->m_denoisedAudioFrame.size(); ++i) {
+ this->m_denoisedAudioFrame[i] = static_cast<int16_t>(
+ std::roundf(this->m_denoisedAudioFrameFloat[i]));
+ }
+
+ return true;
+ }
+
+} /* namespace app */
+} /* namespace arm */ \ No newline at end of file
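RNNoisePostProcess::DoPostProcess opens with the mirror image of QuantizeAndPopulateInput: an affine de-quantisation of the int8 model output, (q - zero_point) * scale. A tiny self-contained example with illustrative quantisation parameters:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main()
    {
        const float scale  = 0.00390625f;  /* illustrative: 1/256 */
        const int   offset = -128;         /* illustrative zero-point */
        const std::vector<int8_t> q = {-128, 0, 127};

        for (int8_t v : q) {
            const float f = (static_cast<float>(v) - offset) * scale;
            printf("%d -> %f\n", v, f);    /* -128 -> 0.0, 0 -> 0.5, 127 -> ~0.996 */
        }
        return 0;
    }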
diff --git a/source/use_case/noise_reduction/src/UseCaseHandler.cc b/source/use_case/noise_reduction/src/UseCaseHandler.cc
index acb8ba7..53bb43e 100644
--- a/source/use_case/noise_reduction/src/UseCaseHandler.cc
+++ b/source/use_case/noise_reduction/src/UseCaseHandler.cc
@@ -21,12 +21,10 @@
#include "ImageUtils.hpp"
#include "InputFiles.hpp"
#include "RNNoiseModel.hpp"
-#include "RNNoiseProcess.hpp"
+#include "RNNoiseFeatureProcessor.hpp"
+#include "RNNoiseProcessing.hpp"
#include "log_macros.h"
-#include <cmath>
-#include <algorithm>
-
namespace arm {
namespace app {
@@ -36,17 +34,6 @@ namespace app {
**/
static void IncrementAppCtxClipIdx(ApplicationContext& ctx);
- /**
- * @brief Quantize the given features and populate the input Tensor.
- * @param[in] inputFeatures Vector of floating point features to quantize.
- * @param[in] quantScale Quantization scale for the inputTensor.
- * @param[in] quantOffset Quantization offset for the inputTensor.
- * @param[in,out] inputTensor TFLite micro tensor to populate.
- **/
- static void QuantizeAndPopulateInput(rnn::vec1D32F& inputFeatures,
- float quantScale, int quantOffset,
- TfLiteTensor* inputTensor);
-
/* Noise reduction inference handler. */
bool NoiseReductionHandler(ApplicationContext& ctx, bool runAll)
{
@@ -57,7 +44,7 @@ namespace app {
size_t memDumpMaxLen = 0;
uint8_t* memDumpBaseAddr = nullptr;
size_t undefMemDumpBytesWritten = 0;
- size_t *pMemDumpBytesWritten = &undefMemDumpBytesWritten;
+ size_t* pMemDumpBytesWritten = &undefMemDumpBytesWritten;
if (ctx.Has("MEM_DUMP_LEN") && ctx.Has("MEM_DUMP_BASE_ADDR") && ctx.Has("MEM_DUMP_BYTE_WRITTEN")) {
memDumpMaxLen = ctx.Get<size_t>("MEM_DUMP_LEN");
memDumpBaseAddr = ctx.Get<uint8_t*>("MEM_DUMP_BASE_ADDR");
@@ -74,8 +61,8 @@ namespace app {
}
/* Populate Pre-Processing related parameters. */
- auto audioParamsWinLen = ctx.Get<uint32_t>("frameLength");
- auto audioParamsWinStride = ctx.Get<uint32_t>("frameStride");
+ auto audioFrameLen = ctx.Get<uint32_t>("frameLength");
+ auto audioFrameStride = ctx.Get<uint32_t>("frameStride");
auto nrNumInputFeatures = ctx.Get<uint32_t>("numInputFeatures");
TfLiteTensor* inputTensor = model.GetInputTensor(0);
@@ -103,7 +90,7 @@ namespace app {
if (ctx.Has("featureFileNames")) {
audioFileAccessorFunc = ctx.Get<std::function<const char*(const uint32_t)>>("featureFileNames");
}
- do{
+ do {
hal_lcd_clear(COLOR_BLACK);
auto startDumpAddress = memDumpBaseAddr + memDumpBytesWritten;
@@ -112,32 +99,38 @@ namespace app {
/* Creating a sliding window through the audio. */
auto audioDataSlider = audio::SlidingWindow<const int16_t>(
audioAccessorFunc(currentIndex),
- audioSizeAccessorFunc(currentIndex), audioParamsWinLen,
- audioParamsWinStride);
+ audioSizeAccessorFunc(currentIndex), audioFrameLen,
+ audioFrameStride);
info("Running inference on input feature map %" PRIu32 " => %s\n", currentIndex,
audioFileAccessorFunc(currentIndex));
memDumpBytesWritten += DumpDenoisedAudioHeader(audioFileAccessorFunc(currentIndex),
- (audioDataSlider.TotalStrides() + 1) * audioParamsWinLen,
+ (audioDataSlider.TotalStrides() + 1) * audioFrameLen,
memDumpBaseAddr + memDumpBytesWritten,
memDumpMaxLen - memDumpBytesWritten);
- rnn::RNNoiseProcess featureProcessor = rnn::RNNoiseProcess();
- rnn::vec1D32F audioFrame(audioParamsWinLen);
- rnn::vec1D32F inputFeatures(nrNumInputFeatures);
- rnn::vec1D32F denoisedAudioFrameFloat(audioParamsWinLen);
- std::vector<int16_t> denoisedAudioFrame(audioParamsWinLen);
+ /* Set up pre and post-processing. */
+ std::shared_ptr<rnn::RNNoiseFeatureProcessor> featureProcessor =
+ std::make_shared<rnn::RNNoiseFeatureProcessor>();
+ std::shared_ptr<rnn::FrameFeatures> frameFeatures =
+ std::make_shared<rnn::FrameFeatures>();
+
+ RNNoisePreProcess preProcess = RNNoisePreProcess(inputTensor, featureProcessor, frameFeatures);
+
+ std::vector<int16_t> denoisedAudioFrame(audioFrameLen);
+ RNNoisePostProcess postProcess = RNNoisePostProcess(outputTensor, denoisedAudioFrame,
+ featureProcessor, frameFeatures);
- std::vector<float> modelOutputFloat(outputTensor->bytes);
- rnn::FrameFeatures frameFeatures;
bool resetGRU = true;
while (audioDataSlider.HasNext()) {
const int16_t* inferenceWindow = audioDataSlider.Next();
- audioFrame = rnn::vec1D32F(inferenceWindow, inferenceWindow+audioParamsWinLen);
- featureProcessor.PreprocessFrame(audioFrame.data(), audioParamsWinLen, frameFeatures);
+ if (!preProcess.DoPreProcess(inferenceWindow, audioFrameLen)) {
+                    printf_err("Pre-processing failed.\n");
+ return false;
+ }
/* Reset or copy over GRU states first to avoid TFLu memory overlap issues. */
if (resetGRU){
@@ -148,53 +141,35 @@ namespace app {
model.CopyGruStates();
}
- QuantizeAndPopulateInput(frameFeatures.m_featuresVec,
- inputTensor->params.scale, inputTensor->params.zero_point,
- inputTensor);
-
/* Strings for presentation/logging. */
std::string str_inf{"Running inference... "};
/* Display message on the LCD - inference running. */
- hal_lcd_display_text(
- str_inf.c_str(), str_inf.size(),
- dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
+ hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
+ dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
info("Inference %zu/%zu\n", audioDataSlider.Index() + 1, audioDataSlider.TotalStrides() + 1);
/* Run inference over this feature sliding window. */
- profiler.StartProfiling("Inference");
- bool success = model.RunInference();
- profiler.StopProfiling();
- resetGRU = false;
-
- if (!success) {
+ if (!RunInference(model, profiler)) {
+                printf_err("Inference failed.\n");
return false;
}
+ resetGRU = false;
- /* De-quantize main model output ready for post-processing. */
- const auto* outputData = tflite::GetTensorData<int8_t>(outputTensor);
- auto outputQuantParams = arm::app::GetTensorQuantParams(outputTensor);
-
- for (size_t i = 0; i < outputTensor->bytes; ++i) {
- modelOutputFloat[i] = (static_cast<float>(outputData[i]) - outputQuantParams.offset)
- * outputQuantParams.scale;
- }
-
- /* Round and cast the post-processed results for dumping to wav. */
- featureProcessor.PostProcessFrame(modelOutputFloat, frameFeatures, denoisedAudioFrameFloat);
- for (size_t i = 0; i < audioParamsWinLen; ++i) {
- denoisedAudioFrame[i] = static_cast<int16_t>(std::roundf(denoisedAudioFrameFloat[i]));
+ /* Carry out post-processing. */
+ if (!postProcess.DoPostProcess()) {
+                    printf_err("Post-processing failed.\n");
+ return false;
}
/* Erase. */
str_inf = std::string(str_inf.size(), ' ');
- hal_lcd_display_text(
- str_inf.c_str(), str_inf.size(),
- dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
+ hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
+ dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
if (memDumpMaxLen > 0) {
- /* Dump output tensors to memory. */
+ /* Dump final post processed output to memory. */
memDumpBytesWritten += DumpOutputDenoisedAudioFrame(
denoisedAudioFrame,
memDumpBaseAddr + memDumpBytesWritten,
@@ -209,6 +184,7 @@ namespace app {
valMemDumpBytesWritten, startDumpAddress);
}
+ /* Finish by dumping the footer. */
DumpDenoisedAudioFooter(memDumpBaseAddr + memDumpBytesWritten, memDumpMaxLen - memDumpBytesWritten);
info("All inferences for audio clip complete.\n");
@@ -216,15 +192,13 @@ namespace app {
IncrementAppCtxClipIdx(ctx);
std::string clearString{' '};
- hal_lcd_display_text(
- clearString.c_str(), clearString.size(),
+ hal_lcd_display_text(clearString.c_str(), clearString.size(),
dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
std::string completeMsg{"Inference complete!"};
/* Display message on the LCD - inference complete. */
- hal_lcd_display_text(
- completeMsg.c_str(), completeMsg.size(),
+ hal_lcd_display_text(completeMsg.c_str(), completeMsg.size(),
dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
} while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx);
@@ -233,7 +207,7 @@ namespace app {
}
size_t DumpDenoisedAudioHeader(const char* filename, size_t dumpSize,
- uint8_t *memAddress, size_t memSize){
+ uint8_t* memAddress, size_t memSize){
if (memAddress == nullptr){
return 0;
@@ -284,7 +258,7 @@ namespace app {
return numBytesWritten;
}
- size_t DumpDenoisedAudioFooter(uint8_t *memAddress, size_t memSize){
+ size_t DumpDenoisedAudioFooter(uint8_t* memAddress, size_t memSize){
if ((memAddress == nullptr) || (memSize < 4)) {
return 0;
}
@@ -294,8 +268,8 @@ namespace app {
return sizeof(int32_t);
}
- size_t DumpOutputDenoisedAudioFrame(const std::vector<int16_t> &audioFrame,
- uint8_t *memAddress, size_t memSize)
+ size_t DumpOutputDenoisedAudioFrame(const std::vector<int16_t>& audioFrame,
+ uint8_t* memAddress, size_t memSize)
{
if (memAddress == nullptr) {
return 0;
@@ -324,7 +298,7 @@ namespace app {
const TfLiteTensor* tensor = model.GetOutputTensor(i);
const auto* tData = tflite::GetTensorData<uint8_t>(tensor);
#if VERIFY_TEST_OUTPUT
- arm::app::DumpTensor(tensor);
+ DumpTensor(tensor);
#endif /* VERIFY_TEST_OUTPUT */
/* Ensure that we don't overflow the allowed limit. */
if (numBytesWritten + tensor->bytes <= memSize) {
@@ -360,20 +334,5 @@ namespace app {
ctx.Set<uint32_t>("clipIndex", curClipIdx);
}
- void QuantizeAndPopulateInput(rnn::vec1D32F& inputFeatures,
- const float quantScale, const int quantOffset, TfLiteTensor* inputTensor)
- {
- const float minVal = std::numeric_limits<int8_t>::min();
- const float maxVal = std::numeric_limits<int8_t>::max();
-
- auto* inputTensorData = tflite::GetTensorData<int8_t>(inputTensor);
-
- for (size_t i=0; i < inputFeatures.size(); ++i) {
- float quantValue = ((inputFeatures[i] / quantScale) + quantOffset);
- inputTensorData[i] = static_cast<int8_t>(std::min<float>(std::max<float>(quantValue, minVal), maxVal));
- }
- }
-
-
} /* namespace app */
} /* namespace arm */
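After this refactor the handler body reduces to a standard streaming shape: slide a window over the clip and run pre-process, inference and post-process once per stride. A condensed sketch of that loop, assuming objects wired as in the RNNoiseProcessing sketch earlier; the GRU state reset/copy step and the LCD/memory-dump plumbing are deliberately elided, and the SlidingWindow header name is an assumption:

    #include "AudioUtils.hpp"          /* assumed home of audio::SlidingWindow */
    #include "RNNoiseModel.hpp"
    #include "RNNoiseProcessing.hpp"

    #include <cstdint>

    namespace arm {
    namespace app {

    static bool DenoiseClip(const int16_t* clip, size_t clipLen,
                            uint32_t frameLen, uint32_t frameStride,
                            RNNoisePreProcess& preProcess,
                            RNNoisePostProcess& postProcess,
                            RNNoiseModel& model)
    {
        auto slider = audio::SlidingWindow<const int16_t>(
            clip, clipLen, frameLen, frameStride);

        while (slider.HasNext()) {
            const int16_t* window = slider.Next();
            if (!preProcess.DoPreProcess(window, frameLen) ||
                !model.RunInference() ||
                !postProcess.DoPostProcess()) {
                return false;   /* abort on the first failing stage */
            }
            /* The denoised frame bound into postProcess is now valid. */
        }
        return true;
    }

    } /* namespace app */
    } /* namespace arm */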
diff --git a/tests/use_case/ad/PostProcessTests.cc b/tests/use_case/ad/PostProcessTests.cc
deleted file mode 100644
index 62fa9e7..0000000
--- a/tests/use_case/ad/PostProcessTests.cc
+++ /dev/null
@@ -1,53 +0,0 @@
-/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
- * SPDX-License-Identifier: Apache-2.0
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "AdPostProcessing.hpp"
-#include <catch.hpp>
-
-TEST_CASE("Softmax_vector") {
-
- std::vector<float> testVec = {1, 2, 3, 4, 1, 2, 3};
- arm::app::Softmax(testVec);
- CHECK((testVec[0] - 0.024) == Approx(0.0).margin(0.001));
- CHECK((testVec[1] - 0.064) == Approx(0.0).margin(0.001));
- CHECK((testVec[2] - 0.175) == Approx(0.0).margin(0.001));
- CHECK((testVec[3] - 0.475) == Approx(0.0).margin(0.001));
- CHECK((testVec[4] - 0.024) == Approx(0.0).margin(0.001));
- CHECK((testVec[5] - 0.064) == Approx(0.0).margin(0.001));
- CHECK((testVec[6] - 0.175) == Approx(0.0).margin(0.001));
-}
-
-TEST_CASE("Output machine index") {
-
- auto index = arm::app::OutputIndexFromFileName("test_id_00.wav");
- CHECK(index == 0);
-
- auto index1 = arm::app::OutputIndexFromFileName("test_id_02.wav");
- CHECK(index1 == 1);
-
- auto index2 = arm::app::OutputIndexFromFileName("test_id_4.wav");
- CHECK(index2 == 2);
-
- auto index3 = arm::app::OutputIndexFromFileName("test_id_6.wav");
- CHECK(index3 == 3);
-
- auto index4 = arm::app::OutputIndexFromFileName("test_id_id_00.wav");
- CHECK(index4 == -1);
-
- auto index5 = arm::app::OutputIndexFromFileName("test_id_7.wav");
- CHECK(index5 == -1);
-} \ No newline at end of file
diff --git a/tests/use_case/kws_asr/MfccTests.cc b/tests/use_case/kws_asr/MfccTests.cc
index 3ebdcf4..883c215 100644
--- a/tests/use_case/kws_asr/MfccTests.cc
+++ b/tests/use_case/kws_asr/MfccTests.cc
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -93,13 +93,13 @@ const std::vector<float> testWavMfcc {
-22.67135, -0.61615, 2.07233, 0.58137, 1.01655, 0.85816, 0.46039, 0.03393, 1.16511, 0.0072,
};
-arm::app::audio::MicroNetMFCC GetMFCCInstance() {
- const int sampFreq = arm::app::audio::MicroNetMFCC::ms_defaultSamplingFreq;
+arm::app::audio::MicroNetKwsMFCC GetMFCCInstance() {
+ const int sampFreq = arm::app::audio::MicroNetKwsMFCC::ms_defaultSamplingFreq;
const int frameLenMs = 40;
const int frameLenSamples = sampFreq * frameLenMs * 0.001;
const int numMfccFeats = 10;
- return arm::app::audio::MicroNetMFCC(numMfccFeats, frameLenSamples);
+ return arm::app::audio::MicroNetKwsMFCC(numMfccFeats, frameLenSamples);
}
template <class T>
diff --git a/tests/use_case/kws_asr/Wav2LetterPostprocessingTest.cc b/tests/use_case/kws_asr/Wav2LetterPostprocessingTest.cc
index 6fd7df3..e343b66 100644
--- a/tests/use_case/kws_asr/Wav2LetterPostprocessingTest.cc
+++ b/tests/use_case/kws_asr/Wav2LetterPostprocessingTest.cc
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,15 +16,17 @@
*/
#include "Wav2LetterPostprocess.hpp"
#include "Wav2LetterModel.hpp"
+#include "ClassificationResult.hpp"
#include <algorithm>
#include <catch.hpp>
#include <limits>
template <typename T>
-static TfLiteTensor GetTestTensor(std::vector <int>& shape,
- T initVal,
- std::vector<T>& vectorBuf)
+static TfLiteTensor GetTestTensor(
+ std::vector<int>& shape,
+ T initVal,
+ std::vector<T>& vectorBuf)
{
REQUIRE(0 != shape.size());
@@ -38,91 +40,112 @@ static TfLiteTensor GetTestTensor(std::vector <int>& shape,
vectorBuf = std::vector<T>(sizeInBytes, initVal);
TfLiteIntArray* dims = tflite::testing::IntArrayFromInts(shape.data());
return tflite::testing::CreateQuantizedTensor(
- vectorBuf.data(), dims,
- 1, 0, "test-tensor");
+ vectorBuf.data(), dims,
+ 1, 0, "test-tensor");
}
TEST_CASE("Checking return value")
{
SECTION("Mismatched post processing parameters and tensor size")
{
- const uint32_t ctxLen = 5;
- const uint32_t innerLen = 3;
- arm::app::audio::asr::Postprocess post{ctxLen, innerLen, 0};
-
+ const uint32_t outputCtxLen = 5;
+ arm::app::AsrClassifier classifier;
+ arm::app::Wav2LetterModel model;
+ model.Init();
+ std::vector<std::string> dummyLabels = {"a", "b", "$"};
+ const uint32_t blankTokenIdx = 2;
+ std::vector<arm::app::ClassificationResult> dummyResult;
std::vector <int> tensorShape = {1, 1, 1, 13};
std::vector <int8_t> tensorVec;
TfLiteTensor tensor = GetTestTensor<int8_t>(
- tensorShape, 100, tensorVec);
- REQUIRE(false == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false));
+ tensorShape, 100, tensorVec);
+
+ arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen,
+ blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx};
+
+ REQUIRE(!post.DoPostProcess());
}
SECTION("Post processing succeeds")
{
- const uint32_t ctxLen = 5;
- const uint32_t innerLen = 3;
- arm::app::audio::asr::Postprocess post{ctxLen, innerLen, 0};
-
- std::vector <int> tensorShape = {1, 1, 13, 1};
- std::vector <int8_t> tensorVec;
+ const uint32_t outputCtxLen = 5;
+ arm::app::AsrClassifier classifier;
+ arm::app::Wav2LetterModel model;
+ model.Init();
+ std::vector<std::string> dummyLabels = {"a", "b", "$"};
+ const uint32_t blankTokenIdx = 2;
+ std::vector<arm::app::ClassificationResult> dummyResult;
+ std::vector<int> tensorShape = {1, 1, 13, 1};
+ std::vector<int8_t> tensorVec;
TfLiteTensor tensor = GetTestTensor<int8_t>(
- tensorShape, 100, tensorVec);
+ tensorShape, 100, tensorVec);
+
+ arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen,
+ blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx};
/* Copy elements to compare later. */
- std::vector <int8_t> originalVec = tensorVec;
+ std::vector<int8_t> originalVec = tensorVec;
/* This step should not erase anything. */
- REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false));
+ REQUIRE(post.DoPostProcess());
}
}
+
TEST_CASE("Postprocessing - erasing required elements")
{
- constexpr uint32_t ctxLen = 5;
+ constexpr uint32_t outputCtxLen = 5;
constexpr uint32_t innerLen = 3;
- constexpr uint32_t nRows = 2*ctxLen + innerLen;
+ constexpr uint32_t nRows = 2*outputCtxLen + innerLen;
constexpr uint32_t nCols = 10;
constexpr uint32_t blankTokenIdx = nCols - 1;
- std::vector <int> tensorShape = {1, 1, nRows, nCols};
+ std::vector<int> tensorShape = {1, 1, nRows, nCols};
+ arm::app::AsrClassifier classifier;
+ arm::app::Wav2LetterModel model;
+ model.Init();
+ std::vector<std::string> dummyLabels = {"a", "b", "$"};
+ std::vector<arm::app::ClassificationResult> dummyResult;
SECTION("First and last iteration")
{
- arm::app::audio::asr::Postprocess post{ctxLen, innerLen, blankTokenIdx};
- std::vector <int8_t> tensorVec;
- TfLiteTensor tensor = GetTestTensor<int8_t>(
- tensorShape, 100, tensorVec);
+ std::vector<int8_t> tensorVec;
+ TfLiteTensor tensor = GetTestTensor<int8_t>(tensorShape, 100, tensorVec);
+ arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen,
+ blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx};
/* Copy elements to compare later. */
- std::vector <int8_t> originalVec = tensorVec;
+        std::vector<int8_t> originalVec = tensorVec;
/* This step should not erase anything. */
- REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, true));
+ post.m_lastIteration = true;
+ REQUIRE(post.DoPostProcess());
REQUIRE(originalVec == tensorVec);
}
SECTION("Right context erase")
{
- arm::app::audio::asr::Postprocess post{ctxLen, innerLen, blankTokenIdx};
-
std::vector <int8_t> tensorVec;
TfLiteTensor tensor = GetTestTensor<int8_t>(
- tensorShape, 100, tensorVec);
+ tensorShape, 100, tensorVec);
+ arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen,
+ blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx};
/* Copy elements to compare later. */
- std::vector <int8_t> originalVec = tensorVec;
+ std::vector<int8_t> originalVec = tensorVec;
/* This step should erase the right context only. */
- REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false));
+ post.m_lastIteration = false;
+ REQUIRE(post.DoPostProcess());
REQUIRE(originalVec != tensorVec);
/* The last ctxLen * 10 elements should be gone. */
- for (size_t i = 0; i < ctxLen; ++i) {
+ for (size_t i = 0; i < outputCtxLen; ++i) {
for (size_t j = 0; j < nCols; ++j) {
- /* Check right context elements are zeroed. */
+ /* Check right context elements are zeroed. Blank token idx should be set to 1 when erasing. */
if (j == blankTokenIdx) {
- CHECK(tensorVec[(ctxLen + innerLen) * nCols + i*nCols + j] == 1);
+ CHECK(tensorVec[(outputCtxLen + innerLen) * nCols + i*nCols + j] == 1);
} else {
- CHECK(tensorVec[(ctxLen + innerLen) * nCols + i*nCols + j] == 0);
+ CHECK(tensorVec[(outputCtxLen + innerLen) * nCols + i*nCols + j] == 0);
}
/* Check left context is preserved. */
@@ -131,45 +154,47 @@ TEST_CASE("Postprocessing - erasing required elements")
}
/* Check inner elements are preserved. */
- for (size_t i = ctxLen * nCols; i < (ctxLen + innerLen) * nCols; ++i) {
+ for (size_t i = outputCtxLen * nCols; i < (outputCtxLen + innerLen) * nCols; ++i) {
CHECK(tensorVec[i] == originalVec[i]);
}
}
SECTION("Left and right context erase")
{
- arm::app::audio::asr::Postprocess post{ctxLen, innerLen, blankTokenIdx};
-
std::vector <int8_t> tensorVec;
- TfLiteTensor tensor = GetTestTensor<int8_t>(tensorShape, 100, tensorVec);
+ TfLiteTensor tensor = GetTestTensor<int8_t>(
+ tensorShape, 100, tensorVec);
+ arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen,
+ blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx};
/* Copy elements to compare later. */
std::vector <int8_t> originalVec = tensorVec;
/* This step should erase right context. */
- REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false));
+ post.m_lastIteration = false;
+ REQUIRE(post.DoPostProcess());
/* Calling it the second time should erase the left context. */
- REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false));
+ REQUIRE(post.DoPostProcess());
REQUIRE(originalVec != tensorVec);
/* The first and last ctxLen * 10 elements should be gone. */
- for (size_t i = 0; i < ctxLen; ++i) {
+ for (size_t i = 0; i < outputCtxLen; ++i) {
for (size_t j = 0; j < nCols; ++j) {
/* Check left and right context elements are zeroed. */
if (j == blankTokenIdx) {
- CHECK(tensorVec[(ctxLen + innerLen) * nCols + i * nCols + j] == 1);
- CHECK(tensorVec[i * nCols + j] == 1);
+ CHECK(tensorVec[(outputCtxLen + innerLen) * nCols + i*nCols + j] == 1);
+ CHECK(tensorVec[i*nCols + j] == 1);
} else {
- CHECK(tensorVec[(ctxLen + innerLen) * nCols + i * nCols + j] == 0);
- CHECK(tensorVec[i * nCols + j] == 0);
+ CHECK(tensorVec[(outputCtxLen + innerLen) * nCols + i*nCols + j] == 0);
+ CHECK(tensorVec[i*nCols + j] == 0);
}
}
}
/* Check inner elements are preserved. */
- for (size_t i = ctxLen * nCols; i < (ctxLen + innerLen) * nCols; ++i) {
+ for (size_t i = outputCtxLen * nCols; i < (outputCtxLen + innerLen) * nCols; ++i) {
/* Check left context is preserved. */
CHECK(tensorVec[i] == originalVec[i]);
}
@@ -177,18 +202,21 @@ TEST_CASE("Postprocessing - erasing required elements")
SECTION("Try left context erase")
{
- /* Should not be able to erase the left context if it is the first iteration. */
- arm::app::audio::asr::Postprocess post{ctxLen, innerLen, blankTokenIdx};
-
std::vector <int8_t> tensorVec;
TfLiteTensor tensor = GetTestTensor<int8_t>(
- tensorShape, 100, tensorVec);
+ tensorShape, 100, tensorVec);
+
+ /* Should not be able to erase the left context if it is the first iteration. */
+ arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen,
+ blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx};
/* Copy elements to compare later. */
std::vector <int8_t> originalVec = tensorVec;
/* First iteration with lastIteration set: nothing should be erased. */
- REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, true));
+ post.m_lastIteration = true;
+ REQUIRE(post.DoPostProcess());
+
REQUIRE(originalVec == tensorVec);
}
-} \ No newline at end of file
+}
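Taken together, the test updates document the API migration for ASR post-processing in kws_asr; the old and new call shapes, exactly as exercised above (labels, context lengths and the blank-token index are the tests' dummy values):

    /* Old API (removed):
     *   arm::app::audio::asr::Postprocess post{ctxLen, innerLen, blankTokenIdx};
     *   post.Invoke(&tensor, Wav2LetterModel::ms_outputRowsIdx, lastIteration);
     *
     * New API:
     *   arm::app::AsrPostProcess post{&tensor, classifier, labels, results,
     *                                 outputCtxLen, blankTokenIdx,
     *                                 Wav2LetterModel::ms_outputRowsIdx};
     *   post.m_lastIteration = lastIteration;
     *   post.DoPostProcess();
     */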
diff --git a/tests/use_case/kws_asr/Wav2LetterPreprocessingTest.cc b/tests/use_case/kws_asr/Wav2LetterPreprocessingTest.cc
index 26ddb24..372152d 100644
--- a/tests/use_case/kws_asr/Wav2LetterPreprocessingTest.cc
+++ b/tests/use_case/kws_asr/Wav2LetterPreprocessingTest.cc
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,64 +16,54 @@
*/
#include "Wav2LetterPreprocess.hpp"
-#include <algorithm>
-#include <catch.hpp>
#include <limits>
+#include <catch.hpp>
constexpr uint32_t numMfccFeatures = 13;
constexpr uint32_t numMfccVectors = 10;
/* Test vector output: generated using test-asr-preprocessing.py. */
-int8_t expectedResult[numMfccVectors][numMfccFeatures*3] = {
- /* Feature vec 0. */
- -32, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11, /* MFCCs. */
- -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, /* Delta 1. */
- -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, /* Delta 2. */
-
- /* Feature vec 1. */
- -31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11,
- -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
- -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10,
-
- /* Feature vec 2. */
- -31, 4, -9, -9, -10, -10, -11, -11, -11, -11, -12, -12, -12,
- -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
- -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10,
-
- /* Feature vec 3. */
- -31, 4, -9, -9, -10, -10, -11, -11, -11, -11, -11, -12, -12,
- -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
- -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10,
-
- /* Feature vec 4 : this should have valid delta 1 and delta 2. */
- -31, 4, -9, -9, -10, -10, -11, -11, -11, -11, -11, -12, -12,
- -38, -29, -9, 1, -2, -7, -8, -8, -12, -16, -14, -5, 5,
- -68, -50, -13, 5, 0, -9, -9, -8, -13, -20, -19, -3, 15,
-
- /* Feature vec 5 : this should have valid delta 1 and delta 2. */
- -31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -11, -12, -12,
- -62, -45, -11, 5, 0, -8, -9, -8, -12, -19, -17, -3, 13,
- -27, -22, -13, -9, -11, -12, -12, -11, -11, -13, -13, -10, -6,
-
- /* Feature vec 6. */
- -31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11,
- -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
- -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10,
-
- /* Feature vec 7. */
- -32, 4, -9, -8, -10, -10, -11, -11, -11, -12, -12, -11, -11,
- -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
- -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10,
-
- /* Feature vec 8. */
- -32, 4, -9, -8, -10, -10, -11, -11, -11, -12, -12, -11, -11,
- -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
- -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10,
-
- /* Feature vec 9. */
- -31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11,
- -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
- -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10
+int8_t expectedResult[numMfccVectors][numMfccFeatures * 3] = {
+ /* Feature vec 0. */
+ {-32, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11, /* MFCCs. */
+ -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, /* Delta 1. */
+ -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10}, /* Delta 2. */
+ /* Feature vec 1. */
+ {-31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11,
+ -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
+ -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10},
+ /* Feature vec 2. */
+ {-31, 4, -9, -9, -10, -10, -11, -11, -11, -11, -12, -12, -12,
+ -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
+ -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10},
+ /* Feature vec 3. */
+ {-31, 4, -9, -9, -10, -10, -11, -11, -11, -11, -11, -12, -12,
+ -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
+ -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10},
+    /* Feature vec 4: this should have valid delta 1 and delta 2. */
+ {-31, 4, -9, -9, -10, -10, -11, -11, -11, -11, -11, -12, -12,
+ -38, -29, -9, 1, -2, -7, -8, -8, -12, -16, -14, -5, 5,
+ -68, -50, -13, 5, 0, -9, -9, -8, -13, -20, -19, -3, 15},
+    /* Feature vec 5: this should have valid delta 1 and delta 2. */
+ {-31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -11, -12, -12,
+ -62, -45, -11, 5, 0, -8, -9, -8, -12, -19, -17, -3, 13,
+ -27, -22, -13, -9, -11, -12, -12, -11, -11, -13, -13, -10, -6},
+ /* Feature vec 6. */
+ {-31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11,
+ -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
+ -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10},
+ /* Feature vec 7. */
+ {-32, 4, -9, -8, -10, -10, -11, -11, -11, -12, -12, -11, -11,
+ -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
+ -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10},
+ /* Feature vec 8. */
+ {-32, 4, -9, -8, -10, -10, -11, -11, -11, -12, -12, -11, -11,
+ -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
+ -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10},
+ /* Feature vec 9. */
+ {-31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11,
+ -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
+ -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10}
};
void PopulateTestWavVector(std::vector<int16_t>& vec)
@@ -97,17 +87,17 @@ void PopulateTestWavVector(std::vector<int16_t>& vec)
TEST_CASE("Preprocessing calculation INT8")
{
-
/* Constants. */
- const uint32_t windowLen = 512;
- const uint32_t windowStride = 160;
- int dimArray[] = {3, 1, numMfccFeatures * 3, numMfccVectors};
- const float quantScale = 0.1410219967365265;
- const int quantOffset = -11;
+ const uint32_t mfccWindowLen = 512;
+ const uint32_t mfccWindowStride = 160;
+ int dimArray[] = {3, 1, numMfccFeatures * 3, numMfccVectors};
+ const float quantScale = 0.1410219967365265;
+ const int quantOffset = -11;
/* Test wav memory. */
- std::vector <int16_t> testWav((windowStride * numMfccVectors) +
- (windowLen - windowStride));
+    std::vector<int16_t> testWav((mfccWindowStride * numMfccVectors) +
+                                 (mfccWindowLen - mfccWindowStride));
/* Populate with dummy input. */
PopulateTestWavVector(testWav);
@@ -117,20 +107,20 @@ TEST_CASE("Preprocessing calculation INT8")
/* Initialise dimensions and the test tensor. */
TfLiteIntArray* dims= tflite::testing::IntArrayFromInts(dimArray);
- TfLiteTensor tensor = tflite::testing::CreateQuantizedTensor(
- tensorVec.data(), dims, quantScale, quantOffset, "preprocessedInput");
+ TfLiteTensor inputTensor = tflite::testing::CreateQuantizedTensor(
+ tensorVec.data(), dims, quantScale, quantOffset, "preprocessedInput");
/* Initialise pre-processing module. */
- arm::app::audio::asr::Preprocess prep{
- numMfccFeatures, windowLen, windowStride, numMfccVectors};
+ arm::app::AsrPreProcess prep{&inputTensor,
+ numMfccFeatures, numMfccVectors, mfccWindowLen, mfccWindowStride};
/* Invoke pre-processing. */
- REQUIRE(prep.Invoke(testWav.data(), testWav.size(), &tensor));
+ REQUIRE(prep.DoPreProcess(testWav.data(), testWav.size()));
/* Wrap the tensor with a std::vector for ease. */
- int8_t * tensorData = tflite::GetTensorData<int8_t>(&tensor);
+ auto* tensorData = tflite::GetTensorData<int8_t>(&inputTensor);
std::vector <int8_t> vecResults =
- std::vector<int8_t>(tensorData, tensorData + tensor.bytes);
+ std::vector<int8_t>(tensorData, tensorData + inputTensor.bytes);
/* Check sizes. */
REQUIRE(vecResults.size() == sizeof(expectedResult));
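
For contrast with the old Invoke(data, size, &tensor) flow removed above, here is a minimal usage sketch of the rebound API: the tensor is supplied once at construction, so the hot path only passes raw audio. It uses only names exercised in this test; the TfLiteTensor setup is assumed to have been done as shown above.

    #include <cstdint>
    #include <vector>

    /* Sketch, assuming inputTensor was created and quantised as in the
     * test above. */
    bool PreprocessAudio(TfLiteTensor* inputTensor,
                         const std::vector<int16_t>& audio)
    {
        constexpr uint32_t nMfccFeatures   = 13;
        constexpr uint32_t nMfccVectors    = 10;
        constexpr uint32_t mfccWindowLen   = 512;
        constexpr uint32_t mfccWindowStride = 160;

        arm::app::AsrPreProcess prep{inputTensor, nMfccFeatures, nMfccVectors,
                                     mfccWindowLen, mfccWindowStride};

        /* Quantised MFCC (+ delta1/delta2) features are written straight
         * into the bound tensor. */
        return prep.DoPreProcess(audio.data(), audio.size());
    }
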
diff --git a/tests/use_case/noise_reduction/RNNoiseProcessingTests.cpp b/tests/use_case/noise_reduction/RNNoiseProcessingTests.cpp
index e28a6da..ca5aab1 100644
--- a/tests/use_case/noise_reduction/RNNoiseProcessingTests.cpp
+++ b/tests/use_case/noise_reduction/RNNoiseProcessingTests.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,7 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-#include "RNNoiseProcess.hpp"
+#include "RNNoiseFeatureProcessor.hpp"
#include <catch.hpp>
#include <limits>
@@ -208,7 +208,7 @@ TEST_CASE("RNNoise preprocessing calculation test", "[RNNoise]")
{
SECTION("FP32")
{
- arm::app::rnn::RNNoiseProcess rnnoiseProcessor;
+ arm::app::rnn::RNNoiseFeatureProcessor rnnoiseProcessor;
arm::app::rnn::FrameFeatures features;
rnnoiseProcessor.PreprocessFrame(testWav0.data(), testWav0.size(), features);
@@ -223,7 +223,7 @@ TEST_CASE("RNNoise preprocessing calculation test", "[RNNoise]")
TEST_CASE("RNNoise postprocessing test", "[RNNoise]")
{
- arm::app::rnn::RNNoiseProcess rnnoiseProcessor;
+ arm::app::rnn::RNNoiseFeatureProcessor rnnoiseProcessor;
arm::app::rnn::FrameFeatures p;
rnnoiseProcessor.PreprocessFrame(testWav0.data(), testWav0.size(), p);
std::vector<float> denoised(testWav0.size());
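
This excerpt stops right after the denoised buffer is allocated, so for orientation here is a hedged sketch of the complete frame round trip through the renamed class. The PostProcessFrame call and the 22-band gain vector are assumptions about the RNNoise pipeline; only PreprocessFrame, FrameFeatures and RNNoiseFeatureProcessor appear in this diff.

    #include <vector>

    /* Sketch under assumptions: the gain vector size (22 bands) and the
     * PostProcessFrame signature are not confirmed by this diff. */
    void DenoiseOneFrame(const std::vector<float>& frame)
    {
        arm::app::rnn::RNNoiseFeatureProcessor processor;
        arm::app::rnn::FrameFeatures features;

        /* Feature extraction, as exercised by the tests above. */
        processor.PreprocessFrame(frame.data(), frame.size(), features);

        /* A real application would run the RNNoise model here to obtain
         * per-band gains; a unity-gain vector stands in for that step. */
        std::vector<float> gains(22, 1.0f); /* 22 bands: assumed. */

        std::vector<float> denoised(frame.size());
        processor.PostProcessFrame(gains, features, denoised); /* assumed API */
    }
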