diff options
Diffstat (limited to 'source/use_case/asr')
-rw-r--r-- | source/use_case/asr/include/AsrClassifier.hpp | 3 | ||||
-rw-r--r-- | source/use_case/asr/src/AsrClassifier.cc | 3 |
2 files changed, 4 insertions, 2 deletions
diff --git a/source/use_case/asr/include/AsrClassifier.hpp b/source/use_case/asr/include/AsrClassifier.hpp index 2c97a39..67a200e 100644 --- a/source/use_case/asr/include/AsrClassifier.hpp +++ b/source/use_case/asr/include/AsrClassifier.hpp @@ -32,12 +32,13 @@ namespace app { * populated by this function. * @param[in] labels Labels vector to match classified classes * @param[in] topNCount Number of top classifications to pick. + * @param[in] use_softmax Whether softmax scaling should be applied to model output. * @return true if successful, false otherwise. **/ bool GetClassificationResults( TfLiteTensor* outputTensor, std::vector<ClassificationResult>& vecResults, - const std::vector <std::string>& labels, uint32_t topNCount) override; + const std::vector <std::string>& labels, uint32_t topNCount, bool use_softmax = false) override; private: /** diff --git a/source/use_case/asr/src/AsrClassifier.cc b/source/use_case/asr/src/AsrClassifier.cc index c18bd88..a715068 100644 --- a/source/use_case/asr/src/AsrClassifier.cc +++ b/source/use_case/asr/src/AsrClassifier.cc @@ -73,8 +73,9 @@ template bool arm::app::AsrClassifier::GetTopResults<int8_t>(TfLiteTensor* tenso bool arm::app::AsrClassifier::GetClassificationResults( TfLiteTensor* outputTensor, std::vector<ClassificationResult>& vecResults, - const std::vector <std::string>& labels, uint32_t topNCount) + const std::vector <std::string>& labels, uint32_t topNCount, bool use_softmax) { + UNUSED(use_softmax); vecResults.clear(); constexpr int minTensorDims = static_cast<int>( |