summaryrefslogtreecommitdiff
path: root/source/use_case/asr/src/AsrClassifier.cc
diff options
context:
space:
mode:
authorRichard Burton <richard.burton@arm.com>2022-04-22 16:14:57 +0100
committerRichard Burton <richard.burton@arm.com>2022-04-22 16:14:57 +0100
commitb40ecf8522052809d2351677a96195d69e4d0c16 (patch)
tree8647dfdae7bcae0ec6d9564ba7a971819fdda431 /source/use_case/asr/src/AsrClassifier.cc
parentc291144b7f08c21d08cdaf79cc64dc420ca70070 (diff)
downloadml-embedded-evaluation-kit-b40ecf8522052809d2351677a96195d69e4d0c16.tar.gz
MLECO-3174: Minor refactoring to implemented use case APIS
Looks large but it is mainly just many small adjustments Removed the inference runner code as it wasn't used Fixes to doc strings Consistent naming e.g. Asr/Kws instead of ASR/KWS Signed-off-by: Richard Burton <richard.burton@arm.com> Change-Id: I43b620b5c51d7910a29a63b509ac4d8a82c3a8fc
Diffstat (limited to 'source/use_case/asr/src/AsrClassifier.cc')
-rw-r--r--source/use_case/asr/src/AsrClassifier.cc196
1 files changed, 102 insertions, 94 deletions
diff --git a/source/use_case/asr/src/AsrClassifier.cc b/source/use_case/asr/src/AsrClassifier.cc
index 84e66b7..4ba8c7b 100644
--- a/source/use_case/asr/src/AsrClassifier.cc
+++ b/source/use_case/asr/src/AsrClassifier.cc
@@ -20,117 +20,125 @@
#include "TensorFlowLiteMicro.hpp"
#include "Wav2LetterModel.hpp"
-template<typename T>
-bool arm::app::AsrClassifier::GetTopResults(TfLiteTensor* tensor,
- std::vector<ClassificationResult>& vecResults,
- const std::vector <std::string>& labels, double scale, double zeroPoint)
-{
- const uint32_t nElems = tensor->dims->data[arm::app::Wav2LetterModel::ms_outputRowsIdx];
- const uint32_t nLetters = tensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx];
-
- if (nLetters != labels.size()) {
- printf("Output size doesn't match the labels' size\n");
- return false;
- }
+namespace arm {
+namespace app {
+
+ template<typename T>
+ bool AsrClassifier::GetTopResults(TfLiteTensor* tensor,
+ std::vector<ClassificationResult>& vecResults,
+ const std::vector <std::string>& labels, double scale, double zeroPoint)
+ {
+ const uint32_t nElems = tensor->dims->data[Wav2LetterModel::ms_outputRowsIdx];
+ const uint32_t nLetters = tensor->dims->data[Wav2LetterModel::ms_outputColsIdx];
+
+ if (nLetters != labels.size()) {
+ printf("Output size doesn't match the labels' size\n");
+ return false;
+ }
- /* NOTE: tensor's size verification against labels should be
- * checked by the calling/public function. */
- if (nLetters < 1) {
- return false;
- }
+ /* NOTE: tensor's size verification against labels should be
+ * checked by the calling/public function. */
+ if (nLetters < 1) {
+ return false;
+ }
- /* Final results' container. */
- vecResults = std::vector<ClassificationResult>(nElems);
+ /* Final results' container. */
+ vecResults = std::vector<ClassificationResult>(nElems);
- T* tensorData = tflite::GetTensorData<T>(tensor);
+ T* tensorData = tflite::GetTensorData<T>(tensor);
- /* Get the top 1 results. */
- for (uint32_t i = 0, row = 0; i < nElems; ++i, row+=nLetters) {
- std::pair<T, uint32_t> top_1 = std::make_pair(tensorData[row + 0], 0);
+ /* Get the top 1 results. */
+ for (uint32_t i = 0, row = 0; i < nElems; ++i, row+=nLetters) {
+ std::pair<T, uint32_t> top_1 = std::make_pair(tensorData[row + 0], 0);
- for (uint32_t j = 1; j < nLetters; ++j) {
- if (top_1.first < tensorData[row + j]) {
- top_1.first = tensorData[row + j];
- top_1.second = j;
+ for (uint32_t j = 1; j < nLetters; ++j) {
+ if (top_1.first < tensorData[row + j]) {
+ top_1.first = tensorData[row + j];
+ top_1.second = j;
+ }
}
+
+ double score = static_cast<int> (top_1.first);
+ vecResults[i].m_normalisedVal = scale * (score - zeroPoint);
+ vecResults[i].m_label = labels[top_1.second];
+ vecResults[i].m_labelIdx = top_1.second;
}
- double score = static_cast<int> (top_1.first);
- vecResults[i].m_normalisedVal = scale * (score - zeroPoint);
- vecResults[i].m_label = labels[top_1.second];
- vecResults[i].m_labelIdx = top_1.second;
+ return true;
}
-
- return true;
-}
-template bool arm::app::AsrClassifier::GetTopResults<uint8_t>(TfLiteTensor* tensor,
- std::vector<ClassificationResult>& vecResults,
- const std::vector <std::string>& labels, double scale, double zeroPoint);
-template bool arm::app::AsrClassifier::GetTopResults<int8_t>(TfLiteTensor* tensor,
- std::vector<ClassificationResult>& vecResults,
- const std::vector <std::string>& labels, double scale, double zeroPoint);
-
-bool arm::app::AsrClassifier::GetClassificationResults(
+ template bool AsrClassifier::GetTopResults<uint8_t>(TfLiteTensor* tensor,
+ std::vector<ClassificationResult>& vecResults,
+ const std::vector <std::string>& labels,
+ double scale, double zeroPoint);
+ template bool AsrClassifier::GetTopResults<int8_t>(TfLiteTensor* tensor,
+ std::vector<ClassificationResult>& vecResults,
+ const std::vector <std::string>& labels,
+ double scale, double zeroPoint);
+
+ bool AsrClassifier::GetClassificationResults(
TfLiteTensor* outputTensor,
std::vector<ClassificationResult>& vecResults,
const std::vector <std::string>& labels, uint32_t topNCount, bool use_softmax)
-{
- UNUSED(use_softmax);
- vecResults.clear();
+ {
+ UNUSED(use_softmax);
+ vecResults.clear();
- constexpr int minTensorDims = static_cast<int>(
- (arm::app::Wav2LetterModel::ms_outputRowsIdx > arm::app::Wav2LetterModel::ms_outputColsIdx)?
- arm::app::Wav2LetterModel::ms_outputRowsIdx : arm::app::Wav2LetterModel::ms_outputColsIdx);
+ constexpr int minTensorDims = static_cast<int>(
+ (Wav2LetterModel::ms_outputRowsIdx > Wav2LetterModel::ms_outputColsIdx)?
+ Wav2LetterModel::ms_outputRowsIdx : Wav2LetterModel::ms_outputColsIdx);
- constexpr uint32_t outColsIdx = arm::app::Wav2LetterModel::ms_outputColsIdx;
+ constexpr uint32_t outColsIdx = Wav2LetterModel::ms_outputColsIdx;
- /* Sanity checks. */
- if (outputTensor == nullptr) {
- printf_err("Output vector is null pointer.\n");
- return false;
- } else if (outputTensor->dims->size < minTensorDims) {
- printf_err("Output tensor expected to be %dD\n", minTensorDims);
- return false;
- } else if (static_cast<uint32_t>(outputTensor->dims->data[outColsIdx]) < topNCount) {
- printf_err("Output vectors are smaller than %" PRIu32 "\n", topNCount);
- return false;
- } else if (static_cast<uint32_t>(outputTensor->dims->data[outColsIdx]) != labels.size()) {
- printf("Output size doesn't match the labels' size\n");
- return false;
- }
+ /* Sanity checks. */
+ if (outputTensor == nullptr) {
+ printf_err("Output vector is null pointer.\n");
+ return false;
+ } else if (outputTensor->dims->size < minTensorDims) {
+ printf_err("Output tensor expected to be %dD\n", minTensorDims);
+ return false;
+ } else if (static_cast<uint32_t>(outputTensor->dims->data[outColsIdx]) < topNCount) {
+ printf_err("Output vectors are smaller than %" PRIu32 "\n", topNCount);
+ return false;
+ } else if (static_cast<uint32_t>(outputTensor->dims->data[outColsIdx]) != labels.size()) {
+ printf("Output size doesn't match the labels' size\n");
+ return false;
+ }
- if (topNCount != 1) {
- warn("TopNCount value ignored in this implementation\n");
- }
+ if (topNCount != 1) {
+ warn("TopNCount value ignored in this implementation\n");
+ }
- /* To return the floating point values, we need quantization parameters. */
- QuantParams quantParams = GetTensorQuantParams(outputTensor);
-
- bool resultState;
-
- switch (outputTensor->type) {
- case kTfLiteUInt8:
- resultState = this->GetTopResults<uint8_t>(
- outputTensor, vecResults,
- labels, quantParams.scale,
- quantParams.offset);
- break;
- case kTfLiteInt8:
- resultState = this->GetTopResults<int8_t>(
- outputTensor, vecResults,
- labels, quantParams.scale,
- quantParams.offset);
- break;
- default:
- printf_err("Tensor type %s not supported by classifier\n",
- TfLiteTypeGetName(outputTensor->type));
+ /* To return the floating point values, we need quantization parameters. */
+ QuantParams quantParams = GetTensorQuantParams(outputTensor);
+
+ bool resultState;
+
+ switch (outputTensor->type) {
+ case kTfLiteUInt8:
+ resultState = this->GetTopResults<uint8_t>(
+ outputTensor, vecResults,
+ labels, quantParams.scale,
+ quantParams.offset);
+ break;
+ case kTfLiteInt8:
+ resultState = this->GetTopResults<int8_t>(
+ outputTensor, vecResults,
+ labels, quantParams.scale,
+ quantParams.offset);
+ break;
+ default:
+ printf_err("Tensor type %s not supported by classifier\n",
+ TfLiteTypeGetName(outputTensor->type));
+ return false;
+ }
+
+ if (!resultState) {
+ printf_err("Failed to get sorted set\n");
return false;
- }
+ }
- if (!resultState) {
- printf_err("Failed to get sorted set\n");
- return false;
- }
+ return true;
+ }
- return true;
-} \ No newline at end of file
+} /* namespace app */
+} /* namespace arm */ \ No newline at end of file