summaryrefslogtreecommitdiff
path: root/source/use_case/asr/src/AsrClassifier.cc
diff options
context:
space:
mode:
authoralexander <alexander.efremov@arm.com>2021-04-29 20:36:09 +0100
committerAlexander Efremov <alexander.efremov@arm.com>2021-05-04 19:57:44 +0000
commitc350cdced0a8a2ca17376f58813e6d48d796ac7c (patch)
treef732cde664837a7cb9429b17e1366bb31a635b15 /source/use_case/asr/src/AsrClassifier.cc
parent6448932cc1c612d78e62c778ebb228b3cbe96a58 (diff)
downloadml-embedded-evaluation-kit-c350cdced0a8a2ca17376f58813e6d48d796ac7c.tar.gz
MLECO-1868: Code static analyzer warnings fixes
Signed-off-by: alexander <alexander.efremov@arm.com> Change-Id: Ie423e9cad3fabec6ab077ded7236813fe4933dea
Diffstat (limited to 'source/use_case/asr/src/AsrClassifier.cc')
-rw-r--r--source/use_case/asr/src/AsrClassifier.cc39
1 files changed, 22 insertions, 17 deletions
diff --git a/source/use_case/asr/src/AsrClassifier.cc b/source/use_case/asr/src/AsrClassifier.cc
index 7377d30..df26a7f 100644
--- a/source/use_case/asr/src/AsrClassifier.cc
+++ b/source/use_case/asr/src/AsrClassifier.cc
@@ -21,13 +21,18 @@
#include "Wav2LetterModel.hpp"
template<typename T>
-bool arm::app::AsrClassifier::_GetTopResults(TfLiteTensor* tensor,
- std::vector<ClassificationResult>& vecResults,
- const std::vector <std::string>& labels, double scale, double zeroPoint)
+bool arm::app::AsrClassifier::GetTopResults(TfLiteTensor* tensor,
+ std::vector<ClassificationResult>& vecResults,
+ const std::vector <std::string>& labels, double scale, double zeroPoint)
{
const uint32_t nElems = tensor->dims->data[arm::app::Wav2LetterModel::ms_outputRowsIdx];
const uint32_t nLetters = tensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx];
+ if (nLetters != labels.size()) {
+ printf("Output size doesn't match the labels' size\n");
+ return false;
+ }
+
/* NOTE: tensor's size verification against labels should be
* checked by the calling/public function. */
if (nLetters < 1) {
@@ -58,12 +63,12 @@ bool arm::app::AsrClassifier::_GetTopResults(TfLiteTensor* tensor,
return true;
}
-template bool arm::app::AsrClassifier::_GetTopResults<uint8_t>(TfLiteTensor* tensor,
- std::vector<ClassificationResult>& vecResults,
- const std::vector <std::string>& labels, double scale, double zeroPoint);
-template bool arm::app::AsrClassifier::_GetTopResults<int8_t>(TfLiteTensor* tensor,
- std::vector<ClassificationResult>& vecResults,
- const std::vector <std::string>& labels, double scale, double zeroPoint);
+template bool arm::app::AsrClassifier::GetTopResults<uint8_t>(TfLiteTensor* tensor,
+ std::vector<ClassificationResult>& vecResults,
+ const std::vector <std::string>& labels, double scale, double zeroPoint);
+template bool arm::app::AsrClassifier::GetTopResults<int8_t>(TfLiteTensor* tensor,
+ std::vector<ClassificationResult>& vecResults,
+ const std::vector <std::string>& labels, double scale, double zeroPoint);
bool arm::app::AsrClassifier::GetClassificationResults(
TfLiteTensor* outputTensor,
@@ -104,16 +109,16 @@ bool arm::app::AsrClassifier::GetClassificationResults(
switch (outputTensor->type) {
case kTfLiteUInt8:
- resultState = this->_GetTopResults<uint8_t>(
- outputTensor, vecResults,
- labels, quantParams.scale,
- quantParams.offset);
+ resultState = this->GetTopResults<uint8_t>(
+ outputTensor, vecResults,
+ labels, quantParams.scale,
+ quantParams.offset);
break;
case kTfLiteInt8:
- resultState = this->_GetTopResults<int8_t>(
- outputTensor, vecResults,
- labels, quantParams.scale,
- quantParams.offset);
+ resultState = this->GetTopResults<int8_t>(
+ outputTensor, vecResults,
+ labels, quantParams.scale,
+ quantParams.offset);
break;
default:
printf_err("Tensor type %s not supported by classifier\n",