summaryrefslogtreecommitdiff
path: root/source/use_case/asr/src/AsrClassifier.cc
diff options
context:
space:
mode:
Diffstat (limited to 'source/use_case/asr/src/AsrClassifier.cc')
-rw-r--r--source/use_case/asr/src/AsrClassifier.cc39
1 files changed, 22 insertions, 17 deletions
diff --git a/source/use_case/asr/src/AsrClassifier.cc b/source/use_case/asr/src/AsrClassifier.cc
index 7377d30..df26a7f 100644
--- a/source/use_case/asr/src/AsrClassifier.cc
+++ b/source/use_case/asr/src/AsrClassifier.cc
@@ -21,13 +21,18 @@
#include "Wav2LetterModel.hpp"
template<typename T>
-bool arm::app::AsrClassifier::_GetTopResults(TfLiteTensor* tensor,
- std::vector<ClassificationResult>& vecResults,
- const std::vector <std::string>& labels, double scale, double zeroPoint)
+bool arm::app::AsrClassifier::GetTopResults(TfLiteTensor* tensor,
+ std::vector<ClassificationResult>& vecResults,
+ const std::vector <std::string>& labels, double scale, double zeroPoint)
{
const uint32_t nElems = tensor->dims->data[arm::app::Wav2LetterModel::ms_outputRowsIdx];
const uint32_t nLetters = tensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx];
+ if (nLetters != labels.size()) {
+ printf("Output size doesn't match the labels' size\n");
+ return false;
+ }
+
/* NOTE: tensor's size verification against labels should be
* checked by the calling/public function. */
if (nLetters < 1) {
@@ -58,12 +63,12 @@ bool arm::app::AsrClassifier::_GetTopResults(TfLiteTensor* tensor,
return true;
}
-template bool arm::app::AsrClassifier::_GetTopResults<uint8_t>(TfLiteTensor* tensor,
- std::vector<ClassificationResult>& vecResults,
- const std::vector <std::string>& labels, double scale, double zeroPoint);
-template bool arm::app::AsrClassifier::_GetTopResults<int8_t>(TfLiteTensor* tensor,
- std::vector<ClassificationResult>& vecResults,
- const std::vector <std::string>& labels, double scale, double zeroPoint);
+template bool arm::app::AsrClassifier::GetTopResults<uint8_t>(TfLiteTensor* tensor,
+ std::vector<ClassificationResult>& vecResults,
+ const std::vector <std::string>& labels, double scale, double zeroPoint);
+template bool arm::app::AsrClassifier::GetTopResults<int8_t>(TfLiteTensor* tensor,
+ std::vector<ClassificationResult>& vecResults,
+ const std::vector <std::string>& labels, double scale, double zeroPoint);
bool arm::app::AsrClassifier::GetClassificationResults(
TfLiteTensor* outputTensor,
@@ -104,16 +109,16 @@ bool arm::app::AsrClassifier::GetClassificationResults(
switch (outputTensor->type) {
case kTfLiteUInt8:
- resultState = this->_GetTopResults<uint8_t>(
- outputTensor, vecResults,
- labels, quantParams.scale,
- quantParams.offset);
+ resultState = this->GetTopResults<uint8_t>(
+ outputTensor, vecResults,
+ labels, quantParams.scale,
+ quantParams.offset);
break;
case kTfLiteInt8:
- resultState = this->_GetTopResults<int8_t>(
- outputTensor, vecResults,
- labels, quantParams.scale,
- quantParams.offset);
+ resultState = this->GetTopResults<int8_t>(
+ outputTensor, vecResults,
+ labels, quantParams.scale,
+ quantParams.offset);
break;
default:
printf_err("Tensor type %s not supported by classifier\n",