diff options
Diffstat (limited to 'source/use_case/asr/include/Wav2LetterPreprocess.hpp')
-rw-r--r-- | source/use_case/asr/include/Wav2LetterPreprocess.hpp | 96 |
1 files changed, 37 insertions, 59 deletions
diff --git a/source/use_case/asr/include/Wav2LetterPreprocess.hpp b/source/use_case/asr/include/Wav2LetterPreprocess.hpp index 13d1589..8c12b3d 100644 --- a/source/use_case/asr/include/Wav2LetterPreprocess.hpp +++ b/source/use_case/asr/include/Wav2LetterPreprocess.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -21,49 +21,44 @@ #include "Wav2LetterMfcc.hpp" #include "AudioUtils.hpp" #include "DataStructures.hpp" +#include "BaseProcessing.hpp" #include "log_macros.h" namespace arm { namespace app { -namespace audio { -namespace asr { /* Class to facilitate pre-processing calculation for Wav2Letter model * for ASR. */ - using AudioWindow = SlidingWindow <const int16_t>; + using AudioWindow = audio::SlidingWindow<const int16_t>; - class Preprocess { + class ASRPreProcess : public BasePreProcess { public: /** * @brief Constructor. - * @param[in] numMfccFeatures Number of MFCC features per window. - * @param[in] windowLen Number of elements in a window. - * @param[in] windowStride Stride (in number of elements) for - * moving the window. - * @param[in] numMfccVectors Number of MFCC vectors per window. + * @param[in] inputTensor Pointer to the TFLite Micro input Tensor. + * @param[in] numMfccFeatures Number of MFCC features per window. + * @param[in] mfccWindowLen Number of audio elements to calculate MFCC features per window. + * @param[in] mfccWindowStride Stride (in number of elements) for moving the MFCC window. + * @param[in] mfccWindowStride Number of MFCC vectors that need to be calculated + * for an inference. */ - Preprocess( - uint32_t numMfccFeatures, - uint32_t windowLen, - uint32_t windowStride, - uint32_t numMfccVectors); - Preprocess() = delete; - ~Preprocess() = default; + ASRPreProcess(TfLiteTensor* inputTensor, + uint32_t numMfccFeatures, + uint32_t audioWindowLen, + uint32_t mfccWindowLen, + uint32_t mfccWindowStride); /** * @brief Calculates the features required from audio data. This * includes MFCC, first and second order deltas, * normalisation and finally, quantisation. The tensor is - * populated with feature from a given window placed along + * populated with features from a given window placed along * in a single row. * @param[in] audioData Pointer to the first element of audio data. * @param[in] audioDataLen Number of elements in the audio data. - * @param[in] tensor Tensor to be populated. * @return true if successful, false in case of error. */ - bool Invoke(const int16_t * audioData, - uint32_t audioDataLen, - TfLiteTensor * tensor); + bool DoPreProcess(const void* audioData, size_t audioDataLen) override; protected: /** @@ -80,32 +75,16 @@ namespace asr { Array2d<float>& delta2); /** - * @brief Given a 2D vector of floats, computes the mean. - * @param[in] vec Vctor of vector of floats. - * @return Mean value. - */ - static float GetMean(Array2d<float>& vec); - - /** - * @brief Given a 2D vector of floats, computes the stddev. - * @param[in] vec Vector of vector of floats. - * @param[in] mean Mean value of the vector passed in. - * @return stddev value. - */ - static float GetStdDev(Array2d<float>& vec, - const float mean); - - /** - * @brief Given a 2D vector of floats, normalises it using - * the mean and the stddev. + * @brief Given a 2D vector of floats, rescale it to have mean of 0 and + * standard deviation of 1. * @param[in,out] vec Vector of vector of floats. */ - static void NormaliseVec(Array2d<float>& vec); + static void StandardizeVecF32(Array2d<float>& vec); /** - * @brief Normalises the MFCC and delta buffers. + * @brief Standardizes all the MFCC and delta buffers to have mean 0 and std. dev 1. */ - void Normalise(); + void Standarize(); /** * @brief Given the quantisation and data type limits, computes @@ -139,7 +118,7 @@ namespace asr { */ template <typename T> bool Quantise( - T * outputBuf, + T* outputBuf, const uint32_t outputBufSz, const float quantScale, const int quantOffset) @@ -160,15 +139,15 @@ namespace asr { const float maxVal = std::numeric_limits<T>::max(); /* Need to transpose while copying and concatenating the tensor. */ - for (uint32_t j = 0; j < this->m_numFeatVectors; ++j) { + for (uint32_t j = 0; j < this->m_numFeatureFrames; ++j) { for (uint32_t i = 0; i < this->m_numMfccFeats; ++i) { - *outputBufMfcc++ = static_cast<T>(Preprocess::GetQuantElem( + *outputBufMfcc++ = static_cast<T>(ASRPreProcess::GetQuantElem( this->m_mfccBuf(i, j), quantScale, quantOffset, minVal, maxVal)); - *outputBufD1++ = static_cast<T>(Preprocess::GetQuantElem( + *outputBufD1++ = static_cast<T>(ASRPreProcess::GetQuantElem( this->m_delta1Buf(i, j), quantScale, quantOffset, minVal, maxVal)); - *outputBufD2++ = static_cast<T>(Preprocess::GetQuantElem( + *outputBufD2++ = static_cast<T>(ASRPreProcess::GetQuantElem( this->m_delta2Buf(i, j), quantScale, quantOffset, minVal, maxVal)); } @@ -181,23 +160,22 @@ namespace asr { } private: - Wav2LetterMFCC m_mfcc; /* MFCC instance. */ + audio::Wav2LetterMFCC m_mfcc; /* MFCC instance. */ + TfLiteTensor* m_inputTensor; /* Model input tensor. */ /* Actual buffers to be populated. */ - Array2d<float> m_mfccBuf; /* Contiguous buffer 1D: MFCC */ - Array2d<float> m_delta1Buf; /* Contiguous buffer 1D: Delta 1 */ - Array2d<float> m_delta2Buf; /* Contiguous buffer 1D: Delta 2 */ + Array2d<float> m_mfccBuf; /* Contiguous buffer 1D: MFCC */ + Array2d<float> m_delta1Buf; /* Contiguous buffer 1D: Delta 1 */ + Array2d<float> m_delta2Buf; /* Contiguous buffer 1D: Delta 2 */ - uint32_t m_windowLen; /* Window length for MFCC. */ - uint32_t m_windowStride; /* Window stride len for MFCC. */ - uint32_t m_numMfccFeats; /* Number of MFCC features per window. */ - uint32_t m_numFeatVectors; /* Number of m_numMfccFeats. */ - AudioWindow m_window; /* Sliding window. */ + uint32_t m_mfccWindowLen; /* Window length for MFCC. */ + uint32_t m_mfccWindowStride; /* Window stride len for MFCC. */ + uint32_t m_numMfccFeats; /* Number of MFCC features per window. */ + uint32_t m_numFeatureFrames; /* How many sets of m_numMfccFeats. */ + AudioWindow m_mfccSlidingWindow; /* Sliding window to calculate MFCCs. */ }; -} /* namespace asr */ -} /* namespace audio */ } /* namespace app */ } /* namespace arm */ |