summaryrefslogtreecommitdiff
path: root/source/use_case/kws/include/KwsProcessing.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/use_case/kws/include/KwsProcessing.hpp')
-rw-r--r--source/use_case/kws/include/KwsProcessing.hpp38
1 files changed, 22 insertions, 16 deletions
diff --git a/source/use_case/kws/include/KwsProcessing.hpp b/source/use_case/kws/include/KwsProcessing.hpp
index ddf38c1..d3de3b3 100644
--- a/source/use_case/kws/include/KwsProcessing.hpp
+++ b/source/use_case/kws/include/KwsProcessing.hpp
@@ -33,18 +33,21 @@ namespace app {
* Implements methods declared by BasePreProcess and anything else needed
* to populate input tensors ready for inference.
*/
- class KWSPreProcess : public BasePreProcess {
+ class KwsPreProcess : public BasePreProcess {
public:
/**
* @brief Constructor
- * @param[in] model Pointer to the KWS Model object.
- * @param[in] numFeatures How many MFCC features to use.
- * @param[in] mfccFrameLength Number of audio samples used to calculate one set of MFCC values when
- * sliding a window through the audio sample.
- * @param[in] mfccFrameStride Number of audio samples between consecutive windows.
+ * @param[in] inputTensor Pointer to the TFLite Micro input Tensor.
+ * @param[in] numFeatures How many MFCC features to use.
+ * @param[in] numFeatureFrames Number of MFCC vectors that need to be calculated
+ * for an inference.
+ * @param[in] mfccFrameLength Number of audio samples used to calculate one set of MFCC values when
+ * sliding a window through the audio sample.
+ * @param[in] mfccFrameStride Number of audio samples between consecutive windows.
**/
- explicit KWSPreProcess(Model* model, size_t numFeatures, int mfccFrameLength, int mfccFrameStride);
+ explicit KwsPreProcess(TfLiteTensor* inputTensor, size_t numFeatures, size_t numFeatureFrames,
+ int mfccFrameLength, int mfccFrameStride);
/**
* @brief Should perform pre-processing of 'raw' input audio data and load it into
@@ -60,8 +63,10 @@ namespace app {
size_t m_audioDataStride; /* Amount of audio to stride across if doing >1 inference in longer clips. */
private:
+ TfLiteTensor* m_inputTensor; /* Model input tensor. */
const int m_mfccFrameLength;
const int m_mfccFrameStride;
+ const size_t m_numMfccFrames; /* How many sets of m_numMfccFeats. */
audio::MicroNetKwsMFCC m_mfcc;
audio::SlidingWindow<const int16_t> m_mfccSlidingWindow;
@@ -99,22 +104,23 @@ namespace app {
* Implements methods declared by BasePostProcess and anything else needed
* to populate result vector.
*/
- class KWSPostProcess : public BasePostProcess {
+ class KwsPostProcess : public BasePostProcess {
private:
- Classifier& m_kwsClassifier;
- const std::vector<std::string>& m_labels;
- std::vector<ClassificationResult>& m_results;
+ TfLiteTensor* m_outputTensor; /* Model output tensor. */
+ Classifier& m_kwsClassifier; /* KWS Classifier object. */
+ const std::vector<std::string>& m_labels; /* KWS Labels. */
+ std::vector<ClassificationResult>& m_results; /* Results vector for a single inference. */
public:
/**
* @brief Constructor
- * @param[in] classifier Classifier object used to get top N results from classification.
- * @param[in] model Pointer to the KWS Model object.
- * @param[in] labels Vector of string labels to identify each output of the model.
- * @param[in/out] results Vector of classification results to store decoded outputs.
+ * @param[in] outputTensor Pointer to the TFLite Micro output Tensor.
+ * @param[in] classifier Classifier object used to get top N results from classification.
+ * @param[in] labels Vector of string labels to identify each output of the model.
+ * @param[in/out] results Vector of classification results to store decoded outputs.
**/
- KWSPostProcess(Classifier& classifier, Model* model,
+ KwsPostProcess(TfLiteTensor* outputTensor, Classifier& classifier,
const std::vector<std::string>& labels,
std::vector<ClassificationResult>& results);