summaryrefslogtreecommitdiff
path: root/source/use_case/kws_asr/include
diff options
context:
space:
mode:
authorRichard Burton <richard.burton@arm.com>2022-05-04 09:45:02 +0100
committerRichard Burton <richard.burton@arm.com>2022-05-04 09:45:02 +0100
commit4e002791bc6781b549c6951cfe44f918289d7e82 (patch)
treeb639243b5fa433657c207783a384bad1ed248536 /source/use_case/kws_asr/include
parentdd6d07b24bbf9023ebe8e8927be8aac3291d0f58 (diff)
downloadml-embedded-evaluation-kit-4e002791bc6781b549c6951cfe44f918289d7e82.tar.gz
MLECO-3173: Add AD, KWS_ASR and Noise reduction use case API's
Signed-off-by: Richard Burton <richard.burton@arm.com> Change-Id: I36f61ce74bf17f7b327cdae9704a22ca54144f37
Diffstat (limited to 'source/use_case/kws_asr/include')
-rw-r--r--source/use_case/kws_asr/include/KwsProcessing.hpp138
-rw-r--r--source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp10
-rw-r--r--source/use_case/kws_asr/include/Wav2LetterModel.hpp12
-rw-r--r--source/use_case/kws_asr/include/Wav2LetterPostprocess.hpp117
-rw-r--r--source/use_case/kws_asr/include/Wav2LetterPreprocess.hpp138
5 files changed, 270 insertions, 145 deletions
diff --git a/source/use_case/kws_asr/include/KwsProcessing.hpp b/source/use_case/kws_asr/include/KwsProcessing.hpp
new file mode 100644
index 0000000..d3de3b3
--- /dev/null
+++ b/source/use_case/kws_asr/include/KwsProcessing.hpp
@@ -0,0 +1,138 @@
+/*
+ * Copyright (c) 2022 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef KWS_PROCESSING_HPP
+#define KWS_PROCESSING_HPP
+
+#include <AudioUtils.hpp>
+#include "BaseProcessing.hpp"
+#include "Model.hpp"
+#include "Classifier.hpp"
+#include "MicroNetKwsMfcc.hpp"
+
+#include <functional>
+
+namespace arm {
+namespace app {
+
+ /**
+ * @brief Pre-processing class for Keyword Spotting use case.
+ * Implements methods declared by BasePreProcess and anything else needed
+ * to populate input tensors ready for inference.
+ */
+ class KwsPreProcess : public BasePreProcess {
+
+ public:
+ /**
+ * @brief Constructor
+ * @param[in] inputTensor Pointer to the TFLite Micro input Tensor.
+ * @param[in] numFeatures How many MFCC features to use.
+ * @param[in] numFeatureFrames Number of MFCC vectors that need to be calculated
+ * for an inference.
+ * @param[in] mfccFrameLength Number of audio samples used to calculate one set of MFCC values when
+ * sliding a window through the audio sample.
+ * @param[in] mfccFrameStride Number of audio samples between consecutive windows.
+ **/
+ explicit KwsPreProcess(TfLiteTensor* inputTensor, size_t numFeatures, size_t numFeatureFrames,
+ int mfccFrameLength, int mfccFrameStride);
+
+ /**
+ * @brief Should perform pre-processing of 'raw' input audio data and load it into
+ * TFLite Micro input tensors ready for inference.
+ * @param[in] input Pointer to the data that pre-processing will work on.
+ * @param[in] inputSize Size of the input data.
+ * @return true if successful, false otherwise.
+ **/
+ bool DoPreProcess(const void* input, size_t inputSize) override;
+
+ size_t m_audioWindowIndex = 0; /* Index of audio slider, used when caching features in longer clips. */
+ size_t m_audioDataWindowSize; /* Amount of audio needed for 1 inference. */
+ size_t m_audioDataStride; /* Amount of audio to stride across if doing >1 inference in longer clips. */
+
+ private:
+ TfLiteTensor* m_inputTensor; /* Model input tensor. */
+ const int m_mfccFrameLength;
+ const int m_mfccFrameStride;
+ const size_t m_numMfccFrames; /* How many sets of m_numMfccFeats. */
+
+ audio::MicroNetKwsMFCC m_mfcc;
+ audio::SlidingWindow<const int16_t> m_mfccSlidingWindow;
+ size_t m_numMfccVectorsInAudioStride;
+ size_t m_numReusedMfccVectors;
+ std::function<void (std::vector<int16_t>&, int, bool, size_t)> m_mfccFeatureCalculator;
+
+ /**
+ * @brief Returns a function to perform feature calculation and populates input tensor data with
+ * MFCC data.
+ *
+ * Input tensor data type check is performed to choose correct MFCC feature data type.
+ * If tensor has an integer data type then original features are quantised.
+ *
+ * Warning: MFCC calculator provided as input must have the same life scope as returned function.
+ *
+ * @param[in] mfcc MFCC feature calculator.
+ * @param[in,out] inputTensor Input tensor pointer to store calculated features.
+ * @param[in] cacheSize Size of the feature vectors cache (number of feature vectors).
+ * @return Function to be called providing audio sample and sliding window index.
+ */
+ std::function<void (std::vector<int16_t>&, int, bool, size_t)>
+ GetFeatureCalculator(audio::MicroNetKwsMFCC& mfcc,
+ TfLiteTensor* inputTensor,
+ size_t cacheSize);
+
+ template<class T>
+ std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
+ FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
+ std::function<std::vector<T> (std::vector<int16_t>& )> compute);
+ };
+
+ /**
+ * @brief Post-processing class for Keyword Spotting use case.
+ * Implements methods declared by BasePostProcess and anything else needed
+ * to populate result vector.
+ */
+ class KwsPostProcess : public BasePostProcess {
+
+ private:
+ TfLiteTensor* m_outputTensor; /* Model output tensor. */
+ Classifier& m_kwsClassifier; /* KWS Classifier object. */
+ const std::vector<std::string>& m_labels; /* KWS Labels. */
+ std::vector<ClassificationResult>& m_results; /* Results vector for a single inference. */
+
+ public:
+ /**
+ * @brief Constructor
+ * @param[in] outputTensor Pointer to the TFLite Micro output Tensor.
+ * @param[in] classifier Classifier object used to get top N results from classification.
+ * @param[in] labels Vector of string labels to identify each output of the model.
+ * @param[in/out] results Vector of classification results to store decoded outputs.
+ **/
+ KwsPostProcess(TfLiteTensor* outputTensor, Classifier& classifier,
+ const std::vector<std::string>& labels,
+ std::vector<ClassificationResult>& results);
+
+ /**
+ * @brief Should perform post-processing of the result of inference then
+ * populate KWS result data for any later use.
+ * @return true if successful, false otherwise.
+ **/
+ bool DoPostProcess() override;
+ };
+
+} /* namespace app */
+} /* namespace arm */
+
+#endif /* KWS_PROCESSING_HPP */ \ No newline at end of file
diff --git a/source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp b/source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp
index 43bd390..af6ba5f 100644
--- a/source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp
+++ b/source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -24,7 +24,7 @@ namespace app {
namespace audio {
/* Class to provide MicroNet specific MFCC calculation requirements. */
- class MicroNetMFCC : public MFCC {
+ class MicroNetKwsMFCC : public MFCC {
public:
static constexpr uint32_t ms_defaultSamplingFreq = 16000;
@@ -34,14 +34,14 @@ namespace audio {
static constexpr bool ms_defaultUseHtkMethod = true;
- explicit MicroNetMFCC(const size_t numFeats, const size_t frameLen)
+ explicit MicroNetKwsMFCC(const size_t numFeats, const size_t frameLen)
: MFCC(MfccParams(
ms_defaultSamplingFreq, ms_defaultNumFbankBins,
ms_defaultMelLoFreq, ms_defaultMelHiFreq,
numFeats, frameLen, ms_defaultUseHtkMethod))
{}
- MicroNetMFCC() = delete;
- ~MicroNetMFCC() = default;
+ MicroNetKwsMFCC() = delete;
+ ~MicroNetKwsMFCC() = default;
};
} /* namespace audio */
diff --git a/source/use_case/kws_asr/include/Wav2LetterModel.hpp b/source/use_case/kws_asr/include/Wav2LetterModel.hpp
index 7c327b3..0e1adc5 100644
--- a/source/use_case/kws_asr/include/Wav2LetterModel.hpp
+++ b/source/use_case/kws_asr/include/Wav2LetterModel.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -34,14 +34,18 @@ namespace arm {
namespace app {
class Wav2LetterModel : public Model {
-
+
public:
/* Indices for the expected model - based on input and output tensor shapes */
- static constexpr uint32_t ms_inputRowsIdx = 1;
- static constexpr uint32_t ms_inputColsIdx = 2;
+ static constexpr uint32_t ms_inputRowsIdx = 1;
+ static constexpr uint32_t ms_inputColsIdx = 2;
static constexpr uint32_t ms_outputRowsIdx = 2;
static constexpr uint32_t ms_outputColsIdx = 3;
+ /* Model specific constants. */
+ static constexpr uint32_t ms_blankTokenIdx = 28;
+ static constexpr uint32_t ms_numMfccFeatures = 13;
+
protected:
/** @brief Gets the reference to op resolver interface class. */
const tflite::MicroOpResolver& GetOpResolver() override;
diff --git a/source/use_case/kws_asr/include/Wav2LetterPostprocess.hpp b/source/use_case/kws_asr/include/Wav2LetterPostprocess.hpp
index 029a641..d1bc9a2 100644
--- a/source/use_case/kws_asr/include/Wav2LetterPostprocess.hpp
+++ b/source/use_case/kws_asr/include/Wav2LetterPostprocess.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,88 +14,95 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-#ifndef KWS_ASR_WAV2LET_POSTPROC_HPP
-#define KWS_ASR_WAV2LET_POSTPROC_HPP
+#ifndef KWS_ASR_WAV2LETTER_POSTPROCESS_HPP
+#define KWS_ASR_WAV2LETTER_POSTPROCESS_HPP
-#include "TensorFlowLiteMicro.hpp" /* TensorFlow headers */
+#include "TensorFlowLiteMicro.hpp" /* TensorFlow headers. */
+#include "BaseProcessing.hpp"
+#include "AsrClassifier.hpp"
+#include "AsrResult.hpp"
+#include "log_macros.h"
namespace arm {
namespace app {
-namespace audio {
-namespace asr {
/**
* @brief Helper class to manage tensor post-processing for "wav2letter"
* output.
*/
- class Postprocess {
+ class AsrPostProcess : public BasePostProcess {
public:
+ bool m_lastIteration = false; /* Flag to set if processing the last set of data for a clip. */
+
/**
- * @brief Constructor
- * @param[in] contextLen Left and right context length for
- * output tensor.
- * @param[in] innerLen This is the length of the section
- * between left and right context.
- * @param[in] blankTokenIdx Blank token index.
+ * @brief Constructor
+ * @param[in] outputTensor Pointer to the TFLite Micro output Tensor.
+ * @param[in] classifier Object used to get top N results from classification.
+ * @param[in] labels Vector of string labels to identify each output of the model.
+ * @param[in/out] result Vector of classification results to store decoded outputs.
+ * @param[in] outputContextLen Left/right context length for output tensor.
+ * @param[in] blankTokenIdx Index in the labels that the "Blank token" takes.
+ * @param[in] reductionAxis The axis that the logits of each time step is on.
**/
- Postprocess(uint32_t contextLen,
- uint32_t innerLen,
- uint32_t blankTokenIdx);
-
- Postprocess() = delete;
- ~Postprocess() = default;
+ AsrPostProcess(TfLiteTensor* outputTensor, AsrClassifier& classifier,
+ const std::vector<std::string>& labels, asr::ResultVec& result,
+ uint32_t outputContextLen,
+ uint32_t blankTokenIdx, uint32_t reductionAxis);
/**
- * @brief Erases the required part of the tensor based
- * on context lengths set up during initialisation
- * @param[in] tensor Pointer to the tensor
- * @param[in] axisIdx Index of the axis on which erase is
- * performed.
- * @param[in] lastIteration Flag to signal is this is the
- * last iteration in which case
- * the right context is preserved.
- * @return true if successful, false otherwise.
- */
- bool Invoke(TfLiteTensor* tensor,
- uint32_t axisIdx,
- bool lastIteration = false);
+ * @brief Should perform post-processing of the result of inference then
+ * populate ASR result data for any later use.
+ * @return true if successful, false otherwise.
+ **/
+ bool DoPostProcess() override;
+
+ /** @brief Gets the output inner length for post-processing. */
+ static uint32_t GetOutputInnerLen(const TfLiteTensor*, uint32_t outputCtxLen);
+
+ /** @brief Gets the output context length (left/right) for post-processing. */
+ static uint32_t GetOutputContextLen(const Model& model, uint32_t inputCtxLen);
+
+ /** @brief Gets the number of feature vectors to be computed. */
+ static uint32_t GetNumFeatureVectors(const Model& model);
private:
- uint32_t m_contextLen; /* Lengths of left and right contexts. */
- uint32_t m_innerLen; /* Length of inner context. */
- uint32_t m_totalLen; /* Total length of the required axis. */
- uint32_t m_countIterations; /* Current number of iterations. */
- uint32_t m_blankTokenIdx; /* Index of the labels blank token. */
+ AsrClassifier& m_classifier; /* ASR Classifier object. */
+ TfLiteTensor* m_outputTensor; /* Model output tensor. */
+ const std::vector<std::string>& m_labels; /* ASR Labels. */
+ asr::ResultVec & m_results; /* Results vector for a single inference. */
+ uint32_t m_outputContextLen; /* lengths of left/right contexts for output. */
+ uint32_t m_outputInnerLen; /* Length of output inner context. */
+ uint32_t m_totalLen; /* Total length of the required axis. */
+ uint32_t m_countIterations; /* Current number of iterations. */
+ uint32_t m_blankTokenIdx; /* Index of the labels blank token. */
+ uint32_t m_reductionAxisIdx; /* Axis containing output logits for a single step. */
+
/**
- * @brief Checks if the tensor and axis index are valid
- * inputs to the object - based on how it has been
- * initialised.
- * @return true if valid, false otherwise.
+ * @brief Checks if the tensor and axis index are valid
+ * inputs to the object - based on how it has been initialised.
+ * @return true if valid, false otherwise.
*/
bool IsInputValid(TfLiteTensor* tensor,
- const uint32_t axisIdx) const;
+ uint32_t axisIdx) const;
/**
- * @brief Gets the tensor data element size in bytes based
- * on the tensor type.
- * @return Size in bytes, 0 if not supported.
+ * @brief Gets the tensor data element size in bytes based
+ * on the tensor type.
+ * @return Size in bytes, 0 if not supported.
*/
- uint32_t GetTensorElementSize(TfLiteTensor* tensor);
+ static uint32_t GetTensorElementSize(TfLiteTensor* tensor);
/**
- * @brief Erases sections from the data assuming row-wise
- * arrangement along the context axis.
- * @return true if successful, false otherwise.
+ * @brief Erases sections from the data assuming row-wise
+ * arrangement along the context axis.
+ * @return true if successful, false otherwise.
*/
bool EraseSectionsRowWise(uint8_t* ptrData,
- const uint32_t strideSzBytes,
- const bool lastIteration);
-
+ uint32_t strideSzBytes,
+ bool lastIteration);
};
-} /* namespace asr */
-} /* namespace audio */
} /* namespace app */
} /* namespace arm */
-#endif /* KWS_ASR_WAV2LET_POSTPROC_HPP */ \ No newline at end of file
+#endif /* KWS_ASR_WAV2LETTER_POSTPROCESS_HPP */ \ No newline at end of file
diff --git a/source/use_case/kws_asr/include/Wav2LetterPreprocess.hpp b/source/use_case/kws_asr/include/Wav2LetterPreprocess.hpp
index 3609c49..1224c23 100644
--- a/source/use_case/kws_asr/include/Wav2LetterPreprocess.hpp
+++ b/source/use_case/kws_asr/include/Wav2LetterPreprocess.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,56 +14,51 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-#ifndef KWS_ASR_WAV2LET_PREPROC_HPP
-#define KWS_ASR_WAV2LET_PREPROC_HPP
+#ifndef KWS_ASR_WAV2LETTER_PREPROCESS_HPP
+#define KWS_ASR_WAV2LETTER_PREPROCESS_HPP
#include "Wav2LetterModel.hpp"
#include "Wav2LetterMfcc.hpp"
#include "AudioUtils.hpp"
#include "DataStructures.hpp"
+#include "BaseProcessing.hpp"
#include "log_macros.h"
namespace arm {
namespace app {
-namespace audio {
-namespace asr {
/* Class to facilitate pre-processing calculation for Wav2Letter model
* for ASR. */
- using AudioWindow = SlidingWindow <const int16_t>;
+ using AudioWindow = audio::SlidingWindow<const int16_t>;
- class Preprocess {
+ class AsrPreProcess : public BasePreProcess {
public:
/**
- * @brief Constructor
- * @param[in] numMfccFeatures Number of MFCC features per window.
- * @param[in] windowLen Number of elements in a window.
- * @param[in] windowStride Stride (in number of elements) for
- * moving the window.
- * @param[in] numMfccVectors Number of MFCC vectors per window.
- */
- Preprocess(
- uint32_t numMfccFeatures,
- uint32_t windowLen,
- uint32_t windowStride,
- uint32_t numMfccVectors);
- Preprocess() = delete;
- ~Preprocess() = default;
+ * @brief Constructor.
+ * @param[in] inputTensor Pointer to the TFLite Micro input Tensor.
+ * @param[in] numMfccFeatures Number of MFCC features per window.
+ * @param[in] numFeatureFrames Number of MFCC vectors that need to be calculated
+ * for an inference.
+ * @param[in] mfccWindowLen Number of audio elements to calculate MFCC features per window.
+ * @param[in] mfccWindowStride Stride (in number of elements) for moving the MFCC window.
+ */
+ AsrPreProcess(TfLiteTensor* inputTensor,
+ uint32_t numMfccFeatures,
+ uint32_t numFeatureFrames,
+ uint32_t mfccWindowLen,
+ uint32_t mfccWindowStride);
/**
* @brief Calculates the features required from audio data. This
* includes MFCC, first and second order deltas,
* normalisation and finally, quantisation. The tensor is
- * populated with feature from a given window placed along
+ * populated with features from a given window placed along
* in a single row.
* @param[in] audioData Pointer to the first element of audio data.
* @param[in] audioDataLen Number of elements in the audio data.
- * @param[in] tensor Tensor to be populated.
* @return true if successful, false in case of error.
*/
- bool Invoke(const int16_t * audioData,
- uint32_t audioDataLen,
- TfLiteTensor * tensor);
+ bool DoPreProcess(const void* audioData, size_t audioDataLen) override;
protected:
/**
@@ -73,49 +68,32 @@ namespace asr {
* @param[in] mfcc MFCC buffers.
* @param[out] delta1 Result of the first diff computation.
* @param[out] delta2 Result of the second diff computation.
- *
- * @return true if successful, false otherwise.
+ * @return true if successful, false otherwise.
*/
static bool ComputeDeltas(Array2d<float>& mfcc,
Array2d<float>& delta1,
Array2d<float>& delta2);
/**
- * @brief Given a 2D vector of floats, computes the mean.
- * @param[in] vec Vector of vector of floats.
- * @return Mean value.
- */
- static float GetMean(Array2d<float>& vec);
-
- /**
- * @brief Given a 2D vector of floats, computes the stddev.
- * @param[in] vec Vector of vector of floats.
- * @param[in] mean Mean value of the vector passed in.
- * @return stddev value.
- */
- static float GetStdDev(Array2d<float>& vec,
- const float mean);
-
- /**
- * @brief Given a 2D vector of floats, normalises it using
- * the mean and the stddev
+ * @brief Given a 2D vector of floats, rescale it to have mean of 0 and
+ * standard deviation of 1.
* @param[in,out] vec Vector of vector of floats.
*/
- static void NormaliseVec(Array2d<float>& vec);
+ static void StandardizeVecF32(Array2d<float>& vec);
/**
- * @brief Normalises the MFCC and delta buffers.
+ * @brief Standardizes all the MFCC and delta buffers to have mean 0 and std. dev 1.
*/
- void Normalise();
+ void Standarize();
/**
* @brief Given the quantisation and data type limits, computes
* the quantised values of a floating point input data.
- * @param[in] elem Element to be quantised.
- * @param[in] quantScale Scale.
- * @param[in] quantOffset Offset.
- * @param[in] minVal Numerical limit - minimum.
- * @param[in] maxVal Numerical limit - maximum.
+ * @param[in] elem Element to be quantised.
+ * @param[in] quantScale Scale.
+ * @param[in] quantOffset Offset.
+ * @param[in] minVal Numerical limit - minimum.
+ * @param[in] maxVal Numerical limit - maximum.
* @return Floating point quantised value.
*/
static float GetQuantElem(
@@ -133,44 +111,43 @@ namespace asr {
* this being the convolution speed up (as we can use
* contiguous memory). The output, however, requires the
* time axis to be in column major arrangement.
- * @param[in] outputBuf Pointer to the output buffer.
- * @param[in] outputBufSz Output buffer's size.
- * @param[in] quantScale Quantisation scale.
- * @param[in] quantOffset Quantisation offset.
+ * @param[in] outputBuf Pointer to the output buffer.
+ * @param[in] outputBufSz Output buffer's size.
+ * @param[in] quantScale Quantisation scale.
+ * @param[in] quantOffset Quantisation offset.
*/
template <typename T>
bool Quantise(
- T * outputBuf,
+ T* outputBuf,
const uint32_t outputBufSz,
const float quantScale,
const int quantOffset)
{
- /* Check the output size will for everything. */
+ /* Check the output size will fit everything. */
if (outputBufSz < (this->m_mfccBuf.size(0) * 3 * sizeof(T))) {
printf_err("Tensor size too small for features\n");
return false;
}
/* Populate. */
- T * outputBufMfcc = outputBuf;
- T * outputBufD1 = outputBuf + this->m_numMfccFeats;
- T * outputBufD2 = outputBufD1 + this->m_numMfccFeats;
+ T* outputBufMfcc = outputBuf;
+ T* outputBufD1 = outputBuf + this->m_numMfccFeats;
+ T* outputBufD2 = outputBufD1 + this->m_numMfccFeats;
const uint32_t ptrIncr = this->m_numMfccFeats * 2; /* (3 vectors - 1 vector) */
const float minVal = std::numeric_limits<T>::min();
const float maxVal = std::numeric_limits<T>::max();
- /* We need to do a transpose while copying and concatenating
- * the tensor. */
- for (uint32_t j = 0; j < this->m_numFeatVectors; ++j) {
+ /* Need to transpose while copying and concatenating the tensor. */
+ for (uint32_t j = 0; j < this->m_numFeatureFrames; ++j) {
for (uint32_t i = 0; i < this->m_numMfccFeats; ++i) {
- *outputBufMfcc++ = static_cast<T>(this->GetQuantElem(
+ *outputBufMfcc++ = static_cast<T>(AsrPreProcess::GetQuantElem(
this->m_mfccBuf(i, j), quantScale,
quantOffset, minVal, maxVal));
- *outputBufD1++ = static_cast<T>(this->GetQuantElem(
+ *outputBufD1++ = static_cast<T>(AsrPreProcess::GetQuantElem(
this->m_delta1Buf(i, j), quantScale,
quantOffset, minVal, maxVal));
- *outputBufD2++ = static_cast<T>(this->GetQuantElem(
+ *outputBufD2++ = static_cast<T>(AsrPreProcess::GetQuantElem(
this->m_delta2Buf(i, j), quantScale,
quantOffset, minVal, maxVal));
}
@@ -183,24 +160,23 @@ namespace asr {
}
private:
- Wav2LetterMFCC m_mfcc; /* MFCC instance. */
+ audio::Wav2LetterMFCC m_mfcc; /* MFCC instance. */
+ TfLiteTensor* m_inputTensor; /* Model input tensor. */
/* Actual buffers to be populated. */
- Array2d<float> m_mfccBuf; /* Contiguous buffer 1D: MFCC */
- Array2d<float> m_delta1Buf; /* Contiguous buffer 1D: Delta 1 */
- Array2d<float> m_delta2Buf; /* Contiguous buffer 1D: Delta 2 */
+ Array2d<float> m_mfccBuf; /* Contiguous buffer 1D: MFCC */
+ Array2d<float> m_delta1Buf; /* Contiguous buffer 1D: Delta 1 */
+ Array2d<float> m_delta2Buf; /* Contiguous buffer 1D: Delta 2 */
- uint32_t m_windowLen; /* Window length for MFCC. */
- uint32_t m_windowStride; /* Window stride len for MFCC. */
- uint32_t m_numMfccFeats; /* Number of MFCC features per window. */
- uint32_t m_numFeatVectors; /* Number of m_numMfccFeats. */
- AudioWindow m_window; /* Sliding window. */
+ uint32_t m_mfccWindowLen; /* Window length for MFCC. */
+ uint32_t m_mfccWindowStride; /* Window stride len for MFCC. */
+ uint32_t m_numMfccFeats; /* Number of MFCC features per window. */
+ uint32_t m_numFeatureFrames; /* How many sets of m_numMfccFeats. */
+ AudioWindow m_mfccSlidingWindow; /* Sliding window to calculate MFCCs. */
};
-} /* namespace asr */
-} /* namespace audio */
} /* namespace app */
} /* namespace arm */
-#endif /* KWS_ASR_WAV2LET_PREPROC_HPP */ \ No newline at end of file
+#endif /* KWS_ASR_WAV2LETTER_PREPROCESS_HPP */ \ No newline at end of file