diff options
author | Richard Burton <richard.burton@arm.com> | 2022-04-22 16:14:57 +0100 |
---|---|---|
committer | Richard Burton <richard.burton@arm.com> | 2022-04-22 16:14:57 +0100 |
commit | b40ecf8522052809d2351677a96195d69e4d0c16 (patch) | |
tree | 8647dfdae7bcae0ec6d9564ba7a971819fdda431 /source/use_case/kws/include/KwsProcessing.hpp | |
parent | c291144b7f08c21d08cdaf79cc64dc420ca70070 (diff) | |
download | ml-embedded-evaluation-kit-b40ecf8522052809d2351677a96195d69e4d0c16.tar.gz |
MLECO-3174: Minor refactoring to implemented use case APIS
Looks large but it is mainly just many small adjustments
Removed the inference runner code as it wasn't used
Fixes to doc strings
Consistent naming e.g. Asr/Kws instead of ASR/KWS
Signed-off-by: Richard Burton <richard.burton@arm.com>
Change-Id: I43b620b5c51d7910a29a63b509ac4d8a82c3a8fc
Diffstat (limited to 'source/use_case/kws/include/KwsProcessing.hpp')
-rw-r--r-- | source/use_case/kws/include/KwsProcessing.hpp | 38 |
1 files changed, 22 insertions, 16 deletions
diff --git a/source/use_case/kws/include/KwsProcessing.hpp b/source/use_case/kws/include/KwsProcessing.hpp index ddf38c1..d3de3b3 100644 --- a/source/use_case/kws/include/KwsProcessing.hpp +++ b/source/use_case/kws/include/KwsProcessing.hpp @@ -33,18 +33,21 @@ namespace app { * Implements methods declared by BasePreProcess and anything else needed * to populate input tensors ready for inference. */ - class KWSPreProcess : public BasePreProcess { + class KwsPreProcess : public BasePreProcess { public: /** * @brief Constructor - * @param[in] model Pointer to the KWS Model object. - * @param[in] numFeatures How many MFCC features to use. - * @param[in] mfccFrameLength Number of audio samples used to calculate one set of MFCC values when - * sliding a window through the audio sample. - * @param[in] mfccFrameStride Number of audio samples between consecutive windows. + * @param[in] inputTensor Pointer to the TFLite Micro input Tensor. + * @param[in] numFeatures How many MFCC features to use. + * @param[in] numFeatureFrames Number of MFCC vectors that need to be calculated + * for an inference. + * @param[in] mfccFrameLength Number of audio samples used to calculate one set of MFCC values when + * sliding a window through the audio sample. + * @param[in] mfccFrameStride Number of audio samples between consecutive windows. **/ - explicit KWSPreProcess(Model* model, size_t numFeatures, int mfccFrameLength, int mfccFrameStride); + explicit KwsPreProcess(TfLiteTensor* inputTensor, size_t numFeatures, size_t numFeatureFrames, + int mfccFrameLength, int mfccFrameStride); /** * @brief Should perform pre-processing of 'raw' input audio data and load it into @@ -60,8 +63,10 @@ namespace app { size_t m_audioDataStride; /* Amount of audio to stride across if doing >1 inference in longer clips. */ private: + TfLiteTensor* m_inputTensor; /* Model input tensor. */ const int m_mfccFrameLength; const int m_mfccFrameStride; + const size_t m_numMfccFrames; /* How many sets of m_numMfccFeats. */ audio::MicroNetKwsMFCC m_mfcc; audio::SlidingWindow<const int16_t> m_mfccSlidingWindow; @@ -99,22 +104,23 @@ namespace app { * Implements methods declared by BasePostProcess and anything else needed * to populate result vector. */ - class KWSPostProcess : public BasePostProcess { + class KwsPostProcess : public BasePostProcess { private: - Classifier& m_kwsClassifier; - const std::vector<std::string>& m_labels; - std::vector<ClassificationResult>& m_results; + TfLiteTensor* m_outputTensor; /* Model output tensor. */ + Classifier& m_kwsClassifier; /* KWS Classifier object. */ + const std::vector<std::string>& m_labels; /* KWS Labels. */ + std::vector<ClassificationResult>& m_results; /* Results vector for a single inference. */ public: /** * @brief Constructor - * @param[in] classifier Classifier object used to get top N results from classification. - * @param[in] model Pointer to the KWS Model object. - * @param[in] labels Vector of string labels to identify each output of the model. - * @param[in/out] results Vector of classification results to store decoded outputs. + * @param[in] outputTensor Pointer to the TFLite Micro output Tensor. + * @param[in] classifier Classifier object used to get top N results from classification. + * @param[in] labels Vector of string labels to identify each output of the model. + * @param[in/out] results Vector of classification results to store decoded outputs. **/ - KWSPostProcess(Classifier& classifier, Model* model, + KwsPostProcess(TfLiteTensor* outputTensor, Classifier& classifier, const std::vector<std::string>& labels, std::vector<ClassificationResult>& results); |