From 4e002791bc6781b549c6951cfe44f918289d7e82 Mon Sep 17 00:00:00 2001
From: Richard Burton
Date: Wed, 4 May 2022 09:45:02 +0100
Subject: MLECO-3173: Add AD, KWS_ASR and Noise reduction use case APIs

Signed-off-by: Richard Burton
Change-Id: I36f61ce74bf17f7b327cdae9704a22ca54144f37
---
 source/use_case/kws_asr/include/KwsProcessing.hpp  | 138 ++++
 .../use_case/kws_asr/include/MicroNetKwsMfcc.hpp   |  10 +-
 .../use_case/kws_asr/include/Wav2LetterModel.hpp   |  12 +-
 .../kws_asr/include/Wav2LetterPostprocess.hpp      | 117 ++---
 .../kws_asr/include/Wav2LetterPreprocess.hpp       | 138 +++---
 source/use_case/kws_asr/src/KwsProcessing.cc       | 212 +++++
 source/use_case/kws_asr/src/MainLoop.cc            | 125 ++----
 source/use_case/kws_asr/src/UseCaseHandler.cc      | 492 +++++++--------
 .../use_case/kws_asr/src/Wav2LetterPostprocess.cc  | 146 ++++--
 .../use_case/kws_asr/src/Wav2LetterPreprocess.cc   | 116 ++---
 source/use_case/kws_asr/usecase.cmake              |   4 +-
 11 files changed, 820 insertions(+), 690 deletions(-)
 create mode 100644 source/use_case/kws_asr/include/KwsProcessing.hpp
 create mode 100644 source/use_case/kws_asr/src/KwsProcessing.cc

(limited to 'source/use_case/kws_asr')

diff --git a/source/use_case/kws_asr/include/KwsProcessing.hpp b/source/use_case/kws_asr/include/KwsProcessing.hpp
new file mode 100644
index 0000000..d3de3b3
--- /dev/null
+++ b/source/use_case/kws_asr/include/KwsProcessing.hpp
@@ -0,0 +1,138 @@
+/*
+ * Copyright (c) 2022 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef KWS_PROCESSING_HPP
+#define KWS_PROCESSING_HPP
+
+#include <AudioUtils.hpp>
+#include "BaseProcessing.hpp"
+#include "Model.hpp"
+#include "Classifier.hpp"
+#include "MicroNetKwsMfcc.hpp"
+
+#include <functional>
+
+namespace arm {
+namespace app {
+
+    /**
+     * @brief   Pre-processing class for Keyword Spotting use case.
+     *          Implements methods declared by BasePreProcess and anything else needed
+     *          to populate input tensors ready for inference.
+     */
+    class KwsPreProcess : public BasePreProcess {
+
+    public:
+        /**
+         * @brief       Constructor
+         * @param[in]   inputTensor       Pointer to the TFLite Micro input Tensor.
+         * @param[in]   numFeatures       How many MFCC features to use.
+         * @param[in]   numFeatureFrames  Number of MFCC vectors that need to be calculated
+         *                                for an inference.
+         * @param[in]   mfccFrameLength   Number of audio samples used to calculate one set of MFCC values when
+         *                                sliding a window through the audio sample.
+         * @param[in]   mfccFrameStride   Number of audio samples between consecutive windows.
+         **/
+        explicit KwsPreProcess(TfLiteTensor* inputTensor, size_t numFeatures, size_t numFeatureFrames,
+                               int mfccFrameLength, int mfccFrameStride);
+
+        /**
+         * @brief       Should perform pre-processing of 'raw' input audio data and load it into
+         *              TFLite Micro input tensors ready for inference.
+         * @param[in]   input      Pointer to the data that pre-processing will work on.
+         * @param[in]   inputSize  Size of the input data.
+         * @return      true if successful, false otherwise.
+         **/
+        bool DoPreProcess(const void* input, size_t inputSize) override;
+
+        size_t m_audioWindowIndex = 0;  /* Index of audio slider, used when caching features in longer clips. */
+        size_t m_audioDataWindowSize;   /* Amount of audio needed for 1 inference. */
+        size_t m_audioDataStride;       /* Amount of audio to stride across if doing >1 inference in longer clips. */
+
+    private:
+        TfLiteTensor* m_inputTensor;    /* Model input tensor. */
+        const int m_mfccFrameLength;
+        const int m_mfccFrameStride;
+        const size_t m_numMfccFrames;   /* How many sets of m_numMfccFeats. */
+
+        audio::MicroNetKwsMFCC m_mfcc;
+        audio::SlidingWindow<const int16_t> m_mfccSlidingWindow;
+        size_t m_numMfccVectorsInAudioStride;
+        size_t m_numReusedMfccVectors;
+        std::function<void (std::vector<int16_t>&, int, bool, size_t)> m_mfccFeatureCalculator;
+
+        /**
+         * @brief Returns a function to perform feature calculation and populates input tensor data with
+         * MFCC data.
+         *
+         * Input tensor data type check is performed to choose correct MFCC feature data type.
+         * If tensor has an integer data type then original features are quantised.
+         *
+         * Warning: MFCC calculator provided as input must have the same life scope as returned function.
+         *
+         * @param[in]       mfcc          MFCC feature calculator.
+         * @param[in,out]   inputTensor   Input tensor pointer to store calculated features.
+         * @param[in]       cacheSize     Size of the feature vectors cache (number of feature vectors).
+         * @return          Function to be called providing audio sample and sliding window index.
+         */
+        std::function<void (std::vector<int16_t>&, int, bool, size_t)>
+        GetFeatureCalculator(audio::MicroNetKwsMFCC& mfcc,
+                             TfLiteTensor* inputTensor,
+                             size_t cacheSize);
+
+        template<class T>
+        std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
+        FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
+                    std::function<std::vector<T> (std::vector<int16_t>&)> compute);
+    };
+
+    /**
+     * @brief   Post-processing class for Keyword Spotting use case.
+     *          Implements methods declared by BasePostProcess and anything else needed
+     *          to populate result vector.
+     */
+    class KwsPostProcess : public BasePostProcess {
+
+    private:
+        TfLiteTensor* m_outputTensor;                   /* Model output tensor. */
+        Classifier& m_kwsClassifier;                    /* KWS Classifier object. */
+        const std::vector<std::string>& m_labels;       /* KWS Labels. */
+        std::vector<ClassificationResult>& m_results;   /* Results vector for a single inference. */
+
+    public:
+        /**
+         * @brief          Constructor
+         * @param[in]      outputTensor  Pointer to the TFLite Micro output Tensor.
+         * @param[in]      classifier    Classifier object used to get top N results from classification.
+         * @param[in]      labels        Vector of string labels to identify each output of the model.
+         * @param[in,out]  results       Vector of classification results to store decoded outputs.
+         **/
+        KwsPostProcess(TfLiteTensor* outputTensor, Classifier& classifier,
+                       const std::vector<std::string>& labels,
+                       std::vector<ClassificationResult>& results);
+
+        /**
+         * @brief    Should perform post-processing of the result of inference then
+         *           populate KWS result data for any later use.
+         * @return   true if successful, false otherwise.
+         **/
+        bool DoPostProcess() override;
+    };
+
+} /* namespace app */
+} /* namespace arm */
+
+#endif /* KWS_PROCESSING_HPP */
\ No newline at end of file
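Taken together, the two classes above reduce each KWS inference to a pre-process/infer/post-process triple. A minimal usage sketch follows; the variable names (`model`, `classifier`, `labels`, `audioWindow`) and the MicroNet parameter values (10 MFCC features, 49 frames, frame length 640, stride 320) are illustrative assumptions, not part of this patch:

    /* Sketch only: drive one KWS inference window with the new API. */
    arm::app::KwsPreProcess preProcess(model.GetInputTensor(0),
                                       10,    /* numFeatures (assumed) */
                                       49,    /* numFeatureFrames (assumed) */
                                       640,   /* mfccFrameLength (assumed) */
                                       320);  /* mfccFrameStride (assumed) */

    std::vector<arm::app::ClassificationResult> results;
    arm::app::KwsPostProcess postProcess(model.GetOutputTensor(0),
                                         classifier, labels, results);

    if (preProcess.DoPreProcess(audioWindow, audioWindowSize)
            && model.RunInference()
            && postProcess.DoPostProcess()) {
        /* 'results' now holds the top classification for this window. */
    }
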
diff --git a/source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp b/source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp
index 43bd390..af6ba5f 100644
--- a/source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp
+++ b/source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
  * SPDX-License-Identifier: Apache-2.0
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
@@ -24,7 +24,7 @@ namespace app {
 namespace audio {
 
     /* Class to provide MicroNet specific MFCC calculation requirements. */
-    class MicroNetMFCC : public MFCC {
+    class MicroNetKwsMFCC : public MFCC {
 
     public:
         static constexpr uint32_t ms_defaultSamplingFreq = 16000;
@@ -34,14 +34,14 @@ namespace audio {
         static constexpr bool     ms_defaultUseHtkMethod = true;
 
-        explicit MicroNetMFCC(const size_t numFeats, const size_t frameLen)
+        explicit MicroNetKwsMFCC(const size_t numFeats, const size_t frameLen)
             :  MFCC(MfccParams(
                         ms_defaultSamplingFreq, ms_defaultNumFbankBins,
                         ms_defaultMelLoFreq, ms_defaultMelHiFreq,
                         numFeats, frameLen, ms_defaultUseHtkMethod))
         {}
-        MicroNetMFCC()  = delete;
-        ~MicroNetMFCC() = default;
+        MicroNetKwsMFCC()  = delete;
+        ~MicroNetKwsMFCC() = default;
     };
 
 } /* namespace audio */
diff --git a/source/use_case/kws_asr/include/Wav2LetterModel.hpp b/source/use_case/kws_asr/include/Wav2LetterModel.hpp
index 7c327b3..0e1adc5 100644
--- a/source/use_case/kws_asr/include/Wav2LetterModel.hpp
+++ b/source/use_case/kws_asr/include/Wav2LetterModel.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
  * SPDX-License-Identifier: Apache-2.0
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
@@ -34,14 +34,18 @@ namespace arm {
 namespace app {
 
     class Wav2LetterModel : public Model {
-
+
     public:
         /* Indices for the expected model - based on input and output tensor shapes */
-        static constexpr uint32_t ms_inputRowsIdx = 1;
-        static constexpr uint32_t ms_inputColsIdx = 2;
+        static constexpr uint32_t ms_inputRowsIdx  = 1;
+        static constexpr uint32_t ms_inputColsIdx  = 2;
         static constexpr uint32_t ms_outputRowsIdx = 2;
         static constexpr uint32_t ms_outputColsIdx = 3;
 
+        /* Model specific constants. */
+        static constexpr uint32_t ms_blankTokenIdx   = 28;
+        static constexpr uint32_t ms_numMfccFeatures = 13;
+
     protected:
         /** @brief   Gets the reference to op resolver interface class. */
         const tflite::MicroOpResolver& GetOpResolver() override;
diff --git a/source/use_case/kws_asr/include/Wav2LetterPostprocess.hpp b/source/use_case/kws_asr/include/Wav2LetterPostprocess.hpp
index 029a641..d1bc9a2 100644
--- a/source/use_case/kws_asr/include/Wav2LetterPostprocess.hpp
+++ b/source/use_case/kws_asr/include/Wav2LetterPostprocess.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
  * SPDX-License-Identifier: Apache-2.0
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,88 +14,95 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#ifndef KWS_ASR_WAV2LET_POSTPROC_HPP
-#define KWS_ASR_WAV2LET_POSTPROC_HPP
+#ifndef KWS_ASR_WAV2LETTER_POSTPROCESS_HPP
+#define KWS_ASR_WAV2LETTER_POSTPROCESS_HPP
 
-#include "TensorFlowLiteMicro.hpp"   /* TensorFlow headers */
+#include "TensorFlowLiteMicro.hpp"   /* TensorFlow headers. */
+#include "BaseProcessing.hpp"
+#include "AsrClassifier.hpp"
+#include "AsrResult.hpp"
+#include "log_macros.h"
 
 namespace arm {
 namespace app {
-namespace audio {
-namespace asr {
 
     /**
     * @brief   Helper class to manage tensor post-processing for "wav2letter"
     *          output.
    */
-    class Postprocess {
+    class AsrPostProcess : public BasePostProcess {
    public:
+        bool m_lastIteration = false;   /* Flag to set if processing the last set of data for a clip. */
+
+        /**
+         * @brief           Constructor
+         * @param[in]       outputTensor      Pointer to the TFLite Micro output Tensor.
+         * @param[in]       classifier        Object used to get top N results from classification.
+         * @param[in]       labels            Vector of string labels to identify each output of the model.
+         * @param[in,out]   result            Vector of classification results to store decoded outputs.
+         * @param[in]       outputContextLen  Left/right context length for output tensor.
+         * @param[in]       blankTokenIdx     Index in the labels that the "Blank token" takes.
+         * @param[in]       reductionAxis     The axis that the logits of each time step is on.
+         **/
-        Postprocess(uint32_t contextLen,
-                    uint32_t innerLen,
-                    uint32_t blankTokenIdx);
-
-        Postprocess() = delete;
-        ~Postprocess() = default;
+        AsrPostProcess(TfLiteTensor* outputTensor, AsrClassifier& classifier,
+                       const std::vector<std::string>& labels, asr::ResultVec& result,
+                       uint32_t outputContextLen,
+                       uint32_t blankTokenIdx, uint32_t reductionAxis);
 
         /**
-         * @brief       Erases the required part of the tensor based
-         *              on context lengths set up during initialisation
-         * @param[in]   tensor          Pointer to the tensor
-         * @param[in]   axisIdx         Index of the axis on which erase is
-         *                              performed.
-         * @param[in]   lastIteration   Flag to signal is this is the
-         *                              last iteration in which case
-         *                              the right context is preserved.
-         * @return      true if successful, false otherwise.
-         */
-        bool Invoke(TfLiteTensor* tensor,
-                    uint32_t axisIdx,
-                    bool lastIteration = false);
+         * @brief    Should perform post-processing of the result of inference then
+         *           populate ASR result data for any later use.
+         * @return   true if successful, false otherwise.
+         **/
+        bool DoPostProcess() override;
+
+        /** @brief   Gets the output inner length for post-processing. */
+        static uint32_t GetOutputInnerLen(const TfLiteTensor*, uint32_t outputCtxLen);
+
+        /** @brief   Gets the output context length (left/right) for post-processing. */
+        static uint32_t GetOutputContextLen(const Model& model, uint32_t inputCtxLen);
+
+        /** @brief   Gets the number of feature vectors to be computed. */
+        static uint32_t GetNumFeatureVectors(const Model& model);
 
     private:
-        uint32_t m_contextLen;      /* Lengths of left and right contexts. */
-        uint32_t m_innerLen;        /* Length of inner context. */
-        uint32_t m_totalLen;        /* Total length of the required axis. */
-        uint32_t m_countIterations; /* Current number of iterations. */
-        uint32_t m_blankTokenIdx;   /* Index of the labels blank token. */
+        AsrClassifier& m_classifier;                /* ASR Classifier object. */
+        TfLiteTensor* m_outputTensor;               /* Model output tensor. */
+        const std::vector<std::string>& m_labels;   /* ASR Labels. */
+        asr::ResultVec& m_results;                  /* Results vector for a single inference. */
+        uint32_t m_outputContextLen;                /* Lengths of left/right contexts for output. */
+        uint32_t m_outputInnerLen;                  /* Length of output inner context. */
+        uint32_t m_totalLen;                        /* Total length of the required axis. */
+        uint32_t m_countIterations;                 /* Current number of iterations. */
+        uint32_t m_blankTokenIdx;                   /* Index of the labels blank token. */
+        uint32_t m_reductionAxisIdx;                /* Axis containing output logits for a single step. */
+
         /**
-         * @brief       Checks if the tensor and axis index are valid
-         *              inputs to the object - based on how it has been
-         *              initialised.
-         * @return      true if valid, false otherwise.
+         * @brief   Checks if the tensor and axis index are valid
+         *          inputs to the object - based on how it has been initialised.
+         * @return  true if valid, false otherwise.
          */
         bool IsInputValid(TfLiteTensor* tensor,
-                          const uint32_t axisIdx) const;
+                          uint32_t axisIdx) const;
 
         /**
-         * @brief       Gets the tensor data element size in bytes based
-         *              on the tensor type.
-         * @return      Size in bytes, 0 if not supported.
+         * @brief   Gets the tensor data element size in bytes based
+         *          on the tensor type.
+         * @return  Size in bytes, 0 if not supported.
          */
-        uint32_t GetTensorElementSize(TfLiteTensor* tensor);
+        static uint32_t GetTensorElementSize(TfLiteTensor* tensor);
 
         /**
-         * @brief       Erases sections from the data assuming row-wise
-         *              arrangement along the context axis.
-         * @return      true if successful, false otherwise.
+         * @brief   Erases sections from the data assuming row-wise
+         *          arrangement along the context axis.
+         * @return  true if successful, false otherwise.
          */
         bool EraseSectionsRowWise(uint8_t* ptrData,
-                                  const uint32_t strideSzBytes,
-                                  const bool lastIteration);
-
+                                  uint32_t strideSzBytes,
+                                  bool lastIteration);
     };
 
-} /* namespace asr */
-} /* namespace audio */
 } /* namespace app */
 } /* namespace arm */
 
-#endif /* KWS_ASR_WAV2LET_POSTPROC_HPP */
\ No newline at end of file
+#endif /* KWS_ASR_WAV2LETTER_POSTPROCESS_HPP */
\ No newline at end of file
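The context/inner split that this class manages comes from the fact that consecutive ASR audio windows overlap: each inference only "owns" the inner frames of its output, while the left and right context frames are erased, except on the first and last iterations where one side has no neighbour. As a worked illustration, assuming the sample Wav2Letter configuration of 296 input MFCC frames, an input context of 98 and an output time axis of 148 steps (these concrete numbers are not stated in the patch):

    /* Input rows = 296, input context = 98 => input inner = 296 - 2*98 = 100. */
    /* Output rows = 148 => input/output ratio = 296/148 = 2.                  */
    /* outputCtxLen   = round(98 / 2)  = 49                                    */
    /* outputInnerLen = 148 - 2 * 49   = 50                                    */

These are exactly the quantities computed by GetOutputContextLen() and GetOutputInnerLen() declared above.
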
diff --git a/source/use_case/kws_asr/include/Wav2LetterPreprocess.hpp b/source/use_case/kws_asr/include/Wav2LetterPreprocess.hpp
index 3609c49..1224c23 100644
--- a/source/use_case/kws_asr/include/Wav2LetterPreprocess.hpp
+++ b/source/use_case/kws_asr/include/Wav2LetterPreprocess.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
  * SPDX-License-Identifier: Apache-2.0
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,56 +14,51 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#ifndef KWS_ASR_WAV2LET_PREPROC_HPP
-#define KWS_ASR_WAV2LET_PREPROC_HPP
+#ifndef KWS_ASR_WAV2LETTER_PREPROCESS_HPP
+#define KWS_ASR_WAV2LETTER_PREPROCESS_HPP
 
 #include "Wav2LetterModel.hpp"
 #include "Wav2LetterMfcc.hpp"
 #include "AudioUtils.hpp"
 #include "DataStructures.hpp"
+#include "BaseProcessing.hpp"
 #include "log_macros.h"
 
 namespace arm {
 namespace app {
-namespace audio {
-namespace asr {
 
     /* Class to facilitate pre-processing calculation for Wav2Letter model
     * for ASR. */
-    using AudioWindow = SlidingWindow<const int16_t>;
+    using AudioWindow = audio::SlidingWindow<const int16_t>;
 
-    class Preprocess {
+    class AsrPreProcess : public BasePreProcess {
     public:
        /**
-         * @brief       Constructor
-         * @param[in]   numMfccFeatures   Number of MFCC features per window.
-         * @param[in]   windowLen         Number of elements in a window.
-         * @param[in]   windowStride      Stride (in number of elements) for
-         *                                moving the window.
-         * @param[in]   numMfccVectors    Number of MFCC vectors per window.
-         */
-        Preprocess(
-            uint32_t  numMfccFeatures,
-            uint32_t  windowLen,
-            uint32_t  windowStride,
-            uint32_t  numMfccVectors);
-        Preprocess() = delete;
-        ~Preprocess() = default;
+         * @brief       Constructor.
+         * @param[in]   inputTensor       Pointer to the TFLite Micro input Tensor.
+         * @param[in]   numMfccFeatures   Number of MFCC features per window.
+         * @param[in]   numFeatureFrames  Number of MFCC vectors that need to be calculated
+         *                                for an inference.
+         * @param[in]   mfccWindowLen     Number of audio elements to calculate MFCC features per window.
+         * @param[in]   mfccWindowStride  Stride (in number of elements) for moving the MFCC window.
+         */
+        AsrPreProcess(TfLiteTensor* inputTensor,
+                      uint32_t  numMfccFeatures,
+                      uint32_t  numFeatureFrames,
+                      uint32_t  mfccWindowLen,
+                      uint32_t  mfccWindowStride);
 
         /**
         * @brief       Calculates the features required from audio data. This
         *              includes MFCC, first and second order deltas,
         *              normalisation and finally, quantisation. The tensor is
-         *              populated with feature from a given window placed along
+         *              populated with features from a given window placed along
         *              in a single row.
         * @param[in]   audioData     Pointer to the first element of audio data.
         * @param[in]   audioDataLen  Number of elements in the audio data.
-         * @param[in]   tensor        Tensor to be populated.
         * @return      true if successful, false in case of error.
         */
-        bool Invoke(const int16_t * audioData,
-                    uint32_t audioDataLen,
-                    TfLiteTensor * tensor);
+        bool DoPreProcess(const void* audioData, size_t audioDataLen) override;
 
    protected:
        /**
@@ -73,49 +68,32 @@ namespace asr {
         * @param[in]  mfcc     MFCC buffers.
         * @param[out] delta1   Result of the first diff computation.
         * @param[out] delta2   Result of the second diff computation.
-         *
-         * @return true if successful, false otherwise.
+         * @return     true if successful, false otherwise.
         */
        static bool ComputeDeltas(Array2d<float>& mfcc,
                                  Array2d<float>& delta1,
                                  Array2d<float>& delta2);
 
        /**
-         * @brief       Given a 2D vector of floats, computes the mean.
-         * @param[in]   vec   Vector of vector of floats.
-         * @return      Mean value.
-         */
-        static float GetMean(Array2d<float>& vec);
-
-        /**
-         * @brief       Given a 2D vector of floats, computes the stddev.
-         * @param[in]   vec    Vector of vector of floats.
-         * @param[in]   mean   Mean value of the vector passed in.
-         * @return      stddev value.
-         */
-        static float GetStdDev(Array2d<float>& vec,
-                               const float mean);
-
-        /**
-         * @brief           Given a 2D vector of floats, normalises it using
-         *                  the mean and the stddev
+         * @brief           Given a 2D vector of floats, rescale it to have mean of 0 and
+         *                  standard deviation of 1.
         * @param[in,out]   vec   Vector of vector of floats.
         */
-        static void NormaliseVec(Array2d<float>& vec);
+        static void StandardizeVecF32(Array2d<float>& vec);
 
        /**
-         * @brief       Normalises the MFCC and delta buffers.
+         * @brief       Standardizes all the MFCC and delta buffers to have mean 0 and std. dev 1.
         */
-        void Normalise();
+        void Standarize();
 
        /**
         * @brief       Given the quantisation and data type limits, computes
         *              the quantised values of a floating point input data.
-         * @param[in]   elem          Element to be quantised.
-         * @param[in]   quantScale    Scale.
-         * @param[in]   quantOffset   Offset.
-         * @param[in]   minVal        Numerical limit - minimum.
-         * @param[in]   maxVal        Numerical limit - maximum.
+         * @param[in]   elem         Element to be quantised.
+         * @param[in]   quantScale   Scale.
+         * @param[in]   quantOffset  Offset.
+         * @param[in]   minVal       Numerical limit - minimum.
+         * @param[in]   maxVal       Numerical limit - maximum.
         * @return      Floating point quantised value.
         */
        static float GetQuantElem(
@@ -133,44 +111,43 @@ namespace asr {
         *              this being the convolution speed up (as we can use
         *              contiguous memory). The output, however, requires the
         *              time axis to be in column major arrangement.
-         * @param[in]   outputBuf     Pointer to the output buffer.
-         * @param[in]   outputBufSz   Output buffer's size.
-         * @param[in]   quantScale    Quantisation scale.
-         * @param[in]   quantOffset   Quantisation offset.
+         * @param[in]  outputBuf    Pointer to the output buffer.
+         * @param[in]  outputBufSz  Output buffer's size.
+         * @param[in]  quantScale   Quantisation scale.
+         * @param[in]  quantOffset  Quantisation offset.
         */
        template <typename T>
        bool Quantise(
-                T * outputBuf,
+                T* outputBuf,
                const uint32_t outputBufSz,
                const float quantScale,
                const int quantOffset)
        {
-            /* Check the output size will for everything. */
+            /* Check the output size will fit everything. */
            if (outputBufSz < (this->m_mfccBuf.size(0) * 3 * sizeof(T))) {
                printf_err("Tensor size too small for features\n");
                return false;
            }
 
            /* Populate. */
-            T * outputBufMfcc = outputBuf;
-            T * outputBufD1 = outputBuf + this->m_numMfccFeats;
-            T * outputBufD2 = outputBufD1 + this->m_numMfccFeats;
+            T* outputBufMfcc = outputBuf;
+            T* outputBufD1 = outputBuf + this->m_numMfccFeats;
+            T* outputBufD2 = outputBufD1 + this->m_numMfccFeats;
            const uint32_t ptrIncr = this->m_numMfccFeats * 2;  /* (3 vectors - 1 vector) */
 
            const float minVal = std::numeric_limits<T>::min();
            const float maxVal = std::numeric_limits<T>::max();
 
-            /* We need to do a transpose while copying and concatenating
-             * the tensor. */
-            for (uint32_t j = 0; j < this->m_numFeatVectors; ++j) {
+            /* Need to transpose while copying and concatenating the tensor. */
+            for (uint32_t j = 0; j < this->m_numFeatureFrames; ++j) {
                for (uint32_t i = 0; i < this->m_numMfccFeats; ++i) {
-                    *outputBufMfcc++ = static_cast<T>(this->GetQuantElem(
+                    *outputBufMfcc++ = static_cast<T>(AsrPreProcess::GetQuantElem(
                            this->m_mfccBuf(i, j), quantScale,
                            quantOffset, minVal, maxVal));
-                    *outputBufD1++ = static_cast<T>(this->GetQuantElem(
+                    *outputBufD1++ = static_cast<T>(AsrPreProcess::GetQuantElem(
                            this->m_delta1Buf(i, j), quantScale,
                            quantOffset, minVal, maxVal));
-                    *outputBufD2++ = static_cast<T>(this->GetQuantElem(
+                    *outputBufD2++ = static_cast<T>(AsrPreProcess::GetQuantElem(
                            this->m_delta2Buf(i, j), quantScale,
                            quantOffset, minVal, maxVal));
                }
@@ -183,24 +160,23 @@ namespace asr {
        }
 
    private:
-        Wav2LetterMFCC m_mfcc;          /* MFCC instance. */
+        audio::Wav2LetterMFCC m_mfcc;   /* MFCC instance. */
+        TfLiteTensor* m_inputTensor;    /* Model input tensor. */
 
        /* Actual buffers to be populated. */
-        Array2d<float> m_mfccBuf;       /* Contiguous buffer 1D: MFCC */
-        Array2d<float> m_delta1Buf;     /* Contiguous buffer 1D: Delta 1 */
-        Array2d<float> m_delta2Buf;     /* Contiguous buffer 1D: Delta 2 */
+        Array2d<float> m_mfccBuf;       /* Contiguous buffer 1D: MFCC */
+        Array2d<float> m_delta1Buf;     /* Contiguous buffer 1D: Delta 1 */
+        Array2d<float> m_delta2Buf;     /* Contiguous buffer 1D: Delta 2 */
 
-        uint32_t m_windowLen;       /* Window length for MFCC. */
-        uint32_t m_windowStride;    /* Window stride len for MFCC. */
-        uint32_t m_numMfccFeats;    /* Number of MFCC features per window. */
-        uint32_t m_numFeatVectors;  /* Number of m_numMfccFeats. */
-        AudioWindow m_window;       /* Sliding window. */
+        uint32_t m_mfccWindowLen;           /* Window length for MFCC. */
+        uint32_t m_mfccWindowStride;        /* Window stride len for MFCC. */
+        uint32_t m_numMfccFeats;            /* Number of MFCC features per window. */
+        uint32_t m_numFeatureFrames;        /* How many sets of m_numMfccFeats. */
+        AudioWindow m_mfccSlidingWindow;    /* Sliding window to calculate MFCCs. */
 
    };
 
-} /* namespace asr */
-} /* namespace audio */
 } /* namespace app */
 } /* namespace arm */
 
-#endif /* KWS_ASR_WAV2LET_PREPROC_HPP */
\ No newline at end of file
+#endif /* KWS_ASR_WAV2LETTER_PREPROCESS_HPP */
\ No newline at end of file
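Quantise() is where the input layout expected by the model is produced: the MFCC, delta1 and delta2 buffers are each kept contiguous per feature index (the row-major arrangement that gives the convolution speed-up mentioned in the comment), and the copy loop transposes them so that each time step becomes one concatenated row of the input tensor. With the 13 MFCC features defined in Wav2LetterModel.hpp, a single row looks like this (illustration only):

    /* One input-tensor row per time step j (3 * 13 = 39 values):
     *
     *   [ mfcc(0..12, j) | delta1(0..12, j) | delta2(0..12, j) ]
     *
     * hence the three write pointers offset by m_numMfccFeats, and the
     * per-iteration pointer increment of m_numMfccFeats * 2 above. */
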
diff --git a/source/use_case/kws_asr/src/KwsProcessing.cc b/source/use_case/kws_asr/src/KwsProcessing.cc
new file mode 100644
index 0000000..328709d
--- /dev/null
+++ b/source/use_case/kws_asr/src/KwsProcessing.cc
@@ -0,0 +1,212 @@
+/*
+ * Copyright (c) 2022 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "KwsProcessing.hpp"
+#include "ImageUtils.hpp"
+#include "log_macros.h"
+#include "MicroNetKwsModel.hpp"
+
+namespace arm {
+namespace app {
+
+    KwsPreProcess::KwsPreProcess(TfLiteTensor* inputTensor, size_t numFeatures, size_t numMfccFrames,
+                                 int mfccFrameLength, int mfccFrameStride
+    ):
+        m_inputTensor{inputTensor},
+        m_mfccFrameLength{mfccFrameLength},
+        m_mfccFrameStride{mfccFrameStride},
+        m_numMfccFrames{numMfccFrames},
+        m_mfcc{audio::MicroNetKwsMFCC(numFeatures, mfccFrameLength)}
+    {
+        this->m_mfcc.Init();
+
+        /* Deduce the data length required for 1 inference from the network parameters. */
+        this->m_audioDataWindowSize = this->m_numMfccFrames * this->m_mfccFrameStride +
+                (this->m_mfccFrameLength - this->m_mfccFrameStride);
+
+        /* Creating an MFCC feature sliding window for the data required for 1 inference. */
+        this->m_mfccSlidingWindow = audio::SlidingWindow<const int16_t>(nullptr, this->m_audioDataWindowSize,
+                this->m_mfccFrameLength, this->m_mfccFrameStride);
+
+        /* For longer audio clips we choose to move by half the audio window size
+         * => for a 1 second window size there is an overlap of 0.5 seconds. */
+        this->m_audioDataStride = this->m_audioDataWindowSize / 2;
+
+        /* To have the previously calculated features re-usable, stride must be multiple
+         * of MFCC features window stride. Reduce stride through audio if needed. */
+        if (0 != this->m_audioDataStride % this->m_mfccFrameStride) {
+            this->m_audioDataStride -= this->m_audioDataStride % this->m_mfccFrameStride;
+        }
+
+        this->m_numMfccVectorsInAudioStride = this->m_audioDataStride / this->m_mfccFrameStride;
+
+        /* Calculate number of the feature vectors in the window overlap region.
+         * These feature vectors will be reused.*/
+        this->m_numReusedMfccVectors = this->m_mfccSlidingWindow.TotalStrides() + 1
+                - this->m_numMfccVectorsInAudioStride;
+
+        /* Construct feature calculation function. */
+        this->m_mfccFeatureCalculator = GetFeatureCalculator(this->m_mfcc, this->m_inputTensor,
+                                                             this->m_numReusedMfccVectors);
+
+        if (!this->m_mfccFeatureCalculator) {
+            printf_err("Feature calculator not initialized.");
+        }
+    }
+
+    bool KwsPreProcess::DoPreProcess(const void* data, size_t inputSize)
+    {
+        UNUSED(inputSize);
+        if (data == nullptr) {
+            printf_err("Data pointer is null");
+            return false;
+        }
+
+        /* Set the features sliding window to the new address. */
+        auto input = static_cast<const int16_t*>(data);
+        this->m_mfccSlidingWindow.Reset(input);
+
+        /* Cache is only usable if we have more than 1 inference in an audio clip. */
+        bool useCache = this->m_audioWindowIndex > 0 && this->m_numReusedMfccVectors > 0;
+
+        /* Use a sliding window to calculate MFCC features frame by frame. */
+        while (this->m_mfccSlidingWindow.HasNext()) {
+            const int16_t* mfccWindow = this->m_mfccSlidingWindow.Next();
+
+            std::vector<int16_t> mfccFrameAudioData = std::vector<int16_t>(mfccWindow,
+                    mfccWindow + this->m_mfccFrameLength);
+
+            /* Compute features for this window and write them to input tensor. */
+            this->m_mfccFeatureCalculator(mfccFrameAudioData, this->m_mfccSlidingWindow.Index(),
+                                          useCache, this->m_numMfccVectorsInAudioStride);
+        }
+
+        debug("Input tensor populated \n");
+
+        return true;
+    }
+
+    /**
+     * @brief Generic feature calculator factory.
+     *
+     * Returns lambda function to compute features using features cache.
+     * Real features math is done by a lambda function provided as a parameter.
+     * Features are written to input tensor memory.
+     *
+     * @tparam T                Feature vector type.
+     * @param[in] inputTensor   Model input tensor pointer.
+     * @param[in] cacheSize     Number of feature vectors to cache. Defined by the sliding window overlap.
+     * @param[in] compute       Features calculator function.
+     * @return                  Lambda function to compute features.
+     */
+    template<class T>
+    std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
+    KwsPreProcess::FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
+                               std::function<std::vector<T> (std::vector<int16_t>&)> compute)
+    {
+        /* Feature cache to be captured by lambda function. */
+        static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize);
+
+        return [=](std::vector<int16_t>& audioDataWindow,
+                   size_t index,
+                   bool useCache,
+                   size_t featuresOverlapIndex)
+        {
+            T* tensorData = tflite::GetTensorData<T>(inputTensor);
+            std::vector<T> features;
+
+            /* Reuse features from cache if cache is ready and sliding windows overlap.
+             * Overlap is in the beginning of sliding window with a size of a feature cache. */
+            if (useCache && index < featureCache.size()) {
+                features = std::move(featureCache[index]);
+            } else {
+                features = std::move(compute(audioDataWindow));
+            }
+            auto size = features.size();
+            auto sizeBytes = sizeof(T) * size;
+            std::memcpy(tensorData + (index * size), features.data(), sizeBytes);
+
+            /* Start renewing cache as soon as the iteration goes out of the windows overlap. */
+            if (index >= featuresOverlapIndex) {
+                featureCache[index - featuresOverlapIndex] = std::move(features);
+            }
+        };
+    }
+
+    template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)>
+    KwsPreProcess::FeatureCalc<int8_t>(TfLiteTensor* inputTensor,
+                                       size_t cacheSize,
+                                       std::function<std::vector<int8_t> (std::vector<int16_t>&)> compute);
+
+    template std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
+    KwsPreProcess::FeatureCalc<float>(TfLiteTensor* inputTensor,
+                                      size_t cacheSize,
+                                      std::function<std::vector<float>(std::vector<int16_t>&)> compute);
+
+
+    std::function<void (std::vector<int16_t>&, int, bool, size_t)>
+    KwsPreProcess::GetFeatureCalculator(audio::MicroNetKwsMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize)
+    {
+        std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> mfccFeatureCalc;
+
+        TfLiteQuantization quant = inputTensor->quantization;
+
+        if (kTfLiteAffineQuantization == quant.type) {
+            auto *quantParams = (TfLiteAffineQuantization *) quant.params;
+            const float quantScale = quantParams->scale->data[0];
+            const int quantOffset = quantParams->zero_point->data[0];
+
+            switch (inputTensor->type) {
+                case kTfLiteInt8: {
+                    mfccFeatureCalc = this->FeatureCalc<int8_t>(inputTensor,
+                                                                cacheSize,
+                                                                [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
+                                                                    return mfcc.MfccComputeQuant<int8_t>(audioDataWindow,
+                                                                                                         quantScale,
+                                                                                                         quantOffset);
+                                                                }
+                    );
+                    break;
+                }
+                default:
+                    printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type));
+            }
+        } else {
+            mfccFeatureCalc = this->FeatureCalc<float>(inputTensor, cacheSize,
+                    [&mfcc](std::vector<int16_t>& audioDataWindow) {
+                        return mfcc.MfccCompute(audioDataWindow); }
+            );
+        }
+        return mfccFeatureCalc;
+    }
+
+    KwsPostProcess::KwsPostProcess(TfLiteTensor* outputTensor, Classifier& classifier,
+                                   const std::vector<std::string>& labels,
+                                   std::vector<ClassificationResult>& results)
+            :m_outputTensor{outputTensor},
+             m_kwsClassifier{classifier},
+             m_labels{labels},
+             m_results{results}
+    {}
+
+    bool KwsPostProcess::DoPostProcess()
+    {
+        return this->m_kwsClassifier.GetClassificationResults(
+                this->m_outputTensor, this->m_results,
+                this->m_labels, 1, true);
+    }
+
+} /* namespace app */
+} /* namespace arm */
\ No newline at end of file
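The stride arithmetic in the constructor above is easiest to see with concrete numbers. Taking the usual MicroNet KWS settings of 49 MFCC frames, frame length 640 and frame stride 320 at 16 kHz (values used here purely for illustration):

    /* m_audioDataWindowSize        = 49 * 320 + (640 - 320) = 16000 samples (1 s) */
    /* m_audioDataStride            = 16000 / 2              = 8000               */
    /*   (8000 is already a multiple of 320, so no reduction is needed)           */
    /* m_numMfccVectorsInAudioStride = 8000 / 320             = 25                 */
    /* m_numReusedMfccVectors        = (48 + 1) - 25          = 24                 */

So when a clip needs more than one inference, 24 of the 49 feature vectors per window are served from the lambda's cache instead of being recomputed.
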
diff --git a/source/use_case/kws_asr/src/MainLoop.cc b/source/use_case/kws_asr/src/MainLoop.cc
index 5c1d0e0..f1d97a0 100644
--- a/source/use_case/kws_asr/src/MainLoop.cc
+++ b/source/use_case/kws_asr/src/MainLoop.cc
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
  * SPDX-License-Identifier: Apache-2.0
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,7 +14,6 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "hal.h"                    /* Brings in platform definitions. */
 #include "InputFiles.hpp"           /* For input images. */
 #include "Labels_micronetkws.hpp"   /* For MicroNetKws label strings. */
 #include "Labels_wav2letter.hpp"    /* For Wav2Letter label strings. */
@@ -24,8 +23,6 @@
 #include "Wav2LetterModel.hpp"      /* ASR model class for running inference. */
 #include "UseCaseCommonUtils.hpp"   /* Utils functions. */
 #include "UseCaseHandler.hpp"       /* Handlers for different user options. */
-#include "Wav2LetterPreprocess.hpp" /* ASR pre-processing class. */
-#include "Wav2LetterPostprocess.hpp"/* ASR post-processing class. */
 #include "log_macros.h"
 
 using KwsClassifier = arm::app::Classifier;
@@ -53,19 +50,8 @@ static void DisplayMenu()
    fflush(stdout);
 }
 
-/** @brief Gets the number of MFCC features for a single window. */
-static uint32_t GetNumMfccFeatures(const arm::app::Model& model);
-
-/** @brief Gets the number of MFCC feature vectors to be computed. */
-static uint32_t GetNumMfccFeatureVectors(const arm::app::Model& model);
-
-/** @brief Gets the output context length (left and right) for post-processing. */
-static uint32_t GetOutputContextLen(const arm::app::Model& model,
-                                    uint32_t inputCtxLen);
-
-/** @brief Gets the output inner length for post-processing. */
-static uint32_t GetOutputInnerLen(const arm::app::Model& model,
-                                  uint32_t outputCtxLen);
+/** @brief Verify input and output tensor are of certain min dimensions. */
+static bool VerifyTensorDimensions(const arm::app::Model& model);
 
 void main_loop()
 {
@@ -84,61 +70,46 @@ void main_loop()
    if (!asrModel.Init(kwsModel.GetAllocator())) {
        printf_err("Failed to initialise ASR model\n");
        return;
+    } else if (!VerifyTensorDimensions(asrModel)) {
+        printf_err("Model's input or output dimension verification failed\n");
+        return;
    }
 
-    /* Initialise ASR pre-processing. */
-    arm::app::audio::asr::Preprocess prep(
-            GetNumMfccFeatures(asrModel),
-            arm::app::asr::g_FrameLength,
-            arm::app::asr::g_FrameStride,
-            GetNumMfccFeatureVectors(asrModel));
-
-    /* Initialise ASR post-processing. */
-    const uint32_t outputCtxLen = GetOutputContextLen(asrModel, arm::app::asr::g_ctxLen);
-    const uint32_t blankTokenIdx = 28;
-    arm::app::audio::asr::Postprocess postp(
-            outputCtxLen,
-            GetOutputInnerLen(asrModel, outputCtxLen),
-            blankTokenIdx);
-
    /* Instantiate application context. */
    arm::app::ApplicationContext caseContext;
 
    arm::app::Profiler profiler{"kws_asr"};
    caseContext.Set<arm::app::Profiler&>("profiler", profiler);
-    caseContext.Set<arm::app::Model&>("kwsmodel", kwsModel);
-    caseContext.Set<arm::app::Model&>("asrmodel", asrModel);
+    caseContext.Set<arm::app::Model&>("kwsModel", kwsModel);
+    caseContext.Set<arm::app::Model&>("asrModel", asrModel);
    caseContext.Set<uint32_t>("clipIndex", 0);
    caseContext.Set<uint32_t>("ctxLen", arm::app::asr::g_ctxLen);  /* Left and right context length (MFCC feat vectors). */
-    caseContext.Set<int>("kwsframeLength", arm::app::kws::g_FrameLength);
-    caseContext.Set<int>("kwsframeStride", arm::app::kws::g_FrameStride);
-    caseContext.Set<float>("kwsscoreThreshold", arm::app::kws::g_ScoreThreshold);  /* Normalised score threshold. */
+    caseContext.Set<int>("kwsFrameLength", arm::app::kws::g_FrameLength);
+    caseContext.Set<int>("kwsFrameStride", arm::app::kws::g_FrameStride);
+    caseContext.Set<float>("kwsScoreThreshold", arm::app::kws::g_ScoreThreshold);  /* Normalised score threshold. */
    caseContext.Set<uint32_t>("kwsNumMfcc", arm::app::kws::g_NumMfcc);
    caseContext.Set<uint32_t>("kwsNumAudioWins", arm::app::kws::g_NumAudioWins);
 
-    caseContext.Set<uint32_t>("asrframeLength", arm::app::asr::g_FrameLength);
-    caseContext.Set<uint32_t>("asrframeStride", arm::app::asr::g_FrameStride);
-    caseContext.Set<float>("asrscoreThreshold", arm::app::asr::g_ScoreThreshold);  /* Normalised score threshold. */
+    caseContext.Set<uint32_t>("asrFrameLength", arm::app::asr::g_FrameLength);
+    caseContext.Set<uint32_t>("asrFrameStride", arm::app::asr::g_FrameStride);
+    caseContext.Set<float>("asrScoreThreshold", arm::app::asr::g_ScoreThreshold);  /* Normalised score threshold. */
 
    KwsClassifier kwsClassifier;            /* Classifier wrapper object. */
    arm::app::AsrClassifier asrClassifier;  /* Classifier wrapper object. */
-    caseContext.Set<arm::app::Classifier&>("kwsclassifier", kwsClassifier);
-    caseContext.Set<arm::app::AsrClassifier&>("asrclassifier", asrClassifier);
-
-    caseContext.Set<arm::app::audio::asr::Preprocess&>("preprocess", prep);
-    caseContext.Set<arm::app::audio::asr::Postprocess&>("postprocess", postp);
+    caseContext.Set<arm::app::Classifier&>("kwsClassifier", kwsClassifier);
+    caseContext.Set<arm::app::AsrClassifier&>("asrClassifier", asrClassifier);
 
    std::vector<std::string> asrLabels;
    arm::app::asr::GetLabelsVector(asrLabels);
    std::vector<std::string> kwsLabels;
    arm::app::kws::GetLabelsVector(kwsLabels);
-    caseContext.Set<const std::vector<std::string>&>("asrlabels", asrLabels);
-    caseContext.Set<const std::vector<std::string>&>("kwslabels", kwsLabels);
+    caseContext.Set<const std::vector<std::string>&>("asrLabels", asrLabels);
+    caseContext.Set<const std::vector<std::string>&>("kwsLabels", kwsLabels);
 
    /* KWS keyword that triggers ASR and associated checks */
-    std::string triggerKeyword = std::string("yes");
+    std::string triggerKeyword = std::string("no");
    if (std::find(kwsLabels.begin(), kwsLabels.end(), triggerKeyword) != kwsLabels.end()) {
-        caseContext.Set<const std::string &>("triggerkeyword", triggerKeyword);
+        caseContext.Set<const std::string &>("triggerKeyword", triggerKeyword);
    }
    else {
        printf_err("Selected trigger keyword not found in labels file\n");
@@ -196,50 +167,26 @@ void main_loop()
    info("Main loop terminated.\n");
 }
 
-static uint32_t GetNumMfccFeatures(const arm::app::Model& model)
-{
-    TfLiteTensor* inputTensor = model.GetInputTensor(0);
-    const int inputCols = inputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputColsIdx];
-    if (0 != inputCols % 3) {
-        printf_err("Number of input columns is not a multiple of 3\n");
-    }
-    return std::max(inputCols/3, 0);
-}
-
-static uint32_t GetNumMfccFeatureVectors(const arm::app::Model& model)
+static bool VerifyTensorDimensions(const arm::app::Model& model)
 {
+    /* Populate tensor related parameters. */
    TfLiteTensor* inputTensor = model.GetInputTensor(0);
-    const int inputRows = inputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx];
-    return std::max(inputRows, 0);
-}
-
-static uint32_t GetOutputContextLen(const arm::app::Model& model, const uint32_t inputCtxLen)
-{
-    const uint32_t inputRows = GetNumMfccFeatureVectors(model);
-    const uint32_t inputInnerLen = inputRows - (2 * inputCtxLen);
-    constexpr uint32_t ms_outputRowsIdx = arm::app::Wav2LetterModel::ms_outputRowsIdx;
-
-    /* Check to make sure that the input tensor supports the above context and inner lengths. */
-    if (inputRows <= 2 * inputCtxLen || inputRows <= inputInnerLen) {
-        printf_err("Input rows not compatible with ctx of %" PRIu32 "\n",
-                   inputCtxLen);
-        return 0;
+    if (!inputTensor->dims) {
+        printf_err("Invalid input tensor dims\n");
+        return false;
+    } else if (inputTensor->dims->size < 3) {
+        printf_err("Input tensor dimension should be >= 3\n");
+        return false;
    }
 
    TfLiteTensor* outputTensor = model.GetOutputTensor(0);
-    const uint32_t outputRows = std::max(outputTensor->dims->data[ms_outputRowsIdx], 0);
-
-    const float tensorColRatio = static_cast<float>(inputRows)/
-                                 static_cast<float>(outputRows);
-
-    return std::round(static_cast<float>(inputCtxLen)/tensorColRatio);
-}
+    if (!outputTensor->dims) {
+        printf_err("Invalid output tensor dims\n");
+        return false;
+    } else if (outputTensor->dims->size < 3) {
+        printf_err("Output tensor dimension should be >= 3\n");
+        return false;
+    }
 
-static uint32_t GetOutputInnerLen(const arm::app::Model& model,
-                                  const uint32_t outputCtxLen)
-{
-    constexpr uint32_t ms_outputRowsIdx = arm::app::Wav2LetterModel::ms_outputRowsIdx;
-    TfLiteTensor* outputTensor = model.GetOutputTensor(0);
-    const uint32_t outputRows = std::max(outputTensor->dims->data[ms_outputRowsIdx], 0);
-    return (outputRows - (2 * outputCtxLen));
+    return true;
 }
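Note that ApplicationContext pairs are looked up both by key string and by the exact type used at Set<>() time, so the camelCase key renames above ("kwsmodel" to "kwsModel" and so on) only work because UseCaseHandler.cc is updated to match in the same patch. A minimal sketch of that contract (not code from the patch):

    /* Producer side (MainLoop.cc): */
    caseContext.Set<int>("kwsFrameLength", arm::app::kws::g_FrameLength);

    /* Consumer side (UseCaseHandler.cc): key string and template type
     * must both match the Set<>() call, or the lookup misbehaves. */
    const auto frameLength = ctx.Get<int>("kwsFrameLength");
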
diff --git a/source/use_case/kws_asr/src/UseCaseHandler.cc b/source/use_case/kws_asr/src/UseCaseHandler.cc
index 1e1a400..01aefae 100644
--- a/source/use_case/kws_asr/src/UseCaseHandler.cc
+++ b/source/use_case/kws_asr/src/UseCaseHandler.cc
@@ -28,6 +28,7 @@
 #include "Wav2LetterMfcc.hpp"
 #include "Wav2LetterPreprocess.hpp"
 #include "Wav2LetterPostprocess.hpp"
+#include "KwsProcessing.hpp"
 #include "AsrResult.hpp"
 #include "AsrClassifier.hpp"
 #include "OutputDecode.hpp"
@@ -39,11 +40,6 @@ using KwsClassifier = arm::app::Classifier;
 namespace arm {
 namespace app {
 
-    enum AsrOutputReductionAxis {
-        AxisRow = 1,
-        AxisCol = 2
-    };
-
    struct KWSOutput {
        bool executionSuccess = false;
        const int16_t* asrAudioStart = nullptr;
        int32_t asrAudioSamples = 0;
    };
 
    /**
-     * @brief           Presents kws inference results using the data presentation
-     *                  object.
-     * @param[in]       results     vector of classification results to be displayed
-     * @return          true if successful, false otherwise
+     * @brief       Presents KWS inference results.
+     * @param[in]   results  Vector of KWS classification results to be displayed.
+     * @return      true if successful, false otherwise.
     **/
-    static bool PresentInferenceResult(std::vector<arm::app::kws::KwsResult>& results);
+    static bool PresentInferenceResult(std::vector<kws::KwsResult>& results);
 
    /**
-     * @brief           Presents asr inference results using the data presentation
-     *                  object.
-     * @param[in]       platform    reference to the hal platform object
-     * @param[in]       results     vector of classification results to be displayed
-     * @return          true if successful, false otherwise
+     * @brief       Presents ASR inference results.
+     * @param[in]   results  Vector of ASR classification results to be displayed.
+     * @return      true if successful, false otherwise.
     **/
-    static bool PresentInferenceResult(std::vector<arm::app::asr::AsrResult>& results);
+    static bool PresentInferenceResult(std::vector<asr::AsrResult>& results);
 
    /**
-     * @brief Returns a function to perform feature calculation and populates input tensor data with
-     * MFCC data.
-     *
-     * Input tensor data type check is performed to choose correct MFCC feature data type.
-     * If tensor has an integer data type then original features are quantised.
-     *
-     * Warning: mfcc calculator provided as input must have the same life scope as returned function.
-     *
-     * @param[in]       mfcc            MFCC feature calculator.
-     * @param[in,out]   inputTensor     Input tensor pointer to store calculated features.
-     * @param[in]       cacheSize       Size of the feature vectors cache (number of feature vectors).
-     *
-     * @return          function        function to be called providing audio sample and sliding window index.
-     **/
-    static std::function<void (std::vector<int16_t>&, int, bool, size_t)>
-    GetFeatureCalculator(audio::MicroNetMFCC&  mfcc,
-                         TfLiteTensor*         inputTensor,
-                         size_t                cacheSize);
-
-    /**
-     * @brief Performs the KWS pipeline.
-     * @param[in,out]   ctx pointer to the application context object
-     *
-     * @return KWSOutput struct containing pointer to audio data where ASR should begin
-     *         and how much data to process.
-     */
-    static KWSOutput doKws(ApplicationContext& ctx) {
+    /**
+     * @brief           Performs the KWS pipeline.
+     * @param[in,out]   ctx  pointer to the application context object
+     * @return          struct containing pointer to audio data where ASR should begin
+     *                  and how much data to process.
+     **/
+    static KWSOutput doKws(ApplicationContext& ctx)
+    {
+        auto& profiler = ctx.Get<Profiler&>("profiler");
+        auto& kwsModel = ctx.Get<Model&>("kwsModel");
+        const auto kwsMfccFrameLength = ctx.Get<int>("kwsFrameLength");
+        const auto kwsMfccFrameStride = ctx.Get<int>("kwsFrameStride");
+        const auto kwsScoreThreshold = ctx.Get<float>("kwsScoreThreshold");
+
+        auto currentIndex = ctx.Get<uint32_t>("clipIndex");
 
        constexpr uint32_t dataPsnTxtInfStartX = 20;
        constexpr uint32_t dataPsnTxtInfStartY = 40;
 
        constexpr int minTensorDims = static_cast<int>(
-            (arm::app::MicroNetKwsModel::ms_inputRowsIdx > arm::app::MicroNetKwsModel::ms_inputColsIdx)?
-             arm::app::MicroNetKwsModel::ms_inputRowsIdx : arm::app::MicroNetKwsModel::ms_inputColsIdx);
+            (MicroNetKwsModel::ms_inputRowsIdx > MicroNetKwsModel::ms_inputColsIdx)?
+             MicroNetKwsModel::ms_inputRowsIdx : MicroNetKwsModel::ms_inputColsIdx);
 
-        KWSOutput output;
+        /* Output struct from doing KWS. */
+        KWSOutput output {};
 
-        auto& profiler = ctx.Get<Profiler&>("profiler");
-        auto& kwsModel = ctx.Get<Model&>("kwsmodel");
        if (!kwsModel.IsInited()) {
            printf_err("KWS model has not been initialised\n");
            return output;
        }
 
-        const int kwsFrameLength = ctx.Get<int>("kwsframeLength");
-        const int kwsFrameStride = ctx.Get<int>("kwsframeStride");
-        const float kwsScoreThreshold = ctx.Get<float>("kwsscoreThreshold");
-
-        TfLiteTensor* kwsOutputTensor = kwsModel.GetOutputTensor(0);
+        /* Get Input and Output tensors for pre/post processing. */
        TfLiteTensor* kwsInputTensor = kwsModel.GetInputTensor(0);
-
+        TfLiteTensor* kwsOutputTensor = kwsModel.GetOutputTensor(0);
        if (!kwsInputTensor->dims) {
            printf_err("Invalid input tensor dims\n");
            return output;
@@ -126,63 +102,32 @@ namespace app {
            return output;
        }
 
-        const uint32_t kwsNumMfccFeats = ctx.Get<uint32_t>("kwsNumMfcc");
-        const uint32_t kwsNumAudioWindows = ctx.Get<uint32_t>("kwsNumAudioWins");
-
-        audio::MicroNetMFCC kwsMfcc = audio::MicroNetMFCC(kwsNumMfccFeats, kwsFrameLength);
-        kwsMfcc.Init();
-
-        /* Deduce the data length required for 1 KWS inference from the network parameters. */
-        auto kwsAudioDataWindowSize = kwsNumAudioWindows * kwsFrameStride +
-                                      (kwsFrameLength - kwsFrameStride);
-        auto kwsMfccWindowSize = kwsFrameLength;
-        auto kwsMfccWindowStride = kwsFrameStride;
-
-        /* We are choosing to move by half the window size => for a 1 second window size,
-         * this means an overlap of 0.5 seconds. */
-        auto kwsAudioDataStride = kwsAudioDataWindowSize / 2;
-
-        info("KWS audio data window size %" PRIu32 "\n", kwsAudioDataWindowSize);
-
-        /* Stride must be multiple of mfcc features window stride to re-use features. */
-        if (0 != kwsAudioDataStride % kwsMfccWindowStride) {
-            kwsAudioDataStride -= kwsAudioDataStride % kwsMfccWindowStride;
-        }
-
-        auto kwsMfccVectorsInAudioStride = kwsAudioDataStride/kwsMfccWindowStride;
+        /* Get input shape for feature extraction. */
+        TfLiteIntArray* inputShape = kwsModel.GetInputShape(0);
+        const uint32_t numMfccFeatures = inputShape->data[MicroNetKwsModel::ms_inputColsIdx];
+        const uint32_t numMfccFrames = inputShape->data[MicroNetKwsModel::ms_inputRowsIdx];
 
        /* We expect to be sampling 1 second worth of data at a time
         * NOTE: This is only used for time stamp calculation. */
-        const float kwsAudioParamsSecondsPerSample = 1.0/audio::MicroNetMFCC::ms_defaultSamplingFreq;
+        const float kwsAudioParamsSecondsPerSample = 1.0 / audio::MicroNetKwsMFCC::ms_defaultSamplingFreq;
 
-        auto currentIndex = ctx.Get<uint32_t>("clipIndex");
+        /* Set up pre and post-processing. */
+        KwsPreProcess preProcess = KwsPreProcess(kwsInputTensor, numMfccFeatures, numMfccFrames,
+                                                 kwsMfccFrameLength, kwsMfccFrameStride);
 
-        /* Creating a mfcc features sliding window for the data required for 1 inference. */
-        auto kwsAudioMFCCWindowSlider = audio::SlidingWindow<const int16_t>(
-                get_audio_array(currentIndex),
-                kwsAudioDataWindowSize, kwsMfccWindowSize,
-                kwsMfccWindowStride);
+        std::vector<ClassificationResult> singleInfResult;
+        KwsPostProcess postProcess = KwsPostProcess(kwsOutputTensor, ctx.Get<KwsClassifier &>("kwsClassifier"),
+                                                    ctx.Get<std::vector<std::string>&>("kwsLabels"),
+                                                    singleInfResult);
 
        /* Creating a sliding window through the whole audio clip. */
        auto audioDataSlider = audio::SlidingWindow<const int16_t>(
                get_audio_array(currentIndex),
                get_audio_array_size(currentIndex),
-                kwsAudioDataWindowSize, kwsAudioDataStride);
-
-        /* Calculate number of the feature vectors in the window overlap region.
-         * These feature vectors will be reused.*/
-        size_t numberOfReusedFeatureVectors = kwsAudioMFCCWindowSlider.TotalStrides() + 1
-                                              - kwsMfccVectorsInAudioStride;
-
-        auto kwsMfccFeatureCalc = GetFeatureCalculator(kwsMfcc, kwsInputTensor,
-                                                       numberOfReusedFeatureVectors);
-
-        if (!kwsMfccFeatureCalc){
-            return output;
-        }
+                preProcess.m_audioDataWindowSize, preProcess.m_audioDataStride);
 
-        /* Container for KWS results. */
-        std::vector<arm::app::kws::KwsResult> kwsResults;
+        /* Declare a container to hold kws results from across the whole audio clip. */
+        std::vector<kws::KwsResult> finalResults;
 
        /* Display message on the LCD - inference running. */
        std::string str_inf{"Running KWS inference... "};
        hal_lcd_display_text(
                str_inf.c_str(), str_inf.size(),
                dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
 
@@ -197,70 +142,56 @@ namespace app {
        while (audioDataSlider.HasNext()) {
            const int16_t* inferenceWindow = audioDataSlider.Next();
 
-            /* We moved to the next window - set the features sliding to the new address. */
-            kwsAudioMFCCWindowSlider.Reset(inferenceWindow);
-
-            /* The first window does not have cache ready. */
-            bool useCache = audioDataSlider.Index() > 0 && numberOfReusedFeatureVectors > 0;
-
-            /* Start calculating features inside one audio sliding window. */
-            while (kwsAudioMFCCWindowSlider.HasNext()) {
-                const int16_t* kwsMfccWindow = kwsAudioMFCCWindowSlider.Next();
-                std::vector<int16_t> kwsMfccAudioData =
-                    std::vector<int16_t>(kwsMfccWindow, kwsMfccWindow + kwsMfccWindowSize);
-
-                /* Compute features for this window and write them to input tensor. */
-                kwsMfccFeatureCalc(kwsMfccAudioData,
-                                   kwsAudioMFCCWindowSlider.Index(),
-                                   useCache,
-                                   kwsMfccVectorsInAudioStride);
-            }
+            preProcess.m_audioWindowIndex = audioDataSlider.Index();
 
-            info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
-                 audioDataSlider.TotalStrides() + 1);
+            /* Run the pre-processing, inference and post-processing. */
+            if (!preProcess.DoPreProcess(inferenceWindow, audio::MicroNetKwsMFCC::ms_defaultSamplingFreq)) {
+                printf_err("KWS Pre-processing failed.");
+                return output;
+            }
 
-            /* Run inference over this audio clip sliding window. */
            if (!RunInference(kwsModel, profiler)) {
-                printf_err("KWS inference failed\n");
+                printf_err("KWS Inference failed.");
                return output;
            }
 
-            std::vector<ClassificationResult> kwsClassificationResult;
-            auto& kwsClassifier = ctx.Get<KwsClassifier&>("kwsclassifier");
+            if (!postProcess.DoPostProcess()) {
+                printf_err("KWS Post-processing failed.");
+                return output;
+            }
 
-            kwsClassifier.GetClassificationResults(
-                kwsOutputTensor, kwsClassificationResult,
-                ctx.Get<std::vector<std::string>&>("kwslabels"), 1, true);
+            info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
+                 audioDataSlider.TotalStrides() + 1);
 
-            kwsResults.emplace_back(
-                kws::KwsResult(
-                    kwsClassificationResult,
-                    audioDataSlider.Index() * kwsAudioParamsSecondsPerSample * kwsAudioDataStride,
-                    audioDataSlider.Index(), kwsScoreThreshold)
-                );
+            /* Add results from this window to our final results vector. */
+            finalResults.emplace_back(
+                    kws::KwsResult(singleInfResult,
+                            audioDataSlider.Index() * kwsAudioParamsSecondsPerSample * preProcess.m_audioDataStride,
+                            audioDataSlider.Index(), kwsScoreThreshold));
 
-            /* Keyword detected. */
-            if (kwsClassificationResult[0].m_label == ctx.Get<const std::string&>("triggerkeyword")) {
-                output.asrAudioStart = inferenceWindow + kwsAudioDataWindowSize;
+            /* Break out when trigger keyword is detected. */
+            if (singleInfResult[0].m_label == ctx.Get<const std::string&>("triggerKeyword")
+                    && singleInfResult[0].m_normalisedVal > kwsScoreThreshold) {
+                output.asrAudioStart = inferenceWindow + preProcess.m_audioDataWindowSize;
                output.asrAudioSamples = get_audio_array_size(currentIndex) -
                        (audioDataSlider.NextWindowStartIndex() -
-                        kwsAudioDataStride + kwsAudioDataWindowSize);
+                        preProcess.m_audioDataStride + preProcess.m_audioDataWindowSize);
                break;
            }
 
 #if VERIFY_TEST_OUTPUT
-            arm::app::DumpTensor(kwsOutputTensor);
+            DumpTensor(kwsOutputTensor);
 #endif /* VERIFY_TEST_OUTPUT */
 
        } /* while (audioDataSlider.HasNext()) */
 
        /* Erase. */
        str_inf = std::string(str_inf.size(), ' ');
-        hal_lcd_display_text(
-                str_inf.c_str(), str_inf.size(),
-                dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
+        hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
+                             dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
 
-        if (!PresentInferenceResult(kwsResults)) {
+        if (!PresentInferenceResult(finalResults)) {
            return output;
        }
 
@@ -271,41 +202,41 @@ namespace app {
    }
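The KWSOutput handoff above is the only coupling between the two pipelines: ASR starts one full KWS window after the start of the window that contained the trigger keyword, and the sample count is whatever remains of the clip from that point. Restated as arithmetic (paraphrasing the code, names shortened for clarity):

    /* asrAudioStart   = inferenceWindow + kwsWindowSize                          */
    /* asrAudioSamples = clipSize - (nextWindowStart - kwsStride + kwsWindowSize) */
    /*                 = clipSize - (samples consumed up to asrAudioStart)        */
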
 
    /**
-     * @brief Performs the ASR pipeline.
-     *
-     * @param[in,out] ctx   pointer to the application context object
-     * @param[in] kwsOutput struct containing pointer to audio data where ASR should begin
-     *                      and how much data to process
-     * @return bool         true if pipeline executed without failure
-     */
-    static bool doAsr(ApplicationContext& ctx, const KWSOutput& kwsOutput) {
+     * @brief           Performs the ASR pipeline.
+     * @param[in,out]   ctx        Pointer to the application context object.
+     * @param[in]       kwsOutput  Struct containing pointer to audio data where ASR should begin
+     *                             and how much data to process.
+     * @return          true if pipeline executed without failure.
+     **/
+    static bool doAsr(ApplicationContext& ctx, const KWSOutput& kwsOutput)
+    {
+        auto& asrModel = ctx.Get<Model&>("asrModel");
+        auto& profiler = ctx.Get<Profiler&>("profiler");
+        auto asrMfccFrameLen = ctx.Get<uint32_t>("asrFrameLength");
+        auto asrMfccFrameStride = ctx.Get<uint32_t>("asrFrameStride");
+        auto asrScoreThreshold = ctx.Get<float>("asrScoreThreshold");
+        auto asrInputCtxLen = ctx.Get<uint32_t>("ctxLen");
+
        constexpr uint32_t dataPsnTxtInfStartX = 20;
        constexpr uint32_t dataPsnTxtInfStartY = 40;
 
-        auto& profiler = ctx.Get<Profiler&>("profiler");
-        hal_lcd_clear(COLOR_BLACK);
-
-        /* Get model reference. */
-        auto& asrModel = ctx.Get<Model&>("asrmodel");
        if (!asrModel.IsInited()) {
            printf_err("ASR model has not been initialised\n");
            return false;
        }
 
-        /* Get score threshold to be applied for the classifier (post-inference). */
-        auto asrScoreThreshold = ctx.Get<float>("asrscoreThreshold");
+        hal_lcd_clear(COLOR_BLACK);
 
-        /* Dimensions of the tensor should have been verified by the callee. */
+        /* Get Input and Output tensors for pre/post processing. */
        TfLiteTensor* asrInputTensor = asrModel.GetInputTensor(0);
        TfLiteTensor* asrOutputTensor = asrModel.GetOutputTensor(0);
-        const uint32_t asrInputRows = asrInputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx];
 
-        /* Populate ASR MFCC related parameters. */
-        auto asrMfccParamsWinLen = ctx.Get<uint32_t>("asrframeLength");
-        auto asrMfccParamsWinStride = ctx.Get<uint32_t>("asrframeStride");
+        /* Get input shape. Dimensions of the tensor should have been verified by
+         * the callee. */
+        TfLiteIntArray* inputShape = asrModel.GetInputShape(0);
 
-        /* Populate ASR inference context and inner lengths for input. */
-        auto asrInputCtxLen = ctx.Get<uint32_t>("ctxLen");
+
+        const uint32_t asrInputRows = asrInputTensor->dims->data[Wav2LetterModel::ms_inputRowsIdx];
        const uint32_t asrInputInnerLen = asrInputRows - (2 * asrInputCtxLen);
 
        /* Make sure the input tensor supports the above context and inner lengths. */
@@ -316,18 +247,9 @@ namespace app {
        }
 
        /* Audio data stride corresponds to inputInnerLen feature vectors. */
-        const uint32_t asrAudioParamsWinLen = (asrInputRows - 1) *
-                                              asrMfccParamsWinStride + (asrMfccParamsWinLen);
-        const uint32_t asrAudioParamsWinStride = asrInputInnerLen * asrMfccParamsWinStride;
-        const float asrAudioParamsSecondsPerSample =
-                                              (1.0/audio::Wav2LetterMFCC::ms_defaultSamplingFreq);
-
-        /* Get pre/post-processing objects */
-        auto& asrPrep = ctx.Get<audio::asr::Preprocess&>("preprocess");
-        auto& asrPostp = ctx.Get<audio::asr::Postprocess&>("postprocess");
-
-        /* Set default reduction axis for post-processing. */
-        const uint32_t reductionAxis = arm::app::Wav2LetterModel::ms_outputRowsIdx;
+        const uint32_t asrAudioDataWindowLen = (asrInputRows - 1) * asrMfccFrameStride + (asrMfccFrameLen);
+        const uint32_t asrAudioDataWindowStride = asrInputInnerLen * asrMfccFrameStride;
+        const float asrAudioParamsSecondsPerSample = 1.0 / audio::Wav2LetterMFCC::ms_defaultSamplingFreq;
 
        /* Get the remaining audio buffer and respective size from KWS results. */
        const int16_t* audioArr = kwsOutput.asrAudioStart;
        const uint32_t audioArrSize = kwsOutput.asrAudioSamples;
 
        /* Audio clip must have enough samples to produce 1 MFCC feature. */
        std::vector<int16_t> audioBuffer = std::vector<int16_t>(audioArr, audioArr + audioArrSize);
-        if (audioArrSize < asrMfccParamsWinLen) {
+        if (audioArrSize < asrMfccFrameLen) {
            printf_err("Not enough audio samples, minimum needed is %" PRIu32 "\n",
-                asrMfccParamsWinLen);
+                       asrMfccFrameLen);
            return false;
        }
 
@@ -345,26 +267,38 @@ namespace app {
        /* Initialise an audio slider. */
        auto audioDataSlider = audio::FractionalSlidingWindow<const int16_t>(
                audioBuffer.data(),
                audioBuffer.size(),
-                asrAudioParamsWinLen,
-                asrAudioParamsWinStride);
+                asrAudioDataWindowLen,
+                asrAudioDataWindowStride);
 
        /* Declare a container for results. */
-        std::vector<arm::app::asr::AsrResult> asrResults;
+        std::vector<asr::AsrResult> asrResults;
 
        /* Display message on the LCD - inference running. */
        std::string str_inf{"Running ASR inference... "};
-        hal_lcd_display_text(
-                str_inf.c_str(), str_inf.size(),
+        hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
                dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
 
-        size_t asrInferenceWindowLen = asrAudioParamsWinLen;
+        size_t asrInferenceWindowLen = asrAudioDataWindowLen;
+
+        /* Set up pre and post-processing objects. */
+        AsrPreProcess asrPreProcess = AsrPreProcess(asrInputTensor, arm::app::Wav2LetterModel::ms_numMfccFeatures,
+                                                    inputShape->data[Wav2LetterModel::ms_inputRowsIdx],
+                                                    asrMfccFrameLen, asrMfccFrameStride);
+
+        std::vector<ClassificationResult> singleInfResult;
+        const uint32_t outputCtxLen = AsrPostProcess::GetOutputContextLen(asrModel, asrInputCtxLen);
+        AsrPostProcess asrPostProcess = AsrPostProcess(
+                asrOutputTensor, ctx.Get<AsrClassifier&>("asrClassifier"),
+                ctx.Get<std::vector<std::string>&>("asrLabels"),
+                singleInfResult, outputCtxLen,
+                Wav2LetterModel::ms_blankTokenIdx, Wav2LetterModel::ms_outputRowsIdx
+        );
 
        /* Start sliding through audio clip. */
        while (audioDataSlider.HasNext()) {
 
            /* If not enough audio see how much can be sent for processing. */
            size_t nextStartIndex = audioDataSlider.NextWindowStartIndex();
-            if (nextStartIndex + asrAudioParamsWinLen > audioBuffer.size()) {
+            if (nextStartIndex + asrAudioDataWindowLen > audioBuffer.size()) {
                asrInferenceWindowLen = audioBuffer.size() - nextStartIndex;
            }
 
            const int16_t* asrInferenceWindow = audioBuffer.data() + nextStartIndex;
 
@@ -373,8 +307,11 @@ namespace app {
            info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
                static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1)));
 
-            /* Calculate MFCCs, deltas and populate the input tensor. */
-            asrPrep.Invoke(asrInferenceWindow, asrInferenceWindowLen, asrInputTensor);
+            /* Run the pre-processing, inference and post-processing. */
+            if (!asrPreProcess.DoPreProcess(asrInferenceWindow, asrInferenceWindowLen)) {
+                printf_err("ASR pre-processing failed.");
+                return false;
+            }
 
            /* Run inference over this audio clip sliding window. */
            if (!RunInference(asrModel, profiler)) {
                printf_err("ASR inference failed\n");
                return false;
            }
 
@@ -382,24 +319,28 @@ namespace app {
-            /* Post-process. */
-            asrPostp.Invoke(asrOutputTensor, reductionAxis, !audioDataSlider.HasNext());
+            /* Post processing needs to know if we are on the last audio window. */
+            asrPostProcess.m_lastIteration = !audioDataSlider.HasNext();
+            if (!asrPostProcess.DoPostProcess()) {
+                printf_err("ASR post-processing failed.");
+                return false;
+            }
 
            /* Get results. */
         /* Start sliding through audio clip. */
         while (audioDataSlider.HasNext()) {

             /* If not enough audio see how much can be sent for processing. */
             size_t nextStartIndex = audioDataSlider.NextWindowStartIndex();
-            if (nextStartIndex + asrAudioParamsWinLen > audioBuffer.size()) {
+            if (nextStartIndex + asrAudioDataWindowLen > audioBuffer.size()) {
                 asrInferenceWindowLen = audioBuffer.size() - nextStartIndex;
             }
@@ -373,8 +307,11 @@

             info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
                  static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1)));

-            /* Calculate MFCCs, deltas and populate the input tensor. */
-            asrPrep.Invoke(asrInferenceWindow, asrInferenceWindowLen, asrInputTensor);
+            /* Run the pre-processing, inference and post-processing. */
+            if (!asrPreProcess.DoPreProcess(asrInferenceWindow, asrInferenceWindowLen)) {
+                printf_err("ASR pre-processing failed.");
+                return false;
+            }

             /* Run inference over this audio clip sliding window. */
             if (!RunInference(asrModel, profiler)) {
@@ -382,24 +319,28 @@
                 return false;
             }

-            /* Post-process. */
-            asrPostp.Invoke(asrOutputTensor, reductionAxis, !audioDataSlider.HasNext());
+            /* Post processing needs to know if we are on the last audio window. */
+            asrPostProcess.m_lastIteration = !audioDataSlider.HasNext();
+            if (!asrPostProcess.DoPostProcess()) {
+                printf_err("ASR post-processing failed.");
+                return false;
+            }

             /* Get results. */
             std::vector<ClassificationResult> asrClassificationResult;
-            auto& asrClassifier = ctx.Get<AsrClassifier&>("asrclassifier");
+            auto& asrClassifier = ctx.Get<AsrClassifier&>("asrClassifier");
             asrClassifier.GetClassificationResults(
                     asrOutputTensor, asrClassificationResult,
-                    ctx.Get<std::vector<std::string>&>("asrlabels"), 1);
+                    ctx.Get<std::vector<std::string>&>("asrLabels"), 1);

             asrResults.emplace_back(asr::AsrResult(asrClassificationResult,
                     (audioDataSlider.Index() *
                      asrAudioParamsSecondsPerSample *
-                     asrAudioParamsWinStride),
+                     asrAudioDataWindowStride),
                     audioDataSlider.Index(), asrScoreThreshold));

 #if VERIFY_TEST_OUTPUT
-            arm::app::DumpTensor(asrOutputTensor, asrOutputTensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx]);
+            armDumpTensor(asrOutputTensor, asrOutputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx]);
 #endif /* VERIFY_TEST_OUTPUT */

             /* Erase */
@@ -417,7 +358,7 @@
         return true;
     }

-    /* Audio inference classification handler. */
+    /* KWS and ASR inference handler. */
     bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
     {
         hal_lcd_clear(COLOR_BLACK);
@@ -434,13 +375,14 @@
         do {
             KWSOutput kwsOutput = doKws(ctx);
             if (!kwsOutput.executionSuccess) {
+                printf_err("KWS failed\n");
                 return false;
             }

             if (kwsOutput.asrAudioStart != nullptr && kwsOutput.asrAudioSamples > 0) {
-                info("Keyword spotted\n");
+                info("Trigger keyword spotted\n");
                 if(!doAsr(ctx, kwsOutput)) {
-                    printf_err("ASR failed");
+                    printf_err("ASR failed\n");
                     return false;
                 }
             }
@@ -452,7 +394,6 @@
         return true;
     }

-
     static bool PresentInferenceResult(std::vector<kws::KwsResult>& results)
     {
         constexpr uint32_t dataPsnTxtStartX1 = 20;
@@ -464,33 +405,31 @@

         /* Display each result. */
         uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr;

-        for (uint32_t i = 0; i < results.size(); ++i) {
-
+        for (auto& result : results) {
             std::string topKeyword{"<none>"};
             float score = 0.f;
-            if (!results[i].m_resultVec.empty()) {
-                topKeyword = results[i].m_resultVec[0].m_label;
-                score = results[i].m_resultVec[0].m_normalisedVal;
+            if (!result.m_resultVec.empty()) {
+                topKeyword = result.m_resultVec[0].m_label;
+                score = result.m_resultVec[0].m_normalisedVal;
             }

             std::string resultStr =
-                std::string{"@"} + std::to_string(results[i].m_timeStamp) +
+                std::string{"@"} + std::to_string(result.m_timeStamp) +
                 std::string{"s: "} + topKeyword + std::string{" ("} +
                 std::to_string(static_cast<int>(score * 100)) + std::string{"%)"};

-            hal_lcd_display_text(
-                    resultStr.c_str(), resultStr.size(),
-                    dataPsnTxtStartX1, rowIdx1, 0);
+            hal_lcd_display_text(resultStr.c_str(), resultStr.size(),
+                    dataPsnTxtStartX1, rowIdx1, 0);
             rowIdx1 += dataPsnTxtYIncr;

             info("For timestamp: %f (inference #: %" PRIu32 "); threshold: %f\n",
-                 results[i].m_timeStamp, results[i].m_inferenceNumber,
-                 results[i].m_threshold);
-            for (uint32_t j = 0; j < results[i].m_resultVec.size(); ++j) {
+                 result.m_timeStamp, result.m_inferenceNumber,
+                 result.m_threshold);
+            for (uint32_t j = 0; j < result.m_resultVec.size(); ++j) {
                 info("\t\tlabel @ %" PRIu32 ": %s, score: %f\n", j,
-                     results[i].m_resultVec[j].m_label.c_str(),
-                     results[i].m_resultVec[j].m_normalisedVal);
+                     result.m_resultVec[j].m_label.c_str(),
+                     result.m_resultVec[j].m_normalisedVal);
             }
         }

@@ -523,143 +462,12 @@

         std::string finalResultStr = audio::asr::DecodeOutput(combinedResults);

-        hal_lcd_display_text(
-                finalResultStr.c_str(), finalResultStr.size(),
-                dataPsnTxtStartX1, dataPsnTxtStartY1, allow_multiple_lines);
+        hal_lcd_display_text(finalResultStr.c_str(), finalResultStr.size(),
+                dataPsnTxtStartX1, dataPsnTxtStartY1, allow_multiple_lines);

         info("Final result: %s\n", finalResultStr.c_str());

         return true;
     }
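The handler above is driven entirely by the KWS-to-ASR handoff. A compact sketch of that contract, with the struct shape and field types inferred from their use in this file (not quoted from the patch), and relying on this file's doKws/doAsr declarations:

    /* Sketch: KWS -> ASR handoff; field types inferred from usage here. */
    struct KWSOutput {
        bool executionSuccess = false;          /* KWS inference ran without error. */
        const int16_t* asrAudioStart = nullptr; /* Audio following the spotted keyword. */
        uint32_t asrAudioSamples = 0;           /* Samples available for ASR. */
    };

    static bool RunPipelineOnce(ApplicationContext& ctx) {
        const KWSOutput kwsOutput = doKws(ctx);
        if (!kwsOutput.executionSuccess) {
            return false;  /* Hard failure: abort the whole pipeline. */
        }
        /* ASR only runs when the trigger keyword left usable audio behind it. */
        if (kwsOutput.asrAudioStart != nullptr && kwsOutput.asrAudioSamples > 0) {
            return doAsr(ctx, kwsOutput);
        }
        return true;  /* No keyword spotted: not an error. */
    }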
-    /**
-     * @brief Generic feature calculator factory.
-     *
-     * Returns lambda function to compute features using features cache.
-     * Real features math is done by a lambda function provided as a parameter.
-     * Features are written to input tensor memory.
-     *
-     * @tparam T feature vector type.
-     * @param inputTensor model input tensor pointer.
-     * @param cacheSize number of feature vectors to cache. Defined by the sliding window overlap.
-     * @param compute features calculator function.
-     * @return lambda function to compute features.
-     **/
-    template<class T>
-    std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
-    FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
-                std::function<std::vector<T> (std::vector<int16_t>& )> compute)
-    {
-        /* Feature cache to be captured by lambda function. */
-        static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize);
-
-        return [=](std::vector<int16_t>& audioDataWindow,
-                   size_t index,
-                   bool useCache,
-                   size_t featuresOverlapIndex)
-        {
-            T* tensorData = tflite::GetTensorData<T>(inputTensor);
-            std::vector<T> features;
-
-            /* Reuse features from cache if cache is ready and sliding windows overlap.
-             * Overlap is in the beginning of sliding window with a size of a feature cache.
-             */
-            if (useCache && index < featureCache.size()) {
-                features = std::move(featureCache[index]);
-            } else {
-                features = std::move(compute(audioDataWindow));
-            }
-            auto size = features.size();
-            auto sizeBytes = sizeof(T) * size;
-            std::memcpy(tensorData + (index * size), features.data(), sizeBytes);
-
-            /* Start renewing cache as soon iteration goes out of the windows overlap. */
-            if (index >= featuresOverlapIndex) {
-                featureCache[index - featuresOverlapIndex] = std::move(features);
-            }
-        };
-    }
-
-    template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)>
-    FeatureCalc<int8_t>(TfLiteTensor* inputTensor,
-                        size_t cacheSize,
-                        std::function<std::vector<int8_t> (std::vector<int16_t>& )> compute);
-
-    template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)>
-    FeatureCalc<uint8_t>(TfLiteTensor* inputTensor,
-                         size_t cacheSize,
-                         std::function<std::vector<uint8_t> (std::vector<int16_t>& )> compute);
-
-    template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)>
-    FeatureCalc<int16_t>(TfLiteTensor* inputTensor,
-                         size_t cacheSize,
-                         std::function<std::vector<int16_t> (std::vector<int16_t>& )> compute);
-
-    template std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
-    FeatureCalc<float>(TfLiteTensor* inputTensor,
-                       size_t cacheSize,
-                       std::function<std::vector<float>(std::vector<int16_t>&)> compute);
-
-
-    static std::function<void (std::vector<int16_t>&, int, bool, size_t)>
-    GetFeatureCalculator(audio::MicroNetMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize)
-    {
-        std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> mfccFeatureCalc;
-
-        TfLiteQuantization quant = inputTensor->quantization;
-
-        if (kTfLiteAffineQuantization == quant.type) {
-
-            auto* quantParams = (TfLiteAffineQuantization*) quant.params;
-            const float quantScale = quantParams->scale->data[0];
-            const int quantOffset = quantParams->zero_point->data[0];
-
-            switch (inputTensor->type) {
-                case kTfLiteInt8: {
-                    mfccFeatureCalc = FeatureCalc<int8_t>(inputTensor,
-                            cacheSize,
-                            [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
-                                return mfcc.MfccComputeQuant<int8_t>(audioDataWindow,
-                                                                     quantScale,
-                                                                     quantOffset);
-                            }
-                    );
-                    break;
-                }
-                case kTfLiteUInt8: {
-                    mfccFeatureCalc = FeatureCalc<uint8_t>(inputTensor,
-                            cacheSize,
-                            [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
-                                return mfcc.MfccComputeQuant<uint8_t>(audioDataWindow,
                                                                      quantScale,
-                                                                      quantOffset);
-                            }
-                    );
-                    break;
-                }
-                case kTfLiteInt16: {
-                    mfccFeatureCalc = FeatureCalc<int16_t>(inputTensor,
-                            cacheSize,
-                            [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
-                                return mfcc.MfccComputeQuant<int16_t>(audioDataWindow,
-                                                                      quantScale,
-                                                                      quantOffset);
-                            }
-                    );
-                    break;
-                }
-                default:
-                    printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type));
-            }
-
-        } else {
-            mfccFeatureCalc = FeatureCalc<float>(inputTensor,
-                    cacheSize,
-                    [&mfcc](std::vector<int16_t>& audioDataWindow) {
-                        return mfcc.MfccCompute(audioDataWindow);
-                    });
-        }
-        return mfccFeatureCalc;
-    }

 } /* namespace app */
 } /* namespace arm */
\ No newline at end of file
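The FeatureCalc factory removed above (its logic now lives in KwsPreProcess) deserves a standalone illustration: consecutive audio windows overlap, so feature vectors computed for the overlapping region are cached and reused rather than recomputed. A stripped-down sketch of the same caching idea, with generic names instead of this patch's API:

    // Sketch: cache feature vectors shared by overlapping sliding windows.
    #include <cstdint>
    #include <functional>
    #include <memory>
    #include <utility>
    #include <vector>

    using Feature = std::vector<float>;
    using Compute = std::function<Feature(const std::vector<int16_t>&)>;

    std::function<Feature(const std::vector<int16_t>&, size_t, bool, size_t)>
    MakeCachedCalc(size_t cacheSize, Compute compute) {
        /* One cache slot per feature vector inside the window overlap. */
        auto cache = std::make_shared<std::vector<Feature>>(cacheSize);
        return [=](const std::vector<int16_t>& window, size_t index,
                   bool useCache, size_t overlapIndex) {
            Feature f = (useCache && index < cache->size())
                            ? std::move((*cache)[index]) /* Reuse overlapped features. */
                            : compute(window);           /* Otherwise compute afresh. */
            /* Once past the overlap, start refilling the cache for the next window. */
            if (index >= overlapIndex) {
                (*cache)[index - overlapIndex] = f;
            }
            return f;
        };
    }

The shared_ptr replaces the original's function-local static so that each factory call gets its own cache; the original's static cache was shared per template instantiation, which is part of why the refactor moves this state into a class member.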
diff --git a/source/use_case/kws_asr/src/Wav2LetterPostprocess.cc b/source/use_case/kws_asr/src/Wav2LetterPostprocess.cc
index 2a76b1b..42f434e 100644
--- a/source/use_case/kws_asr/src/Wav2LetterPostprocess.cc
+++ b/source/use_case/kws_asr/src/Wav2LetterPostprocess.cc
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
  * SPDX-License-Identifier: Apache-2.0
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,62 +15,71 @@
  * limitations under the License.
  */
 #include "Wav2LetterPostprocess.hpp"
+
 #include "Wav2LetterModel.hpp"
 #include "log_macros.h"

+#include <cmath>
+
 namespace arm {
 namespace app {
-namespace audio {
-namespace asr {
-
-    Postprocess::Postprocess(const uint32_t contextLen,
-                             const uint32_t innerLen,
-                             const uint32_t blankTokenIdx)
-        :   m_contextLen(contextLen),
-            m_innerLen(innerLen),
-            m_totalLen(2 * this->m_contextLen + this->m_innerLen),
+
+    AsrPostProcess::AsrPostProcess(TfLiteTensor* outputTensor, AsrClassifier& classifier,
+            const std::vector<std::string>& labels, std::vector<ClassificationResult>& results,
+            const uint32_t outputContextLen,
+            const uint32_t blankTokenIdx, const uint32_t reductionAxisIdx
+            ):
+            m_classifier(classifier),
+            m_outputTensor(outputTensor),
+            m_labels{labels},
+            m_results(results),
+            m_outputContextLen(outputContextLen),
             m_countIterations(0),
-            m_blankTokenIdx(blankTokenIdx)
-    {}
+            m_blankTokenIdx(blankTokenIdx),
+            m_reductionAxisIdx(reductionAxisIdx)
+    {
+        this->m_outputInnerLen = AsrPostProcess::GetOutputInnerLen(this->m_outputTensor, this->m_outputContextLen);
+        this->m_totalLen = (2 * this->m_outputContextLen + this->m_outputInnerLen);
+    }

-    bool Postprocess::Invoke(TfLiteTensor* tensor,
-                             const uint32_t axisIdx,
-                             const bool lastIteration)
+    bool AsrPostProcess::DoPostProcess()
     {
         /* Basic checks. */
-        if (!this->IsInputValid(tensor, axisIdx)) {
+        if (!this->IsInputValid(this->m_outputTensor, this->m_reductionAxisIdx)) {
             return false;
         }

         /* Irrespective of tensor type, we use unsigned "byte" */
-        uint8_t* ptrData = tflite::GetTensorData<uint8_t>(tensor);
-        const uint32_t elemSz = this->GetTensorElementSize(tensor);
+        auto* ptrData = tflite::GetTensorData<uint8_t>(this->m_outputTensor);
+        const uint32_t elemSz = AsrPostProcess::GetTensorElementSize(this->m_outputTensor);

         /* Other sanity checks. */
         if (0 == elemSz) {
             printf_err("Tensor type not supported for post processing\n");
             return false;
-        } else if (elemSz * this->m_totalLen > tensor->bytes) {
+        } else if (elemSz * this->m_totalLen > this->m_outputTensor->bytes) {
             printf_err("Insufficient number of tensor bytes\n");
             return false;
         }
         /* Which axis do we need to process? */
-        switch (axisIdx) {
-            case arm::app::Wav2LetterModel::ms_outputRowsIdx:
-                return this->EraseSectionsRowWise(ptrData,
-                                                  elemSz *
-                                                  tensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx],
-                                                  lastIteration);
+        switch (this->m_reductionAxisIdx) {
+            case Wav2LetterModel::ms_outputRowsIdx:
+                this->EraseSectionsRowWise(
+                        ptrData, elemSz * this->m_outputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx],
+                        this->m_lastIteration);
+                break;
             default:
-                printf_err("Unsupported axis index: %" PRIu32 "\n", axisIdx);
+                printf_err("Unsupported axis index: %" PRIu32 "\n", this->m_reductionAxisIdx);
+                return false;
         }
+        this->m_classifier.GetClassificationResults(this->m_outputTensor,
+                this->m_results, this->m_labels, 1);

-        return false;
+        return true;
     }

-    bool Postprocess::IsInputValid(TfLiteTensor* tensor,
-                                   const uint32_t axisIdx) const
+    bool AsrPostProcess::IsInputValid(TfLiteTensor* tensor, const uint32_t axisIdx) const
     {
         if (nullptr == tensor) {
             return false;
@@ -84,25 +93,23 @@

         if (static_cast<int>(this->m_totalLen) !=
                              tensor->dims->data[axisIdx]) {
-            printf_err("Unexpected tensor dimension for axis %d, \n",
-                tensor->dims->data[axisIdx]);
+            printf_err("Unexpected tensor dimension for axis %d, got %d\n",
+                axisIdx, tensor->dims->data[axisIdx]);
             return false;
         }

         return true;
     }

-    uint32_t Postprocess::GetTensorElementSize(TfLiteTensor* tensor)
+    uint32_t AsrPostProcess::GetTensorElementSize(TfLiteTensor* tensor)
     {
         switch(tensor->type) {
             case kTfLiteUInt8:
-                return 1;
             case kTfLiteInt8:
                 return 1;
             case kTfLiteInt16:
                 return 2;
             case kTfLiteInt32:
-                return 4;
             case kTfLiteFloat32:
                 return 4;
             default:
@@ -113,30 +120,30 @@
         return 0;
     }

-    bool Postprocess::EraseSectionsRowWise(
-            uint8_t*         ptrData,
-            const uint32_t   strideSzBytes,
-            const bool       lastIteration)
+    bool AsrPostProcess::EraseSectionsRowWise(
+            uint8_t* ptrData,
+            const uint32_t strideSzBytes,
+            const bool lastIteration)
     {
         /* In this case, the "zero-ing" is quite simple as the region
          * to be zeroed sits in contiguous memory (row-major). */
-        const uint32_t eraseLen = strideSzBytes * this->m_contextLen;
+        const uint32_t eraseLen = strideSzBytes * this->m_outputContextLen;

         /* Erase left context? */
         if (this->m_countIterations > 0) {
             /* Set output of each classification window to the blank token. */
             std::memset(ptrData, 0, eraseLen);
-            for (size_t windowIdx = 0; windowIdx < this->m_contextLen; windowIdx++) {
+            for (size_t windowIdx = 0; windowIdx < this->m_outputContextLen; windowIdx++) {
                 ptrData[windowIdx*strideSzBytes + this->m_blankTokenIdx] = 1;
             }
         }

         /* Erase right context? */
         if (false == lastIteration) {
-            uint8_t * rightCtxPtr = ptrData + (strideSzBytes * (this->m_contextLen + this->m_innerLen));
+            uint8_t* rightCtxPtr = ptrData + (strideSzBytes * (this->m_outputContextLen + this->m_outputInnerLen));
             /* Set output of each classification window to the blank token. */
             std::memset(rightCtxPtr, 0, eraseLen);
-            for (size_t windowIdx = 0; windowIdx < this->m_contextLen; windowIdx++) {
+            for (size_t windowIdx = 0; windowIdx < this->m_outputContextLen; windowIdx++) {
                 rightCtxPtr[windowIdx*strideSzBytes + this->m_blankTokenIdx] = 1;
             }
         }
@@ -150,7 +157,58 @@
         return true;
     }

-} /* namespace asr */
-} /* namespace audio */
+    uint32_t AsrPostProcess::GetNumFeatureVectors(const Model& model)
+    {
+        TfLiteTensor* inputTensor = model.GetInputTensor(0);
+        const int inputRows = std::max(inputTensor->dims->data[Wav2LetterModel::ms_inputRowsIdx], 0);
+        if (inputRows == 0) {
+            printf_err("Error getting number of input rows for axis: %" PRIu32 "\n",
+                    Wav2LetterModel::ms_inputRowsIdx);
+        }
+        return inputRows;
+    }
+
+    uint32_t AsrPostProcess::GetOutputInnerLen(const TfLiteTensor* outputTensor, const uint32_t outputCtxLen)
+    {
+        const uint32_t outputRows = std::max(outputTensor->dims->data[Wav2LetterModel::ms_outputRowsIdx], 0);
+        if (outputRows == 0) {
+            printf_err("Error getting number of output rows for axis: %" PRIu32 "\n",
+                    Wav2LetterModel::ms_outputRowsIdx);
+        }
+
+        /* Watching for underflow. */
+        int innerLen = (outputRows - (2 * outputCtxLen));
+
+        return std::max(innerLen, 0);
+    }
+
+    uint32_t AsrPostProcess::GetOutputContextLen(const Model& model, const uint32_t inputCtxLen)
+    {
+        const uint32_t inputRows = AsrPostProcess::GetNumFeatureVectors(model);
+        const uint32_t inputInnerLen = inputRows - (2 * inputCtxLen);
+        constexpr uint32_t ms_outputRowsIdx = Wav2LetterModel::ms_outputRowsIdx;
+
+        /* Check to make sure that the input tensor supports the above
+         * context and inner lengths. */
+        if (inputRows <= 2 * inputCtxLen || inputRows <= inputInnerLen) {
+            printf_err("Input rows not compatible with ctx of %" PRIu32 "\n",
+                    inputCtxLen);
+            return 0;
+        }
+
+        TfLiteTensor* outputTensor = model.GetOutputTensor(0);
+        const uint32_t outputRows = std::max(outputTensor->dims->data[ms_outputRowsIdx], 0);
+        if (outputRows == 0) {
+            printf_err("Error getting number of output rows for axis: %" PRIu32 "\n",
+                    Wav2LetterModel::ms_outputRowsIdx);
+            return 0;
+        }
+
+        const float inOutRowRatio = static_cast<float>(inputRows) /
+                                    static_cast<float>(outputRows);
+
+        return std::round(static_cast<float>(inputCtxLen) / inOutRowRatio);
+    }
+
 } /* namespace app */
 } /* namespace arm */
\ No newline at end of file
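To make GetOutputContextLen concrete, a worked example with assumed shapes (not taken from this patch): an input of 296 feature rows with an input context of 98 rows, and a model output of 148 rows, gives inOutRowRatio = 296 / 148 = 2, so the output context is round(98 / 2) = 49 rows; GetOutputInnerLen then yields 148 - 2 * 49 = 50 inner rows, and m_totalLen = 2 * 49 + 50 = 148 matches the output tensor's row dimension, which is exactly what IsInputValid checks.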
diff --git a/source/use_case/kws_asr/src/Wav2LetterPreprocess.cc b/source/use_case/kws_asr/src/Wav2LetterPreprocess.cc
index d3f3579..92b0631 100644
--- a/source/use_case/kws_asr/src/Wav2LetterPreprocess.cc
+++ b/source/use_case/kws_asr/src/Wav2LetterPreprocess.cc
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
  * SPDX-License-Identifier: Apache-2.0
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
@@ -20,41 +20,35 @@
 #include "TensorFlowLiteMicro.hpp"

 #include <algorithm>
-#include <math.h>
+#include <cmath>

 namespace arm {
 namespace app {
-namespace audio {
-namespace asr {
-
-    Preprocess::Preprocess(
-        const uint32_t numMfccFeatures,
-        const uint32_t windowLen,
-        const uint32_t windowStride,
-        const uint32_t numMfccVectors):
-        m_mfcc(numMfccFeatures, windowLen),
-        m_mfccBuf(numMfccFeatures, numMfccVectors),
-        m_delta1Buf(numMfccFeatures, numMfccVectors),
-        m_delta2Buf(numMfccFeatures, numMfccVectors),
-        m_windowLen(windowLen),
-        m_windowStride(windowStride),
+
+    AsrPreProcess::AsrPreProcess(TfLiteTensor* inputTensor, const uint32_t numMfccFeatures,
+                                 const uint32_t numFeatureFrames, const uint32_t mfccWindowLen,
+                                 const uint32_t mfccWindowStride
+                                 ):
+        m_mfcc(numMfccFeatures, mfccWindowLen),
+        m_inputTensor(inputTensor),
+        m_mfccBuf(numMfccFeatures, numFeatureFrames),
+        m_delta1Buf(numMfccFeatures, numFeatureFrames),
+        m_delta2Buf(numMfccFeatures, numFeatureFrames),
+        m_mfccWindowLen(mfccWindowLen),
+        m_mfccWindowStride(mfccWindowStride),
         m_numMfccFeats(numMfccFeatures),
-        m_numFeatVectors(numMfccVectors),
-        m_window()
+        m_numFeatureFrames(numFeatureFrames)
     {
-        if (numMfccFeatures > 0 && windowLen > 0) {
+        if (numMfccFeatures > 0 && mfccWindowLen > 0) {
             this->m_mfcc.Init();
         }
     }

-    bool Preprocess::Invoke(
-        const int16_t* audioData,
-        const uint32_t audioDataLen,
-        TfLiteTensor* tensor)
+    bool AsrPreProcess::DoPreProcess(const void* audioData, const size_t audioDataLen)
     {
-        this->m_window = SlidingWindow<const int16_t>(
-            audioData, audioDataLen,
-            this->m_windowLen, this->m_windowStride);
+        this->m_mfccSlidingWindow = audio::SlidingWindow<const int16_t>(
+                static_cast<const int16_t*>(audioData), audioDataLen,
+                this->m_mfccWindowLen, this->m_mfccWindowStride);

         uint32_t mfccBufIdx = 0;

@@ -62,12 +56,12 @@
         std::fill(m_delta1Buf.begin(), m_delta1Buf.end(), 0.f);
         std::fill(m_delta2Buf.begin(), m_delta2Buf.end(), 0.f);

-        /* While we can slide over the window. */
-        while (this->m_window.HasNext()) {
-            const int16_t* mfccWindow = this->m_window.Next();
+        /* While we can slide over the audio. */
+        while (this->m_mfccSlidingWindow.HasNext()) {
+            const int16_t* mfccWindow = this->m_mfccSlidingWindow.Next();
             auto mfccAudioData = std::vector<int16_t>(
                     mfccWindow,
-                    mfccWindow + this->m_windowLen);
+                    mfccWindow + this->m_mfccWindowLen);
             auto mfcc = this->m_mfcc.MfccCompute(mfccAudioData);
             for (size_t i = 0; i < this->m_mfccBuf.size(0); ++i) {
                 this->m_mfccBuf(i, mfccBufIdx) = mfcc[i];
             }
             ++mfccBufIdx;
         }

         /* Pad MFCC if needed by adding MFCC for zeros. */
-        if (mfccBufIdx != this->m_numFeatVectors) {
-            std::vector<int16_t> zerosWindow = std::vector<int16_t>(this->m_windowLen, 0);
+        if (mfccBufIdx != this->m_numFeatureFrames) {
+            std::vector<int16_t> zerosWindow = std::vector<int16_t>(this->m_mfccWindowLen, 0);
             std::vector<float> mfccZeros = this->m_mfcc.MfccCompute(zerosWindow);

-            while (mfccBufIdx != this->m_numFeatVectors) {
+            while (mfccBufIdx != this->m_numFeatureFrames) {
                 memcpy(&this->m_mfccBuf(0, mfccBufIdx),
                        mfccZeros.data(), sizeof(float) * m_numMfccFeats);
                 ++mfccBufIdx;
             }
         }

         /* Compute first and second order deltas from MFCCs. */
-        this->ComputeDeltas(this->m_mfccBuf,
-                            this->m_delta1Buf,
-                            this->m_delta2Buf);
+        AsrPreProcess::ComputeDeltas(this->m_mfccBuf, this->m_delta1Buf, this->m_delta2Buf);

-        /* Normalise. */
-        this->Normalise();
+        /* Standardize calculated features. */
+        this->Standarize();

         /* Quantise. */
-        QuantParams quantParams = GetTensorQuantParams(tensor);
+        QuantParams quantParams = GetTensorQuantParams(this->m_inputTensor);

         if (0 == quantParams.scale) {
             printf_err("Quantisation scale can't be 0\n");
             return false;
         }

-        switch(tensor->type) {
+        switch(this->m_inputTensor->type) {
             case kTfLiteUInt8:
                 return this->Quantise<uint8_t>(
-                        tflite::GetTensorData<uint8_t>(tensor), tensor->bytes,
+                        tflite::GetTensorData<uint8_t>(this->m_inputTensor), this->m_inputTensor->bytes,
                         quantParams.scale, quantParams.offset);
             case kTfLiteInt8:
                 return this->Quantise<int8_t>(
-                        tflite::GetTensorData<int8_t>(tensor), tensor->bytes,
+                        tflite::GetTensorData<int8_t>(this->m_inputTensor), this->m_inputTensor->bytes,
                         quantParams.scale, quantParams.offset);
             default:
                 printf_err("Unsupported tensor type %s\n",
-                    TfLiteTypeGetName(tensor->type));
+                    TfLiteTypeGetName(this->m_inputTensor->type));
         }

         return false;
     }
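The Quantise step at the end of DoPreProcess maps each standardized float feature into the integer domain of the input tensor. A minimal standalone sketch of that affine quantisation, consistent with the clamping visible in GetQuantElem further down; the rounding formula here is the standard affine map and is an assumption, since the elided part of GetQuantElem is not shown in this hunk:

    // Sketch: affine quantisation of one float feature into an int8 domain.
    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    int8_t QuantiseElem(float elem, float scale, int offset) {
        const float val = std::round(elem / scale) + offset;  // Assumed affine map.
        const float lo  = static_cast<float>(INT8_MIN);
        const float hi  = static_cast<float>(INT8_MAX);
        return static_cast<int8_t>(std::min(std::max(val, lo), hi)); // Clamp to range.
    }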
-    bool Preprocess::ComputeDeltas(Array2d<float>& mfcc,
-                                   Array2d<float>& delta1,
-                                   Array2d<float>& delta2)
+    bool AsrPreProcess::ComputeDeltas(Array2d<float>& mfcc,
+                                      Array2d<float>& delta1,
+                                      Array2d<float>& delta2)
     {
         const std::vector<float> delta1Coeffs =
             {6.66666667e-02,  5.00000000e-02,  3.33333333e-02,
@@ -148,11 +140,11 @@

         /* Iterate through features in MFCC vector. */
         for (size_t i = 0; i < numFeatures; ++i) {
             /* For each feature, iterate through time (t) samples representing feature evolution and
-             * calculate d/dt and d^2/dt^2, using 1d convolution with differential kernels.
+             * calculate d/dt and d^2/dt^2, using 1D convolution with differential kernels.
              * Convolution padding = valid, result size is `time length - kernel length + 1`.
              * The result is padded with 0 from both sides to match the size of initial time samples data.
              *
-             * For the small filter, conv1d implementation as a simple loop is efficient enough.
+             * For the small filter, conv1D implementation as a simple loop is efficient enough.
              * Filters of a greater size would need CMSIS-DSP functions to be used, like arm_fir_f32.
              */
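Only the leading delta kernel coefficients are visible in this hunk, so a generic sketch of the valid-mode 1D convolution the comment describes; the 9-tap kernel below extends the visible values symmetrically (i/60 for i = 4 down to -4, the standard regression-delta formula) and is an assumption, not a quote from the patch:

    // Sketch: first-order delta of one feature's time series via valid 1D convolution.
    #include <cstddef>
    #include <vector>

    std::vector<float> Delta(const std::vector<float>& x) {
        const std::vector<float> k = { 4/60.f,  3/60.f,  2/60.f,  1/60.f, 0.f,
                                      -1/60.f, -2/60.f, -3/60.f, -4/60.f};
        if (x.size() < k.size()) { return {}; }
        std::vector<float> out(x.size(), 0.f);        /* Zero-padded to input length. */
        const size_t valid = x.size() - k.size() + 1; /* Valid convolution size. */
        const size_t off = k.size() / 2;              /* Centre the valid region. */
        for (size_t t = 0; t < valid; ++t) {
            float acc = 0.f;
            for (size_t j = 0; j < k.size(); ++j) { acc += k[j] * x[t + j]; }
            out[t + off] = acc;
        }
        return out;
    }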
@@ -175,20 +167,10 @@
         return true;
     }

-    float Preprocess::GetMean(Array2d<float>& vec)
-    {
-        return math::MathUtils::MeanF32(vec.begin(), vec.totalSize());
-    }
-
-    float Preprocess::GetStdDev(Array2d<float>& vec, const float mean)
-    {
-        return math::MathUtils::StdDevF32(vec.begin(), vec.totalSize(), mean);
-    }
-
-    void Preprocess::NormaliseVec(Array2d<float>& vec)
+    void AsrPreProcess::StandardizeVecF32(Array2d<float>& vec)
     {
-        auto mean = Preprocess::GetMean(vec);
-        auto stddev = Preprocess::GetStdDev(vec, mean);
+        auto mean = math::MathUtils::MeanF32(vec.begin(), vec.totalSize());
+        auto stddev = math::MathUtils::StdDevF32(vec.begin(), vec.totalSize(), mean);

         debug("Mean: %f, Stddev: %f\n", mean, stddev);
         if (stddev == 0) {
@@ -204,14 +186,14 @@
         }
     }

-    void Preprocess::Normalise()
+    void AsrPreProcess::Standarize()
     {
-        Preprocess::NormaliseVec(this->m_mfccBuf);
-        Preprocess::NormaliseVec(this->m_delta1Buf);
-        Preprocess::NormaliseVec(this->m_delta2Buf);
+        AsrPreProcess::StandardizeVecF32(this->m_mfccBuf);
+        AsrPreProcess::StandardizeVecF32(this->m_delta1Buf);
+        AsrPreProcess::StandardizeVecF32(this->m_delta2Buf);
     }

-    float Preprocess::GetQuantElem(
+    float AsrPreProcess::GetQuantElem(
         const float     elem,
         const float     quantScale,
         const int       quantOffset,
@@ -222,7 +204,5 @@
         return std::min(std::max(val, minVal), maxVal);
     }

-} /* namespace asr */
-} /* namespace audio */
 } /* namespace app */
 } /* namespace arm */
\ No newline at end of file
diff --git a/source/use_case/kws_asr/usecase.cmake b/source/use_case/kws_asr/usecase.cmake
index b3fe020..40df4d7 100644
--- a/source/use_case/kws_asr/usecase.cmake
+++ b/source/use_case/kws_asr/usecase.cmake
@@ -1,5 +1,5 @@
 #----------------------------------------------------------------------------
-#  Copyright (c) 2021 Arm Limited. All rights reserved.
+#  Copyright (c) 2021-2022 Arm Limited. All rights reserved.
 #  SPDX-License-Identifier: Apache-2.0
 #
 #  Licensed under the Apache License, Version 2.0 (the "License");
@@ -59,7 +59,7 @@ USER_OPTION(${use_case}_ACTIVATION_BUF_SZ "Activation buffer size for the chosen
     STRING)

 USER_OPTION(${use_case}_MODEL_SCORE_THRESHOLD_KWS "Specify the score threshold [0.0, 1.0) that must be applied to the KWS results for a label to be deemed valid."
-    0.9
+    0.7
     STRING)

 USER_OPTION(${use_case}_MODEL_SCORE_THRESHOLD_ASR "Specify the score threshold [0.0, 1.0) that must be applied to the ASR results for a label to be deemed valid."
--
cgit v1.2.1