diff options
Diffstat (limited to 'source/use_case/kws_asr/src/AsrClassifier.cc')
-rw-r--r-- | source/use_case/kws_asr/src/AsrClassifier.cc | 40 |
1 files changed, 22 insertions, 18 deletions
diff --git a/source/use_case/kws_asr/src/AsrClassifier.cc b/source/use_case/kws_asr/src/AsrClassifier.cc index bc86e09..f1fa6f1 100644 --- a/source/use_case/kws_asr/src/AsrClassifier.cc +++ b/source/use_case/kws_asr/src/AsrClassifier.cc @@ -21,13 +21,17 @@ #include "Wav2LetterModel.hpp" template<typename T> -bool arm::app::AsrClassifier::_GetTopResults(TfLiteTensor* tensor, - std::vector<ClassificationResult>& vecResults, - const std::vector <std::string>& labels, double scale, double zeroPoint) +bool arm::app::AsrClassifier::GetTopResults(TfLiteTensor* tensor, + std::vector<ClassificationResult>& vecResults, + const std::vector <std::string>& labels, double scale, double zeroPoint) { const uint32_t nElems = tensor->dims->data[arm::app::Wav2LetterModel::ms_outputRowsIdx]; const uint32_t nLetters = tensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx]; + if (nLetters != labels.size()) { + printf("Output size doesn't match the labels' size\n"); + return false; + } /* NOTE: tensor's size verification against labels should be * checked by the calling/public function. */ @@ -42,7 +46,7 @@ bool arm::app::AsrClassifier::_GetTopResults(TfLiteTensor* tensor, /* Get the top 1 results. */ for (uint32_t i = 0, row = 0; i < nElems; ++i, row+=nLetters) { - std::pair<T, uint32_t> top_1 = std::make_pair(tensorData[row + 0], 0); + std::pair<T, uint32_t> top_1 = std::make_pair(tensorData[row], 0); for (uint32_t j = 1; j < nLetters; ++j) { if (top_1.first < tensorData[row + j]) { @@ -59,12 +63,12 @@ bool arm::app::AsrClassifier::_GetTopResults(TfLiteTensor* tensor, return true; } -template bool arm::app::AsrClassifier::_GetTopResults<uint8_t>(TfLiteTensor* tensor, - std::vector<ClassificationResult>& vecResults, - const std::vector <std::string>& labels, double scale, double zeroPoint); -template bool arm::app::AsrClassifier::_GetTopResults<int8_t>(TfLiteTensor* tensor, - std::vector<ClassificationResult>& vecResults, - const std::vector <std::string>& labels, double scale, double zeroPoint); +template bool arm::app::AsrClassifier::GetTopResults<uint8_t>(TfLiteTensor* tensor, + std::vector<ClassificationResult>& vecResults, + const std::vector <std::string>& labels, double scale, double zeroPoint); +template bool arm::app::AsrClassifier::GetTopResults<int8_t>(TfLiteTensor* tensor, + std::vector<ClassificationResult>& vecResults, + const std::vector <std::string>& labels, double scale, double zeroPoint); bool arm::app::AsrClassifier::GetClassificationResults( TfLiteTensor* outputTensor, @@ -105,16 +109,16 @@ bool arm::app::AsrClassifier::GetClassificationResults( switch (outputTensor->type) { case kTfLiteUInt8: - resultState = this->_GetTopResults<uint8_t>( - outputTensor, vecResults, - labels, quantParams.scale, - quantParams.offset); + resultState = this->GetTopResults<uint8_t>( + outputTensor, vecResults, + labels, quantParams.scale, + quantParams.offset); break; case kTfLiteInt8: - resultState = this->_GetTopResults<int8_t>( - outputTensor, vecResults, - labels, quantParams.scale, - quantParams.offset); + resultState = this->GetTopResults<int8_t>( + outputTensor, vecResults, + labels, quantParams.scale, + quantParams.offset); break; default: printf_err("Tensor type %s not supported by classifier\n", |