diff options
Diffstat (limited to 'source/use_case/kws')
-rw-r--r-- | source/use_case/kws/include/KwsProcessing.hpp | 23 | ||||
-rw-r--r-- | source/use_case/kws/src/KwsProcessing.cc | 5 | ||||
-rw-r--r-- | source/use_case/kws/src/UseCaseHandler.cc | 4 |
3 files changed, 14 insertions, 18 deletions
diff --git a/source/use_case/kws/include/KwsProcessing.hpp b/source/use_case/kws/include/KwsProcessing.hpp index abf20ab..ddf38c1 100644 --- a/source/use_case/kws/include/KwsProcessing.hpp +++ b/source/use_case/kws/include/KwsProcessing.hpp @@ -38,7 +38,7 @@ namespace app { public: /** * @brief Constructor - * @param[in] model Pointer to the the KWS Model object. + * @param[in] model Pointer to the KWS Model object. * @param[in] numFeatures How many MFCC features to use. * @param[in] mfccFrameLength Number of audio samples used to calculate one set of MFCC values when * sliding a window through the audio sample. @@ -107,24 +107,21 @@ namespace app { std::vector<ClassificationResult>& m_results; public: - const float m_scoreThreshold; /** - * @brief Constructor - * @param[in] classifier Classifier object used to get top N results from classification. - * @param[in] model Pointer to the the Image classification Model object. - * @param[in] labels Vector of string labels to identify each output of the model. - * @param[in] results Vector of classification results to store decoded outputs. - * @param[in] scoreThreshold Predicted model score must be larger than this value to be accepted. + * @brief Constructor + * @param[in] classifier Classifier object used to get top N results from classification. + * @param[in] model Pointer to the KWS Model object. + * @param[in] labels Vector of string labels to identify each output of the model. + * @param[in/out] results Vector of classification results to store decoded outputs. **/ KWSPostProcess(Classifier& classifier, Model* model, const std::vector<std::string>& labels, - std::vector<ClassificationResult>& results, - float scoreThreshold); + std::vector<ClassificationResult>& results); /** - * @brief Should perform post-processing of the result of inference then populate - * populate KWS result data for any later use. - * @return true if successful, false otherwise. + * @brief Should perform post-processing of the result of inference then + * populate KWS result data for any later use. + * @return true if successful, false otherwise. **/ bool DoPostProcess() override; }; diff --git a/source/use_case/kws/src/KwsProcessing.cc b/source/use_case/kws/src/KwsProcessing.cc index b6b230c..14f9fce 100644 --- a/source/use_case/kws/src/KwsProcessing.cc +++ b/source/use_case/kws/src/KwsProcessing.cc @@ -197,11 +197,10 @@ namespace app { KWSPostProcess::KWSPostProcess(Classifier& classifier, Model* model, const std::vector<std::string>& labels, - std::vector<ClassificationResult>& results, float scoreThreshold) + std::vector<ClassificationResult>& results) :m_kwsClassifier{classifier}, m_labels{labels}, - m_results{results}, - m_scoreThreshold{scoreThreshold} + m_results{results} { if (!model->IsInited()) { printf_err("Model is not initialised!.\n"); diff --git a/source/use_case/kws/src/UseCaseHandler.cc b/source/use_case/kws/src/UseCaseHandler.cc index 350d34b..e73a2c3 100644 --- a/source/use_case/kws/src/UseCaseHandler.cc +++ b/source/use_case/kws/src/UseCaseHandler.cc @@ -93,7 +93,7 @@ namespace app { std::vector<ClassificationResult> singleInfResult; KWSPostProcess postprocess = KWSPostProcess(ctx.Get<KwsClassifier &>("classifier"), &model, ctx.Get<std::vector<std::string>&>("labels"), - singleInfResult, scoreThreshold); + singleInfResult); UseCaseRunner runner = UseCaseRunner(&preprocess, &postprocess, &model); @@ -146,7 +146,7 @@ namespace app { /* Add results from this window to our final results vector. */ finalResults.emplace_back(kws::KwsResult(singleInfResult, audioDataSlider.Index() * secondsPerSample * preprocess.m_audioDataStride, - audioDataSlider.Index(), postprocess.m_scoreThreshold)); + audioDataSlider.Index(), scoreThreshold)); #if VERIFY_TEST_OUTPUT TfLiteTensor* outputTensor = model.GetOutputTensor(0); |