diff options
Diffstat (limited to 'source/use_case/kws/include/KwsProcessing.hpp')
-rw-r--r-- | source/use_case/kws/include/KwsProcessing.hpp | 38 |
1 files changed, 22 insertions, 16 deletions
diff --git a/source/use_case/kws/include/KwsProcessing.hpp b/source/use_case/kws/include/KwsProcessing.hpp index ddf38c1..d3de3b3 100644 --- a/source/use_case/kws/include/KwsProcessing.hpp +++ b/source/use_case/kws/include/KwsProcessing.hpp @@ -33,18 +33,21 @@ namespace app { * Implements methods declared by BasePreProcess and anything else needed * to populate input tensors ready for inference. */ - class KWSPreProcess : public BasePreProcess { + class KwsPreProcess : public BasePreProcess { public: /** * @brief Constructor - * @param[in] model Pointer to the KWS Model object. - * @param[in] numFeatures How many MFCC features to use. - * @param[in] mfccFrameLength Number of audio samples used to calculate one set of MFCC values when - * sliding a window through the audio sample. - * @param[in] mfccFrameStride Number of audio samples between consecutive windows. + * @param[in] inputTensor Pointer to the TFLite Micro input Tensor. + * @param[in] numFeatures How many MFCC features to use. + * @param[in] numFeatureFrames Number of MFCC vectors that need to be calculated + * for an inference. + * @param[in] mfccFrameLength Number of audio samples used to calculate one set of MFCC values when + * sliding a window through the audio sample. + * @param[in] mfccFrameStride Number of audio samples between consecutive windows. **/ - explicit KWSPreProcess(Model* model, size_t numFeatures, int mfccFrameLength, int mfccFrameStride); + explicit KwsPreProcess(TfLiteTensor* inputTensor, size_t numFeatures, size_t numFeatureFrames, + int mfccFrameLength, int mfccFrameStride); /** * @brief Should perform pre-processing of 'raw' input audio data and load it into @@ -60,8 +63,10 @@ namespace app { size_t m_audioDataStride; /* Amount of audio to stride across if doing >1 inference in longer clips. */ private: + TfLiteTensor* m_inputTensor; /* Model input tensor. */ const int m_mfccFrameLength; const int m_mfccFrameStride; + const size_t m_numMfccFrames; /* How many sets of m_numMfccFeats. */ audio::MicroNetKwsMFCC m_mfcc; audio::SlidingWindow<const int16_t> m_mfccSlidingWindow; @@ -99,22 +104,23 @@ namespace app { * Implements methods declared by BasePostProcess and anything else needed * to populate result vector. */ - class KWSPostProcess : public BasePostProcess { + class KwsPostProcess : public BasePostProcess { private: - Classifier& m_kwsClassifier; - const std::vector<std::string>& m_labels; - std::vector<ClassificationResult>& m_results; + TfLiteTensor* m_outputTensor; /* Model output tensor. */ + Classifier& m_kwsClassifier; /* KWS Classifier object. */ + const std::vector<std::string>& m_labels; /* KWS Labels. */ + std::vector<ClassificationResult>& m_results; /* Results vector for a single inference. */ public: /** * @brief Constructor - * @param[in] classifier Classifier object used to get top N results from classification. - * @param[in] model Pointer to the KWS Model object. - * @param[in] labels Vector of string labels to identify each output of the model. - * @param[in/out] results Vector of classification results to store decoded outputs. + * @param[in] outputTensor Pointer to the TFLite Micro output Tensor. + * @param[in] classifier Classifier object used to get top N results from classification. + * @param[in] labels Vector of string labels to identify each output of the model. + * @param[in/out] results Vector of classification results to store decoded outputs. **/ - KWSPostProcess(Classifier& classifier, Model* model, + KwsPostProcess(TfLiteTensor* outputTensor, Classifier& classifier, const std::vector<std::string>& labels, std::vector<ClassificationResult>& results); |