diff options
Diffstat (limited to 'source/use_case/asr/include')
-rw-r--r-- | source/use_case/asr/include/AsrClassifier.hpp | 10 | ||||
-rw-r--r-- | source/use_case/asr/include/Wav2LetterModel.hpp | 1 | ||||
-rw-r--r-- | source/use_case/asr/include/Wav2LetterPostprocess.hpp | 15 | ||||
-rw-r--r-- | source/use_case/asr/include/Wav2LetterPreprocess.hpp | 28 |
4 files changed, 28 insertions, 26 deletions
diff --git a/source/use_case/asr/include/AsrClassifier.hpp b/source/use_case/asr/include/AsrClassifier.hpp index 67a200e..a07a721 100644 --- a/source/use_case/asr/include/AsrClassifier.hpp +++ b/source/use_case/asr/include/AsrClassifier.hpp @@ -35,10 +35,10 @@ namespace app { * @param[in] use_softmax Whether softmax scaling should be applied to model output. * @return true if successful, false otherwise. **/ - bool GetClassificationResults( - TfLiteTensor* outputTensor, - std::vector<ClassificationResult>& vecResults, - const std::vector <std::string>& labels, uint32_t topNCount, bool use_softmax = false) override; + bool GetClassificationResults(TfLiteTensor* outputTensor, + std::vector<ClassificationResult>& vecResults, + const std::vector<std::string>& labels, + uint32_t topNCount, bool use_softmax = false) override; private: /** @@ -54,7 +54,7 @@ namespace app { template<typename T> bool GetTopResults(TfLiteTensor* tensor, std::vector<ClassificationResult>& vecResults, - const std::vector <std::string>& labels, double scale, double zeroPoint); + const std::vector<std::string>& labels, double scale, double zeroPoint); }; } /* namespace app */ diff --git a/source/use_case/asr/include/Wav2LetterModel.hpp b/source/use_case/asr/include/Wav2LetterModel.hpp index 895df2b..0078e44 100644 --- a/source/use_case/asr/include/Wav2LetterModel.hpp +++ b/source/use_case/asr/include/Wav2LetterModel.hpp @@ -36,6 +36,7 @@ namespace app { static constexpr uint32_t ms_outputRowsIdx = 2; static constexpr uint32_t ms_outputColsIdx = 3; + /* Model specific constants. */ static constexpr uint32_t ms_blankTokenIdx = 28; static constexpr uint32_t ms_numMfccFeatures = 13; diff --git a/source/use_case/asr/include/Wav2LetterPostprocess.hpp b/source/use_case/asr/include/Wav2LetterPostprocess.hpp index 45defa5..446014d 100644 --- a/source/use_case/asr/include/Wav2LetterPostprocess.hpp +++ b/source/use_case/asr/include/Wav2LetterPostprocess.hpp @@ -30,23 +30,24 @@ namespace app { * @brief Helper class to manage tensor post-processing for "wav2letter" * output. */ - class ASRPostProcess : public BasePostProcess { + class AsrPostProcess : public BasePostProcess { public: bool m_lastIteration = false; /* Flag to set if processing the last set of data for a clip. */ /** * @brief Constructor - * @param[in] outputTensor Pointer to the output Tensor. + * @param[in] outputTensor Pointer to the TFLite Micro output Tensor. + * @param[in] classifier Object used to get top N results from classification. * @param[in] labels Vector of string labels to identify each output of the model. - * @param[in/out] result Vector of classification results to store decoded outputs. + * @param[in/out] result Vector of classification results to store decoded outputs. * @param[in] outputContextLen Left/right context length for output tensor. * @param[in] blankTokenIdx Index in the labels that the "Blank token" takes. * @param[in] reductionAxis The axis that the logits of each time step is on. **/ - ASRPostProcess(AsrClassifier& classifier, TfLiteTensor* outputTensor, - const std::vector<std::string>& labels, asr::ResultVec& result, - uint32_t outputContextLen, - uint32_t blankTokenIdx, uint32_t reductionAxis); + AsrPostProcess(TfLiteTensor* outputTensor, AsrClassifier& classifier, + const std::vector<std::string>& labels, asr::ResultVec& result, + uint32_t outputContextLen, + uint32_t blankTokenIdx, uint32_t reductionAxis); /** * @brief Should perform post-processing of the result of inference then diff --git a/source/use_case/asr/include/Wav2LetterPreprocess.hpp b/source/use_case/asr/include/Wav2LetterPreprocess.hpp index 8c12b3d..dc9a415 100644 --- a/source/use_case/asr/include/Wav2LetterPreprocess.hpp +++ b/source/use_case/asr/include/Wav2LetterPreprocess.hpp @@ -31,22 +31,22 @@ namespace app { * for ASR. */ using AudioWindow = audio::SlidingWindow<const int16_t>; - class ASRPreProcess : public BasePreProcess { + class AsrPreProcess : public BasePreProcess { public: /** * @brief Constructor. * @param[in] inputTensor Pointer to the TFLite Micro input Tensor. * @param[in] numMfccFeatures Number of MFCC features per window. + * @param[in] numFeatureFrames Number of MFCC vectors that need to be calculated + * for an inference. * @param[in] mfccWindowLen Number of audio elements to calculate MFCC features per window. * @param[in] mfccWindowStride Stride (in number of elements) for moving the MFCC window. - * @param[in] mfccWindowStride Number of MFCC vectors that need to be calculated - * for an inference. */ - ASRPreProcess(TfLiteTensor* inputTensor, - uint32_t numMfccFeatures, - uint32_t audioWindowLen, - uint32_t mfccWindowLen, - uint32_t mfccWindowStride); + AsrPreProcess(TfLiteTensor* inputTensor, + uint32_t numMfccFeatures, + uint32_t numFeatureFrames, + uint32_t mfccWindowLen, + uint32_t mfccWindowStride); /** * @brief Calculates the features required from audio data. This @@ -130,9 +130,9 @@ namespace app { } /* Populate. */ - T * outputBufMfcc = outputBuf; - T * outputBufD1 = outputBuf + this->m_numMfccFeats; - T * outputBufD2 = outputBufD1 + this->m_numMfccFeats; + T* outputBufMfcc = outputBuf; + T* outputBufD1 = outputBuf + this->m_numMfccFeats; + T* outputBufD2 = outputBufD1 + this->m_numMfccFeats; const uint32_t ptrIncr = this->m_numMfccFeats * 2; /* (3 vectors - 1 vector) */ const float minVal = std::numeric_limits<T>::min(); @@ -141,13 +141,13 @@ namespace app { /* Need to transpose while copying and concatenating the tensor. */ for (uint32_t j = 0; j < this->m_numFeatureFrames; ++j) { for (uint32_t i = 0; i < this->m_numMfccFeats; ++i) { - *outputBufMfcc++ = static_cast<T>(ASRPreProcess::GetQuantElem( + *outputBufMfcc++ = static_cast<T>(AsrPreProcess::GetQuantElem( this->m_mfccBuf(i, j), quantScale, quantOffset, minVal, maxVal)); - *outputBufD1++ = static_cast<T>(ASRPreProcess::GetQuantElem( + *outputBufD1++ = static_cast<T>(AsrPreProcess::GetQuantElem( this->m_delta1Buf(i, j), quantScale, quantOffset, minVal, maxVal)); - *outputBufD2++ = static_cast<T>(ASRPreProcess::GetQuantElem( + *outputBufD2++ = static_cast<T>(AsrPreProcess::GetQuantElem( this->m_delta2Buf(i, j), quantScale, quantOffset, minVal, maxVal)); } |